Skip to content

Commit 5d56c47

Browse files
Merge pull request #508 from carlocamilloni/main
black
2 parents a43c5ee + 77145ec commit 5d56c47

File tree

2 files changed

+6
-13
lines changed

2 files changed

+6
-13
lines changed

multiego.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def main():
242242
print("- Done in:", elapsed_time, "seconds")
243243

244244
print("- Writing Multi-eGO model")
245-
meGO_LJ = ensemble.sort_LJ(meGO_ensembles, meGO_LJ)
245+
meGO_LJ = ensemble.sort_LJ(meGO_ensembles, meGO_LJ)
246246
io.write_model(meGO_ensembles, meGO_LJ, meGO_LJ_14, args)
247247
et = time.time()
248248
elapsed_time = et - st

src/multiego/ensemble.py

+5-12
Original file line numberDiff line numberDiff line change
@@ -853,9 +853,7 @@ def generate_OO_LJ(meGO_ensemble):
853853
rc_LJ["distance"] = rc_LJ["cutoff"]
854854
rc_LJ["learned"] = 0
855855
rc_LJ["1-4"] = "1>4"
856-
molecule_names_dictionary = {
857-
name.split("_", 1)[1]: name for name in meGO_ensemble["molecules_idx_sbtype_dictionary"]
858-
}
856+
molecule_names_dictionary = {name.split("_", 1)[1]: name for name in meGO_ensemble["molecules_idx_sbtype_dictionary"]}
859857
rc_LJ["molecule_name_ai"] = rc_LJ["ai"].apply(lambda x: "_".join(x.split("_")[1:-1])).map(molecule_names_dictionary)
860858
rc_LJ["molecule_name_aj"] = rc_LJ["aj"].apply(lambda x: "_".join(x.split("_")[1:-1])).map(molecule_names_dictionary)
861859
rc_LJ["ai"] = rc_LJ["ai"].astype("category")
@@ -1336,12 +1334,10 @@ def generate_LJ(meGO_ensemble, train_dataset, parameters):
13361334
meGO_LJ_14["epsilon"] < 0.0, -meGO_LJ_14["epsilon"], 4 * meGO_LJ_14["epsilon"] * (meGO_LJ_14["sigma"] ** 12)
13371335
)
13381336

1339-
13401337
# meGO consistency checks
13411338
consistency_checks(meGO_LJ)
13421339
consistency_checks(meGO_LJ_14)
13431340

1344-
13451341
et = time.time()
13461342
elapsed_time = et - st
13471343
print("\t- Done in:", elapsed_time, "seconds")
@@ -1356,11 +1352,9 @@ def sort_LJ(meGO_ensemble, meGO_LJ):
13561352
meGO_LJ["number_aj"] = meGO_LJ["aj"].map(meGO_ensemble["sbtype_number_dict"]).astype(int)
13571353

13581354
# Filter and explicitly create a copy to avoid the warning
1359-
meGO_LJ = meGO_LJ[
1360-
(meGO_LJ["ai"].cat.codes <= meGO_LJ["aj"].cat.codes)
1361-
].copy()
1355+
meGO_LJ = meGO_LJ[(meGO_LJ["ai"].cat.codes <= meGO_LJ["aj"].cat.codes)].copy()
13621356

1363-
# across molecules use molecule_ai<=molecule_aj
1357+
# across molecules use molecule_ai<=molecule_aj
13641358
(
13651359
meGO_LJ["ai"],
13661360
meGO_LJ["aj"],
@@ -1369,7 +1363,7 @@ def sort_LJ(meGO_ensemble, meGO_LJ):
13691363
meGO_LJ["number_ai"],
13701364
meGO_LJ["number_aj"],
13711365
) = np.where(
1372-
(meGO_LJ["molecule_name_ai"].astype(str)<=meGO_LJ["molecule_name_aj"].astype(str)),
1366+
(meGO_LJ["molecule_name_ai"].astype(str) <= meGO_LJ["molecule_name_aj"].astype(str)),
13731367
[
13741368
meGO_LJ["ai"],
13751369
meGO_LJ["aj"],
@@ -1388,8 +1382,7 @@ def sort_LJ(meGO_ensemble, meGO_LJ):
13881382
],
13891383
)
13901384

1391-
1392-
# in the same molecule use ai<=aj
1385+
# in the same molecule use ai<=aj
13931386
# Apply np.where to swap values only when molecule_name_ai == molecule_name_aj
13941387
(
13951388
meGO_LJ["ai"],

0 commit comments

Comments
 (0)