17 changes: 16 additions & 1 deletion docs/source/reducing_memory_usage.md
@@ -16,7 +16,7 @@ Sequence lengths in the dataset can vary widely. When data is batched, sequences

To reduce memory usage, it’s important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.

<hfoptions id="truncation">
<hfoption id="DPO">

DPO truncation is applied first to the prompt and to the completion via the `max_prompt_length` and `max_completion_length` parameters. The `max_length` parameter is then used to truncate the resulting sequence.
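
For example, all three lengths can be set on the config (the values below are illustrative):

```python
from trl import DPOConfig

# Truncate each prompt to 512 tokens and each completion to 512 tokens,
# then cap the combined prompt + completion sequence at 1024 tokens.
training_args = DPOConfig(
    ...,
    max_prompt_length=512,
    max_completion_length=512,
    max_length=1024,
)
```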
@@ -94,6 +94,21 @@ Packing may cause batch contamination, where adjacent sequences influence one another

</Tip>

## Liger for reducing peak memory usage

[Liger Kernel](https://github.com/linkedin/Liger-Kernel) provides fused kernels that compute the loss without materializing the full logits tensor, which can substantially reduce peak memory usage during training. Install it first with `pip install liger-kernel`.

<hfoptions id="liger">
<hfoption id="DPO">

To reduce peak memory usage with the Liger loss, enable it in the config:

```python
from trl import DPOConfig

training_args = DPOConfig(..., use_liger_loss=True)
```

</hfoption>
</hfoptions>
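
If your model lacks a `get_decoder` method and stores its base model under a non-standard attribute, you can point the trainer at it with `base_model_attribute_name` (a sketch; `"transformer"` is a hypothetical attribute name):

```python
from trl import DPOConfig

# "transformer" is a hypothetical attribute name; use whatever attribute your
# architecture actually stores its base model under (the default is "model").
training_args = DPOConfig(..., use_liger_loss=True, base_model_attribute_name="transformer")
```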

## Disabling model gathering for generation in online methods

When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Online methods involve generating completions from the model as part of the training process. During this step, the model weights are temporarily gathered on a single GPU for generation. For very large models, this gathering can lead to out-of-memory (OOM) errors, as described in this issue: [#2250](https://github.com/huggingface/trl/issues/2250#issue-2598304204).
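
A minimal sketch of how to disable this gathering, assuming the trainer's config exposes the `ds3_gather_for_generation` flag discussed in the linked issue (generation then runs on the sharded weights, which is slower but avoids the OOM):

```python
from trl import OnlineDPOConfig

# Keep the weights sharded across GPUs during generation instead of
# gathering them on a single GPU.
training_args = OnlineDPOConfig(..., ds3_gather_for_generation=False)
```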
76 changes: 75 additions & 1 deletion tests/test_dpo_trainer.py
@@ -29,7 +29,12 @@
PreTrainedTokenizerBase,
is_vision_available,
)
from transformers.testing_utils import (
require_liger_kernel,
require_peft,
require_torch_gpu_if_bnb_not_multi_backend_enabled,
require_vision,
)

from trl import DPOConfig, DPOTrainer, FDivergenceType

@@ -1227,6 +1232,75 @@ def test_padding_free(self):
if param.sum() != 0: # ignore 0 biases
self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))

@require_liger_kernel
@parameterized.expand([(0.1,), (0.5,)])
def test_dpo_trainer_with_liger(self, beta):
"""Test DPO trainer with Liger loss enabled.

This test verifies that:
1. Training runs successfully with Liger loss
2. Model parameters update as expected
3. Loss values are reasonable and finite
4. Training works with both default and custom beta values
"""
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=1,
learning_rate=9e-1,
eval_strategy="steps",
beta=beta,
use_liger_loss=True, # Enable Liger loss
report_to="none",
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

trainer = DPOTrainer(
model=self.model,
ref_model=self.ref_model, # Add reference model
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
)

# Store initial parameters
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

# Train the model
train_output = trainer.train()

# Verify training completed successfully
self.assertIsNotNone(train_output)
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Verify loss is finite
self.assertTrue(np.isfinite(trainer.state.log_history[-1]["train_loss"]))

# Check parameters have been updated
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# Only check non-zero parameters
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))
# Verify new parameters are finite
self.assertTrue(torch.isfinite(new_param).all())

# Verify model can still do forward pass after training
dummy_batch = next(iter(trainer.get_train_dataloader()))
model_inputs = {
"input_ids": dummy_batch["prompt_input_ids"],
"attention_mask": dummy_batch["prompt_attention_mask"],
}
with torch.no_grad():
output = trainer.model(**model_inputs)
self.assertIsNotNone(output)
self.assertIsNone(output.loss)


@require_vision
class DPOVisionTrainerTester(unittest.TestCase):
110 changes: 110 additions & 0 deletions tests/test_utils.py
@@ -31,6 +31,7 @@
generate_model_card,
    get_decoder_outputs_for_liger_loss,
    get_peft_config,
    pad,
)


@@ -451,3 +452,112 @@ def test_no_tensors(self):
expected_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])

self.assertTrue(torch.equal(new_mask, expected_mask))


class TestGetDecoderOutputsForLigerLoss(unittest.TestCase):
def test_reference_free(self):
"""Test that when reference_free is True, the function yields None values."""
        from contextlib import nullcontext

model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

        inputs = tokenizer("Hello world", return_tensors="pt")

with get_decoder_outputs_for_liger_loss(
model=model,
ref_model=model,
reference_free=True,
is_encoder_decoder=False,
base_model_attribute_name="model",
null_ref_context=nullcontext,
ref_model_inputs=inputs
) as (ref_hidden_states, ref_weight, ref_bias):
self.assertIsNone(ref_hidden_states)
self.assertIsNone(ref_weight)
self.assertIsNone(ref_bias)

def test_with_ref_model(self):
"""Test with a real reference model."""
        from contextlib import nullcontext

model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hello world", return_tensors="pt")

with get_decoder_outputs_for_liger_loss(
model=model,
ref_model=model,
reference_free=False,
is_encoder_decoder=False,
base_model_attribute_name="model",
null_ref_context=nullcontext,
ref_model_inputs=inputs
) as (ref_hidden_states, ref_weight, ref_bias):

self.assertIsNotNone(ref_hidden_states)
self.assertIsNotNone(ref_weight)

self.assertEqual(ref_hidden_states.shape[0], inputs["input_ids"].shape[0])
self.assertEqual(ref_hidden_states.shape[1], inputs["input_ids"].shape[1] - 1)
self.assertEqual(ref_hidden_states.shape[2], model.config.hidden_size)

self.assertEqual(ref_weight.shape[0], model.config.vocab_size)
self.assertEqual(ref_weight.shape[1], model.config.hidden_size)

if ref_bias is not None:
self.assertEqual(ref_bias.shape[0], model.config.vocab_size)

@require_peft
def test_with_peft_model(self):
"""Test with a PEFT model that requires merge/unmerge operations."""
        from contextlib import nullcontext

        from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")

peft_config = LoraConfig(
r=8,
lora_alpha=16,
lora_dropout=0.1,
target_modules=["q_proj", "v_proj", "lm_head"],
)
peft_model = get_peft_model(model, peft_config)

inputs = tokenizer("Hello, world!", return_tensors="pt")
input_ids = inputs["input_ids"]

lm_head = peft_model.get_output_embeddings()
original_lm_head_weight = lm_head.base_layer.weight.clone()

with get_decoder_outputs_for_liger_loss(
model=model,
ref_model=peft_model,
reference_free=False,
is_encoder_decoder=False,
base_model_attribute_name="model",
null_ref_context=nullcontext,
ref_model_inputs={"input_ids": input_ids}
) as (ref_hidden_states, ref_weight, ref_bias):
self.assertEqual(ref_hidden_states.shape[0], input_ids.shape[0])
self.assertEqual(ref_hidden_states.shape[1], input_ids.shape[1] - 1)
self.assertEqual(ref_hidden_states.shape[2], peft_model.config.hidden_size)

self.assertEqual(ref_weight.shape[0], peft_model.config.vocab_size)
self.assertEqual(ref_weight.shape[1], peft_model.config.hidden_size)

if ref_bias is not None:
self.assertEqual(ref_bias.shape[0], peft_model.config.vocab_size)

restored_lm_head_weight = peft_model.get_output_embeddings().base_layer.weight
self.assertTrue(torch.equal(original_lm_head_weight, restored_lm_head_weight))


17 changes: 17 additions & 0 deletions trl/trainer/dpo_config.py
@@ -119,6 +119,11 @@ class DPOConfig(TrainingArguments):
- `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
- `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper.

        use_liger_loss (`bool`, *optional*, defaults to `False`):
            Whether to use Liger loss. Requires the `liger-kernel` package to be installed.
        base_model_attribute_name (`str`, *optional*, defaults to `"model"`):
            Name of the attribute in the model that contains the base model. Used to retrieve the base model when
            the model does not expose a `get_decoder` method and `use_liger_loss` is `True`.
beta (`float`, *optional*, defaults to `0.1`):
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
@@ -301,6 +306,18 @@
],
},
)
    use_liger_loss: bool = field(
        default=False,
        metadata={"help": "Whether to use Liger loss. Requires the `liger-kernel` package to be installed."},
    )
    base_model_attribute_name: str = field(
        default="model",
        metadata={
            "help": "Name of the attribute in the model that contains the base model. Used to retrieve the base "
            "model when the model does not expose a `get_decoder` method and `use_liger_loss` is `True`."
        },
    )
beta: float = field(
default=0.1,
metadata={