From f50e74d1f2811d11590bc561c25a7291627ee2b4 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 14 Jan 2025 14:09:26 +0100 Subject: [PATCH 01/17] initial liger support --- tests/test_dpo_trainer.py | 76 ++++++- trl/trainer/dpo_config.py | 18 ++ trl/trainer/dpo_trainer.py | 448 ++++++++++++++++++++++++------------- 3 files changed, 391 insertions(+), 151 deletions(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index a786f6a41ef..242e63f08de 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -29,7 +29,12 @@ PreTrainedTokenizerBase, is_vision_available, ) -from transformers.testing_utils import require_peft, require_torch_gpu_if_bnb_not_multi_backend_enabled, require_vision +from transformers.testing_utils import ( + require_liger_kernel, + require_peft, + require_torch_gpu_if_bnb_not_multi_backend_enabled, + require_vision, +) from trl import DPOConfig, DPOTrainer, FDivergenceType @@ -1204,6 +1209,75 @@ def test_padding_free(self): if param.sum() != 0: self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + @require_liger_kernel + def test_dpo_trainer_with_liger(self): + """Test ORPO trainer with Liger loss enabled. + + This test verifies that: + 1. Training runs successfully with Liger loss + 2. Model parameters update as expected + 3. Loss values are reasonable and finite + 4. 
Training works with both default and custom beta values + """ + beta_values = [0.1, 0.5] # Test multiple beta values + + for beta in beta_values: + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + beta=beta, + use_liger_loss=True, # Enable Liger loss + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + trainer = DPOTrainer( + model=self.model, + ref_model=self.ref_model, # Add reference model + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + # Store initial parameters + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + train_output = trainer.train() + + # Verify training completed successfully + self.assertIsNotNone(train_output) + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Verify loss is finite + self.assertTrue(np.isfinite(trainer.state.log_history[-1]["train_loss"])) + + # Check parameters have been updated + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # Only check non-zero parameters + if param.sum() != 0: + self.assertFalse(torch.equal(param, new_param)) + # Verify new parameters are finite + self.assertTrue(torch.isfinite(new_param).all()) + + # Verify model can still do forward pass after training + dummy_batch = next(iter(trainer.get_train_dataloader())) + with torch.no_grad(): + output = trainer.model( + **{k: v for k, v in dummy_batch.items() if k in trainer.model.forward.__code__.co_varnames} + ) + self.assertIsNotNone(output) + self.assertTrue(torch.isfinite(output.loss)) + @require_vision class DPOVisionTrainerTester(unittest.TestCase): diff --git 
a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index b7dd6b049a7..e8ea7d0dc96 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -150,6 +150,12 @@ class DPOConfig(TrainingArguments): into a single continuous sequence. This approach requires associating a `position_ids` vector to track positional information. Currently, this is only supported with the `flash_attention_2` mechanism, as it can handle the flattened batch structure. + use_liger_loss (`bool`, *optional*, defaults to `False`): + Whether to use Liger loss. + base_model_attribute_name (`str`, *optional*, defaults to `"model"`): + Name of the attribute in the model that contains the base model. This is used to get the base model + from the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is + `True`. """ learning_rate: float = field( @@ -367,3 +373,15 @@ class DPOConfig(TrainingArguments): "batch into a single sample, associated with a position_ids vector. Only possible with flash-attention." }, ) + use_liger_loss: bool = field( + default=False, + metadata={"help": "Whether to use Liger loss."}, + ) + base_model_attribute_name: str = field( + default="model", + metadata={ + "help": "Name of the attribute in the model that contains the base model. This is used to get the base model " + "from the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is " + "`True`." 
+ }, + ) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 3774c84b21a..3698db74335 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -50,7 +50,7 @@ from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalLoopOutput -from transformers.utils import is_peft_available, is_torch_xpu_available +from transformers.utils import is_liger_kernel_available, is_peft_available, is_torch_xpu_available from transformers.utils.deprecation import deprecate_kwarg from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt @@ -74,6 +74,9 @@ if is_peft_available(): from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training +if is_liger_kernel_available(): + from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss + if is_wandb_available(): import wandb @@ -82,6 +85,13 @@ import deepspeed +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor: + """Shift input ids one token to the right, and pad with pad_token_id""" + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + @dataclass class DataCollatorForPreference(DataCollatorMixin): """ @@ -386,6 +396,17 @@ def make_inputs_require_grad(module, input, output): if self.ref_model is not None: disable_dropout_in_model(self.ref_model) + # Liger kernel + if args.use_liger_loss and args.loss_type == "sigmoid": + if not is_liger_kernel_available(): + raise ValueError( + "You set `use_liger_loss=True` but the liger kernel is not available. 
" + "Please install liger-kernel first: `pip install liger-kernel`" + ) + self.dpo_loss_fn = LigerFusedLinearDPOLoss( + ignore_index=args.label_pad_token_id, beta=args.beta, use_ref_model=not args.reference_free + ) + self.max_length = args.max_length self.generate_during_eval = args.generate_during_eval self.label_pad_token_id = args.label_pad_token_id @@ -1110,161 +1131,283 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to if "image_sizes" in concatenated_batch: model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] - prompt_input_ids = concatenated_batch["prompt_input_ids"] - prompt_attention_mask = concatenated_batch["prompt_attention_mask"] - completion_input_ids = concatenated_batch["completion_input_ids"] - completion_attention_mask = concatenated_batch["completion_attention_mask"] - if self.is_encoder_decoder: - labels = completion_input_ids - labels[completion_attention_mask == 0] = self.label_pad_token_id - outputs = model( - input_ids=prompt_input_ids, - attention_mask=prompt_attention_mask, - labels=labels, # we need the labels for the logits to be returned - **model_kwargs, + if self.args.use_liger_loss and self.loss_type == "sigmoid": + if self.is_encoder_decoder: + # 1. Get encoder outputs + encoder_outputs = model.get_encoder()( + concatenated_batch["prompt_input_ids"], + attention_mask=concatenated_batch["prompt_attention_mask"], + return_dict=True, + ) + # 2. Prepare decoder inputs + decoder_input_ids = shift_tokens_right( + concatenated_batch["completion_input_ids"], self.padding_value, model.config.decoder_start_token_id + ) + # 3. 
Get decoder outputs + decoder_outputs = model.get_decoder()( + input_ids=decoder_input_ids, + attention_mask=concatenated_batch["completion_attention_mask"], + encoder_hidden_states=encoder_outputs.last_hidden_state, + encoder_attention_mask=concatenated_batch["prompt_attention_mask"], + use_cache=False, + ) + hidden_states = decoder_outputs.last_hidden_state + + ref_hidden_states = None + if not self.reference_free and self.ref_model is not None: + ref_encoder_outputs = self.ref_model.get_encoder()( + concatenated_batch["prompt_input_ids"], + attention_mask=concatenated_batch["prompt_attention_mask"], + return_dict=True, + ) + ref_decoder_outputs = self.ref_model.get_decoder()( + input_ids=decoder_input_ids, + attention_mask=concatenated_batch["completion_attention_mask"], + encoder_hidden_states=ref_encoder_outputs.last_hidden_state, + encoder_attention_mask=concatenated_batch["prompt_attention_mask"], + use_cache=False, + ) + ref_hidden_states = ref_decoder_outputs.last_hidden_state + + labels = concatenated_batch["completion_input_ids"] + else: + # For decoder-only models + input_ids = torch.cat( + (concatenated_batch["prompt_input_ids"], concatenated_batch["completion_input_ids"]), dim=1 + ) + attention_mask = torch.cat( + (concatenated_batch["prompt_attention_mask"], concatenated_batch["completion_attention_mask"]), + dim=1, + ) + + # Get the base model outputs (before LM head) + if hasattr(model, "get_decoder"): + base_model = model.get_decoder() + else: + base_model = getattr(model, self.args.base_model_attribute_name, model) + + outputs = base_model( + input_ids, + attention_mask=attention_mask, + use_cache=False, + **model_kwargs, + ) + hidden_states = outputs.last_hidden_state[:, :-1] + + # Get reference hidden states if needed + ref_hidden_states = None + if not self.reference_free and self.ref_model is not None: + if hasattr(self.ref_model, "get_decoder"): + ref_base_model = self.ref_model.get_decoder() + else: + ref_base_model = 
getattr(self.ref_model, self.args.base_model_attribute_name, self.ref_model) + + ref_outputs = ref_base_model( + input_ids, + attention_mask=attention_mask, + use_cache=False, + **model_kwargs, + ) + ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] + + labels = input_ids[:, 1:] # Shift right for casual LM + + # Get the LM head + lm_head = model.get_output_embeddings() + + # Get reference model weights if needed + ref_weight = None + ref_bias = None + if not self.reference_free and self.ref_model is not None: + ref_lm_head = self.ref_model.get_output_embeddings() + ref_weight = ref_lm_head.weight + ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None + + # Compute loss using Liger kernel + loss_output = self.dpo_loss_fn( + lm_head.weight, + hidden_states, + labels, + bias=lm_head.bias if hasattr(lm_head, "bias") else None, + ref_input=ref_hidden_states if not self.reference_free else None, + ref_weight=ref_weight if not self.reference_free else None, + ref_bias=ref_bias if not self.reference_free else None, ) - logits = outputs.logits - loss_mask = completion_attention_mask.bool() - else: - # Concatenate the prompt and completion inputs - input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1) - attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1) - # Mask the prompt but not the completion for the loss - loss_mask = torch.cat( - (torch.zeros_like(prompt_attention_mask), completion_attention_mask), - dim=1, + loss, (chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, *aux_outputs) = ( + loss_output ) - # Flush left to reduce the memory usage - # [[0, 0, x, x, x, x], -> [[x, x, x, x], - # [0, x, x, x, 0, 0]] [x, x, x, 0]] - for i in range(attention_mask.size(0)): - first_one_idx = torch.nonzero(attention_mask[i])[0].item() - input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx) - attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx) - 
loss_mask[i] = torch.roll(loss_mask[i], shifts=-first_one_idx) - - # Get the first column idx that is all zeros and remove every column after that - empty_cols = torch.sum(attention_mask, dim=0) == 0 - first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else attention_mask.size(1) - input_ids = input_ids[:, :first_empty_col] - attention_mask = attention_mask[:, :first_empty_col] - loss_mask = loss_mask[:, :first_empty_col] - - # Truncate right - if self.args.max_length is not None: - input_ids = input_ids[:, : self.args.max_length] - attention_mask = attention_mask[:, : self.args.max_length] - loss_mask = loss_mask[:, : self.args.max_length] - - if self.use_num_logits_to_keep: - # Compute num_logits_to_keep based on loss_mask pattern: - # [[0, 0, 0, x, x, x, x], - # [0, 0, 0, x, x, x, 0]] - # ^ start computing logits from here ([:, -(7-3+1):]) - first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() - num_logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 # +1 for the first label - model_kwargs["num_logits_to_keep"] = num_logits_to_keep + output = { + "loss": loss, + "chosen_logps": chosen_logps, + "rejected_logps": rejected_logps, + "mean_chosen_logits": chosen_logits_mean, + "mean_rejected_logits": rejected_logits_mean, + } + if self.aux_loss_enabled and aux_outputs: + output["aux_loss"] = aux_outputs[0] # Assuming aux_loss is the first aux output - if self.padding_free: - # Flatten the input_ids, position_ids, and loss_mask - # input_ids = [[a, b, c, 0], -> input_ids = [[a, b, c, d, e, f, g]] - # [d, e, f, g]] position_ids = [[0, 1, 2, 0, 1, 2, 3]] - input_ids = input_ids[attention_mask.bool()].unsqueeze(0) - loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0) - position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1 - model_kwargs["position_ids"] = position_ids + return output + else: + prompt_input_ids = concatenated_batch["prompt_input_ids"] + prompt_attention_mask = 
concatenated_batch["prompt_attention_mask"] + completion_input_ids = concatenated_batch["completion_input_ids"] + completion_attention_mask = concatenated_batch["completion_attention_mask"] + if self.is_encoder_decoder: + labels = completion_input_ids + labels[completion_attention_mask == 0] = self.label_pad_token_id + outputs = model( + input_ids=prompt_input_ids, + attention_mask=prompt_attention_mask, + labels=labels, # we need the labels for the logits to be returned + **model_kwargs, + ) + logits = outputs.logits + loss_mask = completion_attention_mask.bool() else: - model_kwargs["attention_mask"] = attention_mask - - outputs = model(input_ids, **model_kwargs) - logits = outputs.logits - - # Offset the logits by one to align with the labels - labels = torch.roll(input_ids, shifts=-1, dims=1) - loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool() - - if self.use_num_logits_to_keep: - # Align labels with logits - # logits: -, -, [x2, x3, x4, x5, x6] - # ^ --------- ^ after logits[:, :-1, :] - # labels: [y0, y1, y2, y3, y4, y5, y6] - # ^ --------- ^ with num_logits_to_keep=4, [:, -4:] - # loss_mask: [0, 0, 0, 1, 1, 1, 1] - labels = labels[:, -num_logits_to_keep:] - loss_mask = loss_mask[:, -num_logits_to_keep:] - - if logits.shape[:2] != labels.shape[:2]: - # for llava, the returned logits include the image tokens (placed before the text tokens) - seq_len = labels.shape[1] - logits = logits[:, -seq_len:] - - # Compute the log probabilities of the labels - labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) - per_token_logps[~loss_mask] = 0 - per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1) - - if self.padding_free: - # Unflatten the per_token_logps (shape: [1, sum_seq_len] -> [batch_size, seq_len]) - batch_size, seq_len = attention_mask.shape - per_token_logps_ = torch.zeros( - batch_size, seq_len, 
device=outputs.logits.device, dtype=outputs.logits.dtype - ) - per_token_logps_[attention_mask.bool()] = per_token_logps - per_token_logps = per_token_logps_ + # Concatenate the prompt and completion inputs + input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1) + attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1) + # Mask the prompt but not the completion for the loss + loss_mask = torch.cat( + (torch.zeros_like(prompt_attention_mask), completion_attention_mask), + dim=1, + ) - all_logps = per_token_logps.sum(-1) + # Flush left to reduce the memory usage + # [[0, 0, x, x, x, x], -> [[x, x, x, x], + # [0, x, x, x, 0, 0]] [x, x, x, 0]] + for i in range(attention_mask.size(0)): + first_one_idx = torch.nonzero(attention_mask[i])[0].item() + input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx) + attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx) + loss_mask[i] = torch.roll(loss_mask[i], shifts=-first_one_idx) + + # Get the first column idx that is all zeros and remove every column after that + empty_cols = torch.sum(attention_mask, dim=0) == 0 + first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else attention_mask.size(1) + input_ids = input_ids[:, :first_empty_col] + attention_mask = attention_mask[:, :first_empty_col] + loss_mask = loss_mask[:, :first_empty_col] + + # Truncate right + if self.args.max_length is not None: + input_ids = input_ids[:, : self.args.max_length] + attention_mask = attention_mask[:, : self.args.max_length] + loss_mask = loss_mask[:, : self.args.max_length] + + if self.use_num_logits_to_keep: + # Compute num_logits_to_keep based on loss_mask pattern: + # [[0, 0, 0, x, x, x, x], + # [0, 0, 0, x, x, x, 0]] + # ^ start computing logits from here ([:, -(7-3+1):]) + first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() + num_logits_to_keep = ( + loss_mask.shape[1] - first_compute_index + ).item() + 1 # +1 for the first label + 
model_kwargs["num_logits_to_keep"] = num_logits_to_keep + + if self.padding_free: + # Flatten the input_ids, position_ids, and loss_mask + # input_ids = [[a, b, c, 0], -> input_ids = [[a, b, c, d, e, f, g]] + # [d, e, f, g]] position_ids = [[0, 1, 2, 0, 1, 2, 3]] + input_ids = input_ids[attention_mask.bool()].unsqueeze(0) + loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0) + position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1 + model_kwargs["position_ids"] = position_ids + else: + model_kwargs["attention_mask"] = attention_mask + + outputs = model(input_ids, **model_kwargs) + logits = outputs.logits + + # Offset the logits by one to align with the labels + labels = torch.roll(input_ids, shifts=-1, dims=1) + loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool() + + if self.use_num_logits_to_keep: + # Align labels with logits + # logits: -, -, [x2, x3, x4, x5, x6] + # ^ --------- ^ after logits[:, :-1, :] + # labels: [y0, y1, y2, y3, y4, y5, y6] + # ^ --------- ^ with num_logits_to_keep=4, [:, -4:] + # loss_mask: [0, 0, 0, 1, 1, 1, 1] + labels = labels[:, -num_logits_to_keep:] + loss_mask = loss_mask[:, -num_logits_to_keep:] + + if logits.shape[:2] != labels.shape[:2]: + # for llava, the returned logits include the image tokens (placed before the text tokens) + seq_len = labels.shape[1] + logits = logits[:, -seq_len:] - output = {} + # Compute the log probabilities of the labels + labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + per_token_logps[~loss_mask] = 0 + per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1) - if self.use_weighting: - with torch.no_grad(): - # Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827 - logprobs = F.log_softmax(logits, dim=-1) - weights_adjustment_factor = torch.logsumexp(2 * logprobs, dim=-1) # same as sum(probs**2) in log space - 
per_token_logps_adjusted = per_token_logps - weights_adjustment_factor - all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1) - chosen_weights = all_weights[:num_examples] - rejected_weights = all_weights[num_examples:] - output["policy_weights"] = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1) + if self.padding_free: + # Unflatten the per_token_logps (shape: [1, sum_seq_len] -> [batch_size, seq_len]) + batch_size, seq_len = attention_mask.shape + per_token_logps_ = torch.zeros( + batch_size, seq_len, device=outputs.logits.device, dtype=outputs.logits.dtype + ) + per_token_logps_[attention_mask.bool()] = per_token_logps + per_token_logps = per_token_logps_ + + all_logps = per_token_logps.sum(-1) + + output = {} + + if self.use_weighting: + with torch.no_grad(): + # Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827 + logprobs = F.log_softmax(logits, dim=-1) + weights_adjustment_factor = torch.logsumexp( + 2 * logprobs, dim=-1 + ) # same as sum(probs**2) in log space + per_token_logps_adjusted = per_token_logps - weights_adjustment_factor + all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1) + chosen_weights = all_weights[:num_examples] + rejected_weights = all_weights[num_examples:] + output["policy_weights"] = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1) + + if self.args.rpo_alpha is not None: + # Only use the chosen logits for the RPO loss + chosen_logits = logits[:num_examples] + chosen_labels = labels[:num_examples] + + # Compute the log probabilities of the labels + output["nll_loss"] = F.cross_entropy( + torch.flatten(chosen_logits, end_dim=1), torch.flatten(chosen_labels, end_dim=1), ignore_index=0 + ) - if self.args.rpo_alpha is not None: - # Only use the chosen logits for the RPO loss - chosen_logits = logits[:num_examples] - chosen_labels = labels[:num_examples] + if self.loss_type == "ipo": + all_logps = all_logps / loss_mask.sum(-1) - # Compute 
the log probabilities of the labels - output["nll_loss"] = F.cross_entropy( - torch.flatten(chosen_logits, end_dim=1), torch.flatten(chosen_labels, end_dim=1), ignore_index=0 - ) + output["chosen_logps"] = all_logps[:num_examples] + output["rejected_logps"] = all_logps[num_examples:] - if self.loss_type == "ipo": - all_logps = all_logps / loss_mask.sum(-1) - - output["chosen_logps"] = all_logps[:num_examples] - output["rejected_logps"] = all_logps[num_examples:] - - # Compute the mean logits - if self.padding_free: - # position_ids contains a sequence of range identifiers (e.g., [[0, 1, 2, 0, 1, 2, 3, ...]]). - # There are 2*num_examples ranges in total: the first half corresponds to the chosen tokens, - # and the second half to the rejected tokens. - # To find the start of the rejected tokens, we look for the num_examples+1-th zero in pos_id. - split_idx = (position_ids == 0).nonzero(as_tuple=True)[1][num_examples] - mean_chosen_logits = logits[0, :split_idx][loss_mask[0, :split_idx]].mean() - mean_rejected_logits = logits[0, split_idx:][loss_mask[0, split_idx:]].mean() - else: - mean_chosen_logits = logits[:num_examples][loss_mask[:num_examples]].mean() - mean_rejected_logits = logits[num_examples:][loss_mask[num_examples:]].mean() + # Compute the mean logits + if self.padding_free: + # position_ids contains a sequence of range identifiers (e.g., [[0, 1, 2, 0, 1, 2, 3, ...]]). + # There are 2*num_examples ranges in total: the first half corresponds to the chosen tokens, + # and the second half to the rejected tokens. + # To find the start of the rejected tokens, we look for the num_examples+1-th zero in pos_id. 
+ split_idx = (position_ids == 0).nonzero(as_tuple=True)[1][num_examples] + mean_chosen_logits = logits[0, :split_idx][loss_mask[0, :split_idx]].mean() + mean_rejected_logits = logits[0, split_idx:][loss_mask[0, split_idx:]].mean() + else: + mean_chosen_logits = logits[:num_examples][loss_mask[:num_examples]].mean() + mean_rejected_logits = logits[num_examples:][loss_mask[num_examples:]].mean() - output["mean_chosen_logits"] = mean_chosen_logits - output["mean_rejected_logits"] = mean_rejected_logits + output["mean_chosen_logits"] = mean_chosen_logits + output["mean_rejected_logits"] = mean_rejected_logits - if self.aux_loss_enabled: - output["aux_loss"] = outputs.aux_loss + if self.aux_loss_enabled: + output["aux_loss"] = outputs.aux_loss - return output + return output def get_batch_loss_metrics( self, @@ -1277,16 +1420,21 @@ def get_batch_loss_metrics( model_output = self.concatenated_forward(model, batch) - # if ref_chosen_logps and ref_rejected_logps in batch use them, otherwise use the reference model - if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch: - ref_chosen_logps = batch["ref_chosen_logps"] - ref_rejected_logps = batch["ref_rejected_logps"] + if self.args.use_liger_loss: + losses = model_output["loss"] + chosen_rewards = model_output["chosen_rewards"] + rejected_rewards = model_output["rejected_rewards"] else: - ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch) + # if ref_chosen_logps and ref_rejected_logps in batch use them, otherwise use the reference model + if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch: + ref_chosen_logps = batch["ref_chosen_logps"] + ref_rejected_logps = batch["ref_rejected_logps"] + else: + ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch) - losses, chosen_rewards, rejected_rewards = self.dpo_loss( - model_output["chosen_logps"], model_output["rejected_logps"], ref_chosen_logps, ref_rejected_logps - ) + losses, chosen_rewards, rejected_rewards = 
self.dpo_loss( + model_output["chosen_logps"], model_output["rejected_logps"], ref_chosen_logps, ref_rejected_logps + ) reward_accuracies = (chosen_rewards > rejected_rewards).float() if self.args.rpo_alpha is not None: From e3eebd3416695feadf5b68cb39d6329821ed0d80 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 15 Jan 2025 12:05:19 +0100 Subject: [PATCH 02/17] fix outputs --- tests/test_dpo_trainer.py | 10 ++++++---- trl/trainer/dpo_trainer.py | 18 +++++++++++------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 242e63f08de..e4b9da4d225 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1271,12 +1271,14 @@ def test_dpo_trainer_with_liger(self): # Verify model can still do forward pass after training dummy_batch = next(iter(trainer.get_train_dataloader())) + model_inputs = { + "input_ids": dummy_batch["prompt_input_ids"], + "attention_mask": dummy_batch["prompt_attention_mask"], + } with torch.no_grad(): - output = trainer.model( - **{k: v for k, v in dummy_batch.items() if k in trainer.model.forward.__code__.co_varnames} - ) + output = trainer.model(**model_inputs) self.assertIsNotNone(output) - self.assertTrue(torch.isfinite(output.loss)) + self.assertIsNone(output.loss) @require_vision diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 3698db74335..77532fff0fc 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1233,9 +1233,10 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to ref_weight=ref_weight if not self.reference_free else None, ref_bias=ref_bias if not self.reference_free else None, ) - loss, (chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, *aux_outputs) = ( - loss_output - ) + ( + loss, + (chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, _, _, *aux_outputs), + ) = loss_output output = { "loss": loss, 
@@ -1243,9 +1244,12 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to "rejected_logps": rejected_logps, "mean_chosen_logits": chosen_logits_mean, "mean_rejected_logits": rejected_logits_mean, + "nll_loss": nll_loss, + "chosen_rewards": aux_outputs[0], + "rejected_rewards": aux_outputs[1], } - if self.aux_loss_enabled and aux_outputs: - output["aux_loss"] = aux_outputs[0] # Assuming aux_loss is the first aux output + if self.aux_loss_enabled: + output["aux_loss"] = outputs.aux_loss return output else: @@ -1374,8 +1378,8 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to if self.args.rpo_alpha is not None: # Only use the chosen logits for the RPO loss - chosen_logits = logits[:num_examples] - chosen_labels = labels[:num_examples] + chosen_logits = logits[:num_examples, :-1] if not self.is_encoder_decoder else logits[:num_examples] + chosen_labels = labels[:num_examples, 1:] if self.is_encoder_decoder else labels[:num_examples] # Compute the log probabilities of the labels output["nll_loss"] = F.cross_entropy( From 2d82b39364727fb2976178ca908a9eecaa48fbbd Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 15 Jan 2025 12:10:39 +0100 Subject: [PATCH 03/17] fix config merge conflict --- trl/trainer/dpo_config.py | 430 ++++++++++++++++++++------------------ 1 file changed, 226 insertions(+), 204 deletions(-) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index e8ea7d0dc96..e581240e6f2 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from dataclasses import dataclass, field from enum import Enum from typing import Any, Optional @@ -40,16 +41,70 @@ class DPOConfig(TrainingArguments): command line. 
Parameters: + > Parameters that control the model and reference model + + model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): + Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of the + [`DPOTrainer`] is provided as a string. + ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): + Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `ref_model` argument of the + [`DPOTrainer`] is provided as a string. + model_adapter_name (`str` or `None`, *optional*, defaults to `None`): + Name of the train target PEFT adapter, when using LoRA with multiple adapters. + ref_adapter_name (`str` or `None`, *optional*, defaults to `None`): + Name of the reference PEFT adapter, when using LoRA with multiple adapters. + force_use_ref_model (`bool`, *optional*, defaults to `False`): + If you provide a PEFT model as the active model and wish to use a different model for the `ref_model`, set + this flag to `True`. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model and reference model. + use_num_logits_to_keep (`bool`, *optional*, defaults to `False`): + If `True`, only a specified number of logits are computed in the forward pass. This can be useful for + saving memory and speeding up training by not computing the logits for all tokens, especially in + scenarios when working with very long prompts where labels are ignored (-100). + + > Parameters that control the data preprocessing + + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of processes to use for processing the dataset. + padding_value (`int` or `None`, *optional*, defaults to `None`): + Padding value to use. If `None`, the padding value of the tokenizer is used. + label_pad_token_id (`int`, *optional*, defaults to `-100`): + Padding value to use for labels. 
+ truncation_mode (`str`, *optional*, defaults to `"keep_end"`): + Truncation mode to use when the prompt is too long, either `keep_end` or `keep_start`. + max_prompt_length (`int` or `None`, *optional*, defaults to `None`): + Maximum length of the prompt. + max_completion_length (`int` or `None`, *optional*, defaults to `None`): + Maximum length of the completion. + max_length (`int` or `None`, *optional*, defaults to `None`): + Maximum length of the full sequence (prompt + completion). + padding_free (`bool`, *optional*, defaults to `False`): + Whether forward passes are performed without padding by flattening all sequences in the batch + into a single continuous sequence. This approach requires associating a `position_ids` vector to track + positional information. Currently, this is only supported with the `flash_attention_2` mechanism, as it + can handle the flattened batch structure. + precompute_ref_log_probs (`bool`, *optional*, defaults to `False`): + Whether to precompute the log probabilities from the reference model. Setting this to `True` allows + training without needing the reference model during training, which can help reduce GPU memory usage. If + set to `False` (default), the reference model will be used during training to compute log probabilities + on-the-fly. + precompute_ref_batch_size (`int` or `None`, *optional*, defaults to `None`): + Batch size to use when precomputing reference model log probabilities. This can be set higher than the + training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for + training and `per_device_eval_batch_size` for evaluation. + use_liger_loss (`bool`, *optional*, defaults to `False`): + Whether to use Liger loss. + base_model_attribute_name (`str`, *optional*, defaults to `"model"`): + Name of the attribute in the model that contains the base model. 
This is used to get the base model + from the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is + `True`. + + > Parameters that control the training + learning_rate (`float`, *optional*, defaults to `1e-6`): Initial learning rate for [`AdamW`] optimizer. The default value replaces that of [`~transformers.TrainingArguments`]. - beta (`float`, *optional*, defaults to `0.1`): - Parameter controlling the deviation from the reference model. Higher β means less deviation from the - reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in - the [paper](https://huggingface.co/papers/2310.12036). - label_smoothing (`float`, *optional*, defaults to `0.0`): - Robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report and - [Robust DPO](https://huggingface.co/papers/2403.00409) paper that should be between `0.0` and `0.5`. loss_type (`str`, *optional*, defaults to `"sigmoid"`): Type of loss to use. Possible values are: @@ -66,200 +121,142 @@ class DPOConfig(TrainingArguments): - `"discopop"`: DiscoPOP (a.k.a Log-Ratio Modulated Loss, LRML) loss from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper. - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper. - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper. - use_weighting (`bool`, *optional*, defaults to `False`): - Whether to weight the loss as done in the [WPO](https://huggingface.co/papers/2406.11827) paper. - label_pad_token_id (`int`, *optional*, defaults to `-100`): - Label pad token id. This argument is required if you want to use the default data collator. - padding_value (`int` or `None`, *optional*, defaults to `None`): - Padding value to use. If `None`, the padding value of the tokenizer is used. 
- truncation_mode (`str`, *optional*, defaults to `"keep_end"`): - Truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the - default data collator. - max_length (`int` or `None`, *optional*, defaults to `None`): - Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want - to use the default data collator. - max_prompt_length (`int` or `None`, *optional*, defaults to `None`): - Maximum length of the prompt. This argument is required if you want to use the default data collator. - max_completion_length (`int` or `None`, *optional*, defaults to `None`): - Maximum length of the target. This argument is required if you want to use the default data collator and - your model is an encoder-decoder. - is_encoder_decoder(`Optional[int]`, *optional*, defaults to `None`): - When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, - you need to specify if the model returned by the callable is an encoder-decoder model. - disable_dropout (`bool`, *optional*, defaults to `True`): - Whether to disable dropout in the model and reference model. - generate_during_eval (`bool`, *optional*, defaults to `False`): - If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during - evaluation. - precompute_ref_log_probs (`bool`, *optional*, defaults to `False`): - Whether to precompute reference model log probabilities for training and evaluation datasets. This is - useful when training without the reference model to reduce the total GPU memory needed. - precompute_ref_batch_size (`int` or `None`, *optional*, defaults to `None`): - Batch size to use when precomputing reference model log probabilities. This can be set higher than the - training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for - training and `per_device_eval_batch_size` for evaluation. 
- dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): - Number of processes to use for processing the dataset. - model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): - Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a - string. - ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): - Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model - from a string. - model_adapter_name (`str` or `None`, *optional*, defaults to `None`): - Name of the train target PEFT adapter, when using LoRA with multiple adapters. - ref_adapter_name (`str` or `None`, *optional*, defaults to `None`): - Name of the reference PEFT adapter, when using LoRA with multiple adapters. - reference_free (`bool`, *optional*, defaults to `False`): - If `True`, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal - probability to all responses. - force_use_ref_model (`bool`, *optional*, defaults to `False`): - In case one passes a PEFT model for the active model and you want to use a different model for the - ref_model, set this flag to `True`. + + beta (`float`, *optional*, defaults to `0.1`): + Parameter controlling the deviation from the reference model. Higher β means less deviation from the + reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in + the [paper](https://huggingface.co/papers/2310.12036). f_divergence_type (`str`, *optional*, defaults to `FDivergenceType.REVERSE_KL`): Type of f-divergence regularization function to compute divergence between policy and reference model. f_alpha_divergence_coef (`float`, *optional*, defaults to `1.0`): α coefficient in the α-divergence u^-α regularization function for DPO loss. 
+ reference_free (`bool`, *optional*, defaults to `False`): + Whether to ignore the provided reference model and implicitly use a reference model that assigns equal + probability to all responses. + label_smoothing (`float`, *optional*, defaults to `0.0`): + Robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report and + [Robust DPO](https://huggingface.co/papers/2403.00409) paper that should be between `0.0` and `0.5`. + use_weighting (`bool`, *optional*, defaults to `False`): + Whether to weight the loss as done in the [WPO](https://huggingface.co/papers/2406.11827) paper. + rpo_alpha (`float`, *optional*, defaults to `None`): + α parameter from the [RPO](https://huggingface.co/papers/2404.19733) paper (v3), which controls the + weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the + DPO loss. The paper recommends `rpo_alpha=1.0`. + discopop_tau (`float`, *optional*, defaults to `0.05`): + τ/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls + the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`. sync_ref_model (`bool`, *optional*, defaults to `False`): - When set to `True`, the reference model is synchronized with the active model every `ref_model_sync_steps` - steps, using the `ref_model_mixup_alpha` parameter. This synchronization originites from the + Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using + the `ref_model_mixup_alpha` parameter. This synchronization originates from the + [TR-DPO](https://huggingface.co/papers/2404.09656) paper. ref_model_mixup_alpha (`float`, *optional*, defaults to `0.9`): + α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix + between the current policy and the previous reference policy during updates. 
The reference policy is - updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev` - To use this parameter, you must set `sync_ref_model=True`. + updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you + must set `sync_ref_model=True`. ref_model_sync_steps (`int`, *optional*, defaults to `64`): τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how frequently the current policy is synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`. - rpo_alpha (`float`, *optional*, defaults to `None`): - α parameter from the [RPO](https://huggingface.co/papers/2404.19733) paper (v3), which controls the - weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the - DPO loss. The paper recommends `rpo_alpha=1.0`. - discopop_tau (`float`, *optional*, defaults to `0.05`): - τ/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls - the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`. - use_num_logits_to_keep (`bool`, *optional*, defaults to `False`): - If `True`, only a specified number of logits are computed in the forward pass of CausalLM. This can be - useful for saving memory and speeding up training by not computing the logits for all tokens, especially in - scenarios when working with very long prompts where labels are ignored (-100). - [Read more](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaForCausalLM) - padding_free (`bool`, *optional*, defaults to `False`): - Whether forward passes are performed without padding by flattening all sequences in the batch - into a single continuous sequence. This approach requires associating a `position_ids` vector to track - positional information. 
Currently, this is only supported with the `flash_attention_2` mechanism, as it - can handle the flattened batch structure. - use_liger_loss (`bool`, *optional*, defaults to `False`): - Whether to use Liger loss. - base_model_attribute_name (`str`, *optional*, defaults to `"model"`): - Name of the attribute in the model that contains the base model. This is used to get the base model - from the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is - `True`. + + > Parameters that control the logging + + generate_during_eval (`bool`, *optional*, defaults to `False`): + Whether to generate and log completions from both the model and the reference model to W&B or Comet during + evaluation. """ - learning_rate: float = field( - default=1e-6, + # Parameters that control the model and reference model + model_init_kwargs: Optional[dict[str, Any]] = field( + default=None, metadata={ - "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of " - "`transformers.TrainingArguments`." + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of " + "the `DPOTrainer` is provided as a string." }, ) - beta: float = field( - default=0.1, + ref_model_init_kwargs: Optional[dict[str, Any]] = field( + default=None, metadata={ - "help": "Parameter controlling the deviation from the reference model. " - "Higher β means less deviation from the reference model." + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `ref_model` argument " + "of the `DPOTrainer` is provided as a string." 
}, ) - label_smoothing: float = field( - default=0.0, - metadata={"help": "Label smoothing factor."}, + model_adapter_name: Optional[str] = field( + default=None, + metadata={"help": "Name of the train target PEFT adapter, when using LoRA with multiple adapters."}, ) - loss_type: str = field( - default="sigmoid", + ref_adapter_name: Optional[str] = field( + default=None, + metadata={"help": "Name of the reference PEFT adapter, when using LoRA with multiple adapters."}, + ) + force_use_ref_model: bool = field( + default=False, metadata={ - "help": "Type of loss to use.", - "choices": [ - "sigmoid", - "hinge", - "ipo", - "exo_pair", - "nca_pair", - "robust", - "bco_pair", - "sppo_hard", - "aot", - "aot_pair", - "discopop", - "apo_zero", - "apo_down", - ], + "help": "If you provide a PEFT model as the active model and wish to use a different model for the " + "`ref_model`, set this flag to `True`." }, ) - use_weighting: bool = field( - default=False, - metadata={"help": "Whether to weight the loss as done in the WPO paper."}, + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model and reference model."}, ) - label_pad_token_id: int = field( - default=-100, + use_num_logits_to_keep: bool = field( + default=False, metadata={ - "help": "Label pad token id. This argument is required if you want to use the default data collator." + "help": "If `True`, only a specified number of logits are computed in the forward pass. This can be " + "useful for saving memory and speeding up training by not computing the logits for all tokens, especially " + "in scenarios when working with very long prompts where labels are ignored (-100)." }, ) + + # Parameters that control the data preprocessing + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) padding_value: Optional[int] = field( default=None, metadata={"help": "Padding value to use. 
If `None`, the padding value of the tokenizer is used."}, ) + label_pad_token_id: int = field( + default=-100, + metadata={"help": "Padding value to use for labels."}, + ) truncation_mode: str = field( default="keep_end", metadata={ - "help": "Truncation mode to use when the prompt is too long. This argument is required if you want to use " - "the default data collator.", + "help": "Truncation mode to use when the prompt is too long.", "choices": ["keep_end", "keep_start"], }, ) - max_length: Optional[int] = field( - default=None, - metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."}, - ) max_prompt_length: Optional[int] = field( default=None, - metadata={ - "help": "Maximum length of the prompt. This argument is required if you want to use the default data " - "collator and your model is an encoder-decoder." - }, + metadata={"help": "Maximum length of the prompt."}, ) max_completion_length: Optional[int] = field( default=None, - metadata={ - "help": "Maximum length of the completion. This argument is required if you want to use the default data " - "collator and your model is an encoder-decoder." - }, + metadata={"help": "Maximum length of the completion."}, ) - is_encoder_decoder: Optional[bool] = field( + max_length: Optional[int] = field( default=None, - metadata={ - "help": "When using the `model_init` argument (callable) to instantiate the model instead of the " - "`model` argument, you need to specify if the model returned by the callable is an encoder-decoder model." - }, - ) - disable_dropout: bool = field( - default=True, - metadata={"help": "Whether to disable dropout in the model and reference model."}, + metadata={"help": "Maximum length of the full sequence (prompt + completion)."}, ) - generate_during_eval: bool = field( + padding_free: bool = field( default=False, metadata={ - "help": "If `True`, generates and logs completions from both the model and the reference model " - "to W&B during evaluation." 
+ "help": "Whether forward passes are performed without padding by flattening all sequences in the batch " + "into a single continuous sequence. This approach requires associating a `position_ids` vector to track " + "positional information. Currently, this is only supported with the `flash_attention_2` mechanism, as it " + "can handle the flattened batch structure." }, ) precompute_ref_log_probs: bool = field( default=False, metadata={ - "help": "Whether to precompute reference model log probabilities for training and evaluation datasets. " - "This is useful when training without the reference model to reduce the total GPU memory needed." + "help": "Whether to precompute the log probabilities from the reference model. Setting this to `True` " + "allows training without needing the reference model during training, which can help reduce GPU memory " + "usage. If set to `False` (default), the reference model will be used during training to compute log " + "probabilities on-the-fly." }, ) precompute_ref_batch_size: Optional[int] = field( @@ -270,44 +267,54 @@ class DPOConfig(TrainingArguments): "`per_device_train_batch_size` for training and `per_device_eval_batch_size` for evaluation." }, ) - dataset_num_proc: Optional[int] = field( - default=None, - metadata={"help": "Number of processes to use for processing the dataset."}, + use_liger_loss: bool = field( + default=False, + metadata={"help": "Whether to use Liger loss."}, ) - model_init_kwargs: Optional[dict[str, Any]] = field( - default=None, + base_model_attribute_name: str = field( + default="model", metadata={ - "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the " - "model from a string." + "help": "Name of the attribute in the model that contains the base model. This is used to get the base model " + "from the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is " + "`True`." 
}, ) - ref_model_init_kwargs: Optional[dict[str, Any]] = field( - default=None, + + + # Parameters that control the training + learning_rate: float = field( + default=1e-6, metadata={ - "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the " - "reference model from a string." + "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of " + "`transformers.TrainingArguments`." }, ) - model_adapter_name: Optional[str] = field( - default=None, - metadata={"help": "Name of the train target PEFT adapter, when using LoRA with multiple adapters."}, - ) - ref_adapter_name: Optional[str] = field( - default=None, - metadata={"help": "Name of the reference PEFT adapter, when using LoRA with multiple adapters."}, - ) - reference_free: bool = field( - default=False, + loss_type: str = field( + default="sigmoid", metadata={ - "help": "If `True`, we ignore the _provided_ reference model and implicitly use a reference model that " - "assigns equal probability to all responses." + "help": "Type of loss to use.", + "choices": [ + "sigmoid", + "hinge", + "ipo", + "exo_pair", + "nca_pair", + "robust", + "bco_pair", + "sppo_hard", + "aot", + "aot_pair", + "discopop", + "apo_zero", + "apo_down", + ], }, ) - force_use_ref_model: bool = field( - default=False, + beta: float = field( + default=0.1, metadata={ - "help": "In case one passes a PEFT model for the active model and you want to use a different model for " - "the ref_model, set this flag to `True`." + "help": "Parameter controlling the deviation from the reference model. " + "Higher β means less deviation from the reference model." 
}, ) f_divergence_type: FDivergenceType = field( @@ -321,27 +328,23 @@ class DPOConfig(TrainingArguments): default=1.0, metadata={"help": "α coefficient in the α-divergence u^-α regularization function for DPO loss."}, ) - sync_ref_model: bool = field( + reference_free: bool = field( default=False, metadata={ - "help": "When set to `True`, the reference model is synchronized with the active model every " - "`ref_model_sync_steps` steps, using the `ref_model_mixup_alpha` parameter." + "help": "Whether to ignore the provided reference model and implicitly use a reference model that assigns " + "equal probability to all responses." }, ) - ref_model_mixup_alpha: float = field( - default=0.9, + label_smoothing: float = field( + default=0.0, metadata={ - "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the " - "previous reference policy during updates. The reference policy is updated according to the equation: " - "`π_ref = α * π_θ + (1 - α) * π_ref_prev`" + "help": "Robust DPO label smoothing parameter from the cDPO report and Robust DPO paper that should " + "be between `0.0` and `0.5`." }, ) - ref_model_sync_steps: int = field( - default=64, - metadata={ - "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is " - "synchronized with the reference policy." - }, + use_weighting: bool = field( + default=False, + metadata={"help": "Whether to weight the loss as done in the WPO paper."}, ) rpo_alpha: Optional[float] = field( default=None, @@ -358,30 +361,49 @@ class DPOConfig(TrainingArguments): "loss. The paper recommends the default value `discopop_tau=0.05`." }, ) - use_num_logits_to_keep: bool = field( + sync_ref_model: bool = field( default=False, metadata={ - "help": "If `True`, only a specified number of logits are computed in the forward pass of CausalLM. 
" - "This can be useful for saving memory and speeding up training by not computing the logits for all " - "tokens, especially in scenarios when working with very long prompts where labels are ignored (-100)." + "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` " + "steps, using the `ref_model_mixup_alpha` parameter." }, ) - padding_free: bool = field( - default=False, + ref_model_mixup_alpha: float = field( + default=0.9, metadata={ - "help": "Whether the forward passes are performed without padding, i.e. flattening all the samples in the " - "batch into a single sample, associated with a position_ids vector. Only possible with flash-attention." + "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the " + "previous reference policy during updates. The reference policy is updated according to the equation: " + "`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`." }, ) - use_liger_loss: bool = field( - default=False, - metadata={"help": "Whether to use Liger loss."}, + ref_model_sync_steps: int = field( + default=64, + metadata={ + "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is " + "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`." + }, ) - base_model_attribute_name: str = field( - default="model", + + # Parameters that control the logging + generate_during_eval: bool = field( + default=False, metadata={ - "help": "Name of the attribute in the model that contains the base model. This is used to get the base model " - "from the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is " - "`True`." + "help": "Whether to generate and log completions from both the model and the reference model to W&B or " + "Comet during evaluation." 
},  )  +  + # Deprecated parameters + is_encoder_decoder: Optional[bool] = field( + default=None, + metadata={"help": "Deprecated. This argument is not used anymore."}, + ) + + def __post_init__(self): + if self.is_encoder_decoder is not None: + warnings.warn( + "The `is_encoder_decoder` parameter is deprecated and will be removed in version 0.15. The trainer now " + "automatically determines if the model is an encoder-decoder, so you can safely remove it." + ) + + return super().__post_init__() \ No newline at end of file From 8ae06b1da2342f376044707af881556a6f8ca2bf Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 15 Jan 2025 12:19:11 +0100 Subject: [PATCH 04/17] fix comment --- tests/test_dpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index e908879e04d..20ea8db3f64 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1204,7 +1204,7 @@ def test_padding_free(self): @require_liger_kernel def test_dpo_trainer_with_liger(self): - """Test ORPO trainer with Liger loss enabled. + """Test DPO trainer with Liger loss enabled. This test verifies that: 1. 
Training runs successfully with Liger loss From cc2b7b9af9a1bda0cca8994b3087173c22318112 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 15 Jan 2025 15:22:16 +0100 Subject: [PATCH 05/17] fix peft training --- trl/trainer/dpo_trainer.py | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index a2d55c775f1..60bb20db7e6 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1169,6 +1169,21 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to use_cache=False, ) ref_hidden_states = ref_decoder_outputs.last_hidden_state + elif not self.reference_free: + with self.null_ref_context(): + ref_encoder_outputs = model.get_encoder()( + concatenated_batch["prompt_input_ids"], + attention_mask=concatenated_batch["prompt_attention_mask"], + return_dict=True, + ) + ref_decoder_outputs = model.get_decoder()( + input_ids=decoder_input_ids, + attention_mask=concatenated_batch["completion_attention_mask"], + encoder_hidden_states=ref_encoder_outputs.last_hidden_state, + encoder_attention_mask=concatenated_batch["prompt_attention_mask"], + use_cache=False, + ) + ref_hidden_states = ref_decoder_outputs.last_hidden_state labels = concatenated_batch["completion_input_ids"] else: @@ -1210,6 +1225,19 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to **model_kwargs, ) ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] + elif not self.reference_free: + if hasattr(model, "get_decoder"): + ref_base_model = model.get_decoder() + else: + ref_base_model = getattr(model, self.args.base_model_attribute_name, model) + with self.null_ref_context(): + ref_outputs = ref_base_model( + input_ids, + attention_mask=attention_mask, + use_cache=False, + **model_kwargs, + ) + ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] labels = input_ids[:, 1:] # Shift right for casual LM @@ -1219,8 +1247,12 
@@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to # Get reference model weights if needed ref_weight = None ref_bias = None - if not self.reference_free and self.ref_model is not None: - ref_lm_head = self.ref_model.get_output_embeddings() + if not self.reference_free: + if self.ref_model is not None: + ref_lm_head = self.ref_model.get_output_embeddings() + else: + with self.null_ref_context(): + ref_lm_head = model.get_output_embeddings() ref_weight = ref_lm_head.weight ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None From 03fd0052cf63b16bdfb818e19221fd6ad5e31cb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 17 Jan 2025 16:31:28 +0000 Subject: [PATCH 06/17] use parametrized --- tests/test_dpo_trainer.py | 106 +++++++++++++++++++------------------- 1 file changed, 52 insertions(+), 54 deletions(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 20ea8db3f64..69a64e6a603 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1203,7 +1203,8 @@ def test_padding_free(self): self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) @require_liger_kernel - def test_dpo_trainer_with_liger(self): + @parameterized.expand([(0.1,), (0.5,)]) + def test_dpo_trainer_with_liger(self, beta): """Test DPO trainer with Liger loss enabled. This test verifies that: @@ -1212,66 +1213,63 @@ def test_dpo_trainer_with_liger(self): 3. Loss values are reasonable and finite 4. 
Training works with both default and custom beta values """ - beta_values = [0.1, 0.5] # Test multiple beta values - - for beta in beta_values: - with tempfile.TemporaryDirectory() as tmp_dir: - training_args = DPOConfig( - output_dir=tmp_dir, - per_device_train_batch_size=2, - max_steps=3, - remove_unused_columns=False, - gradient_accumulation_steps=1, - learning_rate=9e-1, - eval_strategy="steps", - beta=beta, - use_liger_loss=True, # Enable Liger loss - report_to="none", - ) + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + beta=beta, + use_liger_loss=True, # Enable Liger loss + report_to="none", + ) - dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") - trainer = DPOTrainer( - model=self.model, - ref_model=self.ref_model, # Add reference model - args=training_args, - processing_class=self.tokenizer, - train_dataset=dummy_dataset["train"], - eval_dataset=dummy_dataset["test"], - ) + trainer = DPOTrainer( + model=self.model, + ref_model=self.ref_model, # Add reference model + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) - # Store initial parameters - previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + # Store initial parameters + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} - # Train the model - train_output = trainer.train() + # Train the model + train_output = trainer.train() - # Verify training completed successfully - self.assertIsNotNone(train_output) - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + # Verify training completed 
successfully + self.assertIsNotNone(train_output) + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - # Verify loss is finite - self.assertTrue(np.isfinite(trainer.state.log_history[-1]["train_loss"])) + # Verify loss is finite + self.assertTrue(np.isfinite(trainer.state.log_history[-1]["train_loss"])) - # Check parameters have been updated - for n, param in previous_trainable_params.items(): - new_param = trainer.model.get_parameter(n) - # Only check non-zero parameters - if param.sum() != 0: - self.assertFalse(torch.equal(param, new_param)) - # Verify new parameters are finite - self.assertTrue(torch.isfinite(new_param).all()) - - # Verify model can still do forward pass after training - dummy_batch = next(iter(trainer.get_train_dataloader())) - model_inputs = { - "input_ids": dummy_batch["prompt_input_ids"], - "attention_mask": dummy_batch["prompt_attention_mask"], - } - with torch.no_grad(): - output = trainer.model(**model_inputs) - self.assertIsNotNone(output) - self.assertIsNone(output.loss) + # Check parameters have been updated + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # Only check non-zero parameters + if param.sum() != 0: + self.assertFalse(torch.equal(param, new_param)) + # Verify new parameters are finite + self.assertTrue(torch.isfinite(new_param).all()) + + # Verify model can still do forward pass after training + dummy_batch = next(iter(trainer.get_train_dataloader())) + model_inputs = { + "input_ids": dummy_batch["prompt_input_ids"], + "attention_mask": dummy_batch["prompt_attention_mask"], + } + with torch.no_grad(): + output = trainer.model(**model_inputs) + self.assertIsNotNone(output) + self.assertIsNone(output.loss) @require_vision From 5f4110fda8bcb5e2c0d19610694cf2e00eefecc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 17 Jan 2025 16:42:04 +0000 Subject: [PATCH 07/17] raise error as soon as dep is not met --- 
trl/trainer/dpo_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 60bb20db7e6..b221bfd9f02 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -398,9 +398,9 @@ def make_inputs_require_grad(module, input, output): disable_dropout_in_model(self.ref_model) # Liger kernel - if args.use_liger_loss and args.loss_type == "sigmoid": + if args.use_liger_loss: if not is_liger_kernel_available(): - raise ValueError( + raise ImportError( "You set `use_liger_loss=True` but the liger kernel is not available. " "Please install liger-kernel first: `pip install liger-kernel`" ) From b22eb24a0e637c7ac3f50f831d85981b3ac1ebf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 17 Jan 2025 16:42:22 +0000 Subject: [PATCH 08/17] move param to the right section --- trl/trainer/dpo_config.py | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index 7f389f4327f..4d7c8103bd0 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -93,12 +93,6 @@ class DPOConfig(TrainingArguments): Batch size to use when precomputing reference model log probabilities. This can be set higher than the training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for training and `per_device_eval_batch_size` for evaluation. - use_liger_loss (`bool`, *optional*, defaults to `False`): - Whether to use Liger loss. - base_model_attribute_name (`str`, *optional*, defaults to `"model"`): - Name of the attribute in the model that contains the base model. This is used to get the base model - from the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is - `True`. 
> Parameters that control the training @@ -122,6 +116,11 @@ class DPOConfig(TrainingArguments): - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper. - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper. + use_liger_loss (`bool`, *optional*, defaults to `False`): + Whether to use Liger loss. + base_model_attribute_name (`str`, *optional*, defaults to `"model"`): + Name of the attribute in the model that contains the base model. This is used to get the base model from + the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is `True`. beta (`float`, *optional*, defaults to `0.1`): Parameter controlling the deviation from the reference model. Higher β means less deviation from the reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in @@ -267,18 +266,6 @@ class DPOConfig(TrainingArguments): "`per_device_train_batch_size` for training and `per_device_eval_batch_size` for evaluation." }, ) - use_liger_loss: bool = field( - default=False, - metadata={"help": "Whether to use Liger loss."}, - ) - base_model_attribute_name: str = field( - default="model", - metadata={ - "help": "Name of the attribute in the model that contains the base model. This is used to get the base model " - "from the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is " - "`True`." - }, - ) # Parameters that control the training learning_rate: float = field( @@ -309,6 +296,18 @@ class DPOConfig(TrainingArguments): ], }, ) + use_liger_loss: bool = field( + default=False, + metadata={"help": "Whether to use Liger loss."}, + ) + base_model_attribute_name: str = field( + default="model", + metadata={ + "help": "Name of the attribute in the model that contains the base model. 
This is used to get the base " + "model from the model when the model does not have a `get_decoder` method in the case when " + "`use_liger_loss` is `True`." + }, + ) beta: float = field( default=0.1, metadata={ From b8e6f8c220821cc17d997c0fec5ab61b3f90c5c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 17 Jan 2025 16:56:24 +0000 Subject: [PATCH 09/17] reducing memory doc --- docs/source/reducing_memory_usage.md | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/docs/source/reducing_memory_usage.md b/docs/source/reducing_memory_usage.md index cf277c1e03f..0051f7ece18 100644 --- a/docs/source/reducing_memory_usage.md +++ b/docs/source/reducing_memory_usage.md @@ -16,7 +16,7 @@ Sequence lengths in the dataset can vary widely, and by default, TRL does not mo To reduce memory usage, it’s important to truncate sequences to a reasonable length. Even discarding just a few tokens from the dataset can result in significant memory savings by minimizing unnecessary padding. Truncation is a good practice and should always be applied to ensure efficient use of resources. While the truncation limit doesn’t need to be overly restrictive, setting a sensible value is essential for optimal performance. - + DPO truncation is applied first to the prompt and to the completion via the `max_prompt_length` and `max_completion_length` parameters. The `max_length` parameter is then used to truncate the resulting sequence. @@ -84,4 +84,22 @@ training_args = SFTConfig(..., packing=True, max_seq_length=512) Packing may cause batch contamination, where adjacent sequences influence one another. This can be problematic for some applications. For more details, see [#1230](https://github.com/huggingface/trl/issues/1230). 
- \ No newline at end of file + + +## Liger for reducing peak memory usage + +[To complete] + + + + +To use Liger for reducing peak memory usage, use the following code snippet: + +```python +from trl import DPOConfig + +training_args = DPOConfig(..., use_liger_loss=True) +``` + + + From 6310dbd067771d41e5a3c7ffa67f89306e2e5e99 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 21 Jan 2025 14:57:04 +0100 Subject: [PATCH 10/17] use liger specific method --- trl/trainer/dpo_trainer.py | 553 +++++++++++++++++++------------------ 1 file changed, 282 insertions(+), 271 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index b221bfd9f02..75b3a8e4b84 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1111,13 +1111,7 @@ def dpo_loss( return losses, chosen_rewards, rejected_rewards - def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]): - """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. - - We do this to avoid doing two forward passes, because it's faster for FSDP. - """ - num_examples = batch["prompt_input_ids"].shape[0] - + def _compute_loss_liger(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]): concatenated_batch = self.concatenated_inputs(batch, padding_value=self.padding_value) model_kwargs = {} @@ -1132,36 +1126,50 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to if "image_sizes" in concatenated_batch: model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] - if self.args.use_liger_loss and self.loss_type == "sigmoid": - if self.is_encoder_decoder: - # 1.
Get encoder outputs + encoder_outputs = model.get_encoder()( + concatenated_batch["prompt_input_ids"], + attention_mask=concatenated_batch["prompt_attention_mask"], + return_dict=True, + ) + # 2. Prepare decoder inputs + decoder_input_ids = shift_tokens_right( + concatenated_batch["completion_input_ids"], self.padding_value, model.config.decoder_start_token_id + ) + # 3. Get decoder outputs + decoder_outputs = model.get_decoder()( + input_ids=decoder_input_ids, + attention_mask=concatenated_batch["completion_attention_mask"], + encoder_hidden_states=encoder_outputs.last_hidden_state, + encoder_attention_mask=concatenated_batch["prompt_attention_mask"], + use_cache=False, + ) + hidden_states = decoder_outputs.last_hidden_state + + ref_hidden_states = None + if not self.reference_free and self.ref_model is not None: + ref_encoder_outputs = self.ref_model.get_encoder()( concatenated_batch["prompt_input_ids"], attention_mask=concatenated_batch["prompt_attention_mask"], return_dict=True, ) - # 2. Prepare decoder inputs - decoder_input_ids = shift_tokens_right( - concatenated_batch["completion_input_ids"], self.padding_value, model.config.decoder_start_token_id - ) - # 3. 
Get decoder outputs - decoder_outputs = model.get_decoder()( + ref_decoder_outputs = self.ref_model.get_decoder()( input_ids=decoder_input_ids, attention_mask=concatenated_batch["completion_attention_mask"], - encoder_hidden_states=encoder_outputs.last_hidden_state, + encoder_hidden_states=ref_encoder_outputs.last_hidden_state, encoder_attention_mask=concatenated_batch["prompt_attention_mask"], use_cache=False, ) - hidden_states = decoder_outputs.last_hidden_state - - ref_hidden_states = None - if not self.reference_free and self.ref_model is not None: - ref_encoder_outputs = self.ref_model.get_encoder()( + ref_hidden_states = ref_decoder_outputs.last_hidden_state + elif not self.reference_free: + with self.null_ref_context(): + ref_encoder_outputs = model.get_encoder()( concatenated_batch["prompt_input_ids"], attention_mask=concatenated_batch["prompt_attention_mask"], return_dict=True, ) - ref_decoder_outputs = self.ref_model.get_decoder()( + ref_decoder_outputs = model.get_decoder()( input_ids=decoder_input_ids, attention_mask=concatenated_batch["completion_attention_mask"], encoder_hidden_states=ref_encoder_outputs.last_hidden_state, @@ -1169,55 +1177,53 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to use_cache=False, ) ref_hidden_states = ref_decoder_outputs.last_hidden_state - elif not self.reference_free: - with self.null_ref_context(): - ref_encoder_outputs = model.get_encoder()( - concatenated_batch["prompt_input_ids"], - attention_mask=concatenated_batch["prompt_attention_mask"], - return_dict=True, - ) - ref_decoder_outputs = model.get_decoder()( - input_ids=decoder_input_ids, - attention_mask=concatenated_batch["completion_attention_mask"], - encoder_hidden_states=ref_encoder_outputs.last_hidden_state, - encoder_attention_mask=concatenated_batch["prompt_attention_mask"], - use_cache=False, - ) - ref_hidden_states = ref_decoder_outputs.last_hidden_state - labels = concatenated_batch["completion_input_ids"] + labels = 
concatenated_batch["completion_input_ids"] + else: + # For decoder-only models + input_ids = torch.cat( + (concatenated_batch["prompt_input_ids"], concatenated_batch["completion_input_ids"]), dim=1 + ) + attention_mask = torch.cat( + (concatenated_batch["prompt_attention_mask"], concatenated_batch["completion_attention_mask"]), + dim=1, + ) + + # Get the base model outputs (before LM head) + if hasattr(model, "get_decoder"): + base_model = model.get_decoder() else: - # For decoder-only models - input_ids = torch.cat( - (concatenated_batch["prompt_input_ids"], concatenated_batch["completion_input_ids"]), dim=1 - ) - attention_mask = torch.cat( - (concatenated_batch["prompt_attention_mask"], concatenated_batch["completion_attention_mask"]), - dim=1, - ) + base_model = getattr(model, self.args.base_model_attribute_name, model) - # Get the base model outputs (before LM head) - if hasattr(model, "get_decoder"): - base_model = model.get_decoder() + outputs = base_model( + input_ids, + attention_mask=attention_mask, + use_cache=False, + **model_kwargs, + ) + hidden_states = outputs.last_hidden_state[:, :-1] + + # Get reference hidden states if needed + ref_hidden_states = None + if not self.reference_free and self.ref_model is not None: + if hasattr(self.ref_model, "get_decoder"): + ref_base_model = self.ref_model.get_decoder() else: - base_model = getattr(model, self.args.base_model_attribute_name, model) + ref_base_model = getattr(self.ref_model, self.args.base_model_attribute_name, self.ref_model) - outputs = base_model( + ref_outputs = ref_base_model( input_ids, attention_mask=attention_mask, use_cache=False, **model_kwargs, ) - hidden_states = outputs.last_hidden_state[:, :-1] - - # Get reference hidden states if needed - ref_hidden_states = None - if not self.reference_free and self.ref_model is not None: - if hasattr(self.ref_model, "get_decoder"): - ref_base_model = self.ref_model.get_decoder() - else: - ref_base_model = getattr(self.ref_model, 
self.args.base_model_attribute_name, self.ref_model) - + ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] + elif not self.reference_free: + if hasattr(model, "get_decoder"): + ref_base_model = model.get_decoder() + else: + ref_base_model = getattr(model, self.args.base_model_attribute_name, model) + with self.null_ref_context(): ref_outputs = ref_base_model( input_ids, attention_mask=attention_mask, @@ -1225,226 +1231,230 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to **model_kwargs, ) ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] - elif not self.reference_free: - if hasattr(model, "get_decoder"): - ref_base_model = model.get_decoder() - else: - ref_base_model = getattr(model, self.args.base_model_attribute_name, model) - with self.null_ref_context(): - ref_outputs = ref_base_model( - input_ids, - attention_mask=attention_mask, - use_cache=False, - **model_kwargs, - ) - ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] - labels = input_ids[:, 1:] # Shift right for casual LM + labels = input_ids[:, 1:] # Shift right for casual LM - # Get the LM head - lm_head = model.get_output_embeddings() - - # Get reference model weights if needed - ref_weight = None - ref_bias = None - if not self.reference_free: - if self.ref_model is not None: - ref_lm_head = self.ref_model.get_output_embeddings() - else: - with self.null_ref_context(): - ref_lm_head = model.get_output_embeddings() - ref_weight = ref_lm_head.weight - ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None - - # Compute loss using Liger kernel - loss_output = self.dpo_loss_fn( - lm_head.weight, - hidden_states, - labels, - bias=lm_head.bias if hasattr(lm_head, "bias") else None, - ref_input=ref_hidden_states if not self.reference_free else None, - ref_weight=ref_weight if not self.reference_free else None, - ref_bias=ref_bias if not self.reference_free else None, - ) - ( - loss, - (chosen_logps, rejected_logps, chosen_logits_mean, 
rejected_logits_mean, nll_loss, _, _, *aux_outputs), - ) = loss_output - - output = { - "loss": loss, - "chosen_logps": chosen_logps, - "rejected_logps": rejected_logps, - "mean_chosen_logits": chosen_logits_mean, - "mean_rejected_logits": rejected_logits_mean, - "nll_loss": nll_loss, - "chosen_rewards": aux_outputs[0], - "rejected_rewards": aux_outputs[1], - } - if self.aux_loss_enabled: - output["aux_loss"] = outputs.aux_loss + # Get the LM head + lm_head = model.get_output_embeddings() - return output - else: - prompt_input_ids = concatenated_batch["prompt_input_ids"] - prompt_attention_mask = concatenated_batch["prompt_attention_mask"] - completion_input_ids = concatenated_batch["completion_input_ids"] - completion_attention_mask = concatenated_batch["completion_attention_mask"] - if self.is_encoder_decoder: - labels = completion_input_ids - labels[completion_attention_mask == 0] = self.label_pad_token_id - outputs = model( - input_ids=prompt_input_ids, - attention_mask=prompt_attention_mask, - labels=labels, # we need the labels for the logits to be returned - **model_kwargs, - ) - logits = outputs.logits - loss_mask = completion_attention_mask.bool() + # Get reference model weights if needed + ref_weight = None + ref_bias = None + if not self.reference_free: + if self.ref_model is not None: + ref_lm_head = self.ref_model.get_output_embeddings() else: - # Concatenate the prompt and completion inputs - input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1) - attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1) - # Mask the prompt but not the completion for the loss - loss_mask = torch.cat( - (torch.zeros_like(prompt_attention_mask), completion_attention_mask), - dim=1, - ) + with self.null_ref_context(): + ref_lm_head = model.get_output_embeddings() + ref_weight = ref_lm_head.weight + ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None + + # Compute loss using Liger kernel + loss_output = 
self.dpo_loss_fn( + lm_head.weight, + hidden_states, + labels, + bias=lm_head.bias if hasattr(lm_head, "bias") else None, + ref_input=ref_hidden_states if not self.reference_free else None, + ref_weight=ref_weight if not self.reference_free else None, + ref_bias=ref_bias if not self.reference_free else None, + ) + ( + loss, + (chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, _, _, *aux_outputs), + ) = loss_output - # Flush left to reduce the memory usage - # [[0, 0, x, x, x, x], -> [[x, x, x, x], - # [0, x, x, x, 0, 0]] [x, x, x, 0]] - for i in range(attention_mask.size(0)): - first_one_idx = torch.nonzero(attention_mask[i])[0].item() - input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx) - attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx) - loss_mask[i] = torch.roll(loss_mask[i], shifts=-first_one_idx) - - # Get the first column idx that is all zeros and remove every column after that - empty_cols = torch.sum(attention_mask, dim=0) == 0 - first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else attention_mask.size(1) - input_ids = input_ids[:, :first_empty_col] - attention_mask = attention_mask[:, :first_empty_col] - loss_mask = loss_mask[:, :first_empty_col] - - # Truncate right - if self.args.max_length is not None: - input_ids = input_ids[:, : self.args.max_length] - attention_mask = attention_mask[:, : self.args.max_length] - loss_mask = loss_mask[:, : self.args.max_length] - - if self.use_num_logits_to_keep: - # Compute num_logits_to_keep based on loss_mask pattern: - # [[0, 0, 0, x, x, x, x], - # [0, 0, 0, x, x, x, 0]] - # ^ start computing logits from here ([:, -(7-3+1):]) - first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() - num_logits_to_keep = ( - loss_mask.shape[1] - first_compute_index - ).item() + 1 # +1 for the first label - model_kwargs["num_logits_to_keep"] = num_logits_to_keep - - if self.padding_free: - # Flatten the input_ids, position_ids, and 
loss_mask - # input_ids = [[a, b, c, 0], -> input_ids = [[a, b, c, d, e, f, g]] - # [d, e, f, g]] position_ids = [[0, 1, 2, 0, 1, 2, 3]] - input_ids = input_ids[attention_mask.bool()].unsqueeze(0) - loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0) - position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1 - model_kwargs["position_ids"] = position_ids - else: - model_kwargs["attention_mask"] = attention_mask - - outputs = model(input_ids, **model_kwargs) - logits = outputs.logits - - # Offset the logits by one to align with the labels - labels = torch.roll(input_ids, shifts=-1, dims=1) - loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool() - - if self.use_num_logits_to_keep: - # Align labels with logits - # logits: -, -, [x2, x3, x4, x5, x6] - # ^ --------- ^ after logits[:, :-1, :] - # labels: [y0, y1, y2, y3, y4, y5, y6] - # ^ --------- ^ with num_logits_to_keep=4, [:, -4:] - # loss_mask: [0, 0, 0, 1, 1, 1, 1] - labels = labels[:, -num_logits_to_keep:] - loss_mask = loss_mask[:, -num_logits_to_keep:] - - if logits.shape[:2] != labels.shape[:2]: - # for llava, the returned logits include the image tokens (placed before the text tokens) - seq_len = labels.shape[1] - logits = logits[:, -seq_len:] + output = { + "loss": loss, + "chosen_logps": chosen_logps, + "rejected_logps": rejected_logps, + "mean_chosen_logits": chosen_logits_mean, + "mean_rejected_logits": rejected_logits_mean, + "nll_loss": nll_loss, + "chosen_rewards": aux_outputs[0], + "rejected_rewards": aux_outputs[1], + } + if self.aux_loss_enabled: + output["aux_loss"] = outputs.aux_loss - # Compute the log probabilities of the labels - labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) - per_token_logps[~loss_mask] = 0 - per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1) + return output - if self.padding_free: - # 
Unflatten the per_token_logps (shape: [1, sum_seq_len] -> [batch_size, seq_len]) - batch_size, seq_len = attention_mask.shape - per_token_logps_ = torch.zeros( - batch_size, seq_len, device=outputs.logits.device, dtype=outputs.logits.dtype - ) - per_token_logps_[attention_mask.bool()] = per_token_logps - per_token_logps = per_token_logps_ - - all_logps = per_token_logps.sum(-1) - - output = {} - - if self.use_weighting: - with torch.no_grad(): - # Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827 - logprobs = F.log_softmax(logits, dim=-1) - weights_adjustment_factor = torch.logsumexp( - 2 * logprobs, dim=-1 - ) # same as sum(probs**2) in log space - per_token_logps_adjusted = per_token_logps - weights_adjustment_factor - all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1) - chosen_weights = all_weights[:num_examples] - rejected_weights = all_weights[num_examples:] - output["policy_weights"] = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1) - - if self.args.rpo_alpha is not None: - # Only use the chosen logits for the RPO loss - chosen_logits = logits[:num_examples, :-1] if not self.is_encoder_decoder else logits[:num_examples] - chosen_labels = labels[:num_examples, 1:] if self.is_encoder_decoder else labels[:num_examples] - - # Compute the log probabilities of the labels - output["nll_loss"] = F.cross_entropy( - torch.flatten(chosen_logits, end_dim=1), torch.flatten(chosen_labels, end_dim=1), ignore_index=0 - ) + def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]): + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. 
+ """ + num_examples = batch["prompt_input_ids"].shape[0] + + concatenated_batch = self.concatenated_inputs(batch, padding_value=self.padding_value) + + model_kwargs = {} + if self.aux_loss_enabled: + model_kwargs["output_router_logits"] = True + + # Add the pixel values and attention masks for vision models + if "pixel_values" in concatenated_batch: + model_kwargs["pixel_values"] = concatenated_batch["pixel_values"] + if "pixel_attention_mask" in concatenated_batch: + model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"] + if "image_sizes" in concatenated_batch: + model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] - if self.loss_type == "ipo": - all_logps = all_logps / loss_mask.sum(-1) + prompt_input_ids = concatenated_batch["prompt_input_ids"] + prompt_attention_mask = concatenated_batch["prompt_attention_mask"] + completion_input_ids = concatenated_batch["completion_input_ids"] + completion_attention_mask = concatenated_batch["completion_attention_mask"] + if self.is_encoder_decoder: + labels = completion_input_ids + labels[completion_attention_mask == 0] = self.label_pad_token_id + outputs = model( + input_ids=prompt_input_ids, + attention_mask=prompt_attention_mask, + labels=labels, # we need the labels for the logits to be returned + **model_kwargs, + ) + logits = outputs.logits + loss_mask = completion_attention_mask.bool() + else: + # Concatenate the prompt and completion inputs + input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1) + attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1) + # Mask the prompt but not the completion for the loss + loss_mask = torch.cat( + (torch.zeros_like(prompt_attention_mask), completion_attention_mask), + dim=1, + ) - output["chosen_logps"] = all_logps[:num_examples] - output["rejected_logps"] = all_logps[num_examples:] + # Flush left to reduce the memory usage + # [[0, 0, x, x, x, x], -> [[x, x, x, x], + # [0, x, x, x, 0, 0]] [x, 
x, x, 0]] + for i in range(attention_mask.size(0)): + first_one_idx = torch.nonzero(attention_mask[i])[0].item() + input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx) + attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx) + loss_mask[i] = torch.roll(loss_mask[i], shifts=-first_one_idx) + + # Get the first column idx that is all zeros and remove every column after that + empty_cols = torch.sum(attention_mask, dim=0) == 0 + first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else attention_mask.size(1) + input_ids = input_ids[:, :first_empty_col] + attention_mask = attention_mask[:, :first_empty_col] + loss_mask = loss_mask[:, :first_empty_col] + + # Truncate right + if self.args.max_length is not None: + input_ids = input_ids[:, : self.args.max_length] + attention_mask = attention_mask[:, : self.args.max_length] + loss_mask = loss_mask[:, : self.args.max_length] + + if self.use_num_logits_to_keep: + # Compute num_logits_to_keep based on loss_mask pattern: + # [[0, 0, 0, x, x, x, x], + # [0, 0, 0, x, x, x, 0]] + # ^ start computing logits from here ([:, -(7-3+1):]) + first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() + num_logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 # +1 for the first label + model_kwargs["num_logits_to_keep"] = num_logits_to_keep - # Compute the mean logits if self.padding_free: - # position_ids contains a sequence of range identifiers (e.g., [[0, 1, 2, 0, 1, 2, 3, ...]]). - # There are 2*num_examples ranges in total: the first half corresponds to the chosen tokens, - # and the second half to the rejected tokens. - # To find the start of the rejected tokens, we look for the num_examples+1-th zero in pos_id. 
- split_idx = (position_ids == 0).nonzero(as_tuple=True)[1][num_examples] - mean_chosen_logits = logits[0, :split_idx][loss_mask[0, :split_idx]].mean() - mean_rejected_logits = logits[0, split_idx:][loss_mask[0, split_idx:]].mean() + # Flatten the input_ids, position_ids, and loss_mask + # input_ids = [[a, b, c, 0], -> input_ids = [[a, b, c, d, e, f, g]] + # [d, e, f, g]] position_ids = [[0, 1, 2, 0, 1, 2, 3]] + input_ids = input_ids[attention_mask.bool()].unsqueeze(0) + loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0) + position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1 + model_kwargs["position_ids"] = position_ids else: - mean_chosen_logits = logits[:num_examples][loss_mask[:num_examples]].mean() - mean_rejected_logits = logits[num_examples:][loss_mask[num_examples:]].mean() + model_kwargs["attention_mask"] = attention_mask + + outputs = model(input_ids, **model_kwargs) + logits = outputs.logits + + # Offset the logits by one to align with the labels + labels = torch.roll(input_ids, shifts=-1, dims=1) + loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool() + + if self.use_num_logits_to_keep: + # Align labels with logits + # logits: -, -, [x2, x3, x4, x5, x6] + # ^ --------- ^ after logits[:, :-1, :] + # labels: [y0, y1, y2, y3, y4, y5, y6] + # ^ --------- ^ with num_logits_to_keep=4, [:, -4:] + # loss_mask: [0, 0, 0, 1, 1, 1, 1] + labels = labels[:, -num_logits_to_keep:] + loss_mask = loss_mask[:, -num_logits_to_keep:] + + if logits.shape[:2] != labels.shape[:2]: + # for llava, the returned logits include the image tokens (placed before the text tokens) + seq_len = labels.shape[1] + logits = logits[:, -seq_len:] + + # Compute the log probabilities of the labels + labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + per_token_logps[~loss_mask] = 0 + per_token_logps = 
torch.roll(per_token_logps, shifts=1, dims=1) + + if self.padding_free: + # Unflatten the per_token_logps (shape: [1, sum_seq_len] -> [batch_size, seq_len]) + batch_size, seq_len = attention_mask.shape + per_token_logps_ = torch.zeros( + batch_size, seq_len, device=outputs.logits.device, dtype=outputs.logits.dtype + ) + per_token_logps_[attention_mask.bool()] = per_token_logps + per_token_logps = per_token_logps_ + + all_logps = per_token_logps.sum(-1) + + output = {} + + if self.use_weighting: + with torch.no_grad(): + # Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827 + logprobs = F.log_softmax(logits, dim=-1) + weights_adjustment_factor = torch.logsumexp(2 * logprobs, dim=-1) # same as sum(probs**2) in log space + per_token_logps_adjusted = per_token_logps - weights_adjustment_factor + all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1) + chosen_weights = all_weights[:num_examples] + rejected_weights = all_weights[num_examples:] + output["policy_weights"] = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1) - output["mean_chosen_logits"] = mean_chosen_logits - output["mean_rejected_logits"] = mean_rejected_logits + if self.args.rpo_alpha is not None: + # Only use the chosen logits for the RPO loss + chosen_logits = logits[:num_examples, :-1] if not self.is_encoder_decoder else logits[:num_examples] + chosen_labels = labels[:num_examples, 1:] if self.is_encoder_decoder else labels[:num_examples] + + # Compute the log probabilities of the labels + output["nll_loss"] = F.cross_entropy( + torch.flatten(chosen_logits, end_dim=1), torch.flatten(chosen_labels, end_dim=1), ignore_index=0 + ) + + if self.loss_type == "ipo": + all_logps = all_logps / loss_mask.sum(-1) + + output["chosen_logps"] = all_logps[:num_examples] + output["rejected_logps"] = all_logps[num_examples:] + + # Compute the mean logits + if self.padding_free: + # position_ids contains a sequence of range identifiers (e.g., [[0, 1, 2, 0, 1, 2, 3, 
...]]). + # There are 2*num_examples ranges in total: the first half corresponds to the chosen tokens, + # and the second half to the rejected tokens. + # To find the start of the rejected tokens, we look for the num_examples+1-th zero in pos_id. + split_idx = (position_ids == 0).nonzero(as_tuple=True)[1][num_examples] + mean_chosen_logits = logits[0, :split_idx][loss_mask[0, :split_idx]].mean() + mean_rejected_logits = logits[0, split_idx:][loss_mask[0, split_idx:]].mean() + else: + mean_chosen_logits = logits[:num_examples][loss_mask[:num_examples]].mean() + mean_rejected_logits = logits[num_examples:][loss_mask[num_examples:]].mean() - if self.aux_loss_enabled: - output["aux_loss"] = outputs.aux_loss + output["mean_chosen_logits"] = mean_chosen_logits + output["mean_rejected_logits"] = mean_rejected_logits - return output + if self.aux_loss_enabled: + output["aux_loss"] = outputs.aux_loss + + return output def get_batch_loss_metrics( self, @@ -1455,13 +1465,14 @@ def get_batch_loss_metrics( """Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" metrics = {} - model_output = self.concatenated_forward(model, batch) - - if self.args.use_liger_loss: + if self.args.use_liger_loss and self.loss_type == "sigmoid": + model_output = self._compute_loss_liger(model, batch) losses = model_output["loss"] chosen_rewards = model_output["chosen_rewards"] rejected_rewards = model_output["rejected_rewards"] else: + model_output = self.concatenated_forward(model, batch) + # if ref_chosen_logps and ref_rejected_logps in batch use them, otherwise use the reference model if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch: ref_chosen_logps = batch["ref_chosen_logps"] From dbece5404c2df02b3c5dcb4761d09d148cf9fef3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 22 Jan 2025 13:11:59 +0100 Subject: [PATCH 11/17] update return signature --- trl/trainer/dpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 9fa3a711a3e..9502509095d 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1264,7 +1264,7 @@ def _compute_loss_liger(self, model: nn.Module, batch: dict[str, Union[list, tor ) ( loss, - (chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, _, _, *aux_outputs), + (chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, *aux_outputs), ) = loss_output output = { From 7fc0615994f96c312b1d9a39b259177a7549906b Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 12 Mar 2025 17:58:53 +0000 Subject: [PATCH 12/17] adding ref model ctx manager --- trl/trainer/dpo_trainer.py | 204 +++++++++++++++++++++++-------------- 1 file changed, 130 insertions(+), 74 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 9502509095d..206bcc8f9ce 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1149,22 +1149,21 @@ def _compute_loss_liger(self, model: nn.Module, batch: dict[str, Union[list, tor use_cache=False, ) hidden_states = decoder_outputs.last_hidden_state - - ref_hidden_states = None + + ref_encoder_outputs = None if not self.reference_free and self.ref_model is not None: ref_encoder_outputs = self.ref_model.get_encoder()( concatenated_batch["prompt_input_ids"], attention_mask=concatenated_batch["prompt_attention_mask"], return_dict=True, ) - ref_decoder_outputs = self.ref_model.get_decoder()( - input_ids=decoder_input_ids, - attention_mask=concatenated_batch["completion_attention_mask"], - encoder_hidden_states=ref_encoder_outputs.last_hidden_state, - encoder_attention_mask=concatenated_batch["prompt_attention_mask"], - use_cache=False, - ) - ref_hidden_states = ref_decoder_outputs.last_hidden_state + # ref_decoder_outputs = self.ref_model.get_decoder()( + # input_ids=decoder_input_ids, + # attention_mask=concatenated_batch["completion_attention_mask"], + # 
encoder_hidden_states=ref_encoder_outputs.last_hidden_state, + # encoder_attention_mask=concatenated_batch["prompt_attention_mask"], + # use_cache=False, + # ) elif not self.reference_free: with self.null_ref_context(): ref_encoder_outputs = model.get_encoder()( @@ -1172,15 +1171,20 @@ def _compute_loss_liger(self, model: nn.Module, batch: dict[str, Union[list, tor attention_mask=concatenated_batch["prompt_attention_mask"], return_dict=True, ) - ref_decoder_outputs = model.get_decoder()( - input_ids=decoder_input_ids, - attention_mask=concatenated_batch["completion_attention_mask"], - encoder_hidden_states=ref_encoder_outputs.last_hidden_state, - encoder_attention_mask=concatenated_batch["prompt_attention_mask"], - use_cache=False, - ) - ref_hidden_states = ref_decoder_outputs.last_hidden_state - + # ref_decoder_outputs = model.get_decoder()( + # input_ids=decoder_input_ids, + # attention_mask=concatenated_batch["completion_attention_mask"], + # encoder_hidden_states=ref_encoder_outputs.last_hidden_state, + # encoder_attention_mask=concatenated_batch["prompt_attention_mask"], + # use_cache=False, + # ) + + ref_model_inputs = { + "input_ids": decoder_input_ids, + "attention_mask": concatenated_batch["completion_attention_mask"], + "encoder_hidden_states": ref_encoder_outputs.last_hidden_state, + "encoder_attention_mask": concatenated_batch["prompt_attention_mask"], + } labels = concatenated_batch["completion_input_ids"] else: # For decoder-only models @@ -1206,66 +1210,73 @@ def _compute_loss_liger(self, model: nn.Module, batch: dict[str, Union[list, tor ) hidden_states = outputs.last_hidden_state[:, :-1] - # Get reference hidden states if needed - ref_hidden_states = None - if not self.reference_free and self.ref_model is not None: - if hasattr(self.ref_model, "get_decoder"): - ref_base_model = self.ref_model.get_decoder() - else: - ref_base_model = getattr(self.ref_model, self.args.base_model_attribute_name, self.ref_model) - - ref_outputs = ref_base_model( - 
input_ids, - attention_mask=attention_mask, - use_cache=False, - **model_kwargs, - ) - ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] - elif not self.reference_free: - if hasattr(model, "get_decoder"): - ref_base_model = model.get_decoder() - else: - ref_base_model = getattr(model, self.args.base_model_attribute_name, model) - with self.null_ref_context(): - ref_outputs = ref_base_model( - input_ids, - attention_mask=attention_mask, - use_cache=False, - **model_kwargs, - ) - ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] - + # # Get reference hidden states if needed + # ref_hidden_states = None + # if not self.reference_free and self.ref_model is not None: + # if hasattr(self.ref_model, "get_decoder"): + # ref_base_model = self.ref_model.get_decoder() + # else: + # ref_base_model = getattr(self.ref_model, self.args.base_model_attribute_name, self.ref_model) + + # ref_outputs = ref_base_model( + # input_ids, + # attention_mask=attention_mask, + # use_cache=False, + # **model_kwargs, + # ) + # ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] + # elif not self.reference_free: + # if hasattr(model, "get_decoder"): + # ref_base_model = model.get_decoder() + # else: + # ref_base_model = getattr(model, self.args.base_model_attribute_name, model) + # with self.null_ref_context(): + # ref_outputs = ref_base_model( + # input_ids, + # attention_mask=attention_mask, + # use_cache=False, + # **model_kwargs, + # ) + # ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] + + ref_model_inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "use_cache": False, + **model_kwargs, + } labels = input_ids[:, 1:] # Shift right for casual LM # Get the LM head lm_head = model.get_output_embeddings() - # Get reference model weights if needed - ref_weight = None - ref_bias = None - if not self.reference_free: - if self.ref_model is not None: - ref_lm_head = self.ref_model.get_output_embeddings() - else: - with self.null_ref_context(): - 
ref_lm_head = model.get_output_embeddings() - ref_weight = ref_lm_head.weight - ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None - - # Compute loss using Liger kernel - loss_output = self.dpo_loss_fn( - lm_head.weight, - hidden_states, - labels, - bias=lm_head.bias if hasattr(lm_head, "bias") else None, - ref_input=ref_hidden_states if not self.reference_free else None, - ref_weight=ref_weight if not self.reference_free else None, - ref_bias=ref_bias if not self.reference_free else None, - ) - ( - loss, - (chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, *aux_outputs), - ) = loss_output + # # Get reference model weights if needed + # ref_weight = None + # ref_bias = None + # if not self.reference_free: + # if self.ref_model is not None: + # ref_lm_head = self.ref_model.get_output_embeddings() + # else: + # with self.null_ref_context(): + # ref_lm_head = model.get_output_embeddings() + # ref_weight = ref_lm_head.weight + # ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None + + with self.get_decoder_outputs_for_liger_loss(ref_model_inputs) as (ref_hidden_states, ref_weight, ref_bias): + # Compute loss using Liger kernel + loss_output = self.dpo_loss_fn( + lm_head.weight, + hidden_states, + labels, + bias=lm_head.bias if hasattr(lm_head, "bias") else None, + ref_input=ref_hidden_states, + ref_weight=ref_weight, + ref_bias=ref_bias, + ) + ( + loss, + (chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, *aux_outputs), + ) = loss_output output = { "loss": loss, @@ -1761,3 +1772,48 @@ def create_model_card( ) model_card.save(os.path.join(self.args.output_dir, "README.md")) + + @contextmanager + def get_decoder_outputs_for_liger_loss( + self, + ref_model_inputs: dict[str, Union[list, torch.LongTensor]], + ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + if self.reference_free: + yield None, None, None + return + + if self.ref_model is None: + ref_model = 
self.model + context_manager = self.null_ref_context() + else: + ref_model = self.ref_model + context_manager = nullcontext() + + if self.is_encoder_decoder or hasattr(ref_model, "get_decoder"): + ref_base_model = ref_model.get_decoder() + else: + ref_base_model = getattr(ref_model, self.args.base_model_attribute_name, ref_model) + + with context_manager: + ref_outputs = ref_base_model( + **ref_model_inputs + ) + + # Get LM head and yield results + ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] + + try: + if self.ref_model is None: + ref_lm_head = self.model.get_output_embeddings() + ref_lm_head.merge() + else: + ref_lm_head = self.ref_model.get_output_embeddings() + + yield ( + ref_hidden_states, + ref_lm_head.weight, + ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None + ) + finally: + if self.ref_model is None: + ref_lm_head.unmerge() \ No newline at end of file From 29fa096844223e526b9f071fd98ebe8b1f6de334 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 12 Mar 2025 18:00:37 +0000 Subject: [PATCH 13/17] removing comment --- trl/trainer/dpo_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 206bcc8f9ce..d65427b5a65 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1799,7 +1799,6 @@ def get_decoder_outputs_for_liger_loss( **ref_model_inputs ) - # Get LM head and yield results ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] try: From d1a4ebf00ba09af090ef389b8b68297c6c62ac3b Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 12 Mar 2025 18:09:04 +0000 Subject: [PATCH 14/17] only merging if peft has been applied --- trl/trainer/dpo_trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index d65427b5a65..871bdcf5ad6 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -71,6 +71,7 @@ peft_module_casting_to_bf16, ) +from 
peft.tuners.tuners_utils import BaseTunerLayer if is_peft_available(): from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training @@ -1804,7 +1805,8 @@ def get_decoder_outputs_for_liger_loss( try: if self.ref_model is None: ref_lm_head = self.model.get_output_embeddings() - ref_lm_head.merge() + if isinstance(ref_lm_head, BaseTunerLayer): + ref_lm_head.merge() else: ref_lm_head = self.ref_model.get_output_embeddings() @@ -1814,5 +1816,5 @@ def get_decoder_outputs_for_liger_loss( ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None ) finally: - if self.ref_model is None: + if self.ref_model is None and isinstance(ref_lm_head, BaseTunerLayer): ref_lm_head.unmerge() \ No newline at end of file From f90b13902bd571cde023248e3db20bc59c9665a7 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 12 Mar 2025 18:12:41 +0000 Subject: [PATCH 15/17] simplifying --- trl/trainer/dpo_trainer.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 871bdcf5ad6..3bd2a615ebe 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1799,16 +1799,12 @@ def get_decoder_outputs_for_liger_loss( ref_outputs = ref_base_model( **ref_model_inputs ) - ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] try: - if self.ref_model is None: - ref_lm_head = self.model.get_output_embeddings() - if isinstance(ref_lm_head, BaseTunerLayer): - ref_lm_head.merge() - else: - ref_lm_head = self.ref_model.get_output_embeddings() + ref_lm_head = ref_model.get_output_embeddings() + if isinstance(ref_lm_head, BaseTunerLayer): + ref_lm_head.merge() yield ( ref_hidden_states, @@ -1816,5 +1812,5 @@ def get_decoder_outputs_for_liger_loss( ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None ) finally: - if self.ref_model is None and isinstance(ref_lm_head, BaseTunerLayer): + if isinstance(ref_lm_head, BaseTunerLayer): ref_lm_head.unmerge() \ No newline at end of file 
From f24a83f7f09f1b8f2772a44532ca6127c2a32cce Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 12 Mar 2025 18:37:05 +0000 Subject: [PATCH 16/17] potential bugfix --- trl/trainer/dpo_trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 3bd2a615ebe..552c136fc34 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1151,13 +1151,14 @@ def _compute_loss_liger(self, model: nn.Module, batch: dict[str, Union[list, tor ) hidden_states = decoder_outputs.last_hidden_state - ref_encoder_outputs = None + ref_encoder_hidden_states = None if not self.reference_free and self.ref_model is not None: ref_encoder_outputs = self.ref_model.get_encoder()( concatenated_batch["prompt_input_ids"], attention_mask=concatenated_batch["prompt_attention_mask"], return_dict=True, ) + ref_encoder_hidden_states = ref_encoder_outputs.last_hidden_state # ref_decoder_outputs = self.ref_model.get_decoder()( # input_ids=decoder_input_ids, # attention_mask=concatenated_batch["completion_attention_mask"], @@ -1172,6 +1173,7 @@ def _compute_loss_liger(self, model: nn.Module, batch: dict[str, Union[list, tor attention_mask=concatenated_batch["prompt_attention_mask"], return_dict=True, ) + ref_encoder_hidden_states = ref_encoder_outputs.last_hidden_state # ref_decoder_outputs = model.get_decoder()( # input_ids=decoder_input_ids, # attention_mask=concatenated_batch["completion_attention_mask"], @@ -1183,7 +1185,7 @@ def _compute_loss_liger(self, model: nn.Module, batch: dict[str, Union[list, tor ref_model_inputs = { "input_ids": decoder_input_ids, "attention_mask": concatenated_batch["completion_attention_mask"], - "encoder_hidden_states": ref_encoder_outputs.last_hidden_state, + "encoder_hidden_states": ref_encoder_hidden_states, "encoder_attention_mask": concatenated_batch["prompt_attention_mask"], } labels = concatenated_batch["completion_input_ids"] From 
41c18339681c6cf27acff007d0208d415cedba5f Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Fri, 14 Mar 2025 13:54:55 +0000 Subject: [PATCH 17/17] adding tests - moving to utils --- tests/test_utils.py | 110 +++++++++++++++++++++++++++++++++++++ trl/trainer/dpo_trainer.py | 56 ++++--------------- trl/trainer/utils.py | 93 +++++++++++++++++++++++++++++-- 3 files changed, 210 insertions(+), 49 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 0061dd5e5e3..6d02372bc91 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -31,6 +31,7 @@ generate_model_card, get_peft_config, pad, + get_decoder_outputs_for_liger_loss, ) @@ -451,3 +452,112 @@ def test_no_tensors(self): expected_mask = torch.tensor([[1, 1, 1], [1, 1, 0]]) self.assertTrue(torch.equal(new_mask, expected_mask)) + + +class TestGetDecoderOutputsForLigerLoss(unittest.TestCase): + def test_reference_free(self): + """Test that when reference_free is True, the function yields None values.""" + from trl.trainer.utils import get_decoder_outputs_for_liger_loss + from contextlib import nullcontext + + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + inputs = tokenizer("Hello world", return_tensors="pt") + + + with get_decoder_outputs_for_liger_loss( + model=model, + ref_model=model, + reference_free=True, + is_encoder_decoder=False, + base_model_attribute_name="model", + null_ref_context=nullcontext, + ref_model_inputs=inputs + ) as (ref_hidden_states, ref_weight, ref_bias): + self.assertIsNone(ref_hidden_states) + self.assertIsNone(ref_weight) + self.assertIsNone(ref_bias) + + def test_with_ref_model(self): + """Test with a real reference model.""" + from trl.trainer.utils import get_decoder_outputs_for_liger_loss + from contextlib import nullcontext + + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = 
AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + inputs = tokenizer("Hello world", return_tensors="pt") + + with get_decoder_outputs_for_liger_loss( + model=model, + ref_model=model, + reference_free=False, + is_encoder_decoder=False, + base_model_attribute_name="model", + null_ref_context=nullcontext, + ref_model_inputs=inputs + ) as (ref_hidden_states, ref_weight, ref_bias): + + self.assertIsNotNone(ref_hidden_states) + self.assertIsNotNone(ref_weight) + + self.assertEqual(ref_hidden_states.shape[0], inputs["input_ids"].shape[0]) + self.assertEqual(ref_hidden_states.shape[1], inputs["input_ids"].shape[1] - 1) + self.assertEqual(ref_hidden_states.shape[2], model.config.hidden_size) + + self.assertEqual(ref_weight.shape[0], model.config.vocab_size) + self.assertEqual(ref_weight.shape[1], model.config.hidden_size) + + if ref_bias is not None: + self.assertEqual(ref_bias.shape[0], model.config.vocab_size) + + @require_peft + def test_with_peft_model(self): + """Test with a PEFT model that requires merge/unmerge operations.""" + from trl.trainer.utils import get_decoder_outputs_for_liger_loss + from contextlib import nullcontext + from peft import get_peft_model + + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + + peft_config = LoraConfig( + r=8, + lora_alpha=16, + lora_dropout=0.1, + target_modules=["q_proj", "v_proj", "lm_head"], + ) + peft_model = get_peft_model(model, peft_config) + + inputs = tokenizer("Hello, world!", return_tensors="pt") + input_ids = inputs["input_ids"] + + lm_head = peft_model.get_output_embeddings() + original_lm_head_weight = lm_head.base_layer.weight.clone() + + with get_decoder_outputs_for_liger_loss( + model=model, + ref_model=peft_model, + reference_free=False, + is_encoder_decoder=False, + base_model_attribute_name="model", + 
null_ref_context=nullcontext, + ref_model_inputs={"input_ids": input_ids} + ) as (ref_hidden_states, ref_weight, ref_bias): + self.assertEqual(ref_hidden_states.shape[0], input_ids.shape[0]) + self.assertEqual(ref_hidden_states.shape[1], input_ids.shape[1] - 1) + self.assertEqual(ref_hidden_states.shape[2], peft_model.config.hidden_size) + + self.assertEqual(ref_weight.shape[0], peft_model.config.vocab_size) + self.assertEqual(ref_weight.shape[1], peft_model.config.hidden_size) + + if ref_bias is not None: + self.assertEqual(ref_bias.shape[0], peft_model.config.vocab_size) + + restored_lm_head_weight = peft_model.get_output_embeddings().base_layer.weight + self.assertTrue(torch.equal(original_lm_head_weight, restored_lm_head_weight)) + + \ No newline at end of file diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 552c136fc34..2b9603c7020 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -69,10 +69,9 @@ pad, pad_to_length, peft_module_casting_to_bf16, + get_decoder_outputs_for_liger_loss, ) -from peft.tuners.tuners_utils import BaseTunerLayer - if is_peft_available(): from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training @@ -1265,7 +1264,15 @@ def _compute_loss_liger(self, model: nn.Module, batch: dict[str, Union[list, tor # ref_weight = ref_lm_head.weight # ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None - with self.get_decoder_outputs_for_liger_loss(ref_model_inputs) as (ref_hidden_states, ref_weight, ref_bias): + with get_decoder_outputs_for_liger_loss( + model=self.model, + ref_model=self.ref_model, + reference_free=self.reference_free, + is_encoder_decoder=self.is_encoder_decoder, + base_model_attribute_name=self.args.base_model_attribute_name, + null_ref_context=self.null_ref_context, + ref_model_inputs=ref_model_inputs + ) as (ref_hidden_states, ref_weight, ref_bias): # Compute loss using Liger kernel loss_output = self.dpo_loss_fn( lm_head.weight, @@ -1774,45 
+1781,4 @@ def create_model_card( paper_id="2305.18290", ) - model_card.save(os.path.join(self.args.output_dir, "README.md")) - - @contextmanager - def get_decoder_outputs_for_liger_loss( - self, - ref_model_inputs: dict[str, Union[list, torch.LongTensor]], - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - if self.reference_free: - yield None, None, None - return - - if self.ref_model is None: - ref_model = self.model - context_manager = self.null_ref_context() - else: - ref_model = self.ref_model - context_manager = nullcontext() - - if self.is_encoder_decoder or hasattr(ref_model, "get_decoder"): - ref_base_model = ref_model.get_decoder() - else: - ref_base_model = getattr(ref_model, self.args.base_model_attribute_name, ref_model) - - with context_manager: - ref_outputs = ref_base_model( - **ref_model_inputs - ) - ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] - - try: - ref_lm_head = ref_model.get_output_embeddings() - if isinstance(ref_lm_head, BaseTunerLayer): - ref_lm_head.merge() - - yield ( - ref_hidden_states, - ref_lm_head.weight, - ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None - ) - finally: - if isinstance(ref_lm_head, BaseTunerLayer): - ref_lm_head.unmerge() \ No newline at end of file + model_card.save(os.path.join(self.args.output_dir, "README.md")) \ No newline at end of file diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 719d952f1f4..901559a29f9 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import contextlib import dataclasses import importlib.resources as pkg_resources import json @@ -20,7 +21,7 @@ from collections import deque from dataclasses import dataclass, field from importlib.metadata import version -from typing import Any, Literal, Optional, Union +from typing import Any, Generator, Literal, Optional, Union import datasets import numpy as np @@ -51,6 +52,12 @@ is_torch_xpu_available, ) +from contextlib import contextmanager + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + from peft.tuners.tuners_utils import BaseTunerLayer + from ..trainer.model_config import ModelConfig @@ -308,9 +315,9 @@ def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]: input_ids = [torch.tensor(ids, dtype=torch.long) for ids in input_ids] attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in attention_mask] labels = [torch.tensor(label, dtype=torch.long) for label in labels] - input_ids = pad(input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id) - attention_mask = pad(attention_mask, padding_side="left", padding_value=0) - labels = pad(labels, padding_side="left", padding_value=self.ignore_index) + input_ids = pad(input_ids, padding_value=self.tokenizer.pad_token_id) + attention_mask = pad(attention_mask, padding_value=0) + labels = pad(labels, padding_value=self.ignore_index) prompts_input_ids = [torch.tensor(ids, dtype=torch.long) for ids in prompts_input_ids] prompt_attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in prompt_attention_mask] @@ -1647,3 +1654,81 @@ def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor return mask else: return mask, *tensors + +@contextmanager +def get_decoder_outputs_for_liger_loss( + model: torch.nn.Module, + ref_model: Optional[torch.nn.Module], + reference_free: bool, + is_encoder_decoder: bool, + base_model_attribute_name: str, + null_ref_context: 
contextlib.ContextDecorator,
+    ref_model_inputs: dict[str, Union[list, torch.LongTensor]],
+) -> Generator[Any, Any, Any]:
+    """
+    Get the decoder outputs for the Liger loss.
+
+    Args:
+        model (`torch.nn.Module`):
+            The model to get the decoder outputs for.
+        ref_model (`torch.nn.Module`):
+            The reference model to get the decoder outputs for.
+        reference_free (`bool`):
+            Whether to compute the loss without a reference model; if `True`, no reference outputs are produced.
+        is_encoder_decoder (`bool`):
+            Whether the model is an encoder-decoder model.
+        base_model_attribute_name (`str`):
+            The attribute name of the base model in the reference model.
+        null_ref_context (`contextlib.ContextDecorator`):
+            The context manager for the reference model.
+        ref_model_inputs (`dict`):
+            The inputs to the reference model.
+
+    Yields:
+        `tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`:
+            The decoder outputs for the Liger loss.
+            The tuple contains the following elements:
+            - `ref_hidden_states`: The hidden states of the reference model.
+            - `ref_lm_head_weight`: The weight of the reference model's language model head.
+            - `ref_lm_head_bias`: The bias of the reference model's language model head.
+ + """ + if reference_free: + yield None, None, None + return + + if ref_model is None: + ref_model = model + context_manager = null_ref_context() + else: + from contextlib import nullcontext + context_manager = nullcontext() + + if is_encoder_decoder or hasattr(ref_model, "get_decoder"): + ref_base_model = ref_model.get_decoder() + else: + ref_base_model = getattr(ref_model, base_model_attribute_name, ref_model) + + with context_manager: + ref_outputs = ref_base_model( + **ref_model_inputs + ) + ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] + + try: + ref_lm_head = ref_model.get_output_embeddings() + if is_peft_available(): + from peft.tuners.tuners_utils import BaseTunerLayer + if isinstance(ref_lm_head, BaseTunerLayer): + ref_lm_head.merge() + + yield ( + ref_hidden_states, + ref_lm_head.weight, + ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None + ) + finally: + if is_peft_available(): + from peft.tuners.tuners_utils import BaseTunerLayer + if isinstance(ref_lm_head, BaseTunerLayer): + ref_lm_head.unmerge() \ No newline at end of file