Commit
add training script using unsloth
kjappelbaum committed Aug 15, 2024
1 parent c0ddf9f commit 6e37708
Showing 4 changed files with 155 additions and 3 deletions.
113 changes: 113 additions & 0 deletions experiments/ablations/continued_pretrain.py
@@ -0,0 +1,113 @@
from unsloth import (
    FastLanguageModel,
    UnslothTrainer,
    UnslothTrainingArguments,
    add_new_tokens,
    is_bfloat16_supported,
)

import torch
import wandb
from datasets import load_dataset
from typing import List, Optional


def load_model(rank: int = 128, train_embeddings: bool = True, add_special_tokens: Optional[List[str]] = None):
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "unsloth/llama-3-8b-bnb-4bit",
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
)

    if add_special_tokens:
        # Only resize the vocabulary when new special tokens are actually provided.
        add_new_tokens(model, tokenizer, new_tokens=add_special_tokens)

target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"]

if train_embeddings:
target_modules += ["embed_tokens", "lm_head"]
model = FastLanguageModel.get_peft_model(
model,
r = rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules = target_modules,
        lora_alpha = rank / 4,  # with rsLoRA below, effective scaling is lora_alpha / sqrt(rank)
lora_dropout = 0, # Supports any, but = 0 is optimized
bias = "none", # Supports any, but = "none" is optimized
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
random_state = 3407,
use_rslora = True, # We support rank stabilized LoRA
loftq_config = None, # And LoftQ
)

return model, tokenizer


def train(model, tokenizer, dataset, run_name: str, batch_size: int = 64, max_seq_length: int = 2048):
wandb.init(
project="chemnlp-ablations",
name=run_name
)
trainer = UnslothTrainer(
model = model,
tokenizer = tokenizer,
train_dataset = dataset,
dataset_text_field = "text",
max_seq_length = max_seq_length,
dataset_num_proc = 2,

args = UnslothTrainingArguments(
per_device_train_batch_size = batch_size,
gradient_accumulation_steps = 1,
warmup_ratio = 0.1,
num_train_epochs = 1,
learning_rate = 5e-5,
embedding_learning_rate = 1e-5,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 1,
optim = "adamw_8bit",
weight_decay = 0.01,
lr_scheduler_type = "linear",
seed = 3407,
output_dir = f"outputs_{run_name}",
),
)

    # Show current GPU memory stats before training
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

    trainer.train()

model.save_pretrained(f"lora_model_{run_name}") # Local saving
tokenizer.save_pretrained(f"lora_model_{run_name}")


def create_dataset(tokenizer, datasets):
EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
def formatting_prompts_func(examples):
outputs = []
for t in examples['text']:
outputs.append(t + EOS_TOKEN)
return { "text" : outputs, }

dataset = load_dataset("json", data_files=datasets)
dataset = dataset["train"]

dataset = dataset.map(formatting_prompts_func, batched = True)

return dataset

if __name__ == "__main__":
model, tokenizer = load_model(train_embeddings=True, add_special_tokens=None)

dataset = create_dataset(tokenizer, ["data/chemnlp_train.json", "data/chemnlp_val.json"])

train(model, tokenizer, dataset, "lora_128", batch_size=64)
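The script is meant to be launched directly; __main__ hardcodes the dataset paths (data/chemnlp_train.json, data/chemnlp_val.json), the LoRA rank, and the run name, so varying an ablation currently means editing the file. A typical invocation, assuming those files exist:

python experiments/ablations/continued_pretrain.py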
1 change: 1 addition & 0 deletions pyproject.toml
@@ -35,6 +35,7 @@ chemnlp-generate-meta = "chemnlp.data.meta_yaml_generator:cli"
chemnlp-augment-meta = "chemnlp.data.meta_yaml_augmenter:cli"
chemnlp-sample = "chemnlp.data.sampler_cli:cli"
chemnlp-add-random-split-column = "chemnlp.data.utils:add_random_split_column_cli"
chemnlp-concatenate-jsonl = "chemnlp.data.utils:concatenate_jsonl_files_cli"

[tool.setuptools_scm]
version_scheme = "post-release"
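The new console script exposes concatenate_jsonl_files (added to src/chemnlp/data/utils.py below) via fire, so its keyword arguments become command-line flags. A hypothetical invocation — the paths are illustrative, not taken from the commit:

chemnlp-concatenate-jsonl --root_dir data/processed --output_file data/chemnlp_train.jsonl --file_type train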
5 changes: 2 additions & 3 deletions src/chemnlp/data/sampler.py
@@ -839,9 +839,8 @@ def export(self, output_dir: str, template: str) -> pd.DataFrame:
df_split = self.df[self.df["split"] == split]
samples = []
for _, row in tqdm(df_split.iterrows(), total=len(df_split)):
-            sample_dict = row.to_dict()
-            sample = self._fill_template(template, sample_dict)
-            samples.append(sample)
+            sampled = self.sample(row, template)
+            samples.append(sampled)
df_out = pd.DataFrame(samples)

# if self.benchmarking_templates:
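For context, the removed lines show what the new Sampler.sample call replaces; a hypothetical sketch of that method, whose actual definition is not part of this diff:

    def sample(self, row, template):
        # Presumably wraps the old per-row logic: convert the row to a dict
        # and fill the template from it, as the deleted lines did inline.
        return self._fill_template(template, row.to_dict())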
39 changes: 39 additions & 0 deletions src/chemnlp/data/utils.py
@@ -8,6 +8,45 @@
import pandas as pd


from pathlib import Path
import fire

def get_all_datasets(root_dir):
return [d.name for d in Path(root_dir).iterdir() if d.is_dir()]

def concatenate_jsonl_files(root_dir, output_file, datasets=None, file_type='train'):
    """Concatenate every <dataset>/chunk_*/template_*/<file_type>.jsonl under root_dir into output_file."""
    root_dir = Path(root_dir)

if datasets is None:
datasets = get_all_datasets(root_dir)
elif isinstance(datasets, str):
datasets = [datasets]

print(f"Processing datasets: {', '.join(datasets)}")
print(f"File type: {file_type}.jsonl")

with open(output_file, 'w') as outfile:
for dataset in datasets:
dataset_path = root_dir / dataset
if not dataset_path.is_dir():
print(f"Warning: Dataset '{dataset}' not found. Skipping.")
continue

for chunk_dir in dataset_path.glob('chunk_*'):
for template_dir in chunk_dir.glob('template_*'):
jsonl_file = template_dir / f'{file_type}.jsonl'
if jsonl_file.is_file():
with open(jsonl_file, 'r') as infile:
for line in infile:
outfile.write(line)

print(f"Concatenated {file_type}.jsonl files have been saved to {output_file}")

def concatenate_jsonl_files_cli():
fire.Fire(concatenate_jsonl_files)



def add_random_split_column(df):
# Calculate the number of rows for each split
n_rows = len(df)
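A hypothetical usage of the new helper from Python, with the directory layout implied by the glob patterns above (the paths are illustrative):

from chemnlp.data.utils import concatenate_jsonl_files

# Merge every dataset's chunk_*/template_*/train.jsonl under data/processed into one file.
concatenate_jsonl_files("data/processed", "data/chemnlp_train.jsonl", file_type="train")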
