Skip to content

Commit

Permalink
implement identifier replacement
Browse files Browse the repository at this point in the history
  • Loading branch information
kjappelbaum committed Aug 13, 2024
1 parent f0451bd commit a21a01d
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docs/api/meta_yaml_augmentor.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ The augmentation process involves:
1. **LLM Integration**: This tool requires integration with an LLM service. Ensure you have the necessary credentials and access set up. By default, it uses `gpt-4o`, for which you need to set the `OPENAI_API_KEY` environment variable.

2. **Output Quality**: The quality of the augmented `meta.yaml` depends on the capabilities of the LLM being used. Manual review and adjustment may be necessary.
It also depends on the quality of the existing `meta.yaml` file. If the existing file doesn't follow the standards (and, for example, hard codes target names) the augmentation may not be successful.
It also depends on the quality of the existing `meta.yaml` file. If the existing file doesn't follow the standards (and, for example, hard codes target names) the augmentation may not be successful.

## Example Usage in Python

Expand Down
1 change: 1 addition & 0 deletions docs/api/sampler.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,5 +153,6 @@ print(mc_result)
- The `TemplateSampler` class supports wrapping of identifiers with tags when the `wrap_identifiers` option is enabled in the configuration.
- Wrapped identifiers use the format `[BEGIN_IDENTIFIER_TYPE]value[END_IDENTIFIER_TYPE]`.
- Identifier types are based on the `IdentifierEnum` class, which includes common chemical identifiers like SMILES, InChI, and others.
- If the dataframe has a `SMILES` column as well as one or more of the columns `["selfies", "deepsmiles", "canonical", "inchi", "iupac_name"]` (they do not even need to be listed in the `meta.yaml`), the engine will automatically, and randomly, substitute one of the other identifiers for `SMILES` when sampling.

For more detailed information on the implementation and advanced usage, please refer to the source code and unit tests.
25 changes: 25 additions & 0 deletions src/chemnlp/data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,26 @@ def __init__(
self.class_balanced = False
self.balance_column = None
self.wrap_identifiers = config.get("wrap_identifiers", False)
self.additional_targets = self._get_additional_targets(df)
self._add_additional_targets_to_meta()

def _get_additional_targets(self, df: pd.DataFrame) -> List[str]:
additional_targets = []
for col in ["selfies", "deepsmiles", "canonical", "inchi", "iupac_name"]:
if col in df.columns:
additional_targets.append(col)
return additional_targets

def _add_additional_targets_to_meta(self):
additional_targets_meta = {
"selfies": {"id": "selfies", "type": "selfies", "description": "SELFIES"},
"deepsmiles": {"id": "deepsmiles", "type": "deepsmiles", "description": "DeepSMILES"},
"canonical": {"id": "canonical", "type": "canonical", "description": "canonical SMILES"},
"inchi": {"id": "inchi", "type": "inchi", "description": "InChI"},
"iupac_name": {"id": "iupac_name", "type": "iupac_name", "description": "IUPAC name"},
}
for target in self.additional_targets:
self.meta["targets"].append(additional_targets_meta[target])

def _wrap_identifier(self, identifier: str, value: str) -> str:
"""Wrap the identifier value with tags if wrap_identifiers is enabled."""
Expand Down Expand Up @@ -517,6 +537,11 @@ def sample(self, sample: pd.Series, template: str) -> str:
"""
if sample is None:
sample = self.df.sample(1).iloc[0]
if self.additional_targets and "SMILES" in sample.index:
non_nan_targets = [target for target in ["SMILES"] + self.additional_targets if pd.notna(sample[target])]
new_target = random.choice(non_nan_targets)
if new_target != "SMILES":
template = template.replace("{SMILES", "{" + new_target)
sample_dict = self.get_sample_dict(sample, template)
return self._fill_template(template, sample_dict)

Expand Down
48 changes: 48 additions & 0 deletions tests/data/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@ def sample_df():
}
)

@pytest.fixture
def sample_multiple_identifier_df():
    """Two-row dataframe carrying SMILES plus SELFIES and InChI columns."""
    smiles = [
        "CC(C)NCC(O)c1ccc(O)c(O)c1",
        "CC1=C(C(=O)NC2=C1C=CC=C2)C3=CC=CC=C3",
    ]
    selfies = [
        "[C][C][Branch1][C][C][N][C][C][Branch1][O][C][=C][C][=C][C][Branch1][O][C][=C][Branch1][O][C][=C]",
        "[C][C]1[=C][Branch1][C][Branch2][=O][N][C]2[=C]1[C][=C][C][=C][C]2[C]3[=C][C][=C][C][=C][C]3",
    ]
    inchi = [
        "InChI=1S/C11H17NO3/c1-8(2)12-6-9(13)10-4-3-5-11(14)7-10/h3-5,7-9,12-14H,6H2,1-2H3",
        "InChI=1S/C15H11NO/c17-14-11-9-5-1-3-7-13(9)16-15(14)10-6-2-4-8-12(10)11/h1-8,16H",
    ]
    # Column order is kept as SMILES, selfies, inchi, then the remaining fields.
    columns = {
        "SMILES": smiles,
        "selfies": selfies,
        "inchi": inchi,
        "compound_name": ["Isoproterenol", "Phenytoin"],
        "LogP": [0.08, 2.47],
        "is_active": [True, False],
        "split": ["train", "test"],
    }
    return pd.DataFrame(columns)

@pytest.fixture
def sample_meta():
Expand Down Expand Up @@ -54,6 +65,23 @@ def sample_meta():
}


@pytest.fixture
def sample_multiple_identifier_meta():
    """Meta dictionary declaring two targets, two identifiers and two templates."""
    targets = [
        {"id": "LogP", "type": "continuous", "description": "Logarithm of partition coefficient"},
        {"id": "is_active", "type": "categorical", "description": "Activity status"},
    ]
    identifiers = [
        {"id": "SMILES", "type": "SMILES", "description": "SMILES notation"},
        {"id": "compound_name", "type": "Other", "description": "Compound name"},
    ]
    templates = [
        "The molecule with SMILES {SMILES#} has a LogP of {LogP#}.",
        "The compound {compound_name#} is {is_active#active&inactive}.",
    ]
    return {"targets": targets, "identifiers": identifiers, "templates": templates}

@pytest.fixture
def sample_config():
return {
Expand Down Expand Up @@ -475,3 +503,23 @@ def test_polymer_multiple_properties(
assert "*CC(*)C" in result
assert "275.0" in result
assert "0.90" in result

def test_additional_targets(sample_multiple_identifier_df, sample_multiple_identifier_meta, sample_config):
    """Extra identifier columns in the dataframe are detected and added as targets."""
    sampler = TemplateSampler(sample_multiple_identifier_df, sample_multiple_identifier_meta, sample_config)
    assert set(sampler.additional_targets) == {"selfies", "inchi"}
    # Two targets come from the meta fixture, two are added for the detected columns.
    assert len(sampler.meta["targets"]) == 4
    target_ids = {target["id"] for target in sampler.meta["targets"]}
    assert {"selfies", "inchi"} <= target_ids



def test_sample_with_random_replacement(sample_multiple_identifier_df, sample_multiple_identifier_meta, sample_config):
    """Sampling repeatedly yields the SMILES, SELFIES and InChI representation at least once each."""
    sampler = TemplateSampler(sample_multiple_identifier_df, sample_multiple_identifier_meta, sample_config)
    template = "The compound with {SMILES__description} {SMILES#} has a {LogP__description} of {LogP#}"
    row = sample_multiple_identifier_df.iloc[0]

    # The identifier is chosen at random on every call; with only 20 draws there
    # is a small but real chance that one of the three representations never
    # shows up. Enough draws make a spurious failure practically impossible.
    n_draws = 200
    results = [sampler.sample(row, template) for _ in range(n_draws)]

    smiles_count = sum("CC(C)NCC(O)c1ccc(O)c(O)c1" in r for r in results)
    selfies_count = sum("[C][C][Branch1][C][C][N][C][C][Branch1][O][C][=C][C][=C][C][Branch1][O][C][=C][Branch1][O][C][=C]" in r for r in results)
    inchi_count = sum("InChI=1S/C11H17NO3/c1-8(2)12-6-9(13)10-4-3-5-11(14)7-10/h3-5,7-9,12-14H,6H2,1-2H3" in r for r in results)
    assert smiles_count > 0
    assert selfies_count > 0
    assert inchi_count > 0
    # Every result contains exactly one of the three representations.
    assert smiles_count + selfies_count + inchi_count == n_draws

0 comments on commit a21a01d

Please sign in to comment.