From a21a01d660e428a313df7df68908c0d7fa237ae7 Mon Sep 17 00:00:00 2001 From: Kevin Maik Jablonka Date: Tue, 13 Aug 2024 13:33:28 -0700 Subject: [PATCH] implement identifier replacement --- docs/api/meta_yaml_augmentor.md | 2 +- docs/api/sampler.md | 1 + src/chemnlp/data/sampler.py | 25 +++++++++++++++++ tests/data/test_sampler.py | 48 +++++++++++++++++++++++++++++++++ 4 files changed, 75 insertions(+), 1 deletion(-) diff --git a/docs/api/meta_yaml_augmentor.md b/docs/api/meta_yaml_augmentor.md index 5b470503c..0e8328295 100644 --- a/docs/api/meta_yaml_augmentor.md +++ b/docs/api/meta_yaml_augmentor.md @@ -48,7 +48,7 @@ The augmentation process involves: 1. **LLM Integration**: This tool requires integration with an LLM service. Ensure you have the necessary credentials and access set up. By default it uses, `gpt-4o`. For this, you need to expose the `OPENAI_API_KEY` environment variable. 2. **Output Quality**: The quality of the augmented `meta.yaml` depends on the capabilities of the LLM being used. Manual review and adjustment may be necessary. -It also depends on the quality of the existing `meta.yaml` file. If the existing file doesn't follow the standards (and, for example, hard codes target names) the augmentation may not be successful. + It also depends on the quality of the existing `meta.yaml` file. If the existing file doesn't follow the standards (and, for example, hard codes target names) the augmentation may not be successful. ## Example Usage in Python diff --git a/docs/api/sampler.md b/docs/api/sampler.md index 460b6ef0c..8bc963288 100644 --- a/docs/api/sampler.md +++ b/docs/api/sampler.md @@ -153,5 +153,6 @@ print(mc_result) - The `TemplateSampler` class supports wrapping of identifiers with tags when the `wrap_identifiers` option is enabled in the configuration. - Wrapped identifiers use the format `[BEGIN_IDENTIFIER_TYPE]value[END_IDENTIFIER_TYPE]`. 
- Identifier types are based on the `IdentifierEnum` class, which includes common chemical identifiers like SMILES, InChI, and others. +- If the dataframe has a `SMILES` column as well as at least one of the columns `["selfies", "deepsmiles", "canonical", "inchi", "iupac_name"]` (they do not even need to be listed in the `meta.yaml`), the engine will automatically and randomly either keep the `SMILES` or replace it with one of the other available identifiers. For more detailed information on the implementation and advanced usage, please refer to the source code and unit tests. diff --git a/src/chemnlp/data/sampler.py b/src/chemnlp/data/sampler.py index 9a7cf5045..5c563c5dc 100644 --- a/src/chemnlp/data/sampler.py +++ b/src/chemnlp/data/sampler.py @@ -58,6 +58,26 @@ def __init__( self.class_balanced = False self.balance_column = None self.wrap_identifiers = config.get("wrap_identifiers", False) + self.additional_targets = self._get_additional_targets(df) + self._add_additional_targets_to_meta() + + def _get_additional_targets(self, df: pd.DataFrame) -> List[str]: + additional_targets = [] + for col in ["selfies", "deepsmiles", "canonical", "inchi", "iupac_name"]: + if col in df.columns: + additional_targets.append(col) + return additional_targets + + def _add_additional_targets_to_meta(self): + additional_targets_meta = { + "selfies": {"id": "selfies", "type": "selfies", "description": "SELFIES"}, + "deepsmiles": {"id": "deepsmiles", "type": "deepsmiles", "description": "DeepSMILES"}, + "canonical": {"id": "canonical", "type": "canonical", "description": "canonical SMILES"}, + "inchi": {"id": "inchi", "type": "inchi", "description": "InChI"}, + "iupac_name": {"id": "iupac_name", "type": "iupac_name", "description": "IUPAC name"}, + } + for target in self.additional_targets: + self.meta["targets"].append(additional_targets_meta[target]) def _wrap_identifier(self, identifier: str, value: str) -> str: """Wrap the identifier value with tags if wrap_identifiers is enabled.""" @@ -517,6 +537,11 @@ def sample(self, sample: 
pd.Series, template: str) -> str: """ if sample is None: sample = self.df.sample(1).iloc[0] + if self.additional_targets and "SMILES" in sample.index: + non_nan_targets = [target for target in ["SMILES"] + self.additional_targets if pd.notna(sample[target])] + new_target = random.choice(non_nan_targets) + if new_target != "SMILES": + template = template.replace("{SMILES", "{" + new_target) sample_dict = self.get_sample_dict(sample, template) return self._fill_template(template, sample_dict) diff --git a/tests/data/test_sampler.py b/tests/data/test_sampler.py index 3c806a4ac..34452b678 100644 --- a/tests/data/test_sampler.py +++ b/tests/data/test_sampler.py @@ -19,6 +19,17 @@ def sample_df(): } ) +@pytest.fixture +def sample_multiple_identifier_df(): + return pd.DataFrame({ + 'SMILES': ['CC(C)NCC(O)c1ccc(O)c(O)c1', 'CC1=C(C(=O)NC2=C1C=CC=C2)C3=CC=CC=C3'], + 'selfies': ['[C][C][Branch1][C][C][N][C][C][Branch1][O][C][=C][C][=C][C][Branch1][O][C][=C][Branch1][O][C][=C]', '[C][C]1[=C][Branch1][C][Branch2][=O][N][C]2[=C]1[C][=C][C][=C][C]2[C]3[=C][C][=C][C][=C][C]3'], + 'inchi': ['InChI=1S/C11H17NO3/c1-8(2)12-6-9(13)10-4-3-5-11(14)7-10/h3-5,7-9,12-14H,6H2,1-2H3', 'InChI=1S/C15H11NO/c17-14-11-9-5-1-3-7-13(9)16-15(14)10-6-2-4-8-12(10)11/h1-8,16H'], + 'compound_name': ['Isoproterenol', 'Phenytoin'], + 'LogP': [0.08, 2.47], + 'is_active': [True, False], + 'split': ['train', 'test'] + }) @pytest.fixture def sample_meta(): @@ -54,6 +65,23 @@ def sample_meta(): } +@pytest.fixture +def sample_multiple_identifier_meta(): + return { + "targets": [ + {"id": "LogP", "type": "continuous", "description": "Logarithm of partition coefficient"}, + {"id": "is_active", "type": "categorical", "description": "Activity status"} + ], + "identifiers": [ + {"id": "SMILES", "type": "SMILES", "description": "SMILES notation"}, + {"id": "compound_name", "type": "Other", "description": "Compound name"} + ], + "templates": [ + "The molecule with SMILES {SMILES#} has a LogP of {LogP#}.", + "The 
compound {compound_name#} is {is_active#active&inactive}." + ] + } + @pytest.fixture def sample_config(): return { @@ -475,3 +503,23 @@ def test_polymer_multiple_properties( assert "*CC(*)C" in result assert "275.0" in result assert "0.90" in result + +def test_additional_targets(sample_multiple_identifier_df, sample_multiple_identifier_meta, sample_config): + sampler = TemplateSampler(sample_multiple_identifier_df, sample_multiple_identifier_meta, sample_config) + assert set(sampler.additional_targets) == {"selfies", "inchi"} + print(sampler.meta['targets']) + assert len(sampler.meta["targets"]) == 4 + + + +def test_sample_with_random_replacement(sample_multiple_identifier_df, sample_multiple_identifier_meta, sample_config): + sampler = TemplateSampler(sample_multiple_identifier_df, sample_multiple_identifier_meta, sample_config) + template = "The compound with {SMILES__description} {SMILES#} has a {LogP__description} of {LogP#}" + results = [sampler.sample(sample_multiple_identifier_df.iloc[0], template) for _ in range(20)] + smiles_count = sum("CC(C)NCC(O)c1ccc(O)c(O)c1" in r for r in results) + selfies_count = sum("[C][C][Branch1][C][C][N][C][C][Branch1][O][C][=C][C][=C][C][Branch1][O][C][=C][Branch1][O][C][=C]" in r for r in results) + inchi_count = sum("InChI=1S/C11H17NO3/c1-8(2)12-6-9(13)10-4-3-5-11(14)7-10/h3-5,7-9,12-14H,6H2,1-2H3" in r for r in results) + assert smiles_count > 0 + assert selfies_count > 0 + assert inchi_count > 0 + assert smiles_count + selfies_count + inchi_count == 20