Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes to mg prior and to account for neighbor excluded AA #504

Merged
merged 3 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,3 @@ outputs/*
analysis/*
.DS_Store
*.ff
!multi-rc-prior.ff/
!multi-mg-prior.ff/
!multi-ego-ready.ff/
80 changes: 63 additions & 17 deletions src/multiego/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,6 +1200,8 @@ def generate_LJ(meGO_ensemble, train_dataset, parameters):
# now we can remove contacts with default c6/c12 because these
# are uninformative and predefined. This also allows replacing them with contacts learned
# by either intra/inter training. We cannot remove 1-4 interactions.
# we should not remove default interactions within a window of 2 neighboring AA to
# avoid replacing them with unwanted interactions
meGO_LJ = meGO_LJ.loc[
~(
(meGO_LJ["epsilon"] > 0)
Expand All @@ -1215,6 +1217,10 @@ def generate_LJ(meGO_ensemble, train_dataset, parameters):
& (meGO_LJ["mg_epsilon"] < 0)
& ((abs(meGO_LJ["epsilon"] - meGO_LJ["mg_epsilon"]) / abs(meGO_LJ["mg_epsilon"])) < parameters.relative_c12d)
& (meGO_LJ["1-4"] == "1>4")
& ~(
(abs(meGO_LJ["ai"].apply(get_residue_number) - meGO_LJ["aj"].apply(get_residue_number)) < 3)
& (meGO_LJ["same_chain"] == True)
)
)
]

Expand Down Expand Up @@ -1339,15 +1345,13 @@ def generate_LJ(meGO_ensemble, train_dataset, parameters):
# Calculate c6 and c12 for meGO_LJ
meGO_LJ["c6"] = np.where(meGO_LJ["epsilon"] < 0.0, 0.0, 4 * meGO_LJ["epsilon"] * (meGO_LJ["sigma"] ** 6))

meGO_LJ["c12"] = np.where(
meGO_LJ["epsilon"] < 0.0, -meGO_LJ["epsilon"], abs(4 * meGO_LJ["epsilon"] * (meGO_LJ["sigma"] ** 12))
)
meGO_LJ["c12"] = np.where(meGO_LJ["epsilon"] < 0.0, -meGO_LJ["epsilon"], 4 * meGO_LJ["epsilon"] * (meGO_LJ["sigma"] ** 12))

# Calculate c6 and c12 for meGO_LJ_14
meGO_LJ_14["c6"] = np.where(meGO_LJ_14["epsilon"] < 0.0, 0.0, 4 * meGO_LJ_14["epsilon"] * (meGO_LJ_14["sigma"] ** 6))

meGO_LJ_14["c12"] = np.where(
meGO_LJ_14["epsilon"] < 0.0, -meGO_LJ_14["epsilon"], abs(4 * meGO_LJ_14["epsilon"] * (meGO_LJ_14["sigma"] ** 12))
meGO_LJ_14["epsilon"] < 0.0, -meGO_LJ_14["epsilon"], 4 * meGO_LJ_14["epsilon"] * (meGO_LJ_14["sigma"] ** 12)
)

meGO_LJ["type"] = 1
Expand Down Expand Up @@ -1418,6 +1422,10 @@ def generate_LJ(meGO_ensemble, train_dataset, parameters):
return meGO_LJ, meGO_LJ_14


def get_residue_number(s):
    """Return the integer that follows the last underscore in ``s``.

    Example: ``"O_ABeta_10"`` -> ``10``. If ``s`` contains no underscore,
    the whole string is converted.

    Raises:
        ValueError: if the trailing field is not a valid integer literal.
    """
    # rpartition yields (head, sep, tail); tail is the text after the last "_"
    # (or the entire string when no "_" is present).
    _, _, trailing_field = s.rpartition("_")
    return int(trailing_field)


def make_pairs_exclusion_topology(meGO_ensemble, meGO_LJ_14, args):
"""
This function prepares the [ exclusion ] and [ pairs ] section to output to topology.top
Expand Down Expand Up @@ -1463,14 +1471,11 @@ def make_pairs_exclusion_topology(meGO_ensemble, meGO_LJ_14, args):
# p14 are specifically the interactions at exactly 3 bonds
exclusion_bonds, p14 = topology.get_14_interaction_list(reduced_topology, bond_pair)

pairs = pd.DataFrame()
# in the case of the MG prior we need to remove interactions in a window of 2 residues
if args.egos == "mg":
# Create a list of tuples (sbtype, residue_number)
sbtype_with_residue = [
(sbtype, resnum_type_dict[sbtype])
for sbtype in reduced_topology["sb_type"]
# if meGO_ensemble["sbtype_type_dict"][sbtype] != "CH1a"
]
sbtype_with_residue = [(sbtype, resnum_type_dict[sbtype]) for sbtype in reduced_topology["sb_type"]]
# Sort the list by residue numbers
sbtype_with_residue.sort(key=lambda x: x[1])
# Initialize a list to hold the filtered combinations
Expand Down Expand Up @@ -1516,10 +1521,8 @@ def make_pairs_exclusion_topology(meGO_ensemble, meGO_LJ_14, args):
mask = df.remove == "Yes"
df = df[~mask]
df.drop(columns=["check", "remove"], inplace=True)
meGO_LJ_14 = pd.concat([meGO_LJ_14, df], axis=0, sort=False, ignore_index=True)

pairs = pd.DataFrame()
if not meGO_LJ_14.empty:
pairs = pd.concat([meGO_LJ_14, df], axis=0, sort=False, ignore_index=True)
elif args.egos == "production" and not meGO_LJ_14.empty:
mol_ai = f"{idx}_{molecule}"
# pairs do not have duplicates because these have been cleaned before
pairs = meGO_LJ_14[meGO_LJ_14["molecule_name_ai"] == mol_ai][
Expand All @@ -1531,11 +1534,57 @@ def make_pairs_exclusion_topology(meGO_ensemble, meGO_LJ_14, args):
"same_chain",
"probability",
"rc_probability",
"mg_sigma",
"mg_epsilon",
"source",
"rep",
]
].copy()
# Intermolecular interactions are excluded
# this needs to be the default repulsion if within two residues
if not pairs.empty:
pairs.loc[
(~pairs["same_chain"])
& (abs(pairs["ai"].apply(get_residue_number) - pairs["aj"].apply(get_residue_number)) < 3),
"c6",
] = 0.0
pairs.loc[
(~pairs["same_chain"])
& (abs(pairs["ai"].apply(get_residue_number) - pairs["aj"].apply(get_residue_number)) < 3),
"c12",
] = pairs["rep"]
# else it should be default mg
pairs.loc[
(~pairs["same_chain"])
& (abs(pairs["ai"].apply(get_residue_number) - pairs["aj"].apply(get_residue_number)) > 2)
& (pairs["mg_epsilon"] < 0.0),
"c6",
] = 0.0
pairs.loc[
(~pairs["same_chain"])
& (abs(pairs["ai"].apply(get_residue_number) - pairs["aj"].apply(get_residue_number)) > 2)
& (pairs["mg_epsilon"] < 0.0),
"c12",
] = -pairs["mg_epsilon"]
pairs.loc[
(~pairs["same_chain"])
& (abs(pairs["ai"].apply(get_residue_number) - pairs["aj"].apply(get_residue_number)) > 2)
& (pairs["mg_epsilon"] > 0.0),
"c6",
] = (
4 * pairs["mg_epsilon"] * (pairs["mg_sigma"] ** 6)
)
pairs.loc[
(~pairs["same_chain"])
& (abs(pairs["ai"].apply(get_residue_number) - pairs["aj"].apply(get_residue_number)) > 2)
& (pairs["mg_epsilon"] > 0.0),
"c12",
] = (
4 * pairs["mg_epsilon"] * (pairs["mg_sigma"] ** 12)
)

# now we are ready to finalize
if not pairs.empty:
# The exclusion list was made based on the atom number
pairs["ai"] = pairs["ai"].map(atnum_type_dict)
pairs["aj"] = pairs["aj"].map(atnum_type_dict)
Expand All @@ -1546,10 +1595,8 @@ def make_pairs_exclusion_topology(meGO_ensemble, meGO_LJ_14, args):
pairs.loc[(pairs["check"].isin(p14) & (pairs["same_chain"])), "remove"] = "No"
mask = pairs.remove == "Yes"
pairs = pairs[~mask]
# finalize
pairs["func"] = 1
# Intermolecular interactions are excluded
pairs.loc[(~pairs["same_chain"]), "c6"] = 0.0
pairs.loc[(~pairs["same_chain"]), "c12"] = pairs["rep"]
# this is a safety check
pairs = pairs[pairs["c12"] > 0.0]
pairs = pairs[
Expand All @@ -1564,7 +1611,6 @@ def make_pairs_exclusion_topology(meGO_ensemble, meGO_LJ_14, args):
"source",
]
]

pairs.dropna(inplace=True)
pairs["ai"] = pairs["ai"].astype(int)
pairs["aj"] = pairs["aj"].astype(int)
Expand Down
2 changes: 1 addition & 1 deletion test/test_inputs/lyso-bnz_ref/topol.top
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@
[ molecules ]
; Compound #mols
Lyso 1
BNZ 5
BNZ 1
2 changes: 1 addition & 1 deletion test/test_outputs/abetaref/case_1/ffnonbonded.itp
Original file line number Diff line number Diff line change
Expand Up @@ -21207,7 +21207,7 @@ CE2_ABeta_10 O2_ABeta_42 1 2.155633e-04 1.318059e-07 ; 0.291335 8.813629e-
O_ABeta_10 OE1_ABeta_11 1 0.000000e+00 4.877029e-06 ; 0.360866 -4.877029e-06 0.091976 7.003680e-02 0.000725 0.000037 2.428469e-06 0.493718 True native_MD 84 90
O_ABeta_10 OE2_ABeta_11 1 0.000000e+00 4.877029e-06 ; 0.360866 -4.877029e-06 0.091976 7.003680e-02 0.000725 0.000037 2.428469e-06 0.493718 True native_MD 84 91
O_ABeta_10 C_ABeta_11 1 1.921719e-04 8.340416e-08 ; 0.275159 1.106960e-01 0.999918 9.820871e-01 0.000725 0.002037 8.269429e-07 0.451327 True native_MD 84 92
O_ABeta_10 O_ABeta_11 1 0.000000e+00 3.000001e-06 ; 0.346545 -3.000001e-06 1.000000 1.000000e+00 1.000000 1.000000 3.000001e-06 0.502491 False basic 84 93
O_ABeta_10 O_ABeta_11 1 0.000000e+00 3.001784e-06 ; 0.346563 -3.001784e-06 0.981208 8.733828e-01 0.000725 0.000037 3.000001e-06 0.502491 True native_MD 84 93
O_ABeta_10 N_ABeta_12 1 1.807813e-04 6.666288e-08 ; 0.267786 1.225641e-01 0.829101 5.994208e-01 0.000725 0.002037 4.799381e-07 0.431321 True native_MD 84 94
O_ABeta_10 CA_ABeta_12 1 1.390373e-03 4.121466e-06 ; 0.379012 1.172603e-01 0.726703 4.512806e-01 0.000725 0.001600 4.153495e-06 0.516300 True native_MD 84 95
O_ABeta_10 CB_ABeta_12 1 1.605453e-03 5.683238e-06 ; 0.390391 1.133807e-01 0.364860 2.514262e-01 0.000725 0.001600 4.153495e-06 0.516300 True native_MD 84 96
Expand Down
Loading
Loading