diff --git a/examples/trl/stack_llama_2/README.md b/examples/trl/stack_llama_2/README.md new file mode 100644 index 0000000000..12b7e4da80 --- /dev/null +++ b/examples/trl/stack_llama_2/README.md @@ -0,0 +1,72 @@ +# DPO pipeline for the creation of StackLlaMa 2: a Stack exchange llama-v2-7b model + +## Prerequisites + +Install all the dependencies in the `requirements.txt`: + +``` +$ pip install -U -r requirements.txt +``` + + +## Training + +There were two main steps to the DPO training process: +1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se: + + ``` + python ../../gaudi_spawn.py --world_size 8 --use_mpi sft_llama2.py \ + --output_dir="./sft" \ + --max_steps=500 \ + --logging_steps=10 \ + --save_steps=10 \ + --per_device_train_batch_size=4 \ + --per_device_eval_batch_size=1 \ + --gradient_accumulation_steps=2 \ + --learning_rate=1e-4 \ + --lr_scheduler_type="cosine" \ + --warmup_steps=100 \ + --weight_decay=0.05 \ + --optim="paged_adamw_32bit" \ + --bf16 \ + --remove_unused_columns=False \ + --run_name="sft_llama2" \ + --report_to=none \ + --use_habana \ + --use_lazy_mode + ``` +2. Run the DPO trainer using the model saved by the previous step: + ``` + python ../../gaudi_spawn.py --world_size 8 --use_mpi dpo_llama2.py \ + --model_name_or_path="sft/final_merged_checkpoint" \ + --output_dir="dpo" \ + --report_to=none + ``` + + +## Merging the adaptors + +To merge the adaptors into the base model we can use the `merge_peft_adapter.py` helper script that comes with TRL: + +``` +python merge_peft_adapter.py --base_model_name="meta-llama/Llama-2-7b-hf" --adapter_model_name="dpo" --output_name="stack-llama-2" +``` + +which will also push the model to your HuggingFace hub account. + +## Running the model + +We can load the DPO-trained LoRA adaptors which were saved by the DPO training step and load them via: + +```py +from peft import AutoPeftModelForCausalLM + + +model = AutoPeftModelForCausalLM.from_pretrained( + "dpo/final_checkpoint", + low_cpu_mem_usage=True, + torch_dtype=torch.bfloat16, +) + +model.generate(...) +``` diff --git a/examples/trl/stack_llama_2/dpo_llama2.py b/examples/trl/stack_llama_2/dpo_llama2.py new file mode 100644 index 0000000000..2b102e1825 --- /dev/null +++ b/examples/trl/stack_llama_2/dpo_llama2.py @@ -0,0 +1,231 @@ +# copy from https://github.com/huggingface/trl/blob/v0.7.6/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py, enable it for Gaudi2 +from dataclasses import dataclass, field +from typing import Dict, Optional + +import torch +from datasets import Dataset, load_dataset +from peft import LoraConfig +from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser + +from optimum.habana import GaudiConfig, GaudiTrainingArguments +from optimum.habana.trl import GaudiDPOTrainer + + +# Define and parse arguments. +@dataclass +class ScriptArguments: + """ + The arguments for the DPO training script. + """ + + # data parameters + beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"}) + + # training parameters + model_name_or_path: Optional[str] = field( + default="../sft/results/final_checkpoint", + metadata={"help": "the location of the SFT model name or path"}, + ) + tokenizer_name_or_path: Optional[str] = field( + default="meta-llama/Llama-2-7b-hf", + metadata={"help": "the location of the SFT model name or path"}, + ) + learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"}) + lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"}) + warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"}) + weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"}) + optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"}) + + per_device_train_batch_size: Optional[int] = field(default=1, metadata={"help": "train batch size per device"}) + per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"}) + gradient_accumulation_steps: Optional[int] = field( + default=4, metadata={"help": "the number of gradient accumulation steps"} + ) + gradient_checkpointing: Optional[bool] = field( + default=False, metadata={"help": "whether to use gradient checkpointing"} + ) + + lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"}) + lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"}) + lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"}) + + max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"}) + max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"}) + max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"}) + logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"}) + save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"}) + eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"}) + + output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"}) + log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"}) + + # instrumentation + sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"}) + report_to: Optional[str] = field( + default="wandb", + metadata={ + "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,' + '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. ' + 'Use `"all"` to report to all integrations installed, `"none"` for no integrations.' + }, + ) + # debug argument for distributed training + ignore_bias_buffers: Optional[bool] = field( + default=False, + metadata={ + "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See" + "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992" + }, + ) + + +def get_stack_exchange_paired( + data_dir: str = "data/rl", + sanity_check: bool = False, + cache_dir: str = None, + num_proc=24, +) -> Dataset: + """Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format. + + The dataset is converted to a dictionary with the following structure: + { + 'prompt': List[str], + 'chosen': List[str], + 'rejected': List[str], + } + + Prompts are structured as follows: + "Question: " + + "\n\nAnswer: " + """ + dataset = load_dataset( + "lvwerra/stack-exchange-paired", + split="train", + cache_dir=cache_dir, + data_dir=data_dir, + ) + original_columns = dataset.column_names + + if sanity_check: + dataset = dataset.select(range(min(len(dataset), 1000))) + + def return_prompt_and_responses(samples) -> Dict[str, str]: + return { + "prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]], + "chosen": samples["response_j"], + "rejected": samples["response_k"], + } + + return dataset.map( + return_prompt_and_responses, + batched=True, + num_proc=num_proc, + remove_columns=original_columns, + ) + + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + # 1. initialize training arguments: + training_args = GaudiTrainingArguments( + per_device_train_batch_size=script_args.per_device_train_batch_size, + per_device_eval_batch_size=script_args.per_device_eval_batch_size, + max_steps=script_args.max_steps, + logging_steps=script_args.logging_steps, + save_steps=script_args.save_steps, + gradient_accumulation_steps=script_args.gradient_accumulation_steps, + gradient_checkpointing=script_args.gradient_checkpointing, + learning_rate=script_args.learning_rate, + evaluation_strategy="steps", + eval_steps=script_args.eval_steps, + output_dir=script_args.output_dir, + report_to=script_args.report_to, + lr_scheduler_type=script_args.lr_scheduler_type, + warmup_steps=script_args.warmup_steps, + optim=script_args.optimizer_type, + bf16=True, + remove_unused_columns=False, + run_name="dpo_llama2", + use_habana=True, + use_lazy_mode=True, + use_hpu_graphs_for_training=True, + use_hpu_graphs_for_inference=True, + ) + # 2. load a pretrained model + model = AutoModelForCausalLM.from_pretrained( + script_args.model_name_or_path, + low_cpu_mem_usage=True, + torch_dtype=torch.bfloat16, + ) + model.config.use_cache = False + + if script_args.ignore_bias_buffers: + # torch distributed hack + model._ddp_params_and_buffers_to_ignore = [ + name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool + ] + + model_ref = AutoModelForCausalLM.from_pretrained( + script_args.model_name_or_path, + low_cpu_mem_usage=True, + torch_dtype=torch.bfloat16, + ) + model_ref.config.use_cache = False + tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name_or_path) + tokenizer.pad_token = tokenizer.eos_token + + # 3. Load the Stack-exchange paired dataset + train_dataset = get_stack_exchange_paired(data_dir="data/rl", sanity_check=script_args.sanity_check) + train_dataset = train_dataset.filter( + lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length + and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length + ) + + # 4. Load evaluation dataset + eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation", sanity_check=True) + eval_dataset = eval_dataset.filter( + lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length + and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length + ) + + peft_config = LoraConfig( + r=script_args.lora_r, + lora_alpha=script_args.lora_alpha, + lora_dropout=script_args.lora_dropout, + target_modules=[ + "q_proj", + "v_proj", + "k_proj", + "out_proj", + "fc_in", + "fc_out", + "wte", + ], + bias="none", + task_type="CAUSAL_LM", + ) + + gaudi_config = GaudiConfig() + gaudi_config.use_fused_adam = True + gaudi_config.use_fused_clip_norm = True + + # 5. initialize the DPO trainer + dpo_trainer = GaudiDPOTrainer( + model, + model_ref, + gaudi_config=gaudi_config, + args=training_args, + beta=script_args.beta, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + peft_config=peft_config, + max_prompt_length=script_args.max_prompt_length, + max_length=script_args.max_length, + ) + + # 6. train + dpo_trainer.train() + + # 7. save + dpo_trainer.save_model(script_args.output_dir) diff --git a/examples/trl/stack_llama_2/merge_peft_adapter.py b/examples/trl/stack_llama_2/merge_peft_adapter.py new file mode 100644 index 0000000000..8913fc62a4 --- /dev/null +++ b/examples/trl/stack_llama_2/merge_peft_adapter.py @@ -0,0 +1,50 @@ +# copy from https://github.com/huggingface/trl/blob/v0.7.6/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py. +# only difference is removal of model.push_to_hub +from dataclasses import dataclass, field +from typing import Optional + +import torch +from peft import PeftConfig, PeftModel +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser + + +@dataclass +class ScriptArguments: + """ + The input names representing the Adapter and Base model fine-tuned with PEFT, and the output name representing the + merged model. + """ + + adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the adapter name"}) + base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"}) + output_name: Optional[str] = field(default=None, metadata={"help": "the merged model name"}) + + +parser = HfArgumentParser(ScriptArguments) +script_args = parser.parse_args_into_dataclasses()[0] +assert script_args.adapter_model_name is not None, "please provide the name of the Adapter you would like to merge" +assert script_args.base_model_name is not None, "please provide the name of the Base model" +assert script_args.output_name is not None, "please provide the output name of the merged model" + +peft_config = PeftConfig.from_pretrained(script_args.adapter_model_name) +if peft_config.task_type == "SEQ_CLS": + # The sequence classification task is used for the reward model in PPO + model = AutoModelForSequenceClassification.from_pretrained( + script_args.base_model_name, num_labels=1, torch_dtype=torch.bfloat16 + ) +else: + model = AutoModelForCausalLM.from_pretrained( + script_args.base_model_name, return_dict=True, torch_dtype=torch.bfloat16 + ) + +tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name) + +# Load the PEFT model +model = PeftModel.from_pretrained(model, script_args.adapter_model_name) +model.eval() + +model = model.merge_and_unload() + +model.save_pretrained(f"{script_args.output_name}") +tokenizer.save_pretrained(f"{script_args.output_name}") +# model.push_to_hub(f"{script_args.output_name}", use_temp_dir=False) diff --git a/examples/trl/stack_llama_2/requirements.txt b/examples/trl/stack_llama_2/requirements.txt new file mode 100644 index 0000000000..c980a4b30c --- /dev/null +++ b/examples/trl/stack_llama_2/requirements.txt @@ -0,0 +1,5 @@ +trl == 0.7.6 +peft == 0.6.2 +datasets +wandb +tyro diff --git a/examples/trl/stack_llama_2/sft_llama2.py b/examples/trl/stack_llama_2/sft_llama2.py new file mode 100644 index 0000000000..1ebff0df14 --- /dev/null +++ b/examples/trl/stack_llama_2/sft_llama2.py @@ -0,0 +1,168 @@ +# Fine-Tune Llama2-7b on SE paired dataset +# copy from https://github.com/huggingface/trl/blob/v0.7.6/examples/research_projects/stack_llama_2/scripts/sft_llama2.py, enable it for Gaudi2 +import logging +import os +from dataclasses import dataclass, field +from typing import Optional + +import torch +import transformers +from datasets import load_dataset +from peft import AutoPeftModelForCausalLM, LoraConfig +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser +from transformers.trainer_utils import is_main_process +from trl.trainer import ConstantLengthDataset + +from optimum.habana import GaudiConfig, GaudiTrainingArguments +from optimum.habana.trl import GaudiSFTTrainer + + +logger = logging.getLogger(__name__) + + +@dataclass +class ScriptArguments: + model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"}) + dataset_name: Optional[str] = field(default="lvwerra/stack-exchange-paired", metadata={"help": "the dataset name"}) + subset: Optional[str] = field(default="data/finetune", metadata={"help": "the subset to use"}) + split: Optional[str] = field(default="train", metadata={"help": "the split to use"}) + size_valid_set: Optional[int] = field(default=4000, metadata={"help": "the size of the validation set"}) + streaming: Optional[bool] = field(default=True, metadata={"help": "whether to stream the dataset"}) + shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"}) + seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"}) + num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"}) + packing: Optional[bool] = field(default=True, metadata={"help": "whether to use packing for SFTTrainer"}) + + # LoraConfig + lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"}) + lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"}) + lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"}) + + +parser = HfArgumentParser((ScriptArguments, GaudiTrainingArguments)) +script_args, training_args = parser.parse_args_into_dataclasses() +peft_config = LoraConfig( + r=script_args.lora_r, + lora_alpha=script_args.lora_alpha, + lora_dropout=script_args.lora_dropout, + target_modules=["q_proj", "v_proj"], + bias="none", + task_type="CAUSAL_LM", +) + +if training_args.group_by_length and script_args.packing: + raise ValueError("Cannot use both packing and group by length") + + +def chars_token_ratio(dataset, tokenizer, nb_examples=400): + """ + Estimate the average number of characters per token in the dataset. + """ + total_characters, total_tokens = 0, 0 + for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples): + text = prepare_sample_text(example) + total_characters += len(text) + if tokenizer.is_fast: + total_tokens += len(tokenizer(text).tokens()) + else: + total_tokens += len(tokenizer.tokenize(text)) + + return total_characters / total_tokens + + +def prepare_sample_text(example): + """Prepare the text from a sample of the dataset.""" + text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}" + return text + + +def create_datasets(tokenizer, args): + dataset = load_dataset( + args.dataset_name, + data_dir=args.subset, + split=args.split, + use_auth_token=True, + num_proc=args.num_workers if not args.streaming else None, + streaming=args.streaming, + ) + if args.streaming: + print("Loading the dataset in streaming mode") + valid_data = dataset.take(args.size_valid_set) + train_data = dataset.skip(args.size_valid_set) + train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=None) + else: + dataset = dataset.train_test_split(test_size=0.005, seed=None) + train_data = dataset["train"] + valid_data = dataset["test"] + print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}") + + chars_per_token = chars_token_ratio(train_data, tokenizer) + print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}") + + train_dataset = ConstantLengthDataset( + tokenizer, + train_data, + formatting_func=prepare_sample_text, + infinite=True, + seq_length=args.seq_length, + chars_per_token=chars_per_token, + ) + valid_dataset = ConstantLengthDataset( + tokenizer, + valid_data, + formatting_func=prepare_sample_text, + infinite=False, + seq_length=args.seq_length, + chars_per_token=chars_per_token, + ) + return train_dataset, valid_dataset + + +base_model = AutoModelForCausalLM.from_pretrained( + script_args.model_name, + low_cpu_mem_usage=True, + torch_dtype=torch.bfloat16, + use_auth_token=True, +) +base_model.config.use_cache = False + +tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True) +tokenizer.pad_token = tokenizer.eos_token +tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training + +log_level = training_args.get_process_log_level() +logger.setLevel(log_level) +transformers.utils.logging.set_verbosity(log_level) +transformers.utils.logging.enable_default_handler() +transformers.utils.logging.enable_explicit_format() + +train_dataset, eval_dataset = create_datasets(tokenizer, script_args) + +gaudi_config = GaudiConfig() +gaudi_config.use_fused_adam = True +gaudi_config.use_fused_clip_norm = True + +trainer = GaudiSFTTrainer( + model=base_model, + gaudi_config=gaudi_config, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + peft_config=peft_config, + packing=script_args.packing, + max_seq_length=None, + tokenizer=tokenizer, + args=training_args, +) +trainer.train() +trainer.save_model(training_args.output_dir) + +# Free memory for merging weights +del base_model +with training_args.main_process_first(desc="merge peft model"): + if is_main_process(training_args.local_rank): + model = AutoPeftModelForCausalLM.from_pretrained(training_args.output_dir, torch_dtype=torch.bfloat16) + model = model.merge_and_unload() + + output_merged_dir = os.path.join(training_args.output_dir, "final_merged_checkpoint") + model.save_pretrained(output_merged_dir, safe_serialization=True) diff --git a/optimum/habana/trl/__init__.py b/optimum/habana/trl/__init__.py new file mode 100644 index 0000000000..e80fac8b8a --- /dev/null +++ b/optimum/habana/trl/__init__.py @@ -0,0 +1,2 @@ +from .trainer.dpo_trainer import GaudiDPOTrainer +from .trainer.sft_trainer import GaudiSFTTrainer diff --git a/optimum/habana/trl/trainer/__init__.py b/optimum/habana/trl/trainer/__init__.py new file mode 100644 index 0000000000..13bf554fd7 --- /dev/null +++ b/optimum/habana/trl/trainer/__init__.py @@ -0,0 +1,21 @@ +# flake8: noqa + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# There is a circular import in the PPOTrainer if we let isort sort these +# isort: on + +from .sft_trainer import GaudiSFTTrainer +from .dpo_trainer import GaudiDPOTrainer diff --git a/optimum/habana/trl/trainer/dpo_trainer.py b/optimum/habana/trl/trainer/dpo_trainer.py new file mode 100644 index 0000000000..e5cfea0cd3 --- /dev/null +++ b/optimum/habana/trl/trainer/dpo_trainer.py @@ -0,0 +1,426 @@ +# DPO Authors: Rafael Rafailov, Archit Sharma, Eric Mitchell, Stefano Ermon, Christopher D. Manning, and Chelsea Finn 2023 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import warnings +from collections import defaultdict +from typing import Callable, Dict, List, Literal, Optional, Tuple, Union + +import torch +import torch.nn as nn +from accelerate.utils import is_deepspeed_available +from datasets import Dataset +from transformers import ( + AutoModelForCausalLM, + DataCollator, + PreTrainedModel, + PreTrainedTokenizerBase, +) +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalLoopOutput +from trl import DPOTrainer, create_reference_model +from trl.import_utils import is_peft_available, is_wandb_available +from trl.trainer.utils import ( + DPODataCollatorWithPadding, + disable_dropout_in_model, + pad_to_length, + peft_module_casting_to_bf16, +) + +from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments + + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + + +if is_wandb_available(): + pass + +if is_deepspeed_available(): + pass + + +class GaudiDPOTrainer(DPOTrainer, GaudiTrainer): + def __init__( + self, + model: Union[PreTrainedModel, nn.Module, str] = None, + ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + beta: float = 0.1, + label_smoothing: float = 0, + loss_type: Literal["sigmoid", "hinge", "ipo", "kto"] = "sigmoid", + args: GaudiTrainingArguments = None, + gaudi_config: GaudiConfig = None, + data_collator: Optional[DataCollator] = None, + label_pad_token_id: int = -100, + padding_value: int = None, + truncation_mode: str = "keep_end", + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + max_length: Optional[int] = None, + max_prompt_length: Optional[int] = None, + max_target_length: Optional[int] = None, + peft_config: Optional[Dict] = None, + is_encoder_decoder: Optional[bool] = None, + disable_dropout: bool = True, + generate_during_eval: bool = False, + compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None, + precompute_ref_log_probs: bool = False, + model_init_kwargs: Optional[Dict] = None, + ref_model_init_kwargs: Optional[Dict] = None, + ): + """ + Copied from DPOTrainer.__init__: https://github.com/huggingface/trl/blob/v0.7.6/trl/trainer/dpo_trainer.py#L127 + The only differences are: + - add new args gaudi_config + - use graph for ref_model + - use GaudiTrainer instead of Trainer + - cast peft model to bf16. + """ + if model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the DPOTrainer. But your model is already instantiated.") + + if ref_model_init_kwargs is None: + ref_model_init_kwargs = {} + elif not isinstance(ref_model, str): + raise ValueError( + "You passed ref_model_kwargs to the DPOTrainer. But your ref_model is already instantiated." + ) + + if isinstance(model, str): + warnings.warn( + "You passed a model_id to the DPOTrainer. This will automatically create an " + "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you." + ) + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + if isinstance(ref_model, str): + warnings.warn( + "You passed a ref model_id to the DPOTrainer. This will automatically create an " + "`AutoModelForCausalLM`" + ) + ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs) + + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + # if model is a peft model and we have a peft_config, we merge and unload it first + if isinstance(model, PeftModel): + model = model.merge_and_unload() + + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if _support_gc_kwargs: + preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **preprare_model_kwargs) + elif getattr(args, "gradient_checkpointing", False): + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + # get peft model with the given config + model = get_peft_model(model, peft_config) + if args.bf16: + peft_module_casting_to_bf16(model) + + # For models that use gradient_checkpoiting, we need to attach a hook that enables input + # to explicitly have `requires_grad=True`, otherwise training will either silently + # fail or completely fail. + elif getattr(args, "gradient_checkpointing", False): + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if generate_during_eval and not is_wandb_available(): + raise ValueError( + "`generate_during_eval=True` requires Weights and Biases to be installed." + " Please install `wandb` to resolve." + ) + + if model is not None: + self.is_encoder_decoder = model.config.is_encoder_decoder + elif is_encoder_decoder is None: + raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") + else: + self.is_encoder_decoder = is_encoder_decoder + + self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) + + if ref_model: + self.ref_model = ref_model + elif self.is_peft_model or precompute_ref_log_probs: + # The `model` with adapters turned off will be used as the reference model + self.ref_model = None + else: + self.ref_model = create_reference_model(model) + + if data_collator is None: + if tokenizer is None: + raise ValueError( + "max_length or a tokenizer must be specified when using the default DPODataCollatorWithPadding" + ) + if max_length is None: + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `max_length` in the DPOTrainer's init" + " it will be set to `512` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_length = 512 + if max_prompt_length is None: + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the DPOTrainer's init" + " it will be set to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_prompt_length = 128 + + if max_target_length is None and self.is_encoder_decoder: + warnings.warn( + "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_target_length` in the DPOTrainer's init" + " it will be set to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_target_length = 128 + + data_collator = DPODataCollatorWithPadding( + pad_token_id=tokenizer.pad_token_id, + label_pad_token_id=label_pad_token_id, + is_encoder_decoder=self.is_encoder_decoder, + ) + + if args.remove_unused_columns: + args.remove_unused_columns = False + # warn users + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" + " we have set it for you, but you should do it yourself in the future.", + UserWarning, + ) + + self.use_dpo_data_collator = True + else: + self.use_dpo_data_collator = False + + if disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + self.max_length = max_length + self.generate_during_eval = generate_during_eval + self.label_pad_token_id = label_pad_token_id + self.padding_value = padding_value if padding_value is not None else tokenizer.pad_token_id + self.max_prompt_length = max_prompt_length + self.truncation_mode = truncation_mode + self.max_target_length = max_target_length + self.tokenizer = tokenizer + self.precompute_ref_log_probs = precompute_ref_log_probs + + # Since ref_logs are precomputed on the first call to get_train/eval_dataloader + # keep track of first called to avoid computation of future calls + self._precomputed_train_ref_log_probs = False + self._precomputed_eval_ref_log_probs = False + + if loss_type in ["hinge", "ipo", "kto_pair"] and label_smoothing > 0: + warnings.warn( + "You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter." + ) + + self.beta = beta + self.label_smoothing = label_smoothing + self.loss_type = loss_type + + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + # tokenize the dataset + train_dataset = train_dataset.map(self.tokenize_row) + if eval_dataset is not None: + eval_dataset = eval_dataset.map(self.tokenize_row) + + GaudiTrainer.__init__( + self, + model=model, + args=args, + gaudi_config=gaudi_config, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + # Deepspeed Zero-3 does not support precompute_ref_log_probs + if self.is_deepspeed_enabled: + if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs: + raise ValueError( + "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`." + ) + + if self.ref_model is None: + if not (self.is_peft_model or self.precompute_ref_log_probs): + raise ValueError( + "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`" + ) + else: + if self.is_deepspeed_enabled: + self.ref_model = self._prepare_deepspeed(self.ref_model) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + from habana_frameworks.torch.hpu import wrap_in_hpu_graph # use graph for ref_model + + ref_model = self.accelerator.unwrap_model(self.ref_model) + ref_model = wrap_in_hpu_graph(ref_model) + + @staticmethod + def concatenated_inputs( + batch: Dict[str, Union[List, torch.LongTensor]], + is_encoder_decoder: bool = False, + label_pad_token_id: int = -100, + padding_value: int = 0, + device: Optional[torch.device] = None, + padded_max_length: int = 0, + ) -> Dict[str, torch.LongTensor]: + """ + Copied from DPOTrainer.concatenated_inputs: https://github.com/huggingface/trl/blob/v0.7.6/trl/trainer/dpo_trainer.py#L701 + - pad to self.max_length in Gaudi2 + """ + concatenated_batch = {} + + if is_encoder_decoder: + max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1]) + else: + max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) + + if padded_max_length != 0: # pad to max_length in Gaudi + max_length = padded_max_length + for k in batch: + if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("chosen", "concatenated") + concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) + for k in batch: + if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("rejected", "concatenated") + concatenated_batch[concatenated_key] = torch.cat( + ( + concatenated_batch[concatenated_key], + pad_to_length(batch[k], max_length, pad_value=pad_value), + ), + dim=0, + ).to(device=device) + + if is_encoder_decoder: + concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device) + concatenated_batch["concatenated_attention_mask"] = ( + batch["prompt_attention_mask"].repeat(2, 1).to(device=device) + ) + + return concatenated_batch + + def concatenated_forward( + self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """ + Copied from DPOTrainer.concatenated_forward: https://github.com/huggingface/trl/blob/v0.7.6/trl/trainer/dpo_trainer.py#L866 + - pad to self.max_length in Gaudi2 + """ + concatenated_batch = self.concatenated_inputs( + batch, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + padding_value=self.padding_value, + device=self.accelerator.device, + padded_max_length=self.max_length, + ) + len_chosen = batch["chosen_labels"].shape[0] + + model_kwargs = ( + { + "labels": concatenated_batch["concatenated_labels"], + "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None), + } + if self.is_encoder_decoder + else {} + ) + all_logits = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + **model_kwargs, + ).logits + + all_logps = self.get_batch_logps( + all_logits, + concatenated_batch["concatenated_labels"], + average_log_prob=False, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits) diff --git a/optimum/habana/trl/trainer/sft_trainer.py b/optimum/habana/trl/trainer/sft_trainer.py new file mode 100644 index 0000000000..49b2525f4c --- /dev/null +++ b/optimum/habana/trl/trainer/sft_trainer.py @@ -0,0 +1,244 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import dataclasses +import inspect +import warnings +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from datasets import Dataset +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + DataCollator, + DataCollatorForLanguageModeling, + PreTrainedModel, + PreTrainedTokenizerBase, +) +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalPrediction +from trl import SFTTrainer +from trl.import_utils import is_peft_available +from trl.trainer.utils import ( + DataCollatorForCompletionOnlyLM, + peft_module_casting_to_bf16, +) + + +if is_peft_available(): + from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training + +from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments + + +class GaudiSFTTrainer(SFTTrainer, GaudiTrainer): + def __init__( + self, + model: Union[PreTrainedModel, nn.Module, str] = None, + args: GaudiTrainingArguments = None, + gaudi_config: GaudiConfig = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + peft_config: Optional["PeftConfig"] = None, + dataset_text_field: Optional[str] = None, + packing: Optional[bool] = False, + formatting_func: Optional[Callable] = None, + max_seq_length: Optional[int] = None, + infinite: Optional[bool] = None, + num_of_sequences: Optional[int] = 1024, + chars_per_token: Optional[float] = 3.6, + dataset_num_proc: Optional[int] = None, + dataset_batch_size: int = 1000, + neftune_noise_alpha: Optional[float] = None, + model_init_kwargs: Optional[Dict] = None, + dataset_kwargs: Optional[Dict] = None, + ): + """ + Copied from SFTTrainer.__init__: https://github.com/huggingface/trl/blob/v0.7.6/trl/trainer/sft_trainer.py#L120 + The only differences are: + - add new args gaudi_config + - use GaudiTrainer instead of Trainer + - cast peft model to bf16. + """ + if model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the SFTTrainer. But your model is already instantiated.") + + if infinite is not None: + warnings.warn( + "The `infinite` argument is deprecated and will be removed in a future version of TRL. Use `TrainingArguments.max_steps` or `TrainingArguments.num_train_epochs` instead to control training length." + ) + + if isinstance(model, str): + warnings.warn( + "You passed a model_id to the SFTTrainer. This will automatically create an " + "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you." + ) + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + if packing and data_collator is not None and isinstance(data_collator, DataCollatorForCompletionOnlyLM): + raise ValueError( + "You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument." + ) + + if is_peft_available() and peft_config is not None: + if not isinstance(peft_config, PeftConfig): + raise ValueError( + "If you want to use the PeftModel, you need to pass a PeftConfig object to the SFTTrainer." + f" and you passed a {type(peft_config)}." + ) + + if not isinstance(model, PeftModel): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {} + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + preprare_model_kwargs = { + "use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False) + } + + if _support_gc_kwargs: + preprare_model_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **preprare_model_kwargs) + + if args is not None: + args = dataclasses.replace(args, gradient_checkpointing=False) + elif getattr(args, "gradient_checkpointing", False) and ( + "use_reentrant" not in gradient_checkpointing_kwargs + or gradient_checkpointing_kwargs["use_reentrant"] + ): + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + model = get_peft_model(model, peft_config) + if args.bf16: + peft_module_casting_to_bf16(model) + + if tokenizer is None: + tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path) + if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + + if max_seq_length is None: + # to overcome some issues with broken tokenizers + max_seq_length = min(tokenizer.model_max_length, 1024) + + warnings.warn( + f"You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to {max_seq_length}" + ) + + self.dataset_num_proc = dataset_num_proc + self.dataset_batch_size = dataset_batch_size + + self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha") + + if neftune_noise_alpha is not None and self._trainer_supports_neftune: + args.neftune_noise_alpha = neftune_noise_alpha + warnings.warn( + "You passed a `neftune_noise_alpha` argument to the SFTTrainer, the value you passed will override the one in the `TrainingArguments`." + ) + # self.neftune_noise_alpha is done at Trainer level + elif not self._trainer_supports_neftune: + self.neftune_noise_alpha = neftune_noise_alpha + + if not packing: + if dataset_text_field is None and formatting_func is None: + raise ValueError( + "You passed `packing=False` to the SFTTrainer, but you didn't pass a `dataset_text_field` or `formatting_func` argument." + ) + + if data_collator is None: + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + if dataset_kwargs is None: + dataset_kwargs = {} + if train_dataset is not None: + train_dataset = self._prepare_dataset( + train_dataset, + tokenizer, + packing, + dataset_text_field, + max_seq_length, + formatting_func, + num_of_sequences, + chars_per_token, + **dataset_kwargs, + ) + if eval_dataset is not None: + _multiple = isinstance(eval_dataset, dict) + _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset} + for _eval_dataset_name, _eval_dataset in _eval_datasets.items(): + _eval_datasets[_eval_dataset_name] = self._prepare_dataset( + _eval_dataset, + tokenizer, + packing, + dataset_text_field, + max_seq_length, + formatting_func, + num_of_sequences, + chars_per_token, + **dataset_kwargs, + ) + if not _multiple: + eval_dataset = _eval_datasets["singleton"] + + if tokenizer.padding_side is not None and tokenizer.padding_side != "right": + warnings.warn( + "You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to " + "overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code." + ) + + GaudiTrainer.__init__( + self, + model=model, + args=args, + gaudi_config=gaudi_config, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + if self.args.max_steps > 0 and packing: + warnings.warn( + "You passed `packing=True` to the SFTTrainer, and you are training your model with `max_steps` strategy. The dataset will be iterated until the `max_steps` are reached." + ) + self.train_dataset.infinite = True + elif self.args.max_steps == -1 and packing: + self.train_dataset.infinite = False