Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions docs/source/reducing_memory_usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,34 @@ training_args = SFTConfig(..., packing=True, max_length=512)
> [!WARNING]
> Packing may cause batch contamination, where adjacent sequences influence one another. This can be problematic for some applications. For more details, see [#1230](https://github.com/huggingface/trl/issues/1230).

## PEFT for parameter-efficient fine-tuning

Parameter-Efficient Fine-Tuning (PEFT) methods like LoRA are among the most effective techniques for reducing memory usage during training. Instead of training all model parameters, PEFT methods train only a small number of adapter parameters, significantly reducing memory requirements and enabling fine-tuning of larger models on limited hardware.

For comprehensive details on using PEFT with TRL, including various adapter methods, quantization options, and advanced configurations, see [PEFT Integration](peft_integration).

To use PEFT for reducing memory usage:

```python
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
target_modules=["q_proj", "v_proj"],
)

trainer = SFTTrainer(
model="Qwen/Qwen2.5-0.5B",
peft_config=peft_config,
args=training_args,
)
Comment on lines +103 to +114

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • in my opinion no need to specify value, defaults are good: LoraConfig()
  • the dataset is missing in SFTTrainer
  • you use training_args is used but never defined. I think you can simply drop it

```

PEFT can be combined with other memory reduction techniques such as quantization (4-bit or 8-bit) for even greater memory savings. See [PEFT Integration](peft_integration) for quantization examples.

## Liger for reducing peak memory usage

> [Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%.
Expand Down
80 changes: 78 additions & 2 deletions docs/source/speeding_up_training.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Speeding Up Training

> [!WARNING]
> Section under construction. Feel free to contribute!
This guide covers various methods to accelerate training in TRL. Each technique includes minimal examples with links to more comprehensive documentation.

## vLLM for fast generation in online methods

Expand Down Expand Up @@ -95,3 +94,80 @@ You can customize the server configuration by passing additional arguments. For

</hfoption>
</hfoptions>

## Flash Attention 2 for faster attention computation

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is out of the scope of the PR, right?


Flash Attention 2 is an optimized implementation of the attention mechanism that can significantly speed up training while reducing memory usage. It's particularly effective for long sequences.

To enable Flash Attention 2, pass `attn_implementation="flash_attention_2"` in the model initialization arguments:

```python
from trl import SFTConfig

training_args = SFTConfig(
...,
model_init_kwargs={"attn_implementation": "flash_attention_2"}
)
```

Flash Attention 2 works across all TRL trainers. For padding-free batching with Flash Attention, see [Reducing Memory Usage](reducing_memory_usage#padding-free).

## PEFT for parameter-efficient training

PEFT (Parameter-Efficient Fine-Tuning) methods like LoRA significantly reduce memory usage and training time by only training a small number of adapter parameters instead of the full model.

```python
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
target_modules=["q_proj", "v_proj"],
)

trainer = SFTTrainer(
model="Qwen/Qwen2.5-0.5B",
peft_config=peft_config,
args=training_args,
)
```

For more details, see [PEFT Integration](peft_integration).

## Liger Kernel for memory optimization

Liger Kernel is a collection of Triton kernels designed for LLM training that can increase throughput by 20% and reduce memory usage by 60%.

```python
from trl import DPOConfig

training_args = DPOConfig(..., use_liger_kernel=True)
```

Liger Kernel is supported across multiple trainers (SFT, DPO, GRPO, KTO, GKD). For more information, see [Liger Kernel Integration](liger_kernel_integration).

## Gradient checkpointing for memory savings

Gradient checkpointing trades compute for memory by not storing all intermediate activations during the forward pass, recomputing them during the backward pass instead.

```python
from trl import SFTConfig

training_args = SFTConfig(..., gradient_checkpointing=True)
```

Gradient checkpointing is available across all TRL trainers. For more memory optimization techniques, see the [Transformers Performance Guide](https://huggingface.co/docs/transformers/perf_train_gpu_one#gradient-checkpointing).

## Mixed precision training

Mixed precision training using bf16 or fp16 can speed up training and reduce memory usage with minimal impact on model quality.

```python
from trl import SFTConfig

training_args = SFTConfig(..., bf16=True) # or fp16=True for older GPUs
```

Use `bf16=True` for Ampere GPUs (A100, RTX 30xx) or newer, and `fp16=True` for older GPUs. Mixed precision training is supported across all TRL trainers.
Loading