Skip to content

Commit

Permalink
test standard templates
Browse files Browse the repository at this point in the history
  • Loading branch information
kjappelbaum committed Aug 14, 2024
1 parent a491182 commit a248c4b
Showing 1 changed file with 85 additions and 15 deletions.
100 changes: 85 additions & 15 deletions tests/data/test_sampler_cli.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
import pytest
import os
import pandas as pd
import yaml
import json
from chemnlp.data.sampler_cli import process_dataset
from chemnlp.data.constants import STANDARD_TABULAR_TEXT_TEMPLATES


@pytest.fixture
def temp_tabular_data_dir(tmp_path):
    """Create a minimal tabular dataset directory for the sampler CLI.

    Layout matches what ``process_dataset`` expects for standard-template
    expansion: ``<tmp>/tabular/test_dataset/{meta.yaml,data_clean.csv}``.

    Returns:
        pathlib.Path: the dataset directory containing meta.yaml and
        data_clean.csv.
    """
    data_dir = tmp_path / "tabular" / "test_dataset"
    data_dir.mkdir(parents=True)

    # Create meta.yaml: one identifier, one continuous target, and a single
    # custom text template (standard templates are added by the CLI itself).
    meta = {
        "identifiers": [{"id": "SMILES", "type": "SMILES", "description": "SMILES"}],
        "targets": [
            {"id": "logP", "type": "continuous", "names": [{"noun": "logP"}], "units": ""}
        ],
        "templates": [
            "Custom template: The molecule with SMILES {SMILES#} has logP {logP#}.",
        ],
    }
    with open(data_dir / "meta.yaml", "w") as f:
        yaml.dump(meta, f)

    # Create data_clean.csv: exactly one row per split so every split file
    # should end up with exactly one sample.
    df = pd.DataFrame(
        {
            "SMILES": ["CC", "CCC", "CCCC"],
            "logP": [1.0, 2.0, 3.0],
            "split": ["train", "test", "valid"],
        }
    )
    df.to_csv(data_dir / "data_clean.csv", index=False)

    return data_dir


@pytest.fixture
def temp_data_dir(tmp_path):
    """Create a generic dataset directory with meta.yaml and data_clean.csv.

    Returns:
        pathlib.Path: the dataset directory.
    """
    dataset_dir = tmp_path / "data"
    dataset_dir.mkdir()

    # Dataset metadata: one identifier, one continuous target, and two
    # templates (the second is a benchmarking-style template with <EOI>).
    meta_config = {
        "identifiers": [{"id": "SMILES", "type": "SMILES"}],
        "targets": [{"id": "property", "type": "continuous"}],
        "templates": [
            "The molecule with SMILES {SMILES#} has property {property#}.",
            "What is the property of the molecule with SMILES {SMILES#}?<EOI>{property#}",
        ],
    }
    with open(dataset_dir / "meta.yaml", "w") as f:
        yaml.dump(meta_config, f)

    # One row per split -> one sample per split file downstream.
    records = {
        "SMILES": ["CC", "CCC", "CCCC"],
        "property": [1.0, 2.0, 3.0],
        "split": ["train", "test", "valid"],
    }
    pd.DataFrame(records).to_csv(dataset_dir / "data_clean.csv", index=False)

    return dataset_dir


@pytest.fixture
def temp_output_dir(tmp_path):
    """Return a fresh, empty directory for sampler CLI output."""
    out = tmp_path / "output"
    out.mkdir()
    return out


def test_process_dataset(temp_data_dir, temp_output_dir):
    """Smoke test: default sampling writes one sample per split for a template."""
    process_dataset(
        data_dir=str(temp_data_dir),
        output_dir=str(temp_output_dir),
        chunksize=1000,
        class_balanced=False,
        benchmarking=False,
        multiple_choice=False,
    )

    # Check that output files were created.
    # NOTE(review): the lines between the call and this assertion were
    # collapsed in the diff view; chunk_0/template_0 is inferred from the
    # standard-templates test below — confirm against the CLI's output layout.
    chunk_dir = temp_output_dir / "chunk_0"
    template_dir = chunk_dir / "template_0"
    assert template_dir.exists()

    # Check the content of the output files
    for split in ["train", "test", "valid"]:
        with open(template_dir / f"{split}.jsonl", "r") as f:
            lines = f.readlines()
        assert len(lines) == 1  # One sample per split
        sample = json.loads(lines[0])
        assert "text" in sample
        assert "SMILES" in sample["text"]
        assert "property" in sample["text"]


def test_process_dataset_benchmarking(temp_data_dir, temp_output_dir):
    """With benchmarking=True, the sampled completion must parse as a number."""
    process_dataset(
        data_dir=str(temp_data_dir),
        output_dir=str(temp_output_dir),
        chunksize=1000,
        class_balanced=False,
        benchmarking=True,
        multiple_choice=False,
    )

    # Check that output files were created.
    # NOTE(review): the directory-setup lines were collapsed in the diff view;
    # layout inferred from the sibling tests — confirm against the CLI.
    chunk_dir = temp_output_dir / "chunk_0"
    template_dir = chunk_dir / "template_0"
    assert template_dir.exists()

    # Check the content of the output files
    for split in ["train", "test", "valid"]:
        with open(template_dir / f"{split}.jsonl", "r") as f:
            lines = f.readlines()
        assert len(lines) == 1  # One sample per split
        sample = json.loads(lines[0])
        # NOTE(review): the lines checking the sample keys and extracting the
        # target were hidden in the diff; only the except-branch was visible.
        # The intent grounded by the visible code: the numeric target must
        # survive a float() round-trip without raising ValueError.
        assert "text" in sample
        target = sample["text"].rsplit("<EOI>", 1)[-1].strip()
        try:
            float(target)
        except ValueError:
            assert False


def test_process_dataset_class_balanced(temp_data_dir, temp_output_dir):
    """class_balanced=True still yields one sample per split with both fields."""
    process_dataset(
        data_dir=str(temp_data_dir),
        output_dir=str(temp_output_dir),
        chunksize=1000,
        class_balanced=True,
        benchmarking=False,
        multiple_choice=False,
    )

    # Check that output files were created.
    # NOTE(review): the directory-setup lines were collapsed in the diff view;
    # layout inferred from the sibling tests — confirm against the CLI.
    chunk_dir = temp_output_dir / "chunk_0"
    template_dir = chunk_dir / "template_0"
    assert template_dir.exists()

    # Check the content of the output files
    for split in ["train", "test", "valid"]:
        with open(template_dir / f"{split}.jsonl", "r") as f:
            lines = f.readlines()
        assert len(lines) == 1  # One sample per split
        sample = json.loads(lines[0])
        assert "text" in sample
        assert "SMILES" in sample["text"]
        assert "property" in sample["text"]



def test_process_dataset_with_standard_templates(temp_tabular_data_dir, temp_output_dir):
    """Standard tabular templates are appended to the dataset's own template.

    Expects one output directory per template: the single custom template from
    meta.yaml plus every standard template that has no <EOI> marker.
    """
    process_dataset(
        data_dir=str(temp_tabular_data_dir),
        output_dir=str(temp_output_dir),
        chunksize=1000,
        class_balanced=False,
        benchmarking=False,
        multiple_choice=False,
        use_standard_templates=True,
    )

    # Check that output files were created
    chunk_dir = temp_output_dir / "chunk_0"

    # Count the number of template directories
    template_dirs = list(chunk_dir.glob("template_*"))

    # Expected number of templates: 1 custom + the standard templates without
    # an <EOI> marker (benchmarking-style templates are skipped here).
    expected_template_count = 1 + sum(
        1 for t in STANDARD_TABULAR_TEXT_TEMPLATES if "<EOI>" not in t
    )
    assert len(template_dirs) == expected_template_count, (
        f"Expected {expected_template_count} templates, "
        f"but found {len(template_dirs)}"
    )

    # Check the content of the output files for each template
    for template_dir in template_dirs:
        for split in ["train", "test", "valid"]:
            with open(template_dir / f"{split}.jsonl", "r") as f:
                lines = f.readlines()
            assert len(lines) == 1  # One sample per split
            sample = json.loads(lines[0])
            assert "text" in sample
            assert "SMILES" in sample["text"]
            assert "logP" in sample["text"]

0 comments on commit a248c4b

Please sign in to comment.