
Commit

wrapping implemented
kjappelbaum committed Aug 12, 2024
1 parent 07ee7a1 commit ff08213
Showing 4 changed files with 107 additions and 29 deletions.
64 changes: 38 additions & 26 deletions docs/api/sampler.md
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@

## Overview

The `sampler` module provides functionality for generating text samples based on templates and data. It is primarily used for creating datasets for natural language processing tasks in chemistry and related fields. The main class in this module is `TemplateSampler`, which allows for flexible text generation with support for multiple choice questions and class balancing.
The `sampler` module provides functionality for generating text samples based on templates and data. It is primarily used for creating datasets for natural language processing tasks in chemistry and related fields. The main class in this module is `TemplateSampler`, which allows for flexible text generation with support for multiple choice questions, class balancing, and identifier wrapping.

## TemplateSampler

@@ -21,41 +21,49 @@ sampler = TemplateSampler(df: pd.DataFrame, meta: Dict, config: Dict, column_dat
- `config`: A dictionary containing configuration parameters for the sampler.
- `column_datafield_sampler`: An optional callable for custom sampling from multiple options.

#### Configuration Options

- `wrap_identifiers`: Boolean flag to enable wrapping of identifiers with tags (default: False).
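
The sketch below illustrates how a flag like `wrap_identifiers` could gate the wrapping step. It is a standalone approximation, not the library's code: `IdentifierType` is a hypothetical stand-in for `IdentifierEnum` with only three members, and the `meta` layout (a list of `identifiers` entries with `id` and `type` keys) is assumed from the way the sampler reads it.

```python
from enum import Enum
from typing import Any, Dict


class IdentifierType(str, Enum):
    """Hypothetical stand-in for chemnlp's IdentifierEnum (only three members shown)."""
    SMILES = "SMILES"
    INCHI = "InChI"
    OTHER = "Other"


def wrap_identifier(meta: Dict[str, Any], config: Dict[str, Any], identifier_id: str, value: str) -> str:
    """Wrap `value` in BEGIN/END tags when wrapping is enabled and the type is known."""
    if not config.get("wrap_identifiers", False):  # the flag defaults to False
        return value
    # Look up the declared type of this identifier in the (assumed) meta layout.
    declared = next(
        (item["type"] for item in meta.get("identifiers", []) if item["id"] == identifier_id),
        None,
    )
    try:
        id_type = IdentifierType(declared)
    except ValueError:  # unknown or missing type: leave the value untouched
        return value
    return f"[BEGIN_{id_type.value}]{value}[END_{id_type.value}]"


meta = {"identifiers": [{"id": "SMILES", "type": "SMILES"}, {"id": "compound_name", "type": "Other"}]}
print(wrap_identifier(meta, {"wrap_identifiers": True}, "SMILES", "CCO"))  # tagged
print(wrap_identifier(meta, {}, "SMILES", "CCO"))  # flag absent, value returned unchanged
```

Reading the flag with a default of `False` keeps existing configurations working unchanged, since wrapping only happens when it is requested explicitly.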

#### Main Methods

##### sample
(... other methods remain the same ...)

```python
def sample(self, sample: Optional[pd.Series], template: str) -> str
```
#### Identifier Wrapping

Generates a text sample based on a template and a data sample.
When `wrap_identifiers` is set to `True` in the configuration, the sampler will wrap identifier values with tags. For example:

- `sample`: A row from the dataset. If None, a random sample is chosen.
- `template`: The template string to be filled.
- Returns: The completed text sample with all variables replaced by their values.
- `[BEGIN_SMILES]CC(C)NCC(O)c1ccc(O)c(O)c1[END_SMILES]`
- `[BEGIN_InChI]InChI=1S/C8H9NO2/c1-6(10)9-7-2-4-8(11)5-3-7/h2-5,11H,1H3,(H,9,10)[END_InChI]`

This feature can be useful for downstream tasks that need to identify specific types of chemical identifiers in the generated text.
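
As an illustration of such a downstream step, the following sketch pulls tagged identifiers back out of generated text with a regular expression. The helper name `extract_wrapped_identifiers` and the pattern are assumptions made for this example; they rely only on the `[BEGIN_TYPE]value[END_TYPE]` layout shown above and are not part of `chemnlp`.

```python
import re
from typing import List, Tuple

# Assumed tag layout, taken from the examples above: [BEGIN_<TYPE>]value[END_<TYPE>].
# The backreference (?P=type) makes sure the closing tag matches the opening one.
TAG_PATTERN = re.compile(
    r"\[BEGIN_(?P<type>\w+)\](?P<value>.*?)\[END_(?P=type)\]",
    re.DOTALL,
)


def extract_wrapped_identifiers(text: str) -> List[Tuple[str, str]]:
    """Return (identifier_type, value) pairs for every wrapped identifier in `text`."""
    return [(m.group("type"), m.group("value")) for m in TAG_PATTERN.finditer(text)]


sample = (
    "The molecule with SMILES "
    "[BEGIN_SMILES]CC(C)NCC(O)c1ccc(O)c(O)c1[END_SMILES] has a LogP of 1.23."
)
print(extract_wrapped_identifiers(sample))
```

The non-greedy `.*?` stops at the first matching end tag, so identifiers that themselves contain square brackets (as SMILES strings often do) are still captured whole.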

#### Usage Examples

##### enable_class_balancing
Basic usage:

```python
def enable_class_balancing(self, column: str)
```
import pandas as pd
from chemnlp.data.sampler import TemplateSampler

Enables class-balanced sampling for a specified column.
# Prepare your data, metadata, and config
df = pd.DataFrame(...)
meta = {...}
config = {...}

- `column`: The column to use for balancing.
# Initialize the sampler
sampler = TemplateSampler(df, meta, config)

##### disable_class_balancing
# Define a template
template = "The molecule with SMILES {SMILES#} has a {property#} of {value#}."

```python
def disable_class_balancing(self)
# Generate a sample
result = sampler.sample(df.iloc[0], template)
print(result)
```

Disables class-balanced sampling and reverts to the original dataset.
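
As a rough illustration of what class-balanced sampling involves, the sketch below downsamples every class in a column to the size of the smallest class using pandas. The function name `balance_by_downsampling` and the downsampling strategy are assumptions made for this example; the sampler's own `_balance_classes` implementation is not shown here and may work differently.

```python
import pandas as pd


def balance_by_downsampling(df: pd.DataFrame, column: str, seed: int = 0) -> pd.DataFrame:
    """Hypothetical helper: downsample each class in `column` to the smallest class size."""
    smallest = df[column].value_counts().min()
    # GroupBy.sample draws `smallest` rows from every class without replacement.
    balanced = df.groupby(column).sample(n=smallest, random_state=seed)
    return balanced.reset_index(drop=True)


df = pd.DataFrame({
    "SMILES": ["CCO", "CCN", "CCC", "CCCl", "CCBr"],
    "label": ["active", "active", "active", "inactive", "inactive"],
})
balanced = balance_by_downsampling(df, "label")
print(balanced["label"].value_counts().to_dict())
```

Keeping the original DataFrame untouched and returning a new one mirrors the documented behavior that disabling balancing reverts to the original dataset.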

#### Usage Examples

Basic usage:
Basic usage with identifier wrapping:

```python
import pandas as pd
@@ -64,7 +72,10 @@ from chemnlp.data.sampler import TemplateSampler
# Prepare your data, metadata, and config
df = pd.DataFrame(...)
meta = {...}
config = {...}
config = {
    'wrap_identifiers': True,
    # ... other config options
}

# Initialize the sampler
sampler = TemplateSampler(df, meta, config)
@@ -75,8 +86,10 @@ template = "The molecule with SMILES {SMILES#} has a {property#} of {value#}."
# Generate a sample
result = sampler.sample(df.iloc[0], template)
print(result)
# Output: The molecule with SMILES [BEGIN_SMILES]CC(C)NCC(O)c1ccc(O)c(O)c1[END_SMILES] has a LogP of 1.23.
```


Using class balancing:

```python
@@ -106,9 +119,8 @@ print(mc_result)

## Notes

- The `TemplateSampler` class supports various types of templates, including those with multiple choice questions.
- Class balancing can be useful for creating balanced datasets for machine learning tasks.
- The sampler can handle both categorical and continuous data types, with proper formatting for continuous values.
- Custom sampling functions can be provided for more control over how values are selected from multiple options.
- The `TemplateSampler` class now supports wrapping of identifiers with tags when the `wrap_identifiers` option is enabled in the configuration.
- Wrapped identifiers use the format `[BEGIN_IDENTIFIER_TYPE]value[END_IDENTIFIER_TYPE]`.
- Identifier types are based on the `IdentifierEnum` class, which includes common chemical identifiers like SMILES, InChI, and others.

For more detailed information on the implementation and advanced usage, please refer to the source code and unit tests.
25 changes: 23 additions & 2 deletions src/chemnlp/data/sampler.py
@@ -8,7 +8,7 @@
from chemnlp.data.random_variable import RandomVariable
from functools import partial
from functools import lru_cache

from chemnlp.data_val.model import IdentifierEnum

# ToDo: handle somewhere that the meta contains multiple templates
class TemplateSampler:
@@ -53,6 +53,25 @@ def __init__(
        self.column_datafield_sampler = column_datafield_sampler or (lambda x: random.sample(x, k=1))
        self.class_balanced = False
        self.balance_column = None
        self.wrap_identifiers = config.get('wrap_identifiers', False)

    def _wrap_identifier(self, identifier: str, value: str) -> str:
        """Wrap the identifier value with tags if wrap_identifiers is enabled."""
        print('wrap_identifier', identifier, value, self.wrap_identifiers)

        if not self.wrap_identifiers:
            return value

        identifier_type = next((item['type'] for item in self.meta['identifiers'] if item['id'] == identifier), None)

        try:
            identifier_type = IdentifierEnum(identifier_type)
        except ValueError:
            identifier_type = None

        if identifier_type:
            return f"[BEGIN_{identifier_type}]{value}[END_{identifier_type}]"
        return value

    def _balance_classes(self, column: str) -> pd.DataFrame:
        """
@@ -416,7 +435,9 @@ def sample(self, sample: pd.Series, template: str) -> str:
    def _fill_template(self, template: str, sample_dict: Dict[str, Union[str, List[str]]]) -> str:
        for key, value in sample_dict.items():
            if isinstance(value, list):
                # Handle list values (e.g., for multiple-choice options)
                value = '\n'.join(value)
            if '#' in key:  # This indicates it's an identifier
                identifier = key.replace('#', '')
                value = self._wrap_identifier(identifier, str(value))
            template = template.replace('{' + key + '}', str(value))
        return template
1 change: 0 additions & 1 deletion src/chemnlp/data_val/model.py
@@ -1,6 +1,5 @@
from typing import List, Optional

import pubchempy as pcp
import requests
from pydantic import Extra, root_validator, validator
from pydantic_yaml import YamlModel, YamlStrEnum
46 changes: 46 additions & 0 deletions tests/data/test_sampler.py
@@ -48,6 +48,16 @@ def sample_config():
}


@pytest.fixture
def sample_config_with_wrapping():
    return {
        'DEFAULT_SIGNIFICANT_DIGITS': 2,
        'multiple_choice_rnd_symbols': ["", ".)", ")"],
        'multiple_choice_benchmarking_templates': False,
        'multiple_choice_benchmarking_format': None,
        'wrap_identifiers': True
    }

# Add these to your existing fixtures or create new ones as needed
@pytest.fixture
def large_sample_df():
@@ -73,6 +83,14 @@ def large_sample_meta(sample_meta):
    return sample_meta


def test_basic_identifier_wrapping(sample_df, sample_meta, sample_config_with_wrapping):
    sampler = TemplateSampler(sample_df, sample_meta, sample_config_with_wrapping)
    template = "SMILES: {SMILES#}, Name: {compound_name#}"
    result = sampler.sample(sample_df.iloc[0], template)
    print(result)
    assert "[BEGIN_SMILES]" in result and "[END_SMILES]" in result
    assert "[BEGIN_Other]" in result and "[END_Other]" in result

def test_get_target_from_row(sample_df, sample_meta, sample_config):
    sampler = TemplateSampler(sample_df, sample_meta, sample_config)
    assert sampler._get_target_from_row(sample_df.iloc[0], "SMILES#") == "CC(C)NCC(O)c1ccc(O)c(O)c1"
@@ -169,3 +187,31 @@ def test_random_sampling(large_sample_df, large_sample_meta, sample_config):

    # Check if we have at least two different results (high probability)
    assert len(set(results)) > 1


def test_multiple_identifier_types(sample_df, sample_meta, sample_config_with_wrapping):
    sampler = TemplateSampler(sample_df, sample_meta, sample_config_with_wrapping)
    template = "SMILES: {SMILES#}, Name: {compound_name#}"
    result = sampler.sample(sample_df.iloc[0], template)
    assert all(tag in result for tag in ["[BEGIN_SMILES]", "[END_SMILES]", "[BEGIN_Other]", "[END_Other]"])


def test_wrapping_with_multiple_choice(sample_df, sample_meta, sample_config_with_wrapping):
    sampler = TemplateSampler(sample_df, sample_meta, sample_config_with_wrapping)
    template = """
Which compound has this SMILES: {SMILES#}?
{%multiple_choice_enum%2%aA1}
{compound_name%}
Answer: {%multiple_choice_result}
"""
    result = sampler.sample(sample_df.iloc[0], template)
    assert "[BEGIN_SMILES]" in result and "[END_SMILES]" in result
    assert "A or B" in result or "a or b" in result or "1 or 2" in result


def test_wrapping_with_continuous_value(large_sample_df, large_sample_meta, sample_config_with_wrapping):
    sampler = TemplateSampler(large_sample_df, large_sample_meta, sample_config_with_wrapping)
    template = "SMILES: {SMILES#}, LogP: {LogP#}"
    result = sampler.sample(large_sample_df.iloc[0], template)
    assert "[BEGIN_SMILES]" in result and "[END_SMILES]" in result
    assert re.search(r"LogP: \d+\.\d{2}", result)  # Checks for 2 decimal places
