Skip to content

Commit

Permalink
add more tests and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
kjappelbaum committed Aug 12, 2024
1 parent 01e9de1 commit 3a713f1
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 17 deletions.
Empty file.
21 changes: 13 additions & 8 deletions src/chemnlp/data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from functools import partial
from functools import lru_cache


# ToDo: handle somewhere that the meta contains multiple templates
class TemplateSampler:
"""
A class for sampling and generating text based on templates and data.
Expand Down Expand Up @@ -165,7 +167,7 @@ def _get_target_from_row(self, sample: pd.Series, var: str) -> str:

return out

def get_sample_dict(self, sample: pd.Series, template: str) -> Dict[str, str]:
def get_sample_dict(self, sample: pd.Series, template: str) -> Dict[str, str]:
"""
Extract and process all target values from a sample row based on a template.
"""
Expand Down Expand Up @@ -266,10 +268,10 @@ def _handle_multiple_choice(self, sample: pd.Series, input_variables: List[str])

def _get_choices_without_indicator(self, multiple_choice_var: str, symbols: List[str], correct_choice: str) -> Tuple[List[str], int]:
cutoff_full_unique = 100
if len(self.df[multiple_choice_var].unique()) < cutoff_full_unique:
all_choices = sorted([str(x) for x in self.df[multiple_choice_var].unique()])
else:
all_choices = sorted([str(x) for x in self.df[multiple_choice_var].sample(cutoff_full_unique).unique()])
all_choices = self.df[multiple_choice_var].unique()
if len(all_choices) > cutoff_full_unique:
all_choices = self.df[multiple_choice_var].sample(cutoff_full_unique).unique()
all_choices = sorted([str(x) for x in all_choices])

if all_choices == ["0", "1"]:
all_choices = ["False", "True"]
Expand All @@ -295,7 +297,7 @@ def _get_choices_with_indicator(self, sample: pd.Series, multiple_choice_var: st
multiple_choices, multiple_choices_indicators = zip(*multiple_choices_combined)

correct_choice_idx = [i for i, (choice, indicator) in enumerate(zip(multiple_choices, multiple_choices_indicators))
if indicator == correct_choice_indicator]
if indicator == correct_choice_indicator]

return list(multiple_choices), correct_choice_idx

Expand Down Expand Up @@ -411,7 +413,10 @@ def sample(self, sample: pd.Series, template: str) -> str:
sample_dict = self.get_sample_dict(sample, template)
return self._fill_template(template, sample_dict)

def _fill_template(self, template: str, sample_dict: Dict[str, str]) -> str:
def _fill_template(self, template: str, sample_dict: Dict[str, Union[str, List[str]]]) -> str:
for key, value in sample_dict.items():
template = template.replace('{' + key + '}', value)
if isinstance(value, list):
# Handle list values (e.g., for multiple-choice options)
value = '\n'.join(value)
template = template.replace('{' + key + '}', str(value))
return template
111 changes: 102 additions & 9 deletions tests/data/test_sampler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest
import pandas as pd
from chemnlp.data.sampler import TemplateSampler
import numpy as np
import re

@pytest.fixture
def sample_df():
Expand Down Expand Up @@ -45,6 +47,32 @@ def sample_config():
'multiple_choice_benchmarking_format': None
}


# Add these to your existing fixtures or create new ones as needed
@pytest.fixture
def large_sample_df():
np.random.seed(42)
return pd.DataFrame({
'SMILES': [f'C{i}' for i in range(1000)],
'CYP2D6_Substrate': np.random.choice([0, 1], size=1000, p=[0.7, 0.3]),
'LogP': np.random.normal(2, 1, 1000),
'compound_name': [f'Compound_{i}' for i in range(1000)],
'split': np.random.choice(['train', 'test', 'valid'], size=1000, p=[0.8, 0.1, 0.1])
})

@pytest.fixture
def large_sample_meta(sample_meta):
sample_meta['targets'].append({
"id": "LogP",
"type": "continuous",
"description": "Logarithm of the partition coefficient",
"names": [{"noun": "LogP value"}, {"noun": "partition coefficient"}],
"units": "log units",
"significant_digits": 2
})
return sample_meta


def test_get_target_from_row(sample_df, sample_meta, sample_config):
sampler = TemplateSampler(sample_df, sample_meta, sample_config)
assert sampler._get_target_from_row(sample_df.iloc[0], "SMILES#") == "CC(C)NCC(O)c1ccc(O)c(O)c1"
Expand All @@ -70,21 +98,86 @@ def test_sample_with_template(sample_df, sample_meta, sample_config):

def test_multiple_choice_template(sample_df, sample_meta, sample_config):
sampler = TemplateSampler(sample_df, sample_meta, sample_config)
template = """
Task: Please answer the multiple choice question.
Question: Is the molecule with the {SMILES__description} {SMILES#} {CYP2D6_Substrate__names__verb}?
Constraint: Even if you are uncertain, you must pick either {%multiple_choice_enum%2%aA1} without using any other words.
Options:
{CYP2D6_Substrate%}
Answer: {%multiple_choice_result}
"""
template = """Task: Please answer the multiple choice question.
Question: Is the molecule with the {SMILES__description} {SMILES#} {CYP2D6_Substrate__names__verb}?
Constraint: Even if you are uncertain, you must pick either {%multiple_choice_enum%2%aA1} without using any other words.
Options:
{CYP2D6_Substrate%}
Answer: {%multiple_choice_result}"""
result = sampler.sample(sample_df.iloc[0], template)
assert "CC(C)NCC(O)c1ccc(O)c(O)c1" in result
assert "A or B" in result or "a or b" in result or "1 or 2" in result
assert result.strip().endswith("A") or result.strip().endswith("a") or result.strip().endswith("1")
# Check that the answer is one of the options
answer_letter = result.split("Answer: ")[1].strip()
# assert that "True" in in the line starting with the answer letter
assert "True" in [line for line in result.split("\n") if line.startswith(answer_letter) and ('True' in line or 'False' in line)][0]

def test_class_balancing(sample_df, sample_meta, sample_config):
sampler = TemplateSampler(sample_df, sample_meta, sample_config)
sampler.enable_class_balancing("CYP2D6_Substrate")
balanced_df = sampler.df
assert len(balanced_df[balanced_df['CYP2D6_Substrate'] == 0]) == len(balanced_df[balanced_df['CYP2D6_Substrate'] == 1])



def test_class_balancing_large_dataset(large_sample_df, large_sample_meta, sample_config):
sampler = TemplateSampler(large_sample_df, large_sample_meta, sample_config)
sampler.enable_class_balancing("CYP2D6_Substrate")
assert len(sampler.df) < len(large_sample_df)
assert len(sampler.df[sampler.df['CYP2D6_Substrate'] == 0]) == len(sampler.df[sampler.df['CYP2D6_Substrate'] == 1])




def test_class_balancing_disable(large_sample_df, large_sample_meta, sample_config):
sampler = TemplateSampler(large_sample_df, large_sample_meta, sample_config)
sampler.enable_class_balancing("CYP2D6_Substrate")
assert len(sampler.df) < len(large_sample_df)

sampler.disable_class_balancing()
assert len(sampler.df) == len(large_sample_df)
assert (sampler.df['CYP2D6_Substrate'].value_counts() != sampler.df['CYP2D6_Substrate'].value_counts().iloc[0]).any()

def test_continuous_value_formatting(large_sample_df, large_sample_meta, sample_config):
sampler = TemplateSampler(large_sample_df, large_sample_meta, sample_config)
template = "The {LogP__names__noun} of {compound_name#} is {LogP#} {LogP__units}."
result = sampler.sample(large_sample_df.iloc[0], template)

assert "LogP value" in result or "partition coefficient" in result
assert "log units" in result
assert re.search(r'\d+\.\d{2} log units', result) # Check if the value is rounded to 2 decimal places

def test_error_handling_invalid_variable(sample_df, sample_meta, sample_config):
sampler = TemplateSampler(sample_df, sample_meta, sample_config)
template = "This is an {invalid_variable#}."

with pytest.raises(KeyError):
sampler.sample(sample_df.iloc[0], template)

def test_multiple_targets_in_template(large_sample_df, large_sample_meta, sample_config):
sampler = TemplateSampler(large_sample_df, large_sample_meta, sample_config)
template = "The molecule {compound_name#} with {SMILES__description} {SMILES#} has a {LogP__names__noun} of {LogP#} {LogP__units} and is {CYP2D6_Substrate#not &NULL}a {CYP2D6_Substrate__names__noun}."
result = sampler.sample(large_sample_df.iloc[0], template)
print(result)
assert all(x in result for x in ['Compound_', 'C', 'LogP value', 'log units', 'CYP'])
assert ('is a' in result and 'not a' not in result) or ('is not a' in result and 'is a' not in result)

def test_consistent_sampling(large_sample_df, large_sample_meta, sample_config):
sampler = TemplateSampler(large_sample_df, large_sample_meta, sample_config)
template = "The {LogP__names__noun} is {LogP#}."

# Sample multiple times with the same row
results = [sampler.sample(large_sample_df.iloc[0], template) for _ in range(10)]

# Check if all results are identical
assert all(result == results[0] for result in results)

def test_random_sampling(large_sample_df, large_sample_meta, sample_config):
sampler = TemplateSampler(large_sample_df, large_sample_meta, sample_config)
template = "The {compound_name#} has a {LogP__names__noun} of {LogP#}."

# Sample multiple times without specifying a row
results = [sampler.sample(None, template) for _ in range(10)]

# Check if we have at least two different results (high probability)
assert len(set(results)) > 1

0 comments on commit 3a713f1

Please sign in to comment.