diff --git a/docs/source/reducing_memory_usage.md b/docs/source/reducing_memory_usage.md
index 6c054906165..ad943d6b870 100644
--- a/docs/source/reducing_memory_usage.md
+++ b/docs/source/reducing_memory_usage.md
@@ -16,7 +16,7 @@ Sequence lengths in the dataset can vary widely. When data is batched, sequences
 To reduce memory usage, it’s important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.
 
- 
+
 DPO truncation is applied first to the prompt and to the completion via the `max_prompt_length` and `max_completion_length` parameters. The `max_length` parameter is then used to truncate the resulting sequence.
@@ -94,6 +94,21 @@ Packing may cause batch contamination, where adjacent sequences influence one an
 
+## Liger for reducing peak memory usage
+
+[Liger Kernel](https://github.com/linkedin/Liger-Kernel) provides a fused, chunked implementation of the DPO loss (`LigerFusedLinearDPOLoss`). Rather than projecting every hidden state through the language-model head and materializing the full logits tensor, whose size scales with batch size, sequence length, and vocabulary size, the fused kernel computes the loss chunk by chunk. This can substantially reduce peak memory usage during training. Note that the Liger loss is currently only applied when `loss_type="sigmoid"`.
+
+To use Liger for reducing peak memory usage, use the following code snippet:
+
+```python
+from trl import DPOConfig
+
+training_args = DPOConfig(..., use_liger_loss=True)
+```
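+
+Liger is an optional dependency: install it with `pip install liger-kernel` before enabling the flag. As a minimal sketch of a full run (the model and dataset names below are illustrative placeholders, not recommendations):
+
+```python
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from trl import DPOConfig, DPOTrainer
+
+model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
+
+# use_liger_loss routes the (sigmoid) DPO loss through LigerFusedLinearDPOLoss
+training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", use_liger_loss=True)
+trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=dataset)
+trainer.train()
+```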
+
 ## Disabling model gathering for generation in online methods
 
 When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Online methods involve generating completions from the model as part of the training process. During this step, the model weights are temporarily gathered on a single GPU for generation. For very large models, this gathering can lead to out-of-memory (OOM) errors, as described in this issue: [#2250](https://github.com/huggingface/trl/issues/2250#issue-2598304204).
diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index c4a0232ee30..5066dc98f8b 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -29,7 +29,12 @@
     PreTrainedTokenizerBase,
     is_vision_available,
 )
-from transformers.testing_utils import require_peft, require_torch_gpu_if_bnb_not_multi_backend_enabled, require_vision
+from transformers.testing_utils import (
+    require_liger_kernel,
+    require_peft,
+    require_torch_gpu_if_bnb_not_multi_backend_enabled,
+    require_vision,
+)
 
 from trl import DPOConfig, DPOTrainer, FDivergenceType
 
@@ -1227,6 +1232,75 @@ def test_padding_free(self):
                 if param.sum() != 0:  # ignore 0 biases
                     self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
 
+    @require_liger_kernel
+    @parameterized.expand([(0.1,), (0.5,)])
+    def test_dpo_trainer_with_liger(self, beta):
+        """Test DPO trainer with Liger loss enabled.
+
+        This test verifies that:
+        1. Training runs successfully with Liger loss
+        2. Model parameters update as expected
+        3. Loss values are reasonable and finite
+        4. Training works with both default and custom beta values
+        """
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = DPOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=2,
+                max_steps=3,
+                remove_unused_columns=False,
+                gradient_accumulation_steps=1,
+                learning_rate=9e-1,
+                eval_strategy="steps",
+                beta=beta,
+                use_liger_loss=True,  # enable the Liger fused DPO loss
+                report_to="none",
+            )
+
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
+
+            trainer = DPOTrainer(
+                model=self.model,
+                ref_model=self.ref_model,  # use an explicit reference model
+                args=training_args,
+                processing_class=self.tokenizer,
+                train_dataset=dummy_dataset["train"],
+                eval_dataset=dummy_dataset["test"],
+            )
+
+            # Store the initial parameters to verify they change during training
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            # Train the model
+            train_output = trainer.train()
+
+            # Verify training completed successfully
+            self.assertIsNotNone(train_output)
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Verify loss is finite
+            self.assertTrue(np.isfinite(trainer.state.log_history[-1]["train_loss"]))
+
+            # Check that the parameters have been updated
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                # Only check non-zero parameters
+                if param.sum() != 0:
+                    self.assertFalse(torch.equal(param, new_param))
+                    # Verify the new parameters are finite
+                    self.assertTrue(torch.isfinite(new_param).all())
+
+            # Verify the model can still run a forward pass after training
+            dummy_batch = next(iter(trainer.get_train_dataloader()))
+            model_inputs = {
+                "input_ids": dummy_batch["prompt_input_ids"],
+                "attention_mask": dummy_batch["prompt_attention_mask"],
+            }
+            with torch.no_grad():
+                output = trainer.model(**model_inputs)
+            self.assertIsNotNone(output)
+            # No labels were passed, so no loss should be returned
+            self.assertIsNone(output.loss)
+
 
 @require_vision
 class DPOVisionTrainerTester(unittest.TestCase):
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 0061dd5e5e3..6d02372bc91 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -31,6 +31,7 @@
     generate_model_card,
+    get_decoder_outputs_for_liger_loss,
     get_peft_config,
     pad,
 )
 
 
@@ -451,3 +452,112 @@ def test_no_tensors(self):
         expected_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
 
         self.assertTrue(torch.equal(new_mask, expected_mask))
+
+
+class TestGetDecoderOutputsForLigerLoss(unittest.TestCase):
+    def test_reference_free(self):
+        """Test that when reference_free is True, the context manager yields only None values."""
+        from contextlib import nullcontext
+
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        inputs = tokenizer("Hello world", return_tensors="pt")
+
+        with get_decoder_outputs_for_liger_loss(
+            model=model,
+            ref_model=model,
+            reference_free=True,
+            is_encoder_decoder=False,
+            base_model_attribute_name="model",
+            null_ref_context=nullcontext,
+            ref_model_inputs=inputs,
+        ) as (ref_hidden_states, ref_weight, ref_bias):
+            self.assertIsNone(ref_hidden_states)
+            self.assertIsNone(ref_weight)
+            self.assertIsNone(ref_bias)
+
+    def test_with_ref_model(self):
+        """Test with a real reference model."""
+        from contextlib import nullcontext
+
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        inputs = tokenizer("Hello world", return_tensors="pt")
+
+        with get_decoder_outputs_for_liger_loss(
+            model=model,
+            ref_model=model,
+            reference_free=False,
+            is_encoder_decoder=False,
+            base_model_attribute_name="model",
+            null_ref_context=nullcontext,
+            ref_model_inputs=inputs,
+        ) as (ref_hidden_states, ref_weight, ref_bias):
+            self.assertIsNotNone(ref_hidden_states)
+            self.assertIsNotNone(ref_weight)
+
+            # The last position is dropped so the hidden states align with the shifted labels
+            self.assertEqual(ref_hidden_states.shape[0], inputs["input_ids"].shape[0])
+            self.assertEqual(ref_hidden_states.shape[1], inputs["input_ids"].shape[1] - 1)
+            self.assertEqual(ref_hidden_states.shape[2], model.config.hidden_size)
+
+            self.assertEqual(ref_weight.shape[0], model.config.vocab_size)
+            self.assertEqual(ref_weight.shape[1], model.config.hidden_size)
+
+            if ref_bias is not None:
+                self.assertEqual(ref_bias.shape[0], model.config.vocab_size)
+
+    @require_peft
+    def test_with_peft_model(self):
+        """Test with a PEFT model that requires merge/unmerge operations on the LM head."""
+        from contextlib import nullcontext
+
+        from peft import get_peft_model
+
+        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+
+        peft_config = LoraConfig(
+            r=8,
+            lora_alpha=16,
+            lora_dropout=0.1,
+            target_modules=["q_proj", "v_proj", "lm_head"],
+        )
+        peft_model = get_peft_model(model, peft_config)
+
+        inputs = tokenizer("Hello, world!", return_tensors="pt")
+        input_ids = inputs["input_ids"]
+
+        lm_head = peft_model.get_output_embeddings()
+        original_lm_head_weight = lm_head.base_layer.weight.clone()
+
+        with get_decoder_outputs_for_liger_loss(
+            model=model,
+            ref_model=peft_model,
+            reference_free=False,
+            is_encoder_decoder=False,
+            base_model_attribute_name="model",
+            null_ref_context=nullcontext,
+            ref_model_inputs={"input_ids": input_ids},
+        ) as (ref_hidden_states, ref_weight, ref_bias):
+            self.assertEqual(ref_hidden_states.shape[0], input_ids.shape[0])
+            self.assertEqual(ref_hidden_states.shape[1], input_ids.shape[1] - 1)
+            self.assertEqual(ref_hidden_states.shape[2], peft_model.config.hidden_size)
+
+            self.assertEqual(ref_weight.shape[0], peft_model.config.vocab_size)
+            self.assertEqual(ref_weight.shape[1], peft_model.config.hidden_size)
+
+            if ref_bias is not None:
+                self.assertEqual(ref_bias.shape[0], peft_model.config.vocab_size)
+
+        # The adapter must be unmerged on exit, leaving the base weights untouched
+        restored_lm_head_weight = peft_model.get_output_embeddings().base_layer.weight
+        self.assertTrue(torch.equal(original_lm_head_weight, restored_lm_head_weight))
diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py
index b7c18e11cc4..2ab34cd7f96 100644
--- a/trl/trainer/dpo_config.py
+++ b/trl/trainer/dpo_config.py
@@ -119,6 +119,11 @@ class DPOConfig(TrainingArguments):
             - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
             - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
 
+        use_liger_loss (`bool`, *optional*, defaults to `False`):
+            Whether to use Liger's fused linear DPO loss, which avoids materializing the full logits tensor and
+            can significantly reduce peak memory usage. Currently only applied when `loss_type="sigmoid"`.
+        base_model_attribute_name (`str`, *optional*, defaults to `"model"`):
+            Name of the model attribute that contains the base model. Used to retrieve the base model when
+            `use_liger_loss` is `True` and the model does not expose a `get_decoder` method.
         beta (`float`, *optional*, defaults to `0.1`):
             Parameter controlling the deviation from the reference model. Higher β means less deviation from the
             reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
@@ -301,6 +306,18 @@
             ],
         },
     )
+    use_liger_loss: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use Liger's fused linear DPO loss to reduce peak memory usage. Currently only "
+            "applied when `loss_type='sigmoid'`."
+        },
+    )
+    base_model_attribute_name: str = field(
+        default="model",
+        metadata={
+            "help": "Name of the model attribute that contains the base model. Used to retrieve the base model "
+            "when `use_liger_loss` is `True` and the model does not expose a `get_decoder` method."
+        },
+    )
     beta: float = field(
         default=0.1,
         metadata={
" + "Please install liger-kernel first: `pip install liger-kernel`" + ) + self.dpo_loss_fn = LigerFusedLinearDPOLoss( + ignore_index=args.label_pad_token_id, beta=args.beta, use_ref_model=not args.reference_free + ) + self.max_length = args.max_length self.generate_during_eval = args.generate_during_eval self.label_pad_token_id = args.label_pad_token_id @@ -1093,6 +1114,195 @@ def dpo_loss( return losses, chosen_rewards, rejected_rewards + def _compute_loss_liger(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]): + concatenated_batch = self.concatenated_inputs(batch, padding_value=self.padding_value) + + model_kwargs = {} + if self.aux_loss_enabled: + model_kwargs["output_router_logits"] = True + + # Add the pixel values and attention masks for vision models + if "pixel_values" in concatenated_batch: + model_kwargs["pixel_values"] = concatenated_batch["pixel_values"] + if "pixel_attention_mask" in concatenated_batch: + model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"] + if "image_sizes" in concatenated_batch: + model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] + + if self.is_encoder_decoder: + # 1. Get encoder outputs + encoder_outputs = model.get_encoder()( + concatenated_batch["prompt_input_ids"], + attention_mask=concatenated_batch["prompt_attention_mask"], + return_dict=True, + ) + # 2. Prepare decoder inputs + decoder_input_ids = shift_tokens_right( + concatenated_batch["completion_input_ids"], self.padding_value, model.config.decoder_start_token_id + ) + # 3. Get decoder outputs + decoder_outputs = model.get_decoder()( + input_ids=decoder_input_ids, + attention_mask=concatenated_batch["completion_attention_mask"], + encoder_hidden_states=encoder_outputs.last_hidden_state, + encoder_attention_mask=concatenated_batch["prompt_attention_mask"], + use_cache=False, + ) + hidden_states = decoder_outputs.last_hidden_state + + ref_encoder_hidden_states = None + if not self.reference_free and self.ref_model is not None: + ref_encoder_outputs = self.ref_model.get_encoder()( + concatenated_batch["prompt_input_ids"], + attention_mask=concatenated_batch["prompt_attention_mask"], + return_dict=True, + ) + ref_encoder_hidden_states = ref_encoder_outputs.last_hidden_state + # ref_decoder_outputs = self.ref_model.get_decoder()( + # input_ids=decoder_input_ids, + # attention_mask=concatenated_batch["completion_attention_mask"], + # encoder_hidden_states=ref_encoder_outputs.last_hidden_state, + # encoder_attention_mask=concatenated_batch["prompt_attention_mask"], + # use_cache=False, + # ) + elif not self.reference_free: + with self.null_ref_context(): + ref_encoder_outputs = model.get_encoder()( + concatenated_batch["prompt_input_ids"], + attention_mask=concatenated_batch["prompt_attention_mask"], + return_dict=True, + ) + ref_encoder_hidden_states = ref_encoder_outputs.last_hidden_state + # ref_decoder_outputs = model.get_decoder()( + # input_ids=decoder_input_ids, + # attention_mask=concatenated_batch["completion_attention_mask"], + # encoder_hidden_states=ref_encoder_outputs.last_hidden_state, + # encoder_attention_mask=concatenated_batch["prompt_attention_mask"], + # use_cache=False, + # ) + + ref_model_inputs = { + "input_ids": decoder_input_ids, + "attention_mask": concatenated_batch["completion_attention_mask"], + "encoder_hidden_states": ref_encoder_hidden_states, + "encoder_attention_mask": concatenated_batch["prompt_attention_mask"], + } + labels = concatenated_batch["completion_input_ids"] + else: + # For decoder-only 
@@ -1093,6 +1114,195 @@ def dpo_loss(
 
         return losses, chosen_rewards, rejected_rewards
 
+    def _compute_loss_liger(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]):
+        concatenated_batch = self.concatenated_inputs(batch, padding_value=self.padding_value)
+
+        model_kwargs = {}
+        if self.aux_loss_enabled:
+            model_kwargs["output_router_logits"] = True
+
+        # Add the pixel values and attention masks for vision models
+        if "pixel_values" in concatenated_batch:
+            model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
+        if "pixel_attention_mask" in concatenated_batch:
+            model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
+        if "image_sizes" in concatenated_batch:
+            model_kwargs["image_sizes"] = concatenated_batch["image_sizes"]
+
+        if self.is_encoder_decoder:
+            # 1. Get encoder outputs
+            encoder_outputs = model.get_encoder()(
+                concatenated_batch["prompt_input_ids"],
+                attention_mask=concatenated_batch["prompt_attention_mask"],
+                return_dict=True,
+            )
+            # 2. Prepare decoder inputs
+            decoder_input_ids = shift_tokens_right(
+                concatenated_batch["completion_input_ids"], self.padding_value, model.config.decoder_start_token_id
+            )
+            # 3. Get decoder outputs
+            decoder_outputs = model.get_decoder()(
+                input_ids=decoder_input_ids,
+                attention_mask=concatenated_batch["completion_attention_mask"],
+                encoder_hidden_states=encoder_outputs.last_hidden_state,
+                encoder_attention_mask=concatenated_batch["prompt_attention_mask"],
+                use_cache=False,
+            )
+            hidden_states = decoder_outputs.last_hidden_state
+
+            # Run the reference encoder; the reference decoder forward happens later inside
+            # `get_decoder_outputs_for_liger_loss`
+            ref_encoder_hidden_states = None
+            if not self.reference_free and self.ref_model is not None:
+                ref_encoder_outputs = self.ref_model.get_encoder()(
+                    concatenated_batch["prompt_input_ids"],
+                    attention_mask=concatenated_batch["prompt_attention_mask"],
+                    return_dict=True,
+                )
+                ref_encoder_hidden_states = ref_encoder_outputs.last_hidden_state
+            elif not self.reference_free:
+                with self.null_ref_context():
+                    ref_encoder_outputs = model.get_encoder()(
+                        concatenated_batch["prompt_input_ids"],
+                        attention_mask=concatenated_batch["prompt_attention_mask"],
+                        return_dict=True,
+                    )
+                ref_encoder_hidden_states = ref_encoder_outputs.last_hidden_state
+
+            ref_model_inputs = {
+                "input_ids": decoder_input_ids,
+                "attention_mask": concatenated_batch["completion_attention_mask"],
+                "encoder_hidden_states": ref_encoder_hidden_states,
+                "encoder_attention_mask": concatenated_batch["prompt_attention_mask"],
+                "use_cache": False,
+            }
+            labels = concatenated_batch["completion_input_ids"]
+        else:
+            # For decoder-only models, run the prompt and completion through the base model as one sequence
+            input_ids = torch.cat(
+                (concatenated_batch["prompt_input_ids"], concatenated_batch["completion_input_ids"]), dim=1
+            )
+            attention_mask = torch.cat(
+                (concatenated_batch["prompt_attention_mask"], concatenated_batch["completion_attention_mask"]),
+                dim=1,
+            )
+
+            # Get the base model outputs (before the LM head)
+            if hasattr(model, "get_decoder"):
+                base_model = model.get_decoder()
+            else:
+                base_model = getattr(model, self.args.base_model_attribute_name, model)
+
+            outputs = base_model(
+                input_ids,
+                attention_mask=attention_mask,
+                use_cache=False,
+                **model_kwargs,
+            )
+            # Drop the last position: the hidden state at position t predicts token t + 1
+            hidden_states = outputs.last_hidden_state[:, :-1]
+
+            ref_model_inputs = {
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "use_cache": False,
+                **model_kwargs,
+            }
+            labels = input_ids[:, 1:]  # Shift right for causal LM
+
+        # Get the LM head
+        lm_head = model.get_output_embeddings()
+
+        with get_decoder_outputs_for_liger_loss(
+            model=self.model,
+            ref_model=self.ref_model,
+            reference_free=self.reference_free,
+            is_encoder_decoder=self.is_encoder_decoder,
+            base_model_attribute_name=self.args.base_model_attribute_name,
+            null_ref_context=self.null_ref_context,
+            ref_model_inputs=ref_model_inputs,
+        ) as (ref_hidden_states, ref_weight, ref_bias):
+            # Compute the loss with the Liger fused linear DPO kernel
+            loss_output = self.dpo_loss_fn(
+                lm_head.weight,
+                hidden_states,
+                labels,
+                bias=lm_head.bias if hasattr(lm_head, "bias") else None,
+                ref_input=ref_hidden_states,
+                ref_weight=ref_weight,
+                ref_bias=ref_bias,
+            )
+            (
+                loss,
+                (chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, *aux_outputs),
+            ) = loss_output
+
+        output = {
+            "loss": loss,
+            "chosen_logps": chosen_logps,
+            "rejected_logps": rejected_logps,
+            "mean_chosen_logits": chosen_logits_mean,
+            "mean_rejected_logits": rejected_logits_mean,
+            "nll_loss": nll_loss,
+            "chosen_rewards": aux_outputs[0],
+            "rejected_rewards": aux_outputs[1],
+        }
+        if self.aux_loss_enabled:
+            output["aux_loss"] = outputs.aux_loss
+
+        return output
+
     def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]):
         """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
@@ -1224,8 +1434,8 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
 
         if self.args.rpo_alpha is not None:
             # Only use the chosen logits for the RPO loss
-            chosen_logits = logits[:num_examples]
-            chosen_labels = labels[:num_examples]
+            chosen_logits = logits[:num_examples, :-1] if not self.is_encoder_decoder else logits[:num_examples]
+            chosen_labels = labels[:num_examples, 1:] if not self.is_encoder_decoder else labels[:num_examples]
 
             # Compute the log probabilities of the labels
             output["nll_loss"] = F.cross_entropy(
@@ -1268,18 +1478,24 @@ def get_batch_loss_metrics(
         """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
         metrics = {}
 
-        model_output = self.concatenated_forward(model, batch)
-
-        # if ref_chosen_logps and ref_rejected_logps in batch use them, otherwise use the reference model
-        if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch:
-            ref_chosen_logps = batch["ref_chosen_logps"]
-            ref_rejected_logps = batch["ref_rejected_logps"]
+        # The Liger fused DPO loss currently only supports the sigmoid loss type
+        if self.args.use_liger_loss and self.loss_type == "sigmoid":
+            model_output = self._compute_loss_liger(model, batch)
+            losses = model_output["loss"]
+            chosen_rewards = model_output["chosen_rewards"]
+            rejected_rewards = model_output["rejected_rewards"]
         else:
-            ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch)
+            model_output = self.concatenated_forward(model, batch)
 
-        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
-            model_output["chosen_logps"], model_output["rejected_logps"], ref_chosen_logps, ref_rejected_logps
-        )
+            # if ref_chosen_logps and ref_rejected_logps in batch use them, otherwise use the reference model
+            if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch:
+                ref_chosen_logps = batch["ref_chosen_logps"]
+                ref_rejected_logps = batch["ref_rejected_logps"]
+            else:
+                ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch)
+
+            losses, chosen_rewards, rejected_rewards = self.dpo_loss(
+                model_output["chosen_logps"], model_output["rejected_logps"], ref_chosen_logps, ref_rejected_logps
+            )
 
         reward_accuracies = (chosen_rewards > rejected_rewards).float()
 
         if self.args.rpo_alpha is not None:
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 719d952f1f4..901559a29f9 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import dataclasses
 import importlib.resources as pkg_resources
 import json
@@ -20,7 +21,7 @@
 from collections import deque
 from dataclasses import dataclass, field
 from importlib.metadata import version
-from typing import Any, Literal, Optional, Union
+from typing import Any, Generator, Literal, Optional, Union
 
 import datasets
 import numpy as np
@@ -51,6 +52,12 @@
     is_torch_xpu_available,
 )
 
+if is_peft_available():
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
 from ..trainer.model_config import ModelConfig
 
 
@@ -308,9 +315,9 @@ def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
         input_ids = [torch.tensor(ids, dtype=torch.long) for ids in input_ids]
         attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in attention_mask]
         labels = [torch.tensor(label, dtype=torch.long) for label in labels]
-        input_ids = pad(input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id)
-        attention_mask = pad(attention_mask, padding_side="left", padding_value=0)
-        labels = pad(labels, padding_side="left", padding_value=self.ignore_index)
+        input_ids = pad(input_ids, padding_value=self.tokenizer.pad_token_id)
+        attention_mask = pad(attention_mask, padding_value=0)
+        labels = pad(labels, padding_value=self.ignore_index)
 
         prompts_input_ids = [torch.tensor(ids, dtype=torch.long) for ids in prompts_input_ids]
         prompt_attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in prompt_attention_mask]
@@ -1647,3 +1654,81 @@ def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor
         return mask
     else:
         return mask, *tensors
+
+
+@contextlib.contextmanager
+def get_decoder_outputs_for_liger_loss(
+    model: torch.nn.Module,
+    ref_model: Optional[torch.nn.Module],
+    reference_free: bool,
+    is_encoder_decoder: bool,
+    base_model_attribute_name: str,
+    null_ref_context: contextlib.ContextDecorator,
+    ref_model_inputs: dict[str, Union[list, torch.LongTensor]],
+) -> Generator[Any, Any, Any]:
+    """
+    Context manager that runs the reference model's decoder and yields the tensors needed by the Liger fused
+    linear DPO loss.
+
+    Args:
+        model (`torch.nn.Module`):
+            The policy model. Used as the reference model (under `null_ref_context`) when `ref_model` is `None`.
+        ref_model (`torch.nn.Module` or `None`):
+            The reference model. If `None`, the policy model is used instead.
+        reference_free (`bool`):
+            Whether to skip the reference model entirely; if `True`, only `None` values are yielded.
+        is_encoder_decoder (`bool`):
+            Whether the model is an encoder-decoder model.
+        base_model_attribute_name (`str`):
+            Name of the attribute holding the base model, used when the reference model has no `get_decoder`
+            method.
+        null_ref_context (`contextlib.ContextDecorator`):
+            Context manager factory that temporarily turns the policy model into the reference model (e.g. by
+            disabling its adapters).
+        ref_model_inputs (`dict`):
+            The inputs to pass to the reference model's decoder.
+
+    Yields:
+        `tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`:
+            - `ref_hidden_states`: The hidden states of the reference model.
+            - `ref_lm_head_weight`: The weight of the reference model's language model head.
+            - `ref_lm_head_bias`: The bias of the reference model's language model head, if any.
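+
+    Example (a minimal sketch mirroring the unit tests; the tiny test model id is illustrative):
+
+    ```python
+    from contextlib import nullcontext
+
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+    tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+    inputs = tokenizer("Hello world", return_tensors="pt")
+
+    with get_decoder_outputs_for_liger_loss(
+        model=model,
+        ref_model=model,
+        reference_free=False,
+        is_encoder_decoder=False,
+        base_model_attribute_name="model",
+        null_ref_context=nullcontext,
+        ref_model_inputs=inputs,
+    ) as (ref_hidden_states, ref_weight, ref_bias):
+        # ref_hidden_states: (batch, seq_len - 1, hidden_size); ref_weight: (vocab_size, hidden_size)
+        print(ref_hidden_states.shape, ref_weight.shape)
+    ```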
+    """
+    if reference_free:
+        # Reference-free DPO: no reference outputs are needed
+        yield None, None, None
+        return
+
+    if ref_model is None:
+        # The policy model doubles as the reference model: run it under the null-ref context
+        ref_model = model
+        context_manager = null_ref_context()
+    else:
+        context_manager = contextlib.nullcontext()
+
+    if is_encoder_decoder or hasattr(ref_model, "get_decoder"):
+        ref_base_model = ref_model.get_decoder()
+    else:
+        ref_base_model = getattr(ref_model, base_model_attribute_name, ref_model)
+
+    with context_manager:
+        ref_outputs = ref_base_model(**ref_model_inputs)
+        # For decoder-only models, drop the last position so the hidden states align with the shifted labels
+        if is_encoder_decoder:
+            ref_hidden_states = ref_outputs.last_hidden_state
+        else:
+            ref_hidden_states = ref_outputs.last_hidden_state[:, :-1]
+
+        ref_lm_head = ref_model.get_output_embeddings()
+        is_tuner_layer = is_peft_available() and isinstance(ref_lm_head, BaseTunerLayer)
+        try:
+            if is_tuner_layer:
+                # Merge the adapter into the base weights so the yielded weight reflects the reference model
+                ref_lm_head.merge()
+            yield (
+                ref_hidden_states,
+                ref_lm_head.weight,
+                ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None,
+            )
+        finally:
+            if is_tuner_layer:
+                ref_lm_head.unmerge()