Commit 9d2972e: fix wrapping
kjappelbaum committed Aug 15, 2024 (1 parent: f72d74f)
Showing 4 changed files with 18 additions and 7 deletions.
data/tabular/melting_points/meta.yaml (2 additions, 2 deletions)
@@ -19,10 +19,10 @@ targets:
 benchmarks: []
 identifiers:
   - id: SMILES
-    type: text
+    type: SMILES
     description: SMILES
   - id: NAME
-    type: text
+    type: Other
     description: name
 license: CC BY 4.0
 links:
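The type fix matters downstream: the sampler keys identifier wrapping on this field, and a type of Other is excluded from wrapping by the CLI config further down in this commit. A minimal sketch of what the corrected entries look like once loaded, with values copied from the diff above:

    # Sketch only: roughly the structure the sampler sees after loading the
    # corrected meta.yaml; keys and values mirror the YAML fields above.
    identifiers = [
        {"id": "SMILES", "type": "SMILES", "description": "SMILES"},
        {"id": "NAME", "type": "Other", "description": "name"},
    ]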
experiments/ablations/continued_pretrain.py (8 additions, 3 deletions)
@@ -23,8 +23,8 @@ def load_model(
         dtype=dtype,
         load_in_4bit=load_in_4bit,
     )
-
-    add_new_tokens(model, tokenizer, new_tokens=add_special_tokens)
+    if add_special_tokens is not None:
+        add_new_tokens(model, tokenizer, new_tokens=add_special_tokens)

     target_modules = [
         "q_proj",
@@ -116,7 +116,12 @@ def formatting_prompts_func(examples):
     return dataset


-def run(data_files: List[str], train_embeddings: bool, run_name: str, batch_size: int, add_special_tokens: Optional[List[str]]=None)
+def run(data_files: List[str], run_name: str, batch_size: int=64, add_special_tokens: Optional[List[str]]=None, train_embeddings: bool=True):
+    print(f"Data files {data_files}")
+    print(f"Run name {run_name}")
+    print(f"Batch size {batch_size}")
+    print(f"Add special tokens {add_special_tokens}")
+    print(f"Train embeddings {train_embeddings}")
     model, tokenizer = load_model(train_embeddings=train_embeddings, add_special_tokens=add_special_tokens)

     dataset = create_dataset(
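For reference, a hypothetical call of the reworked run() signature; the import path assumes the repository root is on PYTHONPATH, and the file names, run name, and special tokens are placeholders rather than values taken from the commit:

    # Hypothetical usage sketch of run() after this commit.
    from experiments.ablations.continued_pretrain import run  # assumes repo root on PYTHONPATH

    run(
        data_files=["data/chunk_0.jsonl"],          # placeholder path
        run_name="ablation-wrapped-identifiers",    # placeholder run name
        batch_size=64,
        add_special_tokens=["[BEGIN_SMILES]", "[END_SMILES]"],  # placeholder tokens
        train_embeddings=True,
    )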
src/chemnlp/data/sampler.py (3 additions, 0 deletions)
@@ -172,6 +172,7 @@ def _wrap_identifier(self, identifier: str, value: str) -> str:
         """Wrap the identifier value with tags if wrap_identifiers is enabled."""

         if not self.wrap_identifiers:
+            logger.debug("Not wrapping identifiers.")
             return value

         identifier_type = next(
@@ -188,9 +189,11 @@ def _wrap_identifier(self, identifier: str, value: str) -> str:
         except ValueError:
             identifier_type = None

+        logger.debug(f'Identifier type: {identifier_type}, value: {value}')
         if identifier_type and identifier_type not in self.config.get(
             "excluded_from_wrapping", []
         ):
+            logger.debug(f"Wrapping {identifier_type} with tags.")
             return f"[BEGIN_{identifier_type}]{value}[END_{identifier_type}]"
         return value

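A standalone sketch of the tag format _wrap_identifier produces when wrapping is enabled; the f-string mirrors the one in the diff, while the exclusion list and example values (CCO, ethanol) are illustrative only:

    excluded_from_wrapping = ["Other"]

    def wrap(identifier_type, value):
        # Mirrors the wrapping branch above: excluded or unknown types pass through.
        if identifier_type and identifier_type not in excluded_from_wrapping:
            return f"[BEGIN_{identifier_type}]{value}[END_{identifier_type}]"
        return value

    print(wrap("SMILES", "CCO"))     # [BEGIN_SMILES]CCO[END_SMILES]
    print(wrap("Other", "ethanol"))  # ethanol (left unwrapped)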
src/chemnlp/data/sampler_cli.py (5 additions, 2 deletions)
@@ -102,6 +102,7 @@ def process_dataset(
         "excluded_from_wrapping": ["Other"],
     }

+
     templates = meta["templates"]
     if benchmarking:
         templates = [t for t in templates if "<EOI>" in t]
@@ -116,7 +117,9 @@
         logger.debug(f"Processing chunk {chunk_idx} to {chunk_output_dir}")
         os.makedirs(chunk_output_dir, exist_ok=True)

-        sampler = TemplateSampler(df_chunk, meta, config, data_dir)
+        sampler = TemplateSampler(df_chunk, meta=meta, config=config, path_data_dir=data_dir)
+        if wrap_identifiers:
+            assert sampler.wrap_identifiers, "Wrap identifiers must be enabled in the sampler"

         for template_idx, template in enumerate(templates):
             print(
@@ -177,7 +180,7 @@ def main(
         benchmarking,
         additional_templates,
         use_standard_templates,
-        wrap_identifiers,
+        wrap_identifiers=wrap_identifiers,
     )


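The keyword-only call in main() guards against silently binding a value to the wrong parameter. A small illustration of the guard added in process_dataset, using a stand-in class (DummySampler is not the real TemplateSampler):

    # Stand-in for TemplateSampler, used only to illustrate the assertion above.
    class DummySampler:
        def __init__(self, df, meta=None, config=None, path_data_dir=None):
            self.wrap_identifiers = bool(config and config.get("wrap_identifiers"))

    wrap_identifiers = True
    sampler = DummySampler(None, meta={}, config={"wrap_identifiers": True}, path_data_dir=".")
    if wrap_identifiers:
        assert sampler.wrap_identifiers, "Wrap identifiers must be enabled in the sampler"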
