From c59ec6d1d091219f8832b75f89109b60bd5fe0f9 Mon Sep 17 00:00:00 2001 From: Behrooz Date: Sun, 2 Nov 2025 12:28:48 -0800 Subject: [PATCH 1/4] docs: Expand training customization examples Resolves #4379 - Add custom callbacks example for logging and monitoring - Add custom evaluation metrics example - Add mixed precision training example (bf16/fp16) - Add gradient accumulation example - Add custom data collator example - Update introduction for better clarity --- docs/source/customization.md | 151 ++++++++++++++++++++++++++++++++++- 1 file changed, 150 insertions(+), 1 deletion(-) diff --git a/docs/source/customization.md b/docs/source/customization.md index 5989858122e..da4f9062a67 100644 --- a/docs/source/customization.md +++ b/docs/source/customization.md @@ -1,6 +1,6 @@ # Training customization -TRL is designed with modularity in mind so that users are able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques. Note: Although these examples use the DPOTrainer, the customization applies to most (if not all) trainers. +TRL is designed with modularity in mind so that users are able to efficiently customize the training loop for their needs. Below are examples on how you can apply and test different techniques. Note: Although these examples use the [`DPOTrainer`], these customization methods apply to most (if not all) trainers in TRL. ## Use different optimizers and schedulers @@ -117,3 +117,152 @@ When training large models, you should better handle the accelerator cache by it ```python training_args = DPOConfig(..., optimize_device_cache=True) ``` + +## Add custom callbacks + +You can customize the training loop by adding callbacks for logging, monitoring, or early stopping. Callbacks allow you to execute custom code at specific points during training. + +```python +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback +from trl import DPOConfig, DPOTrainer + + +class CustomLoggingCallback(TrainerCallback): + def on_log(self, args, state, control, logs=None, **kwargs): + if logs is not None: + print(f"Step {state.global_step}: {logs}") + + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") +training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO") + +trainer = DPOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + tokenizer=tokenizer, + callbacks=[CustomLoggingCallback()], +) +trainer.train() +``` + +## Add custom evaluation metrics + +You can define custom evaluation metrics to track during training. This is useful for monitoring model performance on specific tasks. + +```python +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl import DPOConfig, DPOTrainer + + +def compute_metrics(eval_preds): + # Custom metric computation + logits, labels = eval_preds + # Add your metric computation here + return {"custom_metric": 0.0} + + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") +eval_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="test[:10%]") +training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO", eval_strategy="steps", eval_steps=100) + +trainer = DPOTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + compute_metrics=compute_metrics, +) +trainer.train() +``` + +## Use mixed precision training + +Mixed precision training can significantly speed up training and reduce memory usage. You can enable it by setting `bf16=True` or `fp16=True` in the training config. + +```python +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl import DPOConfig, DPOTrainer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") + +# Use bfloat16 precision (recommended for modern GPUs) +training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO", bf16=True) + +trainer = DPOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + tokenizer=tokenizer, +) +trainer.train() +``` + +Note: Use `bf16=True` for Ampere GPUs (A100, RTX 30xx) or newer, and `fp16=True` for older GPUs. + +## Use gradient accumulation + +When training with limited GPU memory, gradient accumulation allows you to simulate larger batch sizes by accumulating gradients over multiple steps before updating weights. + +```python +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl import DPOConfig, DPOTrainer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") + +# Simulate a batch size of 32 with per_device_train_batch_size=4 and gradient_accumulation_steps=8 +training_args = DPOConfig( + output_dir="Qwen2.5-0.5B-DPO", + per_device_train_batch_size=4, + gradient_accumulation_steps=8, +) + +trainer = DPOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + tokenizer=tokenizer, +) +trainer.train() +``` + +## Use a custom data collator + +You can provide a custom data collator to handle special data preprocessing or padding strategies. + +```python +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl import DPOConfig, DPOTrainer +from trl.trainer.dpo_trainer import DataCollatorForPreference + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") +training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO") + +# Create a custom data collator with specific padding token +data_collator = DataCollatorForPreference(pad_token_id=tokenizer.pad_token_id) + +trainer = DPOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + tokenizer=tokenizer, + data_collator=data_collator, +) +trainer.train() +``` From 0fc059bb13923fcb5dd16d884a43d85723a971a7 Mon Sep 17 00:00:00 2001 From: Behrooz Date: Fri, 28 Nov 2025 14:39:34 -0800 Subject: [PATCH 2/4] docs: streamline code examples and remove obsolete section Address review feedback: - Remove obsolete optimize_device_cache section (no longer in codebase) - Reduce code snippets to show only relevant customization parts - Keep first example complete as reference, subsequent examples focused - Remove ~120 lines of repetitive boilerplate Improves clarity by highlighting the actual customization being demonstrated. --- docs/source/customization.md | 157 ++++------------------------------- 1 file changed, 14 insertions(+), 143 deletions(-) diff --git a/docs/source/customization.md b/docs/source/customization.md index da4f9062a67..b67213e4d9a 100644 --- a/docs/source/customization.md +++ b/docs/source/customization.md @@ -31,30 +31,15 @@ trainer.train() ### Add a learning rate scheduler -You can also play with your training by adding learning rate schedulers. +You can also add learning rate schedulers by passing both optimizer and scheduler: ```python -from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer from torch import optim -from trl import DPOConfig, DPOTrainer - -model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") -tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") -dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") -training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO") optimizer = optim.AdamW(model.parameters(), lr=training_args.learning_rate) lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) -trainer = DPOTrainer( - model=model, - args=training_args, - train_dataset=dataset, - tokenizer=tokenizer, - optimizers=(optimizer, lr_scheduler), -) -trainer.train() +trainer = DPOTrainer(..., optimizers=(optimizer, lr_scheduler)) ``` ## Memory efficient fine-tuning by sharing layers @@ -62,24 +47,11 @@ trainer.train() Another tool you can use for more memory efficient fine-tuning is to share layers between the reference model and the model you want to train. ```python -from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer -from trl import create_reference_model, DPOConfig, DPOTrainer +from trl import create_reference_model -model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") ref_model = create_reference_model(model, num_shared_layers=6) -tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") -dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:1%]") -training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO") -trainer = DPOTrainer( - model=model, - ref_model=ref_model, - args=training_args, - train_dataset=dataset, - tokenizer=tokenizer, -) -trainer.train() +trainer = DPOTrainer(..., ref_model=ref_model) ``` ## Pass 8-bit reference models @@ -89,33 +61,12 @@ Since `trl` supports all keyword arguments when loading a model from `transforme Read more about 8-bit model loading in `transformers` [Load in 8bit or 4bit](https://huggingface.co/docs/transformers/en/peft#load-in-8bit-or-4bit). ```python -from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig -from trl import DPOConfig, DPOTrainer +from transformers import AutoModelForCausalLM, BitsAndBytesConfig -model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") quantization_config = BitsAndBytesConfig(load_in_8bit=True) -ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", quantization_config= quantization_config) -tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") -dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") -training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO") +ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", quantization_config=quantization_config) -trainer = DPOTrainer( - model=model, - ref_model=ref_model, - args=training_args, - train_dataset=dataset, - tokenizer=tokenizer, -) -trainer.train() -``` - -## Use the accelerator cache optimizer - -When training large models, you should better handle the accelerator cache by iteratively clearing it. To do so, simply pass `optimize_device_cache=True` to [`DPOConfig`]: - -```python -training_args = DPOConfig(..., optimize_device_cache=True) +trainer = DPOTrainer(..., ref_model=ref_model) ``` ## Add custom callbacks @@ -123,9 +74,7 @@ training_args = DPOConfig(..., optimize_device_cache=True) You can customize the training loop by adding callbacks for logging, monitoring, or early stopping. Callbacks allow you to execute custom code at specific points during training. ```python -from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback -from trl import DPOConfig, DPOTrainer +from transformers import TrainerCallback class CustomLoggingCallback(TrainerCallback): @@ -134,19 +83,7 @@ class CustomLoggingCallback(TrainerCallback): print(f"Step {state.global_step}: {logs}") -model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") -tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") -dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") -training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO") - -trainer = DPOTrainer( - model=model, - args=training_args, - train_dataset=dataset, - tokenizer=tokenizer, - callbacks=[CustomLoggingCallback()], -) -trainer.train() +trainer = DPOTrainer(..., callbacks=[CustomLoggingCallback()]) ``` ## Add custom evaluation metrics @@ -154,33 +91,15 @@ trainer.train() You can define custom evaluation metrics to track during training. This is useful for monitoring model performance on specific tasks. ```python -from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer -from trl import DPOConfig, DPOTrainer - - def compute_metrics(eval_preds): - # Custom metric computation logits, labels = eval_preds # Add your metric computation here return {"custom_metric": 0.0} -model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") -tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") -train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") -eval_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="test[:10%]") -training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO", eval_strategy="steps", eval_steps=100) +training_args = DPOConfig(..., eval_strategy="steps", eval_steps=100) -trainer = DPOTrainer( - model=model, - args=training_args, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - tokenizer=tokenizer, - compute_metrics=compute_metrics, -) -trainer.train() +trainer = DPOTrainer(..., eval_dataset=eval_dataset, compute_metrics=compute_metrics) ``` ## Use mixed precision training @@ -188,24 +107,8 @@ trainer.train() Mixed precision training can significantly speed up training and reduce memory usage. You can enable it by setting `bf16=True` or `fp16=True` in the training config. ```python -from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer -from trl import DPOConfig, DPOTrainer - -model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") -tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") -dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") - # Use bfloat16 precision (recommended for modern GPUs) -training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO", bf16=True) - -trainer = DPOTrainer( - model=model, - args=training_args, - train_dataset=dataset, - tokenizer=tokenizer, -) -trainer.train() +training_args = DPOConfig(..., bf16=True) ``` Note: Use `bf16=True` for Ampere GPUs (A100, RTX 30xx) or newer, and `fp16=True` for older GPUs. @@ -215,28 +118,12 @@ Note: Use `bf16=True` for Ampere GPUs (A100, RTX 30xx) or newer, and `fp16=True` When training with limited GPU memory, gradient accumulation allows you to simulate larger batch sizes by accumulating gradients over multiple steps before updating weights. ```python -from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer -from trl import DPOConfig, DPOTrainer - -model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") -tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") -dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") - # Simulate a batch size of 32 with per_device_train_batch_size=4 and gradient_accumulation_steps=8 training_args = DPOConfig( - output_dir="Qwen2.5-0.5B-DPO", + ..., per_device_train_batch_size=4, gradient_accumulation_steps=8, ) - -trainer = DPOTrainer( - model=model, - args=training_args, - train_dataset=dataset, - tokenizer=tokenizer, -) -trainer.train() ``` ## Use a custom data collator @@ -244,25 +131,9 @@ trainer.train() You can provide a custom data collator to handle special data preprocessing or padding strategies. ```python -from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer -from trl import DPOConfig, DPOTrainer from trl.trainer.dpo_trainer import DataCollatorForPreference -model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") -tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") -dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") -training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO") - -# Create a custom data collator with specific padding token data_collator = DataCollatorForPreference(pad_token_id=tokenizer.pad_token_id) -trainer = DPOTrainer( - model=model, - args=training_args, - train_dataset=dataset, - tokenizer=tokenizer, - data_collator=data_collator, -) -trainer.train() +trainer = DPOTrainer(..., data_collator=data_collator) ``` From 6edc64c2b278fe4c25e3dc4b5723c3846a1980eb Mon Sep 17 00:00:00 2001 From: Sergio Paniego Blanco Date: Mon, 1 Dec 2025 10:33:56 +0100 Subject: [PATCH 3/4] Update docs/source/customization.md --- docs/source/customization.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/source/customization.md b/docs/source/customization.md index b67213e4d9a..3c6c6b0bf96 100644 --- a/docs/source/customization.md +++ b/docs/source/customization.md @@ -1,6 +1,9 @@ # Training customization -TRL is designed with modularity in mind so that users are able to efficiently customize the training loop for their needs. Below are examples on how you can apply and test different techniques. Note: Although these examples use the [`DPOTrainer`], these customization methods apply to most (if not all) trainers in TRL. +TRL is designed with modularity in mind so that users are able to efficiently customize the training loop for their needs. Below are examples on how you can apply and test different techniques. + +> [!NOTE] +> Although these examples use the [`DPOTrainer`], these customization methods apply to most (if not all) trainers in TRL. ## Use different optimizers and schedulers From 93e2c71d132fbfe850f9dea297126e7fead5f6a5 Mon Sep 17 00:00:00 2001 From: Sergio Paniego Blanco Date: Mon, 1 Dec 2025 10:34:03 +0100 Subject: [PATCH 4/4] Update docs/source/customization.md --- docs/source/customization.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/customization.md b/docs/source/customization.md index 3c6c6b0bf96..19ba1088fd1 100644 --- a/docs/source/customization.md +++ b/docs/source/customization.md @@ -61,7 +61,7 @@ trainer = DPOTrainer(..., ref_model=ref_model) Since `trl` supports all keyword arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning. -Read more about 8-bit model loading in `transformers` [Load in 8bit or 4bit](https://huggingface.co/docs/transformers/en/peft#load-in-8bit-or-4bit). +Read more about 8-bit model loading in `transformers` [Load in 8bit or 4bit](https://huggingface.co/docs/transformers/en/peft). ```python from transformers import AutoModelForCausalLM, BitsAndBytesConfig