From 41da917eb6c1b0becf2cc5295e90ed4bae50a913 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Wed, 20 Dec 2023 02:47:21 -0800 Subject: [PATCH 1/2] add DPO and SFT of TRL support in Gaudi and example Signed-off-by: Wang, Yi A --- examples/trl/stack_llama_2/README.md | 70 ++++ examples/trl/stack_llama_2/dpo_llama2.py | 231 +++++++++++ .../trl/stack_llama_2/merge_peft_adapter.py | 48 +++ examples/trl/stack_llama_2/requirements.txt | 5 + examples/trl/stack_llama_2/sft_llama2.py | 199 +++++++++ optimum/habana/trl/__init__.py | 2 + optimum/habana/trl/trainer/__init__.py | 21 + optimum/habana/trl/trainer/dpo_trainer.py | 387 ++++++++++++++++++ optimum/habana/trl/trainer/sft_trainer.py | 271 ++++++++++++ 9 files changed, 1234 insertions(+) create mode 100644 examples/trl/stack_llama_2/README.md create mode 100644 examples/trl/stack_llama_2/dpo_llama2.py create mode 100644 examples/trl/stack_llama_2/merge_peft_adapter.py create mode 100644 examples/trl/stack_llama_2/requirements.txt create mode 100644 examples/trl/stack_llama_2/sft_llama2.py create mode 100644 optimum/habana/trl/__init__.py create mode 100644 optimum/habana/trl/trainer/__init__.py create mode 100644 optimum/habana/trl/trainer/dpo_trainer.py create mode 100644 optimum/habana/trl/trainer/sft_trainer.py diff --git a/examples/trl/stack_llama_2/README.md b/examples/trl/stack_llama_2/README.md new file mode 100644 index 0000000000..b27f62a413 --- /dev/null +++ b/examples/trl/stack_llama_2/README.md @@ -0,0 +1,70 @@ +# DPO pipeline for the creation of StackLlaMa 2: a Stack exchange llama-v2-7b model + +## Prerequisites + +Install all the dependencies in the `requirements.txt`: + +``` +$ pip install -U -r requirements.txt +``` + + +## Training + +There were two main steps to the DPO training process: +1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se: + + ``` + python ../../gaudi_spawn.py --world_size 8 --use_mpi sft_llama2.py \ + --training-args.output_dir="./sft" \ + --training-args.max_steps=500 \ + --training-args.logging_steps=10 \ + --training-args.save_steps=10 \ + --training-args.per_device_train_batch_size=4 \ + --training-args.per_device_eval_batch_size=1 \ + --training-args.gradient_accumulation_steps=2 \ + --training-args.learning_rate=1e-4 \ + --training-args.lr_scheduler_type="cosine" \ + --training-args.warmup_steps=100 \ + --training-args.weight_decay=0.05 \ + --training-args.optim="paged_adamw_32bit" \ + --training-args.bf16 \ + --training-args.remove_unused_columns=False \ + --training-args.run_name="sft_llama2" \ + --training-args.report_to=none + ``` +2. Run the DPO trainer using the model saved by the previous step: + ``` + python ../../gaudi_spawn.py --world_size 8 --use_mpi dpo_llama2.py \ + --model_name_or_path="sft/final_merged_checkpoint" \ + --output_dir="dpo" \ + --report_to=none + ``` + + +## Merging the adaptors + +To merge the adaptors into the base model we can use the `merge_peft_adapter.py` helper script that comes with TRL: + +``` +python merge_peft_adapter.py --base_model_name="meta-llama/Llama-2-7b-hf" --adapter_model_name="dpo" --output_name="stack-llama-2" +``` + +which will also push the model to your HuggingFace hub account. + +## Running the model + +We can load the DPO-trained LoRA adaptors which were saved by the DPO training step and load them via: + +```py +from peft import AutoPeftModelForCausalLM + + +model = AutoPeftModelForCausalLM.from_pretrained( + "dpo/final_checkpoint", + low_cpu_mem_usage=True, + torch_dtype=torch.bfloat16, +) + +model.generate(...) +``` diff --git a/examples/trl/stack_llama_2/dpo_llama2.py b/examples/trl/stack_llama_2/dpo_llama2.py new file mode 100644 index 0000000000..eaadebd4ae --- /dev/null +++ b/examples/trl/stack_llama_2/dpo_llama2.py @@ -0,0 +1,231 @@ +# 0. imports +from dataclasses import dataclass, field +from typing import Dict, Optional + +import torch +from datasets import Dataset, load_dataset +from peft import LoraConfig +from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser + +from optimum.habana import GaudiConfig, GaudiTrainingArguments +from optimum.habana.trl import GaudiDPOTrainer + + +# Define and parse arguments. +@dataclass +class ScriptArguments: + """ + The arguments for the DPO training script. + """ + + # data parameters + beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"}) + + # training parameters + model_name_or_path: Optional[str] = field( + default="../sft/results/final_checkpoint", + metadata={"help": "the location of the SFT model name or path"}, + ) + tokenizer_name_or_path: Optional[str] = field( + default="meta-llama/Llama-2-7b-hf", + metadata={"help": "the location of the SFT model name or path"}, + ) + learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"}) + lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"}) + warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"}) + weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"}) + optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"}) + + per_device_train_batch_size: Optional[int] = field(default=1, metadata={"help": "train batch size per device"}) + per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"}) + gradient_accumulation_steps: Optional[int] = field( + default=4, metadata={"help": "the number of gradient accumulation steps"} + ) + gradient_checkpointing: Optional[bool] = field( + default=False, metadata={"help": "whether to use gradient checkpointing"} + ) + + lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"}) + lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"}) + lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"}) + + max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"}) + max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"}) + max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"}) + logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"}) + save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"}) + eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"}) + + output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"}) + log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"}) + + # instrumentation + sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"}) + report_to: Optional[str] = field( + default="wandb", + metadata={ + "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,' + '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. ' + 'Use `"all"` to report to all integrations installed, `"none"` for no integrations.' + }, + ) + # debug argument for distributed training + ignore_bias_buffers: Optional[bool] = field( + default=False, + metadata={ + "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See" + "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992" + }, + ) + + +def get_stack_exchange_paired( + data_dir: str = "data/rl", + sanity_check: bool = False, + cache_dir: str = None, + num_proc=24, +) -> Dataset: + """Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format. + + The dataset is converted to a dictionary with the following structure: + { + 'prompt': List[str], + 'chosen': List[str], + 'rejected': List[str], + } + + Prompts are structured as follows: + "Question: " + + "\n\nAnswer: " + """ + dataset = load_dataset( + "lvwerra/stack-exchange-paired", + split="train", + cache_dir=cache_dir, + data_dir=data_dir, + ) + original_columns = dataset.column_names + + if sanity_check: + dataset = dataset.select(range(min(len(dataset), 1000))) + + def return_prompt_and_responses(samples) -> Dict[str, str]: + return { + "prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]], + "chosen": samples["response_j"], + "rejected": samples["response_k"], + } + + return dataset.map( + return_prompt_and_responses, + batched=True, + num_proc=num_proc, + remove_columns=original_columns, + ) + + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + # 1. initialize training arguments: + training_args = GaudiTrainingArguments( + per_device_train_batch_size=script_args.per_device_train_batch_size, + per_device_eval_batch_size=script_args.per_device_eval_batch_size, + max_steps=script_args.max_steps, + logging_steps=script_args.logging_steps, + save_steps=script_args.save_steps, + gradient_accumulation_steps=script_args.gradient_accumulation_steps, + gradient_checkpointing=script_args.gradient_checkpointing, + learning_rate=script_args.learning_rate, + evaluation_strategy="steps", + eval_steps=script_args.eval_steps, + output_dir=script_args.output_dir, + report_to=script_args.report_to, + lr_scheduler_type=script_args.lr_scheduler_type, + warmup_steps=script_args.warmup_steps, + optim=script_args.optimizer_type, + bf16=True, + remove_unused_columns=False, + run_name="dpo_llama2", + use_habana=True, + use_lazy_mode=True, + use_hpu_graphs_for_training=True, + use_hpu_graphs_for_inference=True, + ) + # 2. load a pretrained model + model = AutoModelForCausalLM.from_pretrained( + script_args.model_name_or_path, + low_cpu_mem_usage=True, + torch_dtype=torch.bfloat16, + ) + model.config.use_cache = False + + if script_args.ignore_bias_buffers: + # torch distributed hack + model._ddp_params_and_buffers_to_ignore = [ + name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool + ] + + model_ref = AutoModelForCausalLM.from_pretrained( + script_args.model_name_or_path, + low_cpu_mem_usage=True, + torch_dtype=torch.bfloat16, + ) + model_ref.config.use_cache = False + tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name_or_path) + tokenizer.pad_token = tokenizer.eos_token + + # 3. Load the Stack-exchange paired dataset + train_dataset = get_stack_exchange_paired(data_dir="data/rl", sanity_check=script_args.sanity_check) + train_dataset = train_dataset.filter( + lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length + and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length + ) + + # 4. Load evaluation dataset + eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation", sanity_check=True) + eval_dataset = eval_dataset.filter( + lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length + and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length + ) + + peft_config = LoraConfig( + r=script_args.lora_r, + lora_alpha=script_args.lora_alpha, + lora_dropout=script_args.lora_dropout, + target_modules=[ + "q_proj", + "v_proj", + "k_proj", + "out_proj", + "fc_in", + "fc_out", + "wte", + ], + bias="none", + task_type="CAUSAL_LM", + ) + + gaudi_config = GaudiConfig() + gaudi_config.use_fused_adam = True + gaudi_config.use_fused_clip_norm = True + + # 5. initialize the DPO trainer + dpo_trainer = GaudiDPOTrainer( + model, + model_ref, + gaudi_config=gaudi_config, + args=training_args, + beta=script_args.beta, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + peft_config=peft_config, + max_prompt_length=script_args.max_prompt_length, + max_length=script_args.max_length, + ) + + # 6. train + dpo_trainer.train() + + # 7. save + dpo_trainer.save_model(script_args.output_dir) diff --git a/examples/trl/stack_llama_2/merge_peft_adapter.py b/examples/trl/stack_llama_2/merge_peft_adapter.py new file mode 100644 index 0000000000..058099180a --- /dev/null +++ b/examples/trl/stack_llama_2/merge_peft_adapter.py @@ -0,0 +1,48 @@ +from dataclasses import dataclass, field +from typing import Optional + +import torch +from peft import PeftConfig, PeftModel +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser + + +@dataclass +class ScriptArguments: + """ + The input names representing the Adapter and Base model fine-tuned with PEFT, and the output name representing the + merged model. + """ + + adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the adapter name"}) + base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"}) + output_name: Optional[str] = field(default=None, metadata={"help": "the merged model name"}) + + +parser = HfArgumentParser(ScriptArguments) +script_args = parser.parse_args_into_dataclasses()[0] +assert script_args.adapter_model_name is not None, "please provide the name of the Adapter you would like to merge" +assert script_args.base_model_name is not None, "please provide the name of the Base model" +assert script_args.output_name is not None, "please provide the output name of the merged model" + +peft_config = PeftConfig.from_pretrained(script_args.adapter_model_name) +if peft_config.task_type == "SEQ_CLS": + # The sequence classification task is used for the reward model in PPO + model = AutoModelForSequenceClassification.from_pretrained( + script_args.base_model_name, num_labels=1, torch_dtype=torch.bfloat16 + ) +else: + model = AutoModelForCausalLM.from_pretrained( + script_args.base_model_name, return_dict=True, torch_dtype=torch.bfloat16 + ) + +tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name) + +# Load the PEFT model +model = PeftModel.from_pretrained(model, script_args.adapter_model_name) +model.eval() + +model = model.merge_and_unload() + +model.save_pretrained(f"{script_args.output_name}") +tokenizer.save_pretrained(f"{script_args.output_name}") +# model.push_to_hub(f"{script_args.output_name}", use_temp_dir=False) diff --git a/examples/trl/stack_llama_2/requirements.txt b/examples/trl/stack_llama_2/requirements.txt new file mode 100644 index 0000000000..1da3dabbc8 --- /dev/null +++ b/examples/trl/stack_llama_2/requirements.txt @@ -0,0 +1,5 @@ +trl == 0.7.4 +peft == 0.6.2 +datasets +wandb +tyro diff --git a/examples/trl/stack_llama_2/sft_llama2.py b/examples/trl/stack_llama_2/sft_llama2.py new file mode 100644 index 0000000000..1ce957b319 --- /dev/null +++ b/examples/trl/stack_llama_2/sft_llama2.py @@ -0,0 +1,199 @@ +# Fine-Tune Llama2-7b on SE paired dataset +import logging +import os +from dataclasses import dataclass, field +from typing import Optional + +import torch +import transformers +import tyro +from datasets import load_dataset +from peft import AutoPeftModelForCausalLM, LoraConfig +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.trainer_utils import is_main_process +from trl.trainer import ConstantLengthDataset + +from optimum.habana import GaudiConfig, GaudiTrainingArguments +from optimum.habana.trl import GaudiSFTTrainer + + +logger = logging.getLogger(__name__) + + +@dataclass +class ScriptArguments: + model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"}) + + dataset_name: Optional[str] = field(default="lvwerra/stack-exchange-paired", metadata={"help": "the dataset name"}) + subset: Optional[str] = field(default="data/finetune", metadata={"help": "the subset to use"}) + split: Optional[str] = field(default="train", metadata={"help": "the split to use"}) + size_valid_set: Optional[int] = field(default=4000, metadata={"help": "the size of the validation set"}) + streaming: Optional[bool] = field(default=True, metadata={"help": "whether to stream the dataset"}) + shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"}) + seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"}) + num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"}) + + training_args: GaudiTrainingArguments = field( + default_factory=lambda: GaudiTrainingArguments( + output_dir="./results", + max_steps=500, + logging_steps=10, + save_steps=10, + per_device_train_batch_size=4, + per_device_eval_batch_size=1, + gradient_accumulation_steps=2, + gradient_checkpointing=False, + group_by_length=False, + learning_rate=1e-4, + lr_scheduler_type="cosine", + warmup_steps=100, + weight_decay=0.05, + optim="paged_adamw_32bit", + bf16=True, + remove_unused_columns=False, + run_name="sft_llama2", + report_to="wandb", + use_habana=True, + use_lazy_mode=True, + log_level="info", + ) + ) + + packing: Optional[bool] = field(default=True, metadata={"help": "whether to use packing for SFTTrainer"}) + + peft_config: LoraConfig = field( + default_factory=lambda: LoraConfig( + r=8, + lora_alpha=16, + lora_dropout=0.05, + target_modules=["q_proj", "v_proj"], + bias="none", + task_type="CAUSAL_LM", + ) + ) + + +script_args = tyro.cli(ScriptArguments) + +if script_args.training_args.group_by_length and script_args.packing: + raise ValueError("Cannot use both packing and group by length") + + +def chars_token_ratio(dataset, tokenizer, nb_examples=400): + """ + Estimate the average number of characters per token in the dataset. + """ + total_characters, total_tokens = 0, 0 + for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples): + text = prepare_sample_text(example) + total_characters += len(text) + if tokenizer.is_fast: + total_tokens += len(tokenizer(text).tokens()) + else: + total_tokens += len(tokenizer.tokenize(text)) + + return total_characters / total_tokens + + +def prepare_sample_text(example): + """Prepare the text from a sample of the dataset.""" + text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}" + return text + + +def create_datasets(tokenizer, args): + dataset = load_dataset( + args.dataset_name, + data_dir=args.subset, + split=args.split, + use_auth_token=True, + num_proc=args.num_workers if not args.streaming else None, + streaming=args.streaming, + ) + if args.streaming: + print("Loading the dataset in streaming mode") + valid_data = dataset.take(args.size_valid_set) + train_data = dataset.skip(args.size_valid_set) + train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=None) + else: + dataset = dataset.train_test_split(test_size=0.005, seed=None) + train_data = dataset["train"] + valid_data = dataset["test"] + print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}") + + chars_per_token = chars_token_ratio(train_data, tokenizer) + print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}") + + train_dataset = ConstantLengthDataset( + tokenizer, + train_data, + formatting_func=prepare_sample_text, + infinite=True, + seq_length=args.seq_length, + chars_per_token=chars_per_token, + ) + valid_dataset = ConstantLengthDataset( + tokenizer, + valid_data, + formatting_func=prepare_sample_text, + infinite=False, + seq_length=args.seq_length, + chars_per_token=chars_per_token, + ) + return train_dataset, valid_dataset + + +base_model = AutoModelForCausalLM.from_pretrained( + script_args.model_name, + low_cpu_mem_usage=True, + torch_dtype=torch.bfloat16, + use_auth_token=True, +) +base_model.config.use_cache = False + +peft_config = script_args.peft_config + +tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True) +tokenizer.pad_token = tokenizer.eos_token +tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training + +training_args = script_args.training_args + +log_level = training_args.get_process_log_level() +logger.setLevel(log_level) +transformers.utils.logging.set_verbosity(log_level) +transformers.utils.logging.enable_default_handler() +transformers.utils.logging.enable_explicit_format() + +train_dataset, eval_dataset = create_datasets(tokenizer, script_args) + +gaudi_config = GaudiConfig() +gaudi_config.use_fused_adam = True +gaudi_config.use_fused_clip_norm = True + +trainer = GaudiSFTTrainer( + model=base_model, + gaudi_config=gaudi_config, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + peft_config=peft_config, + packing=script_args.packing, + max_seq_length=None, + tokenizer=tokenizer, + args=training_args, +) +trainer.train() +trainer.save_model(script_args.training_args.output_dir) + +# Free memory for merging weights +del base_model +with script_args.training_args.main_process_first(desc="merge peft model"): + if is_main_process(script_args.training_args.local_rank): + model = AutoPeftModelForCausalLM.from_pretrained( + script_args.training_args.output_dir, torch_dtype=torch.bfloat16 + ) + model = model.merge_and_unload() + + output_merged_dir = os.path.join(script_args.training_args.output_dir, "final_merged_checkpoint") + model.save_pretrained(output_merged_dir, safe_serialization=True) diff --git a/optimum/habana/trl/__init__.py b/optimum/habana/trl/__init__.py new file mode 100644 index 0000000000..e80fac8b8a --- /dev/null +++ b/optimum/habana/trl/__init__.py @@ -0,0 +1,2 @@ +from .trainer.dpo_trainer import GaudiDPOTrainer +from .trainer.sft_trainer import GaudiSFTTrainer diff --git a/optimum/habana/trl/trainer/__init__.py b/optimum/habana/trl/trainer/__init__.py new file mode 100644 index 0000000000..13bf554fd7 --- /dev/null +++ b/optimum/habana/trl/trainer/__init__.py @@ -0,0 +1,21 @@ +# flake8: noqa + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# There is a circular import in the PPOTrainer if we let isort sort these +# isort: on + +from .sft_trainer import GaudiSFTTrainer +from .dpo_trainer import GaudiDPOTrainer diff --git a/optimum/habana/trl/trainer/dpo_trainer.py b/optimum/habana/trl/trainer/dpo_trainer.py new file mode 100644 index 0000000000..c3c7660819 --- /dev/null +++ b/optimum/habana/trl/trainer/dpo_trainer.py @@ -0,0 +1,387 @@ +# DPO Authors: Rafael Rafailov, Archit Sharma, Eric Mitchell, Stefano Ermon, Christopher D. Manning, and Chelsea Finn 2023 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import warnings +from collections import defaultdict +from typing import Callable, Dict, List, Literal, Optional, Tuple, Union + +import torch +import torch.nn as nn +from accelerate.utils import is_deepspeed_available +from datasets import Dataset +from transformers import ( + AutoModelForCausalLM, + DataCollator, + PreTrainedModel, + PreTrainedTokenizerBase, +) +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalLoopOutput +from trl import DPOTrainer, create_reference_model +from trl.import_utils import is_peft_available, is_wandb_available +from trl.trainer.utils import DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length + +from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments + + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + + +if is_wandb_available(): + pass + +if is_deepspeed_available(): + pass + + +class GaudiDPOTrainer(DPOTrainer, GaudiTrainer): + r""" + Initialize DPOTrainer. + + Args: + model (`transformers.PreTrainedModel`): + The model to train, preferably an `AutoModelForSequenceClassification`. + ref_model (`PreTrainedModelWrapper`): + Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no + reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized. + beta (`float`, defaults to 0.1): + The beta factor in DPO loss. Higher beta means less divergence from the initial policy. + loss_type (`str`, defaults to `"sigmoid"`): + The type of DPO loss to use. Either `"sigmoid"` the default DPO loss or `"hinge"` loss from SLiC paper. + args (`transformers.TrainingArguments`): + The arguments to use for training. + data_collator (`transformers.DataCollator`): + The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used + which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. + label_pad_token_id (`int`, defaults to `-100`): + The label pad token id. This argument is required if you want to use the default data collator. + padding_value (`int`, defaults to `0`): + The padding value. This argument is required if you want to use the default data collator. + truncation_mode (`str`, defaults to `keep_end`): + The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator. + train_dataset (`datasets.Dataset`): + The dataset to use for training. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + tokenizer (`transformers.PreTrainedTokenizerBase`): + The tokenizer to use for training. This argument is required if you want to use the default data collator. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be used. + callbacks (`List[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + max_length (`int`, defaults to `None`): + The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator. + max_prompt_length (`int`, defaults to `None`): + The maximum length of the prompt. This argument is required if you want to use the default data collator. + max_target_length (`int`, defaults to `None`): + The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder. + peft_config (`Dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. + is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`): + If no model is provided, we need to know if the model_init returns an encoder-decoder. + disable_dropout (`bool`, defaults to `True`): + Whether or not to disable dropouts in `model` and `ref_model`. + generate_during_eval (`bool`, defaults to `False`): + Whether to sample and log generations during evaluation step. + compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return + a dictionary string to metric values. + model_init_kwargs: (`Optional[Dict]`, *optional*): + Dict of Optional kwargs to pass when instantiating the model from a string + ref_model_init_kwargs: (`Optional[Dict]`, *optional*): + Dict of Optional kwargs to pass when instantiating the ref model from a string + + """ + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module, str] = None, + ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + gaudi_config: GaudiConfig = None, + beta: float = 0.1, + loss_type: Literal["sigmoid", "hinge"] = "sigmoid", + args: GaudiTrainingArguments = None, + data_collator: Optional[DataCollator] = None, + label_pad_token_id: int = -100, + padding_value: int = 0, + truncation_mode: str = "keep_end", + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( + None, + None, + ), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + max_length: Optional[int] = None, + max_prompt_length: Optional[int] = None, + max_target_length: Optional[int] = None, + peft_config: Optional[Dict] = None, + is_encoder_decoder: Optional[bool] = None, + disable_dropout: bool = True, + generate_during_eval: bool = False, + compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None, + model_init_kwargs: Optional[Dict] = None, + ref_model_init_kwargs: Optional[Dict] = None, + ): + if model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the DPOTrainer. But your model is already instantiated.") + + if ref_model_init_kwargs is None: + ref_model_init_kwargs = {} + elif not isinstance(ref_model, str): + raise ValueError( + "You passed ref_model_kwargs to the DPOTrainer. But your ref_model is already instantiated." + ) + + if isinstance(model, str): + warnings.warn( + "You passed a model_id to the DPOTrainer. This will automatically create an " + "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you." + ) + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + if isinstance(ref_model, str): + warnings.warn( + "You passed a ref model_id to the DPOTrainer. This will automatically create an " + "`AutoModelForCausalLM`" + ) + ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs) + + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + # if model is a peft model and we have a peft_config, we merge and unload it first + if isinstance(model, PeftModel): + model = model.merge_and_unload() + + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if _support_gc_kwargs: + preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **preprare_model_kwargs) + elif getattr(args, "gradient_checkpointing", False): + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + # get peft model with the given config + dtype = model.dtype + model = get_peft_model(model, peft_config) + model = model.to(dtype) + # For models that use gradient_checkpoiting, we need to attach a hook that enables input + # to explicitly have `requires_grad=True`, otherwise training will either silently + # fail or completely fail. + elif getattr(args, "gradient_checkpointing", False): + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if generate_during_eval and not is_wandb_available(): + raise ValueError( + "`generate_during_eval=True` requires Weights and Biases to be installed." + " Please install `wandb` to resolve." + ) + + if model is not None: + self.is_encoder_decoder = model.config.is_encoder_decoder + elif is_encoder_decoder is None: + raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") + else: + self.is_encoder_decoder = is_encoder_decoder + + self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) + + if ref_model: + self.ref_model = ref_model + elif self.is_peft_model: + # The `model` with adapters turned off will be used as the reference model + self.ref_model = None + else: + self.ref_model = create_reference_model(model) + + if data_collator is None: + if tokenizer is None: + raise ValueError( + "max_length or a tokenizer must be specified when using the default DPODataCollatorWithPadding" + ) + if max_length is None: + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `max_length` in the DPOTrainer's init" + " it will be set to `512` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_length = 512 + if max_prompt_length is None: + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the DPOTrainer's init" + " it will be set to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_prompt_length = 128 + + if max_target_length is None and self.is_encoder_decoder: + warnings.warn( + "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_target_length` in the DPOTrainer's init" + " it will be set to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_target_length = 128 + + data_collator = DPODataCollatorWithPadding( + tokenizer, + max_length=max_length, + max_prompt_length=max_prompt_length, + label_pad_token_id=label_pad_token_id, + padding_value=padding_value, + truncation_mode=truncation_mode, + is_encoder_decoder=self.is_encoder_decoder, + max_target_length=max_target_length, + ) + + if args.remove_unused_columns: + args.remove_unused_columns = False + # warn users + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" + " we have set it for you, but you should do it yourself in the future.", + UserWarning, + ) + + self.use_dpo_data_collator = True + else: + self.use_dpo_data_collator = False + + if disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + self.max_length = max_length + self.generate_during_eval = generate_during_eval + self.label_pad_token_id = label_pad_token_id + self.padding_value = padding_value + + self.beta = beta + self.loss_type = loss_type + + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + GaudiTrainer.__init__( + self, + model=model, + args=args, + gaudi_config=gaudi_config, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + if self.ref_model is None: + if not hasattr(self.accelerator.unwrap_model(self.model), "disable_adapter"): + raise ValueError( + "You are using a `peft` version that does not support `disable_adapter`. Please update your `peft` version to the latest version." + ) + else: + if self.is_deepspeed_enabled: + self.ref_model = self._prepare_deepspeed(self.ref_model) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + from habana_frameworks.torch.hpu import wrap_in_hpu_graph # use graph for ref_model + + ref_model = self.accelerator.unwrap_model(self.ref_model) + ref_model = wrap_in_hpu_graph(ref_model) + + def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]: + """Concatenate the chosen and rejected inputs into a single tensor. + + Args: + batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length). + + Returns: + A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. + """ + concatenated_batch = {} + if self.is_encoder_decoder: + max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1]) + else: + max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) + + max_length = self.max_length # pad to max_length in Gaudi + + for k in batch: + if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): + pad_value = self.label_pad_token_id if "labels" in k or self.is_encoder_decoder else self.padding_value + concatenated_key = k.replace("chosen", "concatenated") + concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) + for k in batch: + if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): + pad_value = self.label_pad_token_id if "labels" in k or self.is_encoder_decoder else self.padding_value + concatenated_key = k.replace("rejected", "concatenated") + concatenated_batch[concatenated_key] = torch.cat( + ( + concatenated_batch[concatenated_key], + pad_to_length(batch[k], max_length, pad_value=pad_value), + ), + dim=0, + ).to(self.accelerator.device) + + if self.is_encoder_decoder: + concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1) + concatenated_batch["concatenated_attention_mask"] = batch["prompt_attention_mask"].repeat(2, 1) + + return concatenated_batch diff --git a/optimum/habana/trl/trainer/sft_trainer.py b/optimum/habana/trl/trainer/sft_trainer.py new file mode 100644 index 0000000000..0bce27b1ce --- /dev/null +++ b/optimum/habana/trl/trainer/sft_trainer.py @@ -0,0 +1,271 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import warnings +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from datasets import Dataset +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + DataCollator, + DataCollatorForLanguageModeling, + PreTrainedModel, + PreTrainedTokenizerBase, +) +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalPrediction +from trl import SFTTrainer +from trl.import_utils import is_peft_available +from trl.trainer.utils import ( + DataCollatorForCompletionOnlyLM, + PeftSavingCallback, +) + + +if is_peft_available(): + from peft import PeftConfig, PeftModel, get_peft_model + +from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments + + +class GaudiSFTTrainer(SFTTrainer, GaudiTrainer): + r""" + Class definition of the Supervised Finetuning Trainer (SFT Trainer). + This class is a wrapper around the `transformers.Trainer` class and inherits all of its attributes and methods. + The trainer takes care of properly initializing the PeftModel in case a user passes a `PeftConfig` object. + + Args: + model (Union[`transformers.PreTrainedModel`, `nn.Module`, `str`]): + The model to train, can be a `PreTrainedModel`, a `torch.nn.Module` or a string with the model name to + load from cache or download. The model can be also converted to a `PeftModel` if a `PeftConfig` object is + passed to the `peft_config` argument. + args (Optional[`transformers.TrainingArguments`]): + The arguments to tweak for training. Please refer to the official documentation of `transformers.TrainingArguments` + for more information. + data_collator (Optional[`transformers.DataCollator`]): + The data collator to use for training. + train_dataset (Optional[`datasets.Dataset`]): + The dataset to use for training. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset. + eval_dataset (Optional[Union[`datasets.Dataset`, Dict[`str`, `datasets.Dataset`]]]): + The dataset to use for evaluation. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset. + tokenizer (Optional[`transformers.PreTrainedTokenizer`]): + The tokenizer to use for training. If not specified, the tokenizer associated to the model will be used. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be used. + compute_metrics (`Callable[[transformers.EvalPrediction], Dict]`, *optional* defaults to None): + The function used to compute metrics during evaluation. It should return a dictionary mapping metric names to metric values. + If not specified, only the loss will be computed during evaluation. + callbacks (`List[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + peft_config (`Optional[PeftConfig]`): + The PeftConfig object to use to initialize the PeftModel. + dataset_text_field (`Optional[str]`): + The name of the text field of the dataset, in case this is passed by a user, the trainer will automatically create a + `ConstantLengthDataset` based on the `dataset_text_field` argument. + formatting_func (`Optional[Callable]`): + The formatting function to be used for creating the `ConstantLengthDataset`. + max_seq_length (`Optional[int]`): + The maximum sequence length to use for the `ConstantLengthDataset` and for automaticallty creating the Dataset. Defaults to `512`. + infinite (`Optional[bool]`): + Whether to use an infinite dataset or not. Defaults to `False`. + num_of_sequences (`Optional[int]`): + The number of sequences to use for the `ConstantLengthDataset`. Defaults to `1024`. + chars_per_token (`Optional[float]`): + The number of characters per token to use for the `ConstantLengthDataset`. Defaults to `3.6`. You can check how this is computed in the + stack-llama example: https://github.com/huggingface/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/examples/stack_llama/scripts/supervised_finetuning.py#L53. + packing (`Optional[bool]`): + Used only in case `dataset_text_field` is passed. This argument is used by the `ConstantLengthDataset` to pack the sequences + of the dataset. + dataset_num_proc (`Optional[int]`): + The number of workers to use to tokenize the data. Only used when `packing=False`. Defaults to None. + dataset_batch_size (`int`): + The number of examples to tokenize per batch. If batch_size <= 0 or batch_size == None, + tokenize the full dataset as a single batch. Defaults to 1000. + neftune_noise_alpha (`Optional[float]`): + If not `None`, this will activate NEFTune noise embeddings. This has been proven to drastically improve model performances for instrcution + fine-tuning. Check out the original paper here: https://arxiv.org/abs/2310.05914 and the original code here: https://github.com/neelsjain/NEFTune + model_init_kwargs: (`Optional[Dict]`, *optional*): + Dict of Optional kwargs to pass when instantiating the model from a string + """ + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module, str] = None, + args: GaudiTrainingArguments = None, + gaudi_config: GaudiConfig = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + peft_config: Optional["PeftConfig"] = None, + dataset_text_field: Optional[str] = None, + packing: Optional[bool] = False, + formatting_func: Optional[Callable] = None, + max_seq_length: Optional[int] = None, + infinite: Optional[bool] = False, + num_of_sequences: Optional[int] = 1024, + chars_per_token: Optional[float] = 3.6, + dataset_num_proc: Optional[int] = None, + dataset_batch_size: int = 1000, + neftune_noise_alpha: Optional[float] = None, + model_init_kwargs: Optional[Dict] = None, + ): + if model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the SFTTrainer. But your model is already instantiated.") + + if isinstance(model, str): + warnings.warn( + "You passed a model_id to the SFTTrainer. This will automatically create an " + "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you." + ) + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + if packing and data_collator is not None and isinstance(data_collator, DataCollatorForCompletionOnlyLM): + raise ValueError( + "You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument." + ) + + if is_peft_available() and peft_config is not None: + if not isinstance(peft_config, PeftConfig): + raise ValueError( + "If you want to use the PeftModel, you need to pass a PeftConfig object to the SFTTrainer." + f" and you passed a {type(peft_config)}." + ) + + if not isinstance(model, PeftModel): + if getattr(args, "gradient_checkpointing", False): + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + dtype = model.dtype + model = get_peft_model(model, peft_config) + model = model.to(dtype) + + if callbacks is None: + callbacks = [PeftSavingCallback] + + if tokenizer is None: + tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path) + if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + + if max_seq_length is None: + # to overcome some issues with broken tokenizers + max_seq_length = min(tokenizer.model_max_length, 1024) + + warnings.warn( + f"You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to {max_seq_length}" + ) + + self.dataset_num_proc = dataset_num_proc + self.dataset_batch_size = dataset_batch_size + + self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha") + + if neftune_noise_alpha is not None and self._trainer_supports_neftune: + args.neftune_noise_alpha = neftune_noise_alpha + warnings.warn( + "You passed a `neftune_noise_alpha` argument to the SFTTrainer, the value you passed will override the one in the `TrainingArguments`." + ) + # self.neftune_noise_alpha is done at Trainer level + elif not self._trainer_supports_neftune: + self.neftune_noise_alpha = neftune_noise_alpha + + if not packing: + if dataset_text_field is None and formatting_func is None: + raise ValueError( + "You passed `packing=False` to the SFTTrainer, but you didn't pass a `dataset_text_field` or `formatting_func` argument." + ) + + if data_collator is None: + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + if train_dataset is not None: + train_dataset = self._prepare_dataset( + train_dataset, + tokenizer, + packing, + dataset_text_field, + max_seq_length, + formatting_func, + infinite, + num_of_sequences, + chars_per_token, + ) + if eval_dataset is not None: + _multiple = isinstance(eval_dataset, dict) + _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset} + for _eval_dataset_name, _eval_dataset in _eval_datasets.items(): + _eval_datasets[_eval_dataset_name] = self._prepare_dataset( + _eval_dataset, + tokenizer, + packing, + dataset_text_field, + max_seq_length, + formatting_func, + infinite, + num_of_sequences, + chars_per_token, + ) + if not _multiple: + eval_dataset = _eval_datasets["singleton"] + + if tokenizer.padding_side is not None and tokenizer.padding_side != "right": + warnings.warn( + "You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to " + "overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code." + ) + + GaudiTrainer.__init__( + self, + model=model, + args=args, + gaudi_config=gaudi_config, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + if self.args.max_steps > 0 and packing: + warnings.warn( + "You passed `packing=True` to the SFTTrainer, and you are training your model with `max_steps` strategy. The dataset will be iterated until the `max_steps` are reached." + ) + self.train_dataset.infinite = True + elif self.args.max_steps == -1 and packing: + self.train_dataset.infinite = False From 4c7f9b8c066f936bbe487c1bad8f162c811d647d Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Sun, 24 Dec 2023 17:45:30 -0800 Subject: [PATCH 2/2] upgrade SFTTrainer/DPO trainer and stack_llama_2 example to v0.7.6 Signed-off-by: Wang, Yi A --- examples/trl/stack_llama_2/README.md | 34 +-- examples/trl/stack_llama_2/dpo_llama2.py | 2 +- .../trl/stack_llama_2/merge_peft_adapter.py | 2 + examples/trl/stack_llama_2/requirements.txt | 2 +- examples/trl/stack_llama_2/sft_llama2.py | 79 ++---- optimum/habana/trl/trainer/dpo_trainer.py | 243 ++++++++++-------- optimum/habana/trl/trainer/sft_trainer.py | 121 ++++----- 7 files changed, 234 insertions(+), 249 deletions(-) diff --git a/examples/trl/stack_llama_2/README.md b/examples/trl/stack_llama_2/README.md index b27f62a413..12b7e4da80 100644 --- a/examples/trl/stack_llama_2/README.md +++ b/examples/trl/stack_llama_2/README.md @@ -16,22 +16,24 @@ There were two main steps to the DPO training process: ``` python ../../gaudi_spawn.py --world_size 8 --use_mpi sft_llama2.py \ - --training-args.output_dir="./sft" \ - --training-args.max_steps=500 \ - --training-args.logging_steps=10 \ - --training-args.save_steps=10 \ - --training-args.per_device_train_batch_size=4 \ - --training-args.per_device_eval_batch_size=1 \ - --training-args.gradient_accumulation_steps=2 \ - --training-args.learning_rate=1e-4 \ - --training-args.lr_scheduler_type="cosine" \ - --training-args.warmup_steps=100 \ - --training-args.weight_decay=0.05 \ - --training-args.optim="paged_adamw_32bit" \ - --training-args.bf16 \ - --training-args.remove_unused_columns=False \ - --training-args.run_name="sft_llama2" \ - --training-args.report_to=none + --output_dir="./sft" \ + --max_steps=500 \ + --logging_steps=10 \ + --save_steps=10 \ + --per_device_train_batch_size=4 \ + --per_device_eval_batch_size=1 \ + --gradient_accumulation_steps=2 \ + --learning_rate=1e-4 \ + --lr_scheduler_type="cosine" \ + --warmup_steps=100 \ + --weight_decay=0.05 \ + --optim="paged_adamw_32bit" \ + --bf16 \ + --remove_unused_columns=False \ + --run_name="sft_llama2" \ + --report_to=none \ + --use_habana \ + --use_lazy_mode ``` 2. Run the DPO trainer using the model saved by the previous step: ``` diff --git a/examples/trl/stack_llama_2/dpo_llama2.py b/examples/trl/stack_llama_2/dpo_llama2.py index eaadebd4ae..2b102e1825 100644 --- a/examples/trl/stack_llama_2/dpo_llama2.py +++ b/examples/trl/stack_llama_2/dpo_llama2.py @@ -1,4 +1,4 @@ -# 0. imports +# copy from https://github.com/huggingface/trl/blob/v0.7.6/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py, enable it for Gaudi2 from dataclasses import dataclass, field from typing import Dict, Optional diff --git a/examples/trl/stack_llama_2/merge_peft_adapter.py b/examples/trl/stack_llama_2/merge_peft_adapter.py index 058099180a..8913fc62a4 100644 --- a/examples/trl/stack_llama_2/merge_peft_adapter.py +++ b/examples/trl/stack_llama_2/merge_peft_adapter.py @@ -1,3 +1,5 @@ +# copy from https://github.com/huggingface/trl/blob/v0.7.6/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py. +# only difference is removal of model.push_to_hub from dataclasses import dataclass, field from typing import Optional diff --git a/examples/trl/stack_llama_2/requirements.txt b/examples/trl/stack_llama_2/requirements.txt index 1da3dabbc8..c980a4b30c 100644 --- a/examples/trl/stack_llama_2/requirements.txt +++ b/examples/trl/stack_llama_2/requirements.txt @@ -1,4 +1,4 @@ -trl == 0.7.4 +trl == 0.7.6 peft == 0.6.2 datasets wandb diff --git a/examples/trl/stack_llama_2/sft_llama2.py b/examples/trl/stack_llama_2/sft_llama2.py index 1ce957b319..1ebff0df14 100644 --- a/examples/trl/stack_llama_2/sft_llama2.py +++ b/examples/trl/stack_llama_2/sft_llama2.py @@ -1,4 +1,5 @@ # Fine-Tune Llama2-7b on SE paired dataset +# copy from https://github.com/huggingface/trl/blob/v0.7.6/examples/research_projects/stack_llama_2/scripts/sft_llama2.py, enable it for Gaudi2 import logging import os from dataclasses import dataclass, field @@ -6,11 +7,10 @@ import torch import transformers -import tyro from datasets import load_dataset from peft import AutoPeftModelForCausalLM, LoraConfig from tqdm import tqdm -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser from transformers.trainer_utils import is_main_process from trl.trainer import ConstantLengthDataset @@ -24,7 +24,6 @@ @dataclass class ScriptArguments: model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"}) - dataset_name: Optional[str] = field(default="lvwerra/stack-exchange-paired", metadata={"help": "the dataset name"}) subset: Optional[str] = field(default="data/finetune", metadata={"help": "the subset to use"}) split: Optional[str] = field(default="train", metadata={"help": "the split to use"}) @@ -33,50 +32,26 @@ class ScriptArguments: shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"}) seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"}) num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"}) - - training_args: GaudiTrainingArguments = field( - default_factory=lambda: GaudiTrainingArguments( - output_dir="./results", - max_steps=500, - logging_steps=10, - save_steps=10, - per_device_train_batch_size=4, - per_device_eval_batch_size=1, - gradient_accumulation_steps=2, - gradient_checkpointing=False, - group_by_length=False, - learning_rate=1e-4, - lr_scheduler_type="cosine", - warmup_steps=100, - weight_decay=0.05, - optim="paged_adamw_32bit", - bf16=True, - remove_unused_columns=False, - run_name="sft_llama2", - report_to="wandb", - use_habana=True, - use_lazy_mode=True, - log_level="info", - ) - ) - packing: Optional[bool] = field(default=True, metadata={"help": "whether to use packing for SFTTrainer"}) - peft_config: LoraConfig = field( - default_factory=lambda: LoraConfig( - r=8, - lora_alpha=16, - lora_dropout=0.05, - target_modules=["q_proj", "v_proj"], - bias="none", - task_type="CAUSAL_LM", - ) - ) - - -script_args = tyro.cli(ScriptArguments) + # LoraConfig + lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"}) + lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"}) + lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"}) + + +parser = HfArgumentParser((ScriptArguments, GaudiTrainingArguments)) +script_args, training_args = parser.parse_args_into_dataclasses() +peft_config = LoraConfig( + r=script_args.lora_r, + lora_alpha=script_args.lora_alpha, + lora_dropout=script_args.lora_dropout, + target_modules=["q_proj", "v_proj"], + bias="none", + task_type="CAUSAL_LM", +) -if script_args.training_args.group_by_length and script_args.packing: +if training_args.group_by_length and script_args.packing: raise ValueError("Cannot use both packing and group by length") @@ -152,14 +127,10 @@ def create_datasets(tokenizer, args): ) base_model.config.use_cache = False -peft_config = script_args.peft_config - tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True) tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training -training_args = script_args.training_args - log_level = training_args.get_process_log_level() logger.setLevel(log_level) transformers.utils.logging.set_verbosity(log_level) @@ -184,16 +155,14 @@ def create_datasets(tokenizer, args): args=training_args, ) trainer.train() -trainer.save_model(script_args.training_args.output_dir) +trainer.save_model(training_args.output_dir) # Free memory for merging weights del base_model -with script_args.training_args.main_process_first(desc="merge peft model"): - if is_main_process(script_args.training_args.local_rank): - model = AutoPeftModelForCausalLM.from_pretrained( - script_args.training_args.output_dir, torch_dtype=torch.bfloat16 - ) +with training_args.main_process_first(desc="merge peft model"): + if is_main_process(training_args.local_rank): + model = AutoPeftModelForCausalLM.from_pretrained(training_args.output_dir, torch_dtype=torch.bfloat16) model = model.merge_and_unload() - output_merged_dir = os.path.join(script_args.training_args.output_dir, "final_merged_checkpoint") + output_merged_dir = os.path.join(training_args.output_dir, "final_merged_checkpoint") model.save_pretrained(output_merged_dir, safe_serialization=True) diff --git a/optimum/habana/trl/trainer/dpo_trainer.py b/optimum/habana/trl/trainer/dpo_trainer.py index c3c7660819..e5cfea0cd3 100644 --- a/optimum/habana/trl/trainer/dpo_trainer.py +++ b/optimum/habana/trl/trainer/dpo_trainer.py @@ -31,7 +31,12 @@ from transformers.trainer_utils import EvalLoopOutput from trl import DPOTrainer, create_reference_model from trl.import_utils import is_peft_available, is_wandb_available -from trl.trainer.utils import DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length +from trl.trainer.utils import ( + DPODataCollatorWithPadding, + disable_dropout_in_model, + pad_to_length, + peft_module_casting_to_bf16, +) from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments @@ -48,89 +53,25 @@ class GaudiDPOTrainer(DPOTrainer, GaudiTrainer): - r""" - Initialize DPOTrainer. - - Args: - model (`transformers.PreTrainedModel`): - The model to train, preferably an `AutoModelForSequenceClassification`. - ref_model (`PreTrainedModelWrapper`): - Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no - reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized. - beta (`float`, defaults to 0.1): - The beta factor in DPO loss. Higher beta means less divergence from the initial policy. - loss_type (`str`, defaults to `"sigmoid"`): - The type of DPO loss to use. Either `"sigmoid"` the default DPO loss or `"hinge"` loss from SLiC paper. - args (`transformers.TrainingArguments`): - The arguments to use for training. - data_collator (`transformers.DataCollator`): - The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used - which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. - label_pad_token_id (`int`, defaults to `-100`): - The label pad token id. This argument is required if you want to use the default data collator. - padding_value (`int`, defaults to `0`): - The padding value. This argument is required if you want to use the default data collator. - truncation_mode (`str`, defaults to `keep_end`): - The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator. - train_dataset (`datasets.Dataset`): - The dataset to use for training. - eval_dataset (`datasets.Dataset`): - The dataset to use for evaluation. - tokenizer (`transformers.PreTrainedTokenizerBase`): - The tokenizer to use for training. This argument is required if you want to use the default data collator. - model_init (`Callable[[], transformers.PreTrainedModel]`): - The model initializer to use for training. If None is specified, the default model initializer will be used. - callbacks (`List[transformers.TrainerCallback]`): - The callbacks to use for training. - optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): - The optimizer and scheduler to use for training. - preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): - The function to use to preprocess the logits before computing the metrics. - max_length (`int`, defaults to `None`): - The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator. - max_prompt_length (`int`, defaults to `None`): - The maximum length of the prompt. This argument is required if you want to use the default data collator. - max_target_length (`int`, defaults to `None`): - The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder. - peft_config (`Dict`, defaults to `None`): - The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. - is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`): - If no model is provided, we need to know if the model_init returns an encoder-decoder. - disable_dropout (`bool`, defaults to `True`): - Whether or not to disable dropouts in `model` and `ref_model`. - generate_during_eval (`bool`, defaults to `False`): - Whether to sample and log generations during evaluation step. - compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): - The function to use to compute the metrics. Must take a `EvalPrediction` and return - a dictionary string to metric values. - model_init_kwargs: (`Optional[Dict]`, *optional*): - Dict of Optional kwargs to pass when instantiating the model from a string - ref_model_init_kwargs: (`Optional[Dict]`, *optional*): - Dict of Optional kwargs to pass when instantiating the ref model from a string - - """ - def __init__( self, model: Union[PreTrainedModel, nn.Module, str] = None, ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, - gaudi_config: GaudiConfig = None, beta: float = 0.1, - loss_type: Literal["sigmoid", "hinge"] = "sigmoid", + label_smoothing: float = 0, + loss_type: Literal["sigmoid", "hinge", "ipo", "kto"] = "sigmoid", args: GaudiTrainingArguments = None, + gaudi_config: GaudiConfig = None, data_collator: Optional[DataCollator] = None, label_pad_token_id: int = -100, - padding_value: int = 0, + padding_value: int = None, truncation_mode: str = "keep_end", train_dataset: Optional[Dataset] = None, eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None, model_init: Optional[Callable[[], PreTrainedModel]] = None, callbacks: Optional[List[TrainerCallback]] = None, - optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( - None, - None, - ), + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, max_length: Optional[int] = None, max_prompt_length: Optional[int] = None, @@ -140,9 +81,18 @@ def __init__( disable_dropout: bool = True, generate_during_eval: bool = False, compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None, + precompute_ref_log_probs: bool = False, model_init_kwargs: Optional[Dict] = None, ref_model_init_kwargs: Optional[Dict] = None, ): + """ + Copied from DPOTrainer.__init__: https://github.com/huggingface/trl/blob/v0.7.6/trl/trainer/dpo_trainer.py#L127 + The only differences are: + - add new args gaudi_config + - use graph for ref_model + - use GaudiTrainer instead of Trainer + - cast peft model to bf16. + """ if model_init_kwargs is None: model_init_kwargs = {} elif not isinstance(model, str): @@ -203,9 +153,10 @@ def make_inputs_require_grad(module, input, output): model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) # get peft model with the given config - dtype = model.dtype model = get_peft_model(model, peft_config) - model = model.to(dtype) + if args.bf16: + peft_module_casting_to_bf16(model) + # For models that use gradient_checkpoiting, we need to attach a hook that enables input # to explicitly have `requires_grad=True`, otherwise training will either silently # fail or completely fail. @@ -237,7 +188,7 @@ def make_inputs_require_grad(module, input, output): if ref_model: self.ref_model = ref_model - elif self.is_peft_model: + elif self.is_peft_model or precompute_ref_log_probs: # The `model` with adapters turned off will be used as the reference model self.ref_model = None else: @@ -272,14 +223,9 @@ def make_inputs_require_grad(module, input, output): max_target_length = 128 data_collator = DPODataCollatorWithPadding( - tokenizer, - max_length=max_length, - max_prompt_length=max_prompt_length, + pad_token_id=tokenizer.pad_token_id, label_pad_token_id=label_pad_token_id, - padding_value=padding_value, - truncation_mode=truncation_mode, is_encoder_decoder=self.is_encoder_decoder, - max_target_length=max_target_length, ) if args.remove_unused_columns: @@ -303,12 +249,34 @@ def make_inputs_require_grad(module, input, output): self.max_length = max_length self.generate_during_eval = generate_during_eval self.label_pad_token_id = label_pad_token_id - self.padding_value = padding_value + self.padding_value = padding_value if padding_value is not None else tokenizer.pad_token_id + self.max_prompt_length = max_prompt_length + self.truncation_mode = truncation_mode + self.max_target_length = max_target_length + self.tokenizer = tokenizer + self.precompute_ref_log_probs = precompute_ref_log_probs + + # Since ref_logs are precomputed on the first call to get_train/eval_dataloader + # keep track of first called to avoid computation of future calls + self._precomputed_train_ref_log_probs = False + self._precomputed_eval_ref_log_probs = False + + if loss_type in ["hinge", "ipo", "kto_pair"] and label_smoothing > 0: + warnings.warn( + "You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter." + ) self.beta = beta + self.label_smoothing = label_smoothing self.loss_type = loss_type self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + # tokenize the dataset + train_dataset = train_dataset.map(self.tokenize_row) + if eval_dataset is not None: + eval_dataset = eval_dataset.map(self.tokenize_row) + GaudiTrainer.__init__( self, model=model, @@ -330,10 +298,17 @@ def make_inputs_require_grad(module, input, output): "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." ) + # Deepspeed Zero-3 does not support precompute_ref_log_probs + if self.is_deepspeed_enabled: + if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs: + raise ValueError( + "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`." + ) + if self.ref_model is None: - if not hasattr(self.accelerator.unwrap_model(self.model), "disable_adapter"): + if not (self.is_peft_model or self.precompute_ref_log_probs): raise ValueError( - "You are using a `peft` version that does not support `disable_adapter`. Please update your `peft` version to the latest version." + "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`" ) else: if self.is_deepspeed_enabled: @@ -341,36 +316,51 @@ def make_inputs_require_grad(module, input, output): else: self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) - from habana_frameworks.torch.hpu import wrap_in_hpu_graph # use graph for ref_model - - ref_model = self.accelerator.unwrap_model(self.ref_model) - ref_model = wrap_in_hpu_graph(ref_model) + from habana_frameworks.torch.hpu import wrap_in_hpu_graph # use graph for ref_model - def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]: - """Concatenate the chosen and rejected inputs into a single tensor. + ref_model = self.accelerator.unwrap_model(self.ref_model) + ref_model = wrap_in_hpu_graph(ref_model) - Args: - batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length). - - Returns: - A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. + @staticmethod + def concatenated_inputs( + batch: Dict[str, Union[List, torch.LongTensor]], + is_encoder_decoder: bool = False, + label_pad_token_id: int = -100, + padding_value: int = 0, + device: Optional[torch.device] = None, + padded_max_length: int = 0, + ) -> Dict[str, torch.LongTensor]: + """ + Copied from DPOTrainer.concatenated_inputs: https://github.com/huggingface/trl/blob/v0.7.6/trl/trainer/dpo_trainer.py#L701 + - pad to self.max_length in Gaudi2 """ concatenated_batch = {} - if self.is_encoder_decoder: + + if is_encoder_decoder: max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1]) else: max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) - max_length = self.max_length # pad to max_length in Gaudi - + if padded_max_length != 0: # pad to max_length in Gaudi + max_length = padded_max_length for k in batch: if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): - pad_value = self.label_pad_token_id if "labels" in k or self.is_encoder_decoder else self.padding_value + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 concatenated_key = k.replace("chosen", "concatenated") concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) for k in batch: if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): - pad_value = self.label_pad_token_id if "labels" in k or self.is_encoder_decoder else self.padding_value + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 concatenated_key = k.replace("rejected", "concatenated") concatenated_batch[concatenated_key] = torch.cat( ( @@ -378,10 +368,59 @@ def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) - pad_to_length(batch[k], max_length, pad_value=pad_value), ), dim=0, - ).to(self.accelerator.device) + ).to(device=device) - if self.is_encoder_decoder: - concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1) - concatenated_batch["concatenated_attention_mask"] = batch["prompt_attention_mask"].repeat(2, 1) + if is_encoder_decoder: + concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device) + concatenated_batch["concatenated_attention_mask"] = ( + batch["prompt_attention_mask"].repeat(2, 1).to(device=device) + ) return concatenated_batch + + def concatenated_forward( + self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """ + Copied from DPOTrainer.concatenated_forward: https://github.com/huggingface/trl/blob/v0.7.6/trl/trainer/dpo_trainer.py#L866 + - pad to self.max_length in Gaudi2 + """ + concatenated_batch = self.concatenated_inputs( + batch, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + padding_value=self.padding_value, + device=self.accelerator.device, + padded_max_length=self.max_length, + ) + len_chosen = batch["chosen_labels"].shape[0] + + model_kwargs = ( + { + "labels": concatenated_batch["concatenated_labels"], + "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None), + } + if self.is_encoder_decoder + else {} + ) + all_logits = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + **model_kwargs, + ).logits + + all_logps = self.get_batch_logps( + all_logits, + concatenated_batch["concatenated_labels"], + average_log_prob=False, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits) diff --git a/optimum/habana/trl/trainer/sft_trainer.py b/optimum/habana/trl/trainer/sft_trainer.py index 0bce27b1ce..49b2525f4c 100644 --- a/optimum/habana/trl/trainer/sft_trainer.py +++ b/optimum/habana/trl/trainer/sft_trainer.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses +import inspect import warnings from typing import Callable, Dict, List, Optional, Tuple, Union @@ -31,80 +33,17 @@ from trl.import_utils import is_peft_available from trl.trainer.utils import ( DataCollatorForCompletionOnlyLM, - PeftSavingCallback, + peft_module_casting_to_bf16, ) if is_peft_available(): - from peft import PeftConfig, PeftModel, get_peft_model + from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments class GaudiSFTTrainer(SFTTrainer, GaudiTrainer): - r""" - Class definition of the Supervised Finetuning Trainer (SFT Trainer). - This class is a wrapper around the `transformers.Trainer` class and inherits all of its attributes and methods. - The trainer takes care of properly initializing the PeftModel in case a user passes a `PeftConfig` object. - - Args: - model (Union[`transformers.PreTrainedModel`, `nn.Module`, `str`]): - The model to train, can be a `PreTrainedModel`, a `torch.nn.Module` or a string with the model name to - load from cache or download. The model can be also converted to a `PeftModel` if a `PeftConfig` object is - passed to the `peft_config` argument. - args (Optional[`transformers.TrainingArguments`]): - The arguments to tweak for training. Please refer to the official documentation of `transformers.TrainingArguments` - for more information. - data_collator (Optional[`transformers.DataCollator`]): - The data collator to use for training. - train_dataset (Optional[`datasets.Dataset`]): - The dataset to use for training. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset. - eval_dataset (Optional[Union[`datasets.Dataset`, Dict[`str`, `datasets.Dataset`]]]): - The dataset to use for evaluation. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset. - tokenizer (Optional[`transformers.PreTrainedTokenizer`]): - The tokenizer to use for training. If not specified, the tokenizer associated to the model will be used. - model_init (`Callable[[], transformers.PreTrainedModel]`): - The model initializer to use for training. If None is specified, the default model initializer will be used. - compute_metrics (`Callable[[transformers.EvalPrediction], Dict]`, *optional* defaults to None): - The function used to compute metrics during evaluation. It should return a dictionary mapping metric names to metric values. - If not specified, only the loss will be computed during evaluation. - callbacks (`List[transformers.TrainerCallback]`): - The callbacks to use for training. - optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): - The optimizer and scheduler to use for training. - preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): - The function to use to preprocess the logits before computing the metrics. - peft_config (`Optional[PeftConfig]`): - The PeftConfig object to use to initialize the PeftModel. - dataset_text_field (`Optional[str]`): - The name of the text field of the dataset, in case this is passed by a user, the trainer will automatically create a - `ConstantLengthDataset` based on the `dataset_text_field` argument. - formatting_func (`Optional[Callable]`): - The formatting function to be used for creating the `ConstantLengthDataset`. - max_seq_length (`Optional[int]`): - The maximum sequence length to use for the `ConstantLengthDataset` and for automaticallty creating the Dataset. Defaults to `512`. - infinite (`Optional[bool]`): - Whether to use an infinite dataset or not. Defaults to `False`. - num_of_sequences (`Optional[int]`): - The number of sequences to use for the `ConstantLengthDataset`. Defaults to `1024`. - chars_per_token (`Optional[float]`): - The number of characters per token to use for the `ConstantLengthDataset`. Defaults to `3.6`. You can check how this is computed in the - stack-llama example: https://github.com/huggingface/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/examples/stack_llama/scripts/supervised_finetuning.py#L53. - packing (`Optional[bool]`): - Used only in case `dataset_text_field` is passed. This argument is used by the `ConstantLengthDataset` to pack the sequences - of the dataset. - dataset_num_proc (`Optional[int]`): - The number of workers to use to tokenize the data. Only used when `packing=False`. Defaults to None. - dataset_batch_size (`int`): - The number of examples to tokenize per batch. If batch_size <= 0 or batch_size == None, - tokenize the full dataset as a single batch. Defaults to 1000. - neftune_noise_alpha (`Optional[float]`): - If not `None`, this will activate NEFTune noise embeddings. This has been proven to drastically improve model performances for instrcution - fine-tuning. Check out the original paper here: https://arxiv.org/abs/2310.05914 and the original code here: https://github.com/neelsjain/NEFTune - model_init_kwargs: (`Optional[Dict]`, *optional*): - Dict of Optional kwargs to pass when instantiating the model from a string - """ - def __init__( self, model: Union[PreTrainedModel, nn.Module, str] = None, @@ -124,19 +63,32 @@ def __init__( packing: Optional[bool] = False, formatting_func: Optional[Callable] = None, max_seq_length: Optional[int] = None, - infinite: Optional[bool] = False, + infinite: Optional[bool] = None, num_of_sequences: Optional[int] = 1024, chars_per_token: Optional[float] = 3.6, dataset_num_proc: Optional[int] = None, dataset_batch_size: int = 1000, neftune_noise_alpha: Optional[float] = None, model_init_kwargs: Optional[Dict] = None, + dataset_kwargs: Optional[Dict] = None, ): + """ + Copied from SFTTrainer.__init__: https://github.com/huggingface/trl/blob/v0.7.6/trl/trainer/sft_trainer.py#L120 + The only differences are: + - add new args gaudi_config + - use GaudiTrainer instead of Trainer + - cast peft model to bf16. + """ if model_init_kwargs is None: model_init_kwargs = {} elif not isinstance(model, str): raise ValueError("You passed model_kwargs to the SFTTrainer. But your model is already instantiated.") + if infinite is not None: + warnings.warn( + "The `infinite` argument is deprecated and will be removed in a future version of TRL. Use `TrainingArguments.max_steps` or `TrainingArguments.num_train_epochs` instead to control training length." + ) + if isinstance(model, str): warnings.warn( "You passed a model_id to the SFTTrainer. This will automatically create an " @@ -157,7 +109,28 @@ def __init__( ) if not isinstance(model, PeftModel): - if getattr(args, "gradient_checkpointing", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {} + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + preprare_model_kwargs = { + "use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False) + } + + if _support_gc_kwargs: + preprare_model_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **preprare_model_kwargs) + + if args is not None: + args = dataclasses.replace(args, gradient_checkpointing=False) + elif getattr(args, "gradient_checkpointing", False) and ( + "use_reentrant" not in gradient_checkpointing_kwargs + or gradient_checkpointing_kwargs["use_reentrant"] + ): # For backward compatibility with older versions of transformers if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() @@ -168,12 +141,9 @@ def make_inputs_require_grad(module, input, output): model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - dtype = model.dtype model = get_peft_model(model, peft_config) - model = model.to(dtype) - - if callbacks is None: - callbacks = [PeftSavingCallback] + if args.bf16: + peft_module_casting_to_bf16(model) if tokenizer is None: tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path) @@ -210,6 +180,9 @@ def make_inputs_require_grad(module, input, output): if data_collator is None: data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + if dataset_kwargs is None: + dataset_kwargs = {} if train_dataset is not None: train_dataset = self._prepare_dataset( train_dataset, @@ -218,9 +191,9 @@ def make_inputs_require_grad(module, input, output): dataset_text_field, max_seq_length, formatting_func, - infinite, num_of_sequences, chars_per_token, + **dataset_kwargs, ) if eval_dataset is not None: _multiple = isinstance(eval_dataset, dict) @@ -233,9 +206,9 @@ def make_inputs_require_grad(module, input, output): dataset_text_field, max_seq_length, formatting_func, - infinite, num_of_sequences, chars_per_token, + **dataset_kwargs, ) if not _multiple: eval_dataset = _eval_datasets["singleton"]