Commit 61ca373: implement eval

kjappelbaum committed Aug 15, 2024 (1 parent: 9d2972e)

Showing 3 changed files with 25 additions and 13 deletions.

experiments/ablations/continued_pretrain.py (26 changes: 19 additions & 7 deletions)

```diff
@@ -8,6 +8,7 @@
 from datasets import load_dataset
 import fire
 
+
 def load_model(
     rank: int = 128,
     train_embeddings: bool = True,
@@ -56,7 +57,7 @@ def load_model(
 
 
 def train(
-    model, tokenizer, dataset, run_name: str, batch_size: int = 64, max_seq_length=2048
+    model, tokenizer, dataset, run_name: str, batch_size: int = 64, max_seq_length=2048, eval_dataset=None
 ):
     wandb.init(project="chemnlp-ablations", name=run_name)
     trainer = UnslothTrainer(
@@ -66,6 +67,7 @@ def train(
         dataset_text_field="text",
         max_seq_length=max_seq_length,
         dataset_num_proc=2,
+        eval_dataset=eval_dataset,
         args=UnslothTrainingArguments(
             per_device_train_batch_size=batch_size,
             gradient_accumulation_steps=1,
@@ -81,6 +83,8 @@ def train(
             lr_scheduler_type="linear",
             seed=3407,
             output_dir=f"outputs_{run_name}",
+            eval_strategy="steps" if eval_dataset is not None else "no",
+            eval_steps=10_000 if eval_dataset is not None else None,
         ),
     )
```
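
This hunk is the heart of the commit: evaluation is opt-in. When an eval_dataset is supplied, the trainer evaluates every 10,000 steps; without one, evaluation stays off (in recent transformers releases, eval_strategy is the current name for the older evaluation_strategy argument). A minimal sketch of the toggle; the helper name is hypothetical and not part of the commit:

```python
# Hypothetical helper mirroring the conditional above; not part of the commit.
def eval_kwargs(eval_dataset):
    """Evaluate every 10,000 steps when an eval set is given; otherwise stay off."""
    if eval_dataset is not None:
        return {"eval_strategy": "steps", "eval_steps": 10_000}
    return {"eval_strategy": "no", "eval_steps": None}
```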

```diff
@@ -116,19 +120,27 @@ def formatting_prompts_func(examples):
     return dataset
 
 
-def run(data_files: List[str], run_name: str, batch_size: int=64, add_special_tokens: Optional[List[str]]=None, train_embeddings: bool=True):
+def run(
+    data_files: List[str],
+    run_name: str,
+    batch_size: int = 64,
+    add_special_tokens: Optional[List[str]] = None,
+    train_embeddings: bool = True,
+    eval_data_files: Optional[List[str]] = None,
+):
     print(f"Data files {data_files}")
     print(f"Run name {run_name}")
     print(f"Batch size {batch_size}")
     print(f"Add special tokens {add_special_tokens}")
     print(f"Train embeddings {train_embeddings}")
-    model, tokenizer = load_model(train_embeddings=train_embeddings, add_special_tokens=add_special_tokens )
-
-    dataset = create_dataset(
-        tokenizer, data_files
+    model, tokenizer = load_model(
+        train_embeddings=train_embeddings, add_special_tokens=add_special_tokens
     )
 
-    train(model, tokenizer, dataset, run_name, batch_size=batch_size)
+    dataset = create_dataset(tokenizer, data_files)
+    eval_dataset = create_dataset(tokenizer, eval_data_files) if eval_data_files else None
+
+    train(model, tokenizer, dataset, run_name, batch_size=batch_size, eval_dataset=eval_dataset)
 
 
 if __name__ == "__main__":
```
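
The new eval_data_files parameter is threaded through create_dataset into train(). A sketch of how the updated entry point might be called; the import path and file names are assumptions, not part of the commit (the script also imports fire, so the same arguments are presumably exposed on the command line):

```python
# Hypothetical invocation; import path and data file names are placeholders.
from continued_pretrain import run

run(
    data_files=["train_shard_0.jsonl"],     # placeholder training shards
    run_name="ablation_eval_demo",          # hypothetical run name
    batch_size=64,
    eval_data_files=["val_shard_0.jsonl"],  # switches on eval every 10k steps
)
```
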
src/chemnlp/data/sampler.py (3 changes: 0 additions & 3 deletions)

```diff
@@ -172,7 +172,6 @@ def _wrap_identifier(self, identifier: str, value: str) -> str:
         """Wrap the identifier value with tags if wrap_identifiers is enabled."""
 
         if not self.wrap_identifiers:
-            logger.debug("Not wrapping identifiers.")
             return value
 
         identifier_type = next(
@@ -189,11 +188,9 @@ def _wrap_identifier(self, identifier: str, value: str) -> str:
         except ValueError:
             identifier_type = None
 
-        logger.debug(f'Identifier type: {identifier_type}, value: {value}')
         if identifier_type and identifier_type not in self.config.get(
             "excluded_from_wrapping", []
         ):
-            logger.debug(f"Wrapping {identifier_type} with tags.")
             return f"[BEGIN_{identifier_type}]{value}[END_{identifier_type}]"
         return value
 
```
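
The deleted lines are debug logging only; the wrapping behavior is unchanged. For reference, the tag format produced by the surviving return statement (the SMILES value here is an illustrative example, not from the commit):

```python
# The wrapping shape kept by this change, per the return statement above.
identifier_type, value = "SMILES", "CCO"  # illustrative example values
wrapped = f"[BEGIN_{identifier_type}]{value}[END_{identifier_type}]"
print(wrapped)  # [BEGIN_SMILES]CCO[END_SMILES]
```
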
src/chemnlp/data/sampler_cli.py (9 changes: 6 additions & 3 deletions)

```diff
@@ -102,7 +102,6 @@ def process_dataset(
         "excluded_from_wrapping": ["Other"],
     }
 
-
     templates = meta["templates"]
     if benchmarking:
         templates = [t for t in templates if "<EOI>" in t]
@@ -117,9 +116,13 @@ def process_dataset(
         logger.debug(f"Processing chunk {chunk_idx} to {chunk_output_dir}")
         os.makedirs(chunk_output_dir, exist_ok=True)
 
-        sampler = TemplateSampler(df_chunk, meta=meta, config=config, path_data_dir=data_dir)
+        sampler = TemplateSampler(
+            df_chunk, meta=meta, config=config, path_data_dir=data_dir
+        )
         if wrap_identifiers:
-            assert sampler.wrap_identifiers, "Wrap identifiers must be enabled in the sampler"
+            assert (
+                sampler.wrap_identifiers
+            ), "Wrap identifiers must be enabled in the sampler"
 
         for template_idx, template in enumerate(templates):
             print(
```
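
These are pure reformatting changes for line length. One detail worth noting in the reworked assertion: the parentheses wrap only the condition. Parenthesizing the condition and message together would build a two-element tuple, which is always truthy, so the assert could never fire. A self-contained illustration (the Sampler class is a stand-in, not the real TemplateSampler):

```python
class Sampler:
    wrap_identifiers = True  # stand-in for TemplateSampler's attribute


sampler = Sampler()

# Correct, as in the diff: parentheses around the condition only.
assert (
    sampler.wrap_identifiers
), "Wrap identifiers must be enabled in the sampler"

# Pitfall: a (condition, message) tuple is always truthy, so the assert
# below could never fail; recent Pythons emit a SyntaxWarning for it.
# assert (sampler.wrap_identifiers, "Wrap identifiers must be enabled")
```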
