diff --git a/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py b/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py index b5c59eef9b7..da89b0526c1 100644 --- a/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py +++ b/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py @@ -9,7 +9,7 @@ from peft import LoraConfig from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed -from trl import DPOTrainer +from trl import DPOConfig, DPOTrainer # Define and parse arguments. @@ -167,7 +167,7 @@ def return_prompt_and_responses(samples) -> Dict[str, str]: ) # 4. initialize training arguments: - training_args = TrainingArguments( + training_args = DPOConfig( per_device_train_batch_size=script_args.per_device_train_batch_size, per_device_eval_batch_size=script_args.per_device_eval_batch_size, max_steps=script_args.max_steps, diff --git a/examples/scripts/dpo.py b/examples/scripts/dpo.py index 587a5efdfb6..6b76b110431 100644 --- a/examples/scripts/dpo.py +++ b/examples/scripts/dpo.py @@ -55,7 +55,7 @@ from datasets import Dataset, load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments -from trl import DPOTrainer, ModelConfig, get_kbit_device_map, get_peft_config, get_quantization_config +from trl import DPOConfig, DPOTrainer, ModelConfig, get_kbit_device_map, get_peft_config, get_quantization_config @dataclass @@ -116,7 +116,7 @@ def split_prompt_and_responses(sample) -> Dict[str, str]: if __name__ == "__main__": - parser = HfArgumentParser((ScriptArguments, TrainingArguments, ModelConfig)) + parser = HfArgumentParser((ScriptArguments, DPOConfig, ModelConfig)) args, training_args, model_config = parser.parse_args_into_dataclasses() ################ diff --git a/tests/slow/test_dpo_slow.py b/tests/slow/test_dpo_slow.py index c90aae96465..7ee4433726e 100644 --- a/tests/slow/test_dpo_slow.py +++ b/tests/slow/test_dpo_slow.py @@ -20,9 
+20,9 @@ from accelerate.utils.memory import release_memory from datasets import load_dataset from parameterized import parameterized -from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig -from trl import DPOTrainer, is_peft_available +from trl import DPOConfig, DPOTrainer, is_peft_available from ..testing_utils import require_bitsandbytes, require_peft, require_torch_gpu from .testing_constants import DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST @@ -60,7 +60,7 @@ def test_dpo_bare_model(self, model_id, loss_type, pre_compute_logits): tokenizer = AutoTokenizer.from_pretrained(model_id) with tempfile.TemporaryDirectory() as tmp_dir: - training_args = TrainingArguments( + training_args = DPOConfig( output_dir=tmp_dir, per_device_train_batch_size=2, max_steps=2, @@ -114,7 +114,7 @@ def test_dpo_peft_model(self, model_id, loss_type, pre_compute_logits, gradient_ tokenizer = AutoTokenizer.from_pretrained(model_id) with tempfile.TemporaryDirectory() as tmp_dir: - training_args = TrainingArguments( + training_args = DPOConfig( output_dir=tmp_dir, per_device_train_batch_size=2, max_steps=2, @@ -178,7 +178,7 @@ def test_dpo_peft_model_qlora(self, model_id, loss_type, pre_compute_logits, gra tokenizer = AutoTokenizer.from_pretrained(model_id) with tempfile.TemporaryDirectory() as tmp_dir: - training_args = TrainingArguments( + training_args = DPOConfig( output_dir=tmp_dir, per_device_train_batch_size=2, max_steps=2, diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 9852ac8b7b4..f5a0f61e506 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -18,9 +18,9 @@ from datasets import Dataset from parameterized import parameterized from pytest import mark -from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments +from 
transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer -from trl import DPOTrainer +from trl import DPOConfig, DPOTrainer, FDivergenceType from .testing_utils import require_bitsandbytes, require_no_wandb, require_peft @@ -92,7 +92,7 @@ def _init_dummy_dataset(self): ) def test_dpo_trainer(self, name, loss_type, pre_compute): with tempfile.TemporaryDirectory() as tmp_dir: - training_args = TrainingArguments( + training_args = DPOConfig( output_dir=tmp_dir, per_device_train_batch_size=2, max_steps=3, @@ -140,7 +140,7 @@ def test_dpo_trainer(self, name, loss_type, pre_compute): def test_dpo_trainer_without_providing_ref_model(self): with tempfile.TemporaryDirectory() as tmp_dir: - training_args = TrainingArguments( + training_args = DPOConfig( output_dir=tmp_dir, per_device_train_batch_size=2, max_steps=3, @@ -190,7 +190,7 @@ def test_dpo_trainer_without_providing_ref_model_with_lora(self): ) with tempfile.TemporaryDirectory() as tmp_dir: - training_args = TrainingArguments( + training_args = DPOConfig( output_dir=tmp_dir, per_device_train_batch_size=2, max_steps=3, @@ -230,7 +230,7 @@ def test_dpo_trainer_without_providing_ref_model_with_lora(self): def test_dpo_trainer_padding_token_is_none(self): with tempfile.TemporaryDirectory() as tmp_dir: - training_args = TrainingArguments( + training_args = DPOConfig( output_dir=tmp_dir, per_device_train_batch_size=2, max_steps=3, @@ -265,7 +265,7 @@ def test_dpo_trainer_padding_token_is_none(self): def test_dpo_trainer_w_dataset_num_proc(self): with tempfile.TemporaryDirectory() as tmp_dir: - training_args = TrainingArguments( + training_args = DPOConfig( output_dir=tmp_dir, per_device_train_batch_size=2, max_steps=3, @@ -302,7 +302,7 @@ def test_dpo_trainer_w_dataset_num_proc(self): @require_no_wandb def test_dpo_trainer_generate_during_eval_no_wandb(self): with tempfile.TemporaryDirectory() as tmp_dir: - training_args = TrainingArguments( + training_args = DPOConfig( output_dir=tmp_dir, 
per_device_train_batch_size=2, max_steps=3, @@ -348,7 +348,7 @@ def test_dpo_lora_save(self): model_peft = get_peft_model(model, lora_config) with tempfile.TemporaryDirectory() as tmp_dir: - training_args = TrainingArguments( + training_args = DPOConfig( output_dir=tmp_dir, per_device_train_batch_size=2, max_steps=3, @@ -407,7 +407,7 @@ def test_dpo_lora_bf16_autocast_llama(self): model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True) with tempfile.TemporaryDirectory() as tmp_dir: - training_args = TrainingArguments( + training_args = DPOConfig( output_dir=tmp_dir, per_device_train_batch_size=2, max_steps=3, @@ -475,7 +475,7 @@ def test_dpo_lora_bf16_autocast(self, name, loss_type, pre_compute, gen_during_e model = AutoModelForCausalLM.from_pretrained(self.model_id, load_in_4bit=True) with tempfile.TemporaryDirectory() as tmp_dir: - training_args = TrainingArguments( + training_args = DPOConfig( output_dir=tmp_dir, per_device_train_batch_size=2, max_steps=3, @@ -528,7 +528,7 @@ def test_dpo_lora_tags(self): model = AutoModelForCausalLM.from_pretrained(model_id) with tempfile.TemporaryDirectory() as tmp_dir: - training_args = TrainingArguments( + training_args = DPOConfig( output_dir=tmp_dir, per_device_train_batch_size=2, max_steps=3, @@ -563,7 +563,7 @@ def test_dpo_tags(self): model = AutoModelForCausalLM.from_pretrained(model_id) with tempfile.TemporaryDirectory() as tmp_dir: - training_args = TrainingArguments( + training_args = DPOConfig( output_dir=tmp_dir, per_device_train_batch_size=2, max_steps=3, @@ -587,3 +587,93 @@ def test_dpo_tags(self): ) assert trainer.model.model_tags == trainer._tag_names + + def test_dpo_loss_alpha_div_f(self): + model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM" + tokenizer = AutoTokenizer.from_pretrained(model_id) + + # lora model + model = AutoModelForCausalLM.from_pretrained(model_id) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + 
per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + evaluation_strategy="steps", + f_divergence_type=FDivergenceType.ALPHA_DIVERGENCE.value, + f_alpha_divergence_coef=0.5, + ) + + dummy_dataset = self._init_dummy_dataset() + + # dpo train lora model with a lora config + trainer = DPOTrainer( + model=model, + ref_model=None, + beta=0.1, + args=training_args, + tokenizer=tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + ) + + # Fake chosen and rejected log probs + policy_chosen_logps = torch.FloatTensor([410.0, 0.1]) + policy_rejected_logps = torch.FloatTensor([810.5, 0.2]) + reference_chosen_logps = torch.FloatTensor([-610.0, -0.1]) + reference_rejected_logps = torch.FloatTensor([110.6, 0.5]) + losses, _, _ = trainer.dpo_loss(policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps) + + assert torch.isfinite(losses).cpu().numpy().all() + + def test_dpo_loss_js_div_f(self): + model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM" + tokenizer = AutoTokenizer.from_pretrained(model_id) + + # lora model + model = AutoModelForCausalLM.from_pretrained(model_id) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + evaluation_strategy="steps", + f_divergence_type=FDivergenceType.JS_DIVERGENCE.value, + f_alpha_divergence_coef=0.5, + ) + + dummy_dataset = self._init_dummy_dataset() + + # dpo train lora model with a lora config + trainer = DPOTrainer( + model=model, + ref_model=None, + beta=0.1, + args=training_args, + tokenizer=tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + ) + + # Fake chosen and rejected log probs + policy_chosen_logps = torch.FloatTensor([410.0, 0.1]) + policy_rejected_logps = torch.FloatTensor([95.5, 
0.2]) + reference_chosen_logps = torch.FloatTensor([-610.0, -0.1]) + reference_rejected_logps = torch.FloatTensor([5.5, 0.5]) + losses, _, _ = trainer.dpo_loss(policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps) + + assert torch.isfinite(losses).cpu().numpy().all() diff --git a/trl/__init__.py b/trl/__init__.py index 943323390b9..5a2b49ffc14 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -22,7 +22,10 @@ ) from .trainer import ( DataCollatorForCompletionOnlyLM, + DPOConfig, DPOTrainer, + FDivergenceConstants, + FDivergenceType, IterativeSFTTrainer, KTOConfig, KTOTrainer, diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index 86e654213fc..4f45cfb5270 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -36,7 +36,7 @@ if is_diffusers_available(): from .ddpo_trainer import DDPOTrainer -from .dpo_trainer import DPOTrainer +from .dpo_trainer import DPOConfig, DPOTrainer, FDivergenceConstants, FDivergenceType from .iterative_sft_trainer import IterativeSFTTrainer from .kto_config import KTOConfig from .kto_trainer import KTOTrainer diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py new file mode 100644 index 00000000000..8792fe5f0b2 --- /dev/null +++ b/trl/trainer/dpo_config.py @@ -0,0 +1,46 @@ +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, Optional + +from transformers import TrainingArguments + + +class FDivergenceType(Enum): + REVERSE_KL = "reverse_kl" + JS_DIVERGENCE = "js_divergence" + ALPHA_DIVERGENCE = "alpha_divergence" + + +class FDivergenceConstants: + ALPHA_DIVERGENCE_COEF_KEY = "alpha_divergence_coef" + ALPHA_DIVERGENCE_COEF_DEFAULT = 1.0 + + +@dataclass +class DPOConfig(TrainingArguments): + """ + DPOConfig collects all training arguments related to the [`DPOTrainer`] class. + Using [`HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + Parameters: + f_divergence_type (`FDivergenceType`, *optional*, defaults to `FDivergenceType.REVERSE_KL`): + The type of f-divergence regularization function to compute divergence between policy and reference model. This argument is optional, defaults to `FDivergenceType.REVERSE_KL`. + f_alpha_divergence_coef (`float`, *optional*, defaults to `1.0`): + The alpha coefficient in the alpha-divergence (u^-alpha) regularization function for the DPO loss, used when `f_divergence_type` is `FDivergenceType.ALPHA_DIVERGENCE`. This argument is optional, defaults to `1.0`. 
+ """ + + f_divergence_type: Optional[FDivergenceType] = FDivergenceType.REVERSE_KL + """The type of f-divergence regularization function to compute divergence between policy and reference model, This argument is optional, defaults to `FDivergenceType.REVERSE_KL`.""" + f_alpha_divergence_coef: float = field(default=1.0, metadata={"help": "the alpha coef in alpha-divergence(u^-alpha) regularization function for DPO loss"}) + """The alpha coef in alpha-divergence(u^-alpha) regularization function for DPO loss.""" diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 0dfc361f5f3..18c6937c758 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -42,12 +42,14 @@ from ..import_utils import is_peft_available, is_wandb_available from ..models import PreTrainedModelWrapper, create_reference_model +from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType from .utils import ( DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length, peft_module_casting_to_bf16, trl_sanitze_kwargs_for_tagging, + cap_exp, ) @@ -78,7 +80,7 @@ class DPOTrainer(Trainer): The robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report that should be between 0 and 0.5. loss_type (`str`, defaults to `"sigmoid"`): The type of DPO loss to use. Either `"sigmoid"` the default DPO loss,`"hinge"` loss from [SLiC](https://arxiv.org/abs/2305.10425) paper, `"ipo"` from [IPO](https://arxiv.org/abs/2310.12036) paper, or `"kto"` from the HALOs [report](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf). - args (`transformers.TrainingArguments`): + args (`DPOConfig`): The arguments to use for training. data_collator (`transformers.DataCollator`): The data collator to use for training. 
If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used @@ -146,7 +148,7 @@ def __init__( beta: float = 0.1, label_smoothing: float = 0, loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid", - args: Optional[TrainingArguments] = None, + args: Optional[DPOConfig] = None, data_collator: Optional[DataCollator] = None, label_pad_token_id: int = -100, padding_value: Optional[int] = None, @@ -174,6 +176,8 @@ def __init__( ref_adapter_name: Optional[str] = None, reference_free: bool = False, ): + if type(args) == TrainingArguments: + raise ValueError("Please use `DPOConfig` instead of TrainingArguments.") if model_init_kwargs is None: model_init_kwargs = {} elif not isinstance(model, str): @@ -366,6 +370,12 @@ def make_inputs_require_grad(module, input, output): self._stored_metrics = defaultdict(lambda: defaultdict(list)) + # DPO-specific parameters + self.f_divergence_type = args.f_divergence_type + self.f_divergence_params = { + FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef + } + self.dataset_num_proc = dataset_num_proc # Compute that only on the main process for faster data processing. @@ -857,15 +867,41 @@ def dpo_loss( The losses tensor contains the DPO loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. 
""" - pi_logratios = policy_chosen_logps - policy_rejected_logps - if self.reference_free: - ref_logratios = torch.tensor([0], dtype=pi_logratios.dtype, device=pi_logratios.device) + chosen_logratios = policy_chosen_logps.to(self.accelerator.device) - ( + not self.reference_free) * reference_chosen_logps.to(self.accelerator.device) + rejected_logratios = policy_rejected_logps.to(self.accelerator.device) - ( + not self.reference_free) * reference_rejected_logps.to(self.accelerator.device) + + if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE.value: + # The alpha-divergence formula: (1 - u^-alpha) / alpha + # The divergence difference between the chosen and rejected sample is: + # (1 - u[w]^-alpha) / alpha - (1 - u[l]^-alpha) / alpha + # = (u[l]^-alpha - u[w]^-alpha) / alpha + # where u[w] and u[l] are the policy/reference probability ratios + # for the chosen and rejected samples, respectively. + alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT + if self.f_divergence_params and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY in self.f_divergence_params: + alpha_coef = float(self.f_divergence_params[FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY]) + logits = (cap_exp(rejected_logratios * -alpha_coef) - cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef else: - ref_logratios = reference_chosen_logps - reference_rejected_logps - - pi_logratios = pi_logratios.to(self.accelerator.device) - ref_logratios = ref_logratios.to(self.accelerator.device) - logits = pi_logratios - ref_logratios + pi_logratios = policy_chosen_logps - policy_rejected_logps + if self.reference_free: + ref_logratios = torch.tensor([0], dtype=pi_logratios.dtype, device=pi_logratios.device) + else: + ref_logratios = reference_chosen_logps - reference_rejected_logps + + pi_logratios = pi_logratios.to(self.accelerator.device) + ref_logratios = ref_logratios.to(self.accelerator.device) + logits = pi_logratios - ref_logratios + + if self.f_divergence_type == 
FDivergenceType.JS_DIVERGENCE.value: + # The js-divergence formula: log(2 * u / (1 + u)) + # The divergence difference between the chosen and rejected sample is: + # log(2 * u[w] / (1 + u[w])) - log(2 * u[l] / (1 + u[l])) + # = log(u[w]) - log(u[l]) - (log(1 + u[w]) - log(1 + u[l])) + # where u[w] and u[l] are the policy/reference probability ratios + # for the chosen and rejected samples, respectively. + logits -= (F.softplus(chosen_logratios) - F.softplus(rejected_logratios)) # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 26a5538cf4f..541399d20b0 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -732,3 +732,31 @@ def get_peft_config(model_config: ModelConfig) -> "Optional[PeftConfig]": ) return peft_config + + +def get_exp_cap(value, decimal=4): + """ + Get the exponent cap of a value. This is used to cap the exponent of a value to avoid overflow. + The formula is: log(value.dtype.max) + + E.g. + For float32 data type, the maximum exponent value is 88.7228 to 4 decimal points. + + + Args: + value (`torch.Tensor`): + The input tensor to obtain the data type + decimal (`int`): + The number of decimal points of the output exponent cap. + e.g. directly calling exp(log(torch.float32.max)) will result in inf + so we cap the exponent to 88.7228 to avoid overflow. 
+ """ + vdtype_max = torch.zeros([1]).to(value.dtype) + torch.finfo(value.dtype).max + vdtype_log_max = torch.log(vdtype_max).to(value.device) + return torch.floor(vdtype_log_max * 10 ** decimal) / 10 ** decimal if decimal > 0 else vdtype_log_max + + +def cap_exp(value, cap=-1): + # Cap the exponent value below the upper-bound to avoid overflow, before calling torch.exp + cap = get_exp_cap(value) if cap < 0 else cap + return torch.exp(torch.clamp(value, max=cap))