Nothing Special   »   [go: up one dir, main page]

Skip to content

Commit

Permalink
Update entrypoint scripts (#18)
Browse files Browse the repository at this point in the history
* Cleanup and rename

* Add test for config

* Update script parameters

* Finished first adaptation
  • Loading branch information
giovannidoni authored Jan 20, 2021
1 parent 665020b commit 9c80acb
Show file tree
Hide file tree
Showing 31 changed files with 300 additions and 69 deletions.
2 changes: 1 addition & 1 deletion gmplabtools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from gmplabtools.pamm import Pamm
from .pamm import Pamm
from ._version import __version__
6 changes: 3 additions & 3 deletions gmplabtools/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .tools import *
from .transition_rates import *
from .clustering import *
from .tools import oracle_shrinkage, CovDim, DataSampler
from .transition_rates import ClusterRates
from .clustering import calculate_adjacency, merge, adjancency_dendrogram
2 changes: 1 addition & 1 deletion gmplabtools/pamm/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .pamm import *
from .pamm import Gauss, PammGMM, Pamm
Binary file modified gmplabtools/pamm/bin/pamm
Binary file not shown.
1 change: 1 addition & 0 deletions gmplabtools/shared/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .config import get_config
66 changes: 59 additions & 7 deletions gmplabtools/shared/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import json
import json, importlib, pkgutil, shutil, os
from types import SimpleNamespace
import toml

Expand All @@ -7,7 +7,63 @@ class WrongConfigFormat(Exception):
pass


def get_config(filename, section=None):
def _module_exists(module):
try:
return pkgutil.get_loader(module) is not None
except ImportError:
return False


def _importer(module_name, obj):
    """Fetch attribute ``obj`` from a module given by name or by file path.

    Parameters
    ----------
    module_name : str
        Either an importable dotted module name or the path to a Python
        source file.
    obj : str
        Name of the attribute (typically a class) to look up in the module.

    Returns
    -------
    object or None
        The requested attribute, or None when the module cannot be located
        or does not define ``obj``.

    Notes
    -----
    A source file is loaded straight from its location. The previous
    approach copied the file into the current directory and imported it by
    bare name, which only worked when the current directory was on
    ``sys.path``, failed when the file already lived there, and left a
    stray copy behind.
    """
    import importlib.util
    import sys

    if _module_exists(module_name):
        module = importlib.import_module(module_name)
    elif os.path.isfile(module_name):
        path = os.path.abspath(module_name)
        # splitext only strips the final extension, unlike str.replace.
        base_module_name = os.path.splitext(os.path.basename(path))[0]
        cached = sys.modules.get(base_module_name)
        if cached is not None and getattr(cached, "__file__", None) == path:
            # Already loaded from this very file: reuse it so repeated
            # lookups hand back the same class objects.
            module = cached
        else:
            spec = importlib.util.spec_from_file_location(base_module_name, path)
            if spec is None or spec.loader is None:
                # Not a file type the import system knows how to load.
                return None
            module = importlib.util.module_from_spec(spec)
            # Register under the base name only when that name is free, so
            # an unrelated module that is already imported is not clobbered.
            registered = cached is None
            if registered:
                sys.modules[base_module_name] = module
            try:
                spec.loader.exec_module(module)
            except BaseException:
                if registered:
                    sys.modules.pop(base_module_name, None)
                raise
    else:
        return None

    return getattr(module, obj, None)


def import_external(config, class_name):
    """Resolve ``class_name`` to a class, looking in bundled and user modules.

    The lookup order is:

    1. ``gmplabtools.shared.processing``;
    2. the module(s) given by ``config.pymodule`` (``config.pymodules`` is
       accepted as well, since that is the key used in the sample config);
    3. the module(s) given by ``config.mymodule``.

    Parameters
    ----------
    config : namespace-like
        Configuration object. The module attributes are optional and may
        each hold a single string or a list of strings; every entry is
        either a dotted module name or a path to a Python file.
    class_name : str
        Name of the class to retrieve.

    Returns
    -------
    type
        The first match found.

    Raises
    ------
    ModuleNotFoundError
        If none of the locations defines ``class_name``.
    """
    try:
        module = importlib.import_module(".processing", package="gmplabtools.shared")
        return getattr(module, class_name)
    except AttributeError:
        # Not a bundled class: fall through to the user-supplied modules.
        pass

    for key in ("pymodule", "pymodules", "mymodule"):
        # getattr with a default: a config without the key must not fail.
        value = getattr(config, key, None)
        if not value:
            continue
        sources = [value] if isinstance(value, str) else value
        for source in sources:
            class_type = _importer(source, class_name)
            if class_type is not None:
                return class_type

    msg = (
        f"Cannot find class {class_name} in system packages or in those provided with "
        f"in mymodule and pymodule config values."
    )
    raise ModuleNotFoundError(msg)


class RecursiveNamespace(SimpleNamespace):
    """A ``SimpleNamespace`` that converts nested dictionaries recursively.

    Every dict found among the keyword arguments becomes a
    ``RecursiveNamespace`` itself, including dicts stored inside lists at
    any nesting depth, so nested configuration can be read with attribute
    access (``cfg.section.option``). All other values are kept as given.
    """

    @staticmethod
    def map_entry(entry):
        """Convert ``entry`` if it is a dict or a list; otherwise return it.

        A dict becomes a ``RecursiveNamespace``; a list becomes a new list
        whose items have been converted the same way.
        """
        if isinstance(entry, dict):
            return RecursiveNamespace(**entry)
        if isinstance(entry, list):
            return [RecursiveNamespace.map_entry(item) for item in entry]
        return entry

    def __init__(self, **kwargs):
        """Store ``kwargs`` as attributes, converting nested containers."""
        super().__init__(**kwargs)
        for key, val in kwargs.items():
            # isinstance (not an exact type comparison) so that dict/list
            # subclasses are handled too, consistently with map_entry.
            if isinstance(val, (dict, list)):
                setattr(self, key, self.map_entry(val))


def get_config(filename):
with open(filename, 'r') as f:
if filename.endswith('.json'):
_config = json.load(f)
Expand All @@ -18,8 +74,4 @@ def get_config(filename, section=None):
f"Format of the '{filename}' is not supported. "
"Available formats: .json, .toml"
)
if section is not None:
parse = _config[section]
else:
parse = _config
return SimpleNamespace(**parse)
return SimpleNamespace(**_config)
13 changes: 13 additions & 0 deletions gmplabtools/shared/processing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from sklearn.base import TransformerMixin
import sklearn.decomposition as dim_red

# Re-export every public object of sklearn.decomposition that exposes a
# ``transform`` attribute, so it can be fetched from this module by name.
# ``dir()`` yields name strings, so the check has to be made on the object
# itself; testing the string always fails and leaves the list empty.
_SKLEARN_TRANSFORMERS = {
    _name: getattr(dim_red, _name)
    for _name in dir(dim_red)
    if not _name.startswith("_") and hasattr(getattr(dim_red, _name), "transform")
}
# Names listed in __all__ must exist as module attributes.
globals().update(_SKLEARN_TRANSFORMERS)

__all__ = ["NullTransformer"] + sorted(_SKLEARN_TRANSFORMERS)


class NullTransformer(TransformerMixin):
    """Identity transformer: the data passes through unchanged.

    Meant as a drop-in where a transformer object is expected but no
    transformation is wanted.
    """

    def fit(self, x, y=None):
        """Do nothing and return the transformer itself.

        Returning ``self`` (rather than the data) allows the usual
        ``transformer = transformer.fit(x)`` followed by
        ``transformer.transform(x)``. ``y`` is accepted and ignored so the
        method can be called with a target argument as well.
        """
        return self

    def transform(self, x):
        """Return ``x`` unchanged."""
        return x

    def predict(self, x):
        """Return ``x`` unchanged (kept for backward compatibility)."""
        return x
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandas as pd
from jinja2 import Environment, PackageLoader

from gmplabtools.martini.data.non_bonded import ORIGINAL
from gmplabtools.simulations.data.non_bonded import ORIGINAL


def cache():
Expand Down Expand Up @@ -60,13 +60,13 @@ def __init__(self, fields, parameter_file=None):

@classmethod
def get_data(cls, filename):
param = pkgutil.get_data('gmplabtools.martini.data', filename)
param = pkgutil.get_data('gmplabtools.simulations.data', filename)
return pd.read_csv(io.BytesIO(param))

@property
@cache()
def df(self):
df = Param.get_data('martini-non-bonded.csv')
df = Param.get_data('simulations-non-bonded.csv')
df['type'] = df['type'].map(Param.non_bonded_strength)
return df

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse

from gmplabtools.shared.config import get_config
from gmplabtools.martini.simulation import SetupSim
from gmplabtools.simulations.simulation import SetupSim


def main(config):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from skopt.space import Real

from gmplabtools.shared.config import get_config
from gmplabtools.martini.parameters import Param
from gmplabtools.simulations.parameters import Param


def main(config, n, res):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _init_path(self):

def _init_template(self, config):
template = Martini(**config).get_template(self.config.params['template'])
template_file = os.path.join(self.full_path, 'martini.itp')
template_file = os.path.join(self.full_path, 'simulations.itp')
return template_file, template

def _generate_simulation(self, params):
Expand Down
19 changes: 9 additions & 10 deletions scripts/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,24 @@
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram

from gmplabtools.shared.config import get_config
from gmplabtools.shared import get_config
from gmplabtools.analysis import DataSampler, calculate_adjacency, adjancency_dendrogram
from gmplabtools.pamm.pamm import Pamm
from gmplabtools.pamm import Pamm


def main(config):

x = np.loadtxt(config.trj_filename)
x = np.loadtxt("all_system_transformed")

if config.generate_grid:
if config.cluster["generate_grid"]:
d = DataSampler(config.distance, norm=config.p)
grid, indices = d.minmax_sample(x, config.size)
np.savetxt("{}.grid".format(config.savegrid), indices + 1, fmt="%d")

p = Pamm(config.pamm_input)
print(p.command_parser)
p = Pamm(config.cluster["pamm_input"])
p.run()

if 'bootstrap' in config.pamm_input:
if "bootstrap" in config.cluster["pamm_input"]:
adjacency, mapping = calculate_adjacency(
prob=p.p,
clusters=p.cluster,
Expand All @@ -31,8 +30,8 @@ def main(config):

z = adjancency_dendrogram(adjacency)
fig, ax = plt.subplots()
_ = dendrogram(z, ax=ax, **config.dendrogram)['leaves']
fig.savefig('clusters_dendrogram.png')
_ = dendrogram(z, ax=ax, **config.cluster["dendrogram"])["leaves"]
fig.savefig("clusters_dendrogram.png")


if __name__ == "__main__":
Expand All @@ -42,4 +41,4 @@ def main(config):
help="config file")
args = parser.parse_args()

main(get_config(args.config, "cluster"))
main(get_config(args.config))
10 changes: 9 additions & 1 deletion scripts/config.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
mymodule=[]
pymodules=[]

[transform]
components = 7
plot = true

[transform.trajectories]
Expand All @@ -14,6 +16,12 @@ rcut = 8
nmax = 8
lmax = 8

[transform.transformer]
name = "PCA"

[transform.transformer.param]
n_components=7

[cluster]
trj_filename = "allsoap.pca"
distance = "minkowski"
Expand Down
6 changes: 3 additions & 3 deletions scripts/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@


def main(config):
gmm = PammGMM.read_clusters(config.pamm_output + ".pamm")
gmm = PammGMM.read_clusters(config.cluster["o"] + ".pamm")

print("There are {} clusters".format(np.unique(gmm.pk).shape[0]))

for k,f in config.extrapolate_on_files.items():
for k, f in config.extrapolate_on_files.items():
x = np.loadtxt(f)
x_ = gmm.predict_proba(x)
clusters = np.argmax(x_, axis=1).reshape((-1, 1))
Expand All @@ -30,4 +30,4 @@ def main(config):
help="config file")
args = parser.parse_args()

main(get_config(args.config, "predict"))
main(get_config(args.config))
63 changes: 31 additions & 32 deletions scripts/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
import numpy as np
from ase.io import read
from dscribe.descriptors import SOAP
from sklearn.pipeline import make_pipeline
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

from gmplabtools.shared.config import get_config
from gmplabtools.shared.config import get_config, import_external


def read_traj(filename, index=":", start=None, end=None, stride=None):
Expand All @@ -16,25 +14,22 @@ def read_traj(filename, index=":", start=None, end=None, stride=None):
return read(filename, index=index, format="xyz")


def plot(pca, fname):
def plot(x, fname):
plt.figure(figsize=(11, 8), dpi=80)
plt.ylabel('2nd Coord')
plt.xlabel('1st Coord')
plt.title('2D PS - C1 fiber')
plt.style.context('seaborn-whitegrid')
plt.scatter(pca[:, 0], pca[:, 1], c=pca[:, 2], cmap='inferno')
plt.ylabel('2nd coord')
plt.xlabel('1st coord')
plt.scatter(x[:, 0], x[:, 1], c=x[:, 2], cmap='inferno')
plt.colorbar()
plt.savefig(fname)
plt.close()


def main(config):
traj = {name: read_traj(traj) for name, traj in config.trajectories.items()}

if hasattr(config, "full_traj"):
all_traj = read_traj(config.full_traj)
else:
all_traj = sum(traj.values(), [])
n_traj = len(config.trajectories)
traj = {name: read_traj(traj) for name, traj in config.transform["trajectories"].items()}

all_traj = sum(traj.values(), [])

# info on how it works (and installation):
# https://singroup.github.io/dscribe/tutorials/soap.html
Expand All @@ -47,36 +42,40 @@ def main(config):
msg = ", ".join(
["{}: {}".format(name, k.shape[0]) for name, k in soap.items()]
)
print(all_soap.shape)
print(msg)

tranformer = make_pipeline(PCA(n_components=config.components))
transformer = import_external(config.transform["transformer"]["name"])
params = config.transform["transformer"]["params"]
transformer = transformer(**params)

tranformer = tranformer.fit(all_soap)
transformer = transformer.fit(all_soap)

# calculate variance ratios on the merged data
variance = tranformer.named_steps['pca'].explained_variance_ratio_
var=np.cumsum(np.round(variance, decimals=3)*100)
print("PCA (dim={}) variance: {}".format(
str(tranformer.named_steps['pca'].n_components_),
var)
)
if "PCA" in str(config.transform["transformer"]["name"]):
variance = transformer.explained_variance_ratio_
var = np.cumsum(np.round(variance, decimals=3) * 100)
print("PCA (dim={}) variance: {}".format(
str(transformer.named_steps['pca'].n_components_),
var)
)

transformed = {name: tranformer.transform(k) for name, k in soap.items()}
all_pca = tranformer.transform(all_soap)
transformed = {name: transformer.transform(k) for name, k in soap.items()}
red_dim = transformer.transform(all_soap)

np.savetxt("allsoap.pca", all_pca)
np.savetxt("all_system_transformed.txt", red_dim)

for k, x in transformed.items():
np.savetxt("{}soap.pca".format(k), x)

np.savetxt("{}_soap.txt".format(k), x)

sample_length = min(2000, x.shape[0])

if config.plot:
for k, x in transformed.items():
np.random.shuffle(x)
plot(x[:5000, :], "PhaseSpace2D_{}.png".format(k))
plot(x[:sample_length, :], "scatter_plot_{}.png".format(k))

np.random.shuffle(all_pca)
plot(all_pca[:15000, :], "PhaseSpace2D_all.png")
np.random.shuffle(red_dim)
plot(red_dim[:sample_length * n_traj, :], "scatter_plot_all.png")


if __name__ == "__main__":
Expand All @@ -86,4 +85,4 @@ def main(config):
help="config file")
args = parser.parse_args()

main(get_config(args.config, "transform"))
main(get_config(args.config))
Loading

0 comments on commit 9c80acb

Please sign in to comment.