Commit

revise training script
kjappelbaum committed Aug 15, 2024
1 parent 6e37708 commit f72d74f
Showing 3 changed files with 102 additions and 84 deletions.
experiments/ablations/continued_pretrain.py: 139 changes (78 additions, 61 deletions)
@@ -2,82 +2,89 @@
import torch
from unsloth import add_new_tokens
from typing import Optional, List
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
from unsloth import UnslothTrainer, UnslothTrainingArguments
import fire
import wandb
from datasets import load_dataset
import fire


def load_model(rank: int = 128, train_embeddings: bool = True, add_special_tokens: Optional[List[str]]=None):
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
def load_model(
rank: int = 128,
train_embeddings: bool = True,
add_special_tokens: Optional[List[str]] = None,
):
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "unsloth/llama-3-8b-bnb-4bit",
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
model_name="unsloth/llama-3-8b-bnb-4bit",
max_seq_length=max_seq_length,
dtype=dtype,
load_in_4bit=load_in_4bit,
)

add_new_tokens(model, tokenizer, new_tokens = add_special_tokens)
add_new_tokens(model, tokenizer, new_tokens=add_special_tokens)

target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"]
target_modules = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]

if train_embeddings:
target_modules += ["embed_tokens", "lm_head"]
target_modules += ["embed_tokens", "lm_head"]
model = FastLanguageModel.get_peft_model(
model,
r = rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules = target_modules,
lora_alpha = rank/4,
lora_dropout = 0, # Supports any, but = 0 is optimized
bias = "none", # Supports any, but = "none" is optimized
r=rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules=target_modules,
lora_alpha=rank / 4,
lora_dropout=0, # Supports any, but = 0 is optimized
bias="none", # Supports any, but = "none" is optimized
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
random_state = 3407,
use_rslora = True, # We support rank stabilized LoRA
loftq_config = None, # And LoftQ
use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
random_state=3407,
use_rslora=True, # We support rank stabilized LoRA
loftq_config=None, # And LoftQ
)

return model, tokenizer
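
(A note on the adapter scaling implied by the settings above, assuming the standard LoRA/rsLoRA conventions rather than anything stated in this diff: with the default rank = 128 and lora_alpha = rank / 4 = 32, plain LoRA would scale the adapter update by alpha / r = 32 / 128 = 0.25, whereas use_rslora=True scales by alpha / sqrt(r) = 32 / sqrt(128) ≈ 2.83, avoiding the aggressive down-weighting of high-rank adapters that the alpha / r rule produces.)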


def train(model, tokenizer, dataset, run_name: str, batch_size:int =64, max_seq_length = 2048):
wandb.init(
project="chemnlp-ablations",
name=run_name
)
def train(
model, tokenizer, dataset, run_name: str, batch_size: int = 64, max_seq_length=2048
):
wandb.init(project="chemnlp-ablations", name=run_name)
trainer = UnslothTrainer(
model = model,
tokenizer = tokenizer,
train_dataset = dataset,
dataset_text_field = "text",
max_seq_length = max_seq_length,
dataset_num_proc = 2,

args = UnslothTrainingArguments(
per_device_train_batch_size = batch_size,
gradient_accumulation_steps = 1,
warmup_ratio = 0.1,
num_train_epochs = 1,
learning_rate = 5e-5,
embedding_learning_rate = 1e-5,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 1,
optim = "adamw_8bit",
weight_decay = 0.01,
lr_scheduler_type = "linear",
seed = 3407,
output_dir = f"outputs_{run_name}",
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=max_seq_length,
dataset_num_proc=2,
args=UnslothTrainingArguments(
per_device_train_batch_size=batch_size,
gradient_accumulation_steps=1,
warmup_ratio=0.1,
num_train_epochs=1,
learning_rate=5e-5,
embedding_learning_rate=1e-5,
fp16=not is_bfloat16_supported(),
bf16=is_bfloat16_supported(),
logging_steps=1,
optim="adamw_8bit",
weight_decay=0.01,
lr_scheduler_type="linear",
seed=3407,
output_dir=f"outputs_{run_name}",
),
)

#@title Show current memory stats
# @title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
@@ -86,28 +93,38 @@ def train(model, tokenizer, dataset, run_name: str, batch_size:int =64, max_seq_length = 2048):

trainer_stats = trainer.train()

model.save_pretrained(f"lora_model_{run_name}") # Local saving
model.save_pretrained(f"lora_model_{run_name}") # Local saving
tokenizer.save_pretrained(f"lora_model_{run_name}")


def create_dataset(tokenizer, datasets):
EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN

def formatting_prompts_func(examples):
outputs = []
for t in examples['text']:
for t in examples["text"]:
outputs.append(t + EOS_TOKEN)
return { "text" : outputs, }
return {
"text": outputs,
}

dataset = load_dataset("json", data_files=datasets)
dataset = dataset["train"]

dataset = dataset.map(formatting_prompts_func, batched = True)
dataset = dataset.map(formatting_prompts_func, batched=True)

return dataset

if __name__ == "__main__":
model, tokenizer = load_model(train_embeddings=True, add_special_tokens=None)

dataset = create_dataset(tokenizer, ["data/chemnlp_train.json", "data/chemnlp_val.json"])
def run(data_files: List[str], train_embeddings: bool, run_name: str, batch_size: int, add_special_tokens: Optional[List[str]] = None):
    model, tokenizer = load_model(train_embeddings=train_embeddings, add_special_tokens=add_special_tokens)

train(model, tokenizer, dataset, "lora_128", batch_size=64)
dataset = create_dataset(
tokenizer, data_files
)

train(model, tokenizer, dataset, run_name, batch_size=batch_size)


if __name__ == "__main__":
fire.Fire(run)
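
A minimal usage sketch for the new fire-based entry point; the import path, data file, and run name below are placeholders, not values taken from the repository:

    # Roughly equivalent to:
    #   python continued_pretrain.py --data_files data/chemnlp_train.jsonl \
    #       --train_embeddings True --run_name lora_128 --batch_size 64
    # (fire.Fire(run) exposes run()'s keyword arguments as command-line flags.)
    from continued_pretrain import run  # assumes the script's directory is on PYTHONPATH

    run(
        data_files=["data/chemnlp_train.jsonl"],  # placeholder JSONL path(s)
        train_embeddings=True,
        run_name="lora_128",
        batch_size=64,
        add_special_tokens=None,
    )
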
pyproject.toml: 30 changes (15 additions, 15 deletions)
@@ -13,21 +13,21 @@ dynamic = ["version"]
[project.optional-dependencies]
dev = ["pre-commit", "pytest"]
dataset_creation = [
"PyTDC",
"rdkit",
"ruamel.yaml",
"selfies",
"deepsmiles",
"pubchempy",
"bioc",
"pylatexenc",
"canonicalize_psmiles@git+https://github.com/Ramprasad-Group/canonicalize_psmiles.git",
"rxn-chem-utils",
"backoff",
"givemeconformer",
"chembl_webresource_client",
"dask",
"pandarallel",
"PyTDC",
"rdkit",
"ruamel.yaml",
"selfies",
"deepsmiles",
"pubchempy",
"bioc",
"pylatexenc",
"canonicalize_psmiles@git+https://github.com/Ramprasad-Group/canonicalize_psmiles.git",
"rxn-chem-utils",
"backoff",
"givemeconformer",
"chembl_webresource_client",
"dask",
"pandarallel"
]

[project.scripts]
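
For reference, a sketch of how the optional-dependency group above is typically installed (assuming an editable install from the repository root): pip install -e ".[dataset_creation]"
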
src/chemnlp/data/utils.py: 17 changes (9 additions, 8 deletions)
@@ -9,12 +9,13 @@


from pathlib import Path
import fire


def get_all_datasets(root_dir):
return [d.name for d in Path(root_dir).iterdir() if d.is_dir()]

def concatenate_jsonl_files(root_dir, output_file, datasets=None, file_type='train'):

def concatenate_jsonl_files(root_dir, output_file, datasets=None, file_type="train"):
root_dir = Path(root_dir)

if datasets is None:
@@ -25,28 +26,28 @@ def concatenate_jsonl_files(root_dir, output_file, datasets=None, file_type='train'):
print(f"Processing datasets: {', '.join(datasets)}")
print(f"File type: {file_type}.jsonl")

with open(output_file, 'w') as outfile:
with open(output_file, "w") as outfile:
for dataset in datasets:
dataset_path = root_dir / dataset
if not dataset_path.is_dir():
print(f"Warning: Dataset '{dataset}' not found. Skipping.")
continue

for chunk_dir in dataset_path.glob('chunk_*'):
for template_dir in chunk_dir.glob('template_*'):
jsonl_file = template_dir / f'{file_type}.jsonl'
for chunk_dir in dataset_path.glob("chunk_*"):
for template_dir in chunk_dir.glob("template_*"):
jsonl_file = template_dir / f"{file_type}.jsonl"
if jsonl_file.is_file():
with open(jsonl_file, 'r') as infile:
with open(jsonl_file, "r") as infile:
for line in infile:
outfile.write(line)

print(f"Concatenated {file_type}.jsonl files have been saved to {output_file}")


def concatenate_jsonl_files_cli():
fire.Fire(concatenate_jsonl_files)
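
A minimal usage sketch of the concatenation helper (the root directory and output path are placeholders; the chunk_*/template_*/train.jsonl layout matches the glob patterns above):

    from chemnlp.data.utils import concatenate_jsonl_files

    # Merge every train.jsonl found under <root>/<dataset>/chunk_*/template_*/
    # into one file; datasets=None means "use every dataset directory under root".
    concatenate_jsonl_files(
        root_dir="data",                # placeholder root directory
        output_file="train_all.jsonl",  # placeholder output path
        datasets=None,
        file_type="train",
    )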



def add_random_split_column(df):
# Calculate the number of rows for each split
n_rows = len(df)
