diff --git a/docs/api/sampler.md b/docs/api/sampler.md index 2ca3a9e10..e289d50a6 100644 --- a/docs/api/sampler.md +++ b/docs/api/sampler.md @@ -2,7 +2,7 @@ ## Overview -The `sampler` module provides functionality for generating text samples based on templates and data. It is primarily used for creating datasets for natural language processing tasks in chemistry and related fields. The main class in this module is `TemplateSampler`, which allows for flexible text generation with support for multiple choice questions and class balancing. +The `sampler` module provides functionality for generating text samples based on templates and data. It is primarily used for creating datasets for natural language processing tasks in chemistry and related fields. The main class in this module is `TemplateSampler`, which allows for flexible text generation with support for multiple choice questions, class balancing, and identifier wrapping. ## TemplateSampler @@ -21,41 +21,49 @@ sampler = TemplateSampler(df: pd.DataFrame, meta: Dict, config: Dict, column_dat - `config`: A dictionary containing configuration parameters for the sampler. - `column_datafield_sampler`: An optional callable for custom sampling from multiple options. +#### Configuration Options + +- `wrap_identifiers`: Boolean flag to enable wrapping of identifiers with tags (default: False). + #### Main Methods -##### sample +(... other methods remain the same ...) -```python -def sample(self, sample: Optional[pd.Series], template: str) -> str -``` +#### Identifier Wrapping -Generates a text sample based on a template and a data sample. +When `wrap_identifiers` is set to `True` in the configuration, the sampler will wrap identifier values with tags. For example: -- `sample`: A row from the dataset. If None, a random sample is chosen. -- `template`: The template string to be filled. -- Returns: The completed text sample with all variables replaced by their values. 
+- `[BEGIN_SMILES]CC(C)NCC(O)c1ccc(O)c(O)c1[END_SMILES]` +- `[BEGIN_InChI]InChI=1S/C8H9NO2/c1-6(10)9-7-2-4-8(11)5-3-7/h2-5,11H,1H3,(H,9,10)[END_InChI]` + +This feature can be useful for downstream tasks that need to identify specific types of chemical identifiers in the generated text. + +#### Usage Examples -##### enable_class_balancing +Basic usage: ```python -def enable_class_balancing(self, column: str) -``` +import pandas as pd +from chemnlp.data.sampler import TemplateSampler -Enables class-balanced sampling for a specified column. +# Prepare your data, metadata, and config +df = pd.DataFrame(...) +meta = {...} +config = {...} -- `column`: The column to use for balancing. +# Initialize the sampler +sampler = TemplateSampler(df, meta, config) -##### disable_class_balancing +# Define a template +template = "The molecule with SMILES {SMILES#} has a {property#} of {value#}." -```python -def disable_class_balancing(self) +# Generate a sample +result = sampler.sample(df.iloc[0], template) +print(result) ``` -Disables class-balanced sampling and reverts to the original dataset. -#### Usage Examples - -Basic usage: +Basic usage with identifier wrapping: ```python import pandas as pd @@ -64,7 +72,10 @@ from chemnlp.data.sampler import TemplateSampler # Prepare your data, metadata, and config df = pd.DataFrame(...) meta = {...} -config = {...} +config = { + 'wrap_identifiers': True, + # ... other config options +} # Initialize the sampler sampler = TemplateSampler(df, meta, config) @@ -75,8 +86,10 @@ template = "The molecule with SMILES {SMILES#} has a {property#} of {value#}." # Generate a sample result = sampler.sample(df.iloc[0], template) print(result) +# Output: The molecule with SMILES [BEGIN_SMILES]CC(C)NCC(O)c1ccc(O)c(O)c1[END_SMILES] has a LogP of 1.23. ``` + Using class balancing: ```python @@ -106,9 +119,8 @@ print(mc_result) ## Notes -- The `TemplateSampler` class supports various types of templates, including those with multiple choice questions. 
def _wrap_identifier(self, identifier: str, value: str) -> str:
    """Wrap an identifier value in ``[BEGIN_<type>]`` / ``[END_<type>]`` tags.

    The value is returned unchanged when wrapping is disabled
    (``self.wrap_identifiers`` is falsy), when ``identifier`` has no entry
    in ``self.meta['identifiers']``, or when the declared ``type`` of that
    entry is not a valid ``IdentifierEnum`` value.

    Args:
        identifier: Identifier id as it appears in the template, without
            the ``#`` marker (e.g. ``"SMILES"``).
        value: String form of the identifier value to wrap.

    Returns:
        The tagged value, or ``value`` itself if no tag applies.
    """
    if not self.wrap_identifiers:
        return value

    declared_type = next(
        (
            item["type"]
            for item in self.meta["identifiers"]
            if item["id"] == identifier
        ),
        None,
    )
    if declared_type is None:
        # The key is not a declared identifier, so there is no tag to apply.
        return value

    try:
        # Use ``.value`` so the tag text does not depend on how the enum
        # member formats itself inside an f-string (this differs between
        # Python versions for ``str``-mixin enums).
        tag = IdentifierEnum(declared_type).value
    except ValueError:
        # Declared type is not one of the known identifier types.
        return value

    if not tag:
        return value
    return f"[BEGIN_{tag}]{value}[END_{tag}]"


def _fill_template(self, template: str, sample_dict: Dict[str, Union[str, List[str]]]) -> str:
    """Substitute every ``{key}`` placeholder in ``template``.

    Args:
        template: Template text containing ``{key}`` placeholders.
        sample_dict: Mapping from placeholder key (without braces) to its
            replacement. List values are joined with newlines.

    Returns:
        The template with each placeholder that has an entry in
        ``sample_dict`` replaced by its (string) value.
    """
    for key, value in sample_dict.items():
        if isinstance(value, list):
            # Handle list values (e.g., for multiple-choice options)
            value = '\n'.join(value)
        if '#' in key:
            # Keys carrying '#' are candidates for tagging; whether a tag is
            # actually applied is decided by _wrap_identifier.
            identifier = key.replace('#', '')
            value = self._wrap_identifier(identifier, str(value))
        template = template.replace('{' + key + '}', str(value))
    return template
def test_multiple_identifier_types(sample_df, sample_meta, sample_config_with_wrapping):
    """Each identifier in the template is wrapped with its own tag pair."""
    sampler = TemplateSampler(sample_df, sample_meta, sample_config_with_wrapping)
    rendered = sampler.sample(
        sample_df.iloc[0],
        "SMILES: {SMILES#}, Name: {compound_name#}",
    )
    expected_tags = ("[BEGIN_SMILES]", "[END_SMILES]", "[BEGIN_Other]", "[END_Other]")
    for tag in expected_tags:
        assert tag in rendered


def test_wrapping_with_multiple_choice(sample_df, sample_meta, sample_config_with_wrapping):
    """Wrapping still applies when the template also has multiple-choice parts."""
    # NOTE(review): whitespace inside this template was reconstructed from a
    # flattened diff; confirm it against the original file.
    template = """
    Which compound has this SMILES: {SMILES#}?
    {%multiple_choice_enum%2%aA1}
    {compound_name%}
    Answer: {%multiple_choice_result}
    """
    sampler = TemplateSampler(sample_df, sample_meta, sample_config_with_wrapping)
    rendered = sampler.sample(sample_df.iloc[0], template)

    assert "[BEGIN_SMILES]" in rendered
    assert "[END_SMILES]" in rendered
    # The enumeration hint uses one of the three symbol styles.
    assert any(hint in rendered for hint in ("A or B", "a or b", "1 or 2"))


def test_wrapping_with_continuous_value(large_sample_df, large_sample_meta, sample_config_with_wrapping):
    """The SMILES is tagged; the LogP value is matched as a plain number."""
    sampler = TemplateSampler(large_sample_df, large_sample_meta, sample_config_with_wrapping)
    first_row = large_sample_df.iloc[0]
    rendered = sampler.sample(first_row, "SMILES: {SMILES#}, LogP: {LogP#}")

    assert "[BEGIN_SMILES]" in rendered
    assert "[END_SMILES]" in rendered
    # Checks for 2 decimal places
    assert re.search(r"LogP: \d+\.\d{2}", rendered)