diff --git a/experiments/ablations/continued_pretrain.py b/experiments/ablations/continued_pretrain.py
new file mode 100644
index 000000000..e0308911b
--- /dev/null
+++ b/experiments/ablations/continued_pretrain.py
@@ -0,0 +1,113 @@
+from unsloth import FastLanguageModel
+import torch
+from unsloth import add_new_tokens
+from typing import Optional, List
+from transformers import TrainingArguments
+from unsloth import is_bfloat16_supported
+from unsloth import UnslothTrainer, UnslothTrainingArguments
+import fire
+import wandb
+from datasets import load_dataset
+
+
+def load_model(rank: int = 128, train_embeddings: bool = True, add_special_tokens: Optional[List[str]] = None):
+    max_seq_length = 2048  # Choose any! We auto support RoPE Scaling internally!
+    dtype = None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
+    load_in_4bit = True  # Use 4bit quantization to reduce memory usage. Can be False.
+
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        model_name = "unsloth/llama-3-8b-bnb-4bit",
+        max_seq_length = max_seq_length,
+        dtype = dtype,
+        load_in_4bit = load_in_4bit,
+    )
+
+    if add_special_tokens:  # add_new_tokens expects a list of tokens, so skip when None
+        add_new_tokens(model, tokenizer, new_tokens = add_special_tokens)
+
+    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
+                      "gate_proj", "up_proj", "down_proj"]
+
+    if train_embeddings:
+        target_modules += ["embed_tokens", "lm_head"]
+    model = FastLanguageModel.get_peft_model(
+        model,
+        r = rank,  # Choose any number > 0! Suggested 8, 16, 32, 64, 128
+        target_modules = target_modules,
+        lora_alpha = rank / 4,
+        lora_dropout = 0,  # Supports any, but = 0 is optimized
+        bias = "none",  # Supports any, but = "none" is optimized
+        # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
+        use_gradient_checkpointing = "unsloth",  # True or "unsloth" for very long context
+        random_state = 3407,
+        use_rslora = True,  # We support rank stabilized LoRA
+        loftq_config = None,  # And LoftQ
+    )
+
+    return model, tokenizer
+
+
+def train(model, tokenizer, dataset, run_name: str, batch_size: int = 64, max_seq_length: int = 2048):
+    wandb.init(
+        project="chemnlp-ablations",
+        name=run_name
+    )
+    trainer = UnslothTrainer(
+        model = model,
+        tokenizer = tokenizer,
+        train_dataset = dataset,
+        dataset_text_field = "text",
+        max_seq_length = max_seq_length,
+        dataset_num_proc = 2,
+        args = UnslothTrainingArguments(
+            per_device_train_batch_size = batch_size,
+            gradient_accumulation_steps = 1,
+            warmup_ratio = 0.1,
+            num_train_epochs = 1,
+            learning_rate = 5e-5,
+            embedding_learning_rate = 1e-5,
+            fp16 = not is_bfloat16_supported(),
+            bf16 = is_bfloat16_supported(),
+            logging_steps = 1,
+            optim = "adamw_8bit",
+            weight_decay = 0.01,
+            lr_scheduler_type = "linear",
+            seed = 3407,
+            output_dir = f"outputs_{run_name}",
+        ),
+    )
+
+    # Show current memory stats
+    gpu_stats = torch.cuda.get_device_properties(0)
+    start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
+    max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
+    print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
+    print(f"{start_gpu_memory} GB of memory reserved.")
+
+    trainer_stats = trainer.train()
+
+    model.save_pretrained(f"lora_model_{run_name}")  # Local saving
+    tokenizer.save_pretrained(f"lora_model_{run_name}")
+
+
+def create_dataset(tokenizer, datasets):
+    EOS_TOKEN = tokenizer.eos_token  # Must add EOS_TOKEN
+
+    def formatting_prompts_func(examples):
+        outputs = []
+        for t in examples['text']:
+            outputs.append(t + EOS_TOKEN)
+        return {"text": outputs}
+
+    dataset = load_dataset("json", data_files=datasets)
+    dataset = dataset["train"]
+
+    dataset = dataset.map(formatting_prompts_func, batched = True)
+
+    return dataset
+
+
+if __name__ == "__main__":
+    model, tokenizer = load_model(train_embeddings=True, add_special_tokens=None)
+
+    dataset = create_dataset(tokenizer, ["data/chemnlp_train.json", "data/chemnlp_val.json"])
+
+    train(model, tokenizer, dataset, "lora_128", batch_size=64)
diff --git a/pyproject.toml b/pyproject.toml
index fca4e156e..c1e6c5fd1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,6 +35,7 @@ chemnlp-generate-meta = "chemnlp.data.meta_yaml_generator:cli"
 chemnlp-augment-meta = "chemnlp.data.meta_yaml_augmenter:cli"
 chemnlp-sample = "chemnlp.data.sampler_cli:cli"
 chemnlp-add-random-split-column = "chemnlp.data.utils:add_random_split_column_cli"
+chemnlp-concatenate-jsonl = "chemnlp.data.utils:concatenate_jsonl_files_cli"
 
 [tool.setuptools_scm]
 version_scheme = "post-release"
diff --git a/src/chemnlp/data/sampler.py b/src/chemnlp/data/sampler.py
index bbea75dfb..539733587 100644
--- a/src/chemnlp/data/sampler.py
+++ b/src/chemnlp/data/sampler.py
@@ -839,9 +839,8 @@ def export(self, output_dir: str, template: str) -> pd.DataFrame:
             df_split = self.df[self.df["split"] == split]
             samples = []
             for _, row in tqdm(df_split.iterrows(), total=len(df_split)):
-                sample_dict = row.to_dict()
-                sample = self._fill_template(template, sample_dict)
-                samples.append(sample)
+                sampled = self.sample(row, template)
+                samples.append(sampled)
             df_out = pd.DataFrame(samples)
 
             # if self.benchmarking_templates:
diff --git a/src/chemnlp/data/utils.py b/src/chemnlp/data/utils.py
index 3bbeb7b31..cf8788714 100644
--- a/src/chemnlp/data/utils.py
+++ b/src/chemnlp/data/utils.py
@@ -8,6 +8,45 @@
 import pandas as pd
 
+from pathlib import Path
+
+import fire
+
+
+def get_all_datasets(root_dir):
+    return [d.name for d in Path(root_dir).iterdir() if d.is_dir()]
+
+
+def concatenate_jsonl_files(root_dir, output_file, datasets=None, file_type='train'):
+    root_dir = Path(root_dir)
+
+    if datasets is None:
+        datasets = get_all_datasets(root_dir)
+    elif isinstance(datasets, str):
+        datasets = [datasets]
+
+    print(f"Processing datasets: {', '.join(datasets)}")
+    print(f"File type: {file_type}.jsonl")
+
+    with open(output_file, 'w') as outfile:
+        for dataset in datasets:
+            dataset_path = root_dir / dataset
+            if not dataset_path.is_dir():
+                print(f"Warning: Dataset '{dataset}' not found. Skipping.")
+                continue
+
+            for chunk_dir in dataset_path.glob('chunk_*'):
+                for template_dir in chunk_dir.glob('template_*'):
+                    jsonl_file = template_dir / f'{file_type}.jsonl'
+                    if jsonl_file.is_file():
+                        with open(jsonl_file, 'r') as infile:
+                            for line in infile:
+                                outfile.write(line)
+
+    print(f"Concatenated {file_type}.jsonl files have been saved to {output_file}")
+
+
+def concatenate_jsonl_files_cli():
+    fire.Fire(concatenate_jsonl_files)
+
+
 def add_random_split_column(df):
     # Calculate the number of rows for each split
     n_rows = len(df)
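
Usage note: the new chemnlp-concatenate-jsonl entry point wraps concatenate_jsonl_files with fire.Fire, so it can be driven from the shell or called directly from Python. A minimal sketch, assuming sampled outputs live under a <root_dir>/<dataset>/chunk_*/template_*/<file_type>.jsonl layout (the layout the glob patterns above expect); the "sampled" root directory is illustrative only:

    from chemnlp.data.utils import concatenate_jsonl_files

    # Merge the train-split JSONL chunks of every dataset under "sampled"
    # (illustrative path) into the single file that continued_pretrain.py
    # later loads via load_dataset("json", ...).
    concatenate_jsonl_files(
        root_dir="sampled",
        output_file="data/chemnlp_train.json",
        datasets=None,       # None means: include every dataset subdirectory
        file_type="train",
    )

The equivalent console invocation would be `chemnlp-concatenate-jsonl sampled data/chemnlp_train.json --file_type train`, with argument parsing handled by Fire.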
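Likewise, once a run finishes, the adapter directory written by save_pretrained can be reloaded for a quick smoke test. A sketch following the usual Unsloth notebook pattern, not part of this PR; the directory name below is derived from the run_name used in __main__ above:

    from unsloth import FastLanguageModel

    # Reload the LoRA adapters saved as lora_model_lora_128 (run_name "lora_128").
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "lora_model_lora_128",
        max_seq_length = 2048,
        dtype = None,
        load_in_4bit = True,
    )
    FastLanguageModel.for_inference(model)  # switch to Unsloth's faster inference mode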