258 commits
ea21e98
wip
qgallouedec Aug 17, 2025
a6941ea
Merge branch 'main' into refactor-dpo
qgallouedec Oct 7, 2025
a7aab5a
progress
qgallouedec Oct 8, 2025
6da159e
ref log p
qgallouedec Oct 8, 2025
d2f5227
precompute
qgallouedec Oct 9, 2025
a68ea0f
Merge branch 'main' into refactor-dpo
qgallouedec Oct 10, 2025
fcf62d1
hinge loss
qgallouedec Oct 10, 2025
36513d0
fix default
qgallouedec Oct 11, 2025
f023e65
fix precompute
qgallouedec Oct 11, 2025
aefc01b
disable dropout
qgallouedec Oct 11, 2025
d2aa8f3
move to exp
qgallouedec Oct 11, 2025
4b92f64
progress
qgallouedec Oct 11, 2025
ea664ba
fix hash, disable dropout and other changes
qgallouedec Oct 15, 2025
1252bb1
Merge branch 'main' into refactor-dpo
qgallouedec Nov 27, 2025
ebc7631
Merge branch 'main' into refactor-dpo
qgallouedec Dec 5, 2025
fda430d
Merge branch 'main' into refactor-dpo
qgallouedec Dec 12, 2025
c321700
progress
qgallouedec Dec 12, 2025
a8fc02c
First step
qgallouedec Dec 12, 2025
8dcb6d3
other alignments
qgallouedec Dec 12, 2025
45e67ae
same for RLOO
qgallouedec Dec 12, 2025
1d8b788
minor alignments
qgallouedec Dec 12, 2025
14b7525
style
qgallouedec Dec 13, 2025
e896cb6
test collator
qgallouedec Dec 13, 2025
aed0a88
robust
qgallouedec Dec 15, 2025
e1c46fb
fix
qgallouedec Dec 15, 2025
473cab5
fixes
qgallouedec Dec 15, 2025
c8a8b40
lr
qgallouedec Dec 15, 2025
66dbb27
ipo and exo_pair
qgallouedec Dec 15, 2025
aa8d6d7
Merge branch 'main' into refactor-dpo
qgallouedec Dec 15, 2025
7b8fd87
harmony dataset
qgallouedec Dec 15, 2025
e684742
thinking in rejected
qgallouedec Dec 15, 2025
9cc3f4e
nca_pair
qgallouedec Dec 15, 2025
7130bbf
bco pair
qgallouedec Dec 15, 2025
9cf7cc8
aot and aot_paired
qgallouedec Dec 16, 2025
dbfcdc5
apo_zero and apo down
qgallouedec Dec 16, 2025
22a6af6
discopop
qgallouedec Dec 16, 2025
2d249c1
contrib in doc
qgallouedec Dec 16, 2025
3fabfa4
sft loss + deprecated max prompt/completion length
qgallouedec Dec 16, 2025
e0cd28a
Merge branch 'main' into refactor-dpo
qgallouedec Dec 16, 2025
81ce5fe
keep start
qgallouedec Dec 16, 2025
412d8b5
some modification
qgallouedec Dec 16, 2025
29deff8
add "Parameters that need to be implemented"
qgallouedec Dec 16, 2025
cfa6f30
Merge branch 'main' into refactor-dpo
qgallouedec Dec 16, 2025
dfac706
Deprecate ref_model_init_kwargs and update DPOTrainer to use it
qgallouedec Dec 16, 2025
d757440
Deprecate generate_during_eval parameter in DPOConfig
qgallouedec Dec 17, 2025
26d6c7d
Remove ref_model_init_kwargs from DPOConfig
qgallouedec Dec 17, 2025
1a6fce0
Disallow PeftModel + peft_config in trainers
qgallouedec Dec 17, 2025
07f3a3a
remove tests
qgallouedec Dec 17, 2025
f97f2cc
remove old comments
qgallouedec Dec 17, 2025
4aeb899
Disallow passing PeftModel with peft_config in DPOTrainer
qgallouedec Dec 17, 2025
90c06e5
Merge branch 'main' into refactor-dpo
qgallouedec Dec 18, 2025
8c2c7ff
explicit ref model
qgallouedec Dec 18, 2025
1f579dc
deprecate force_use_ref_model
qgallouedec Dec 18, 2025
ae2646c
style
qgallouedec Dec 18, 2025
f83cf3c
deprecate force_use_ref_model and use_logits_to_keep parameters in DP…
qgallouedec Dec 19, 2025
66f96f8
proper peft integration
qgallouedec Dec 19, 2025
ef86117
proper peft integration
qgallouedec Dec 19, 2025
3c49e91
Merge branch 'main' into refactor-dpo
qgallouedec Dec 22, 2025
9345f5c
Deprecate model_adapter_name and ref_adapter_name parameters in DPOCo…
qgallouedec Dec 22, 2025
fd99c1e
fix type hint
qgallouedec Dec 23, 2025
8d3a61c
WPO
qgallouedec Dec 23, 2025
5a8f8fc
Merge branch 'main' into refactor-dpo
qgallouedec Dec 23, 2025
41bc3d3
Merge branch 'main' into refactor-dpo
qgallouedec Dec 23, 2025
19f21cd
style
qgallouedec Dec 23, 2025
5b7c446
Deprecate `label_pad_token_id` in `DPOConfig`
qgallouedec Dec 23, 2025
b615fca
loss weights
qgallouedec Dec 23, 2025
b6ee841
f divergence
qgallouedec Dec 23, 2025
e74ec7e
Add LD-DPO support with ld_alpha parameter in DPOConfig and DPOTrainer
qgallouedec Dec 23, 2025
408db53
some greek letter fixes
qgallouedec Dec 23, 2025
2dc1dd8
fix param name
qgallouedec Dec 23, 2025
fe48a17
deprecated tools
qgallouedec Dec 23, 2025
b341619
start liger integration
qgallouedec Dec 24, 2025
ae841e3
Deprecate reference_free parameter in DPOConfig and update related wa…
qgallouedec Dec 24, 2025
b15afbd
start liger
qgallouedec Jan 5, 2026
24e451b
Merge branch 'main' into refactor-dpo
qgallouedec Jan 6, 2026
6e9e1bf
Merge branch 'main' into refactor-dpo
qgallouedec Jan 6, 2026
ae8a03f
2026
qgallouedec Jan 6, 2026
5fbcfbc
Merge branch 'main' into refactor-dpo
qgallouedec Jan 9, 2026
40e8d51
Merge branch 'main' into refactor-dpo
qgallouedec Jan 12, 2026
0a6a209
Merge branch 'main' into refactor-dpo
qgallouedec Jan 12, 2026
90631a3
Merge branch 'main' into refactor-dpo
qgallouedec Jan 12, 2026
59763b8
Merge branch 'main' into refactor-dpo
qgallouedec Jan 13, 2026
2339198
Merge branch 'main' into refactor-dpo
qgallouedec Jan 13, 2026
7adb54e
Deprecate use_liger_loss parameter in DPOConfig and update related wa…
qgallouedec Jan 13, 2026
d4b35e5
Merge branch 'main' into refactor-dpo
qgallouedec Jan 14, 2026
a1d9d1a
Merge branch 'main' into refactor-dpo
qgallouedec Jan 14, 2026
8792365
Merge branch 'main' into refactor-dpo
qgallouedec Jan 16, 2026
de5f182
Add Iterative Reasoning Preference Optimization paper and update DPOC…
qgallouedec Jan 16, 2026
8b975ef
Merge branch 'main' into refactor-dpo
qgallouedec Jan 19, 2026
facccbc
Refactor learning rate comments and add validation for sync_ref_model…
qgallouedec Jan 20, 2026
dce26cd
Add ref_model_mixup_alpha and ref_model_sync_steps parameters to DPOC…
qgallouedec Jan 20, 2026
337ddac
update comment
qgallouedec Jan 20, 2026
d326bc2
precompute (wip)
qgallouedec Jan 20, 2026
5b11c71
qol
qgallouedec Jan 20, 2026
4d49b8c
precompute
qgallouedec Jan 23, 2026
9466af1
revert test
qgallouedec Jan 23, 2026
a334b76
is chat for processed
qgallouedec Jan 23, 2026
7a5a866
Merge branch 'main' into refactor-dpo
qgallouedec Jan 25, 2026
09d3f42
fix dtype in test
qgallouedec Jan 25, 2026
8af36f4
prediction_step for eval
qgallouedec Jan 25, 2026
164aab3
preference tool call dataset
qgallouedec Jan 25, 2026
1072de3
vlm support
qgallouedec Jan 25, 2026
cab3245
fix vlm collator
qgallouedec Jan 26, 2026
261627c
style
qgallouedec Jan 26, 2026
07b3632
fill example
qgallouedec Jan 26, 2026
2bd02e6
Merge branch 'main' into refactor-dpo
qgallouedec Jan 26, 2026
1dd9af1
doc
qgallouedec Jan 27, 2026
3d5f3ab
fix doc
qgallouedec Jan 27, 2026
f75b31d
dpo doc ready!
qgallouedec Jan 27, 2026
4380fd8
hide sidebar
qgallouedec Jan 27, 2026
01b788c
`,`
qgallouedec Jan 27, 2026
6eac594
RLAIF-V-Dataset
qgallouedec Jan 27, 2026
5b83fa8
Update documentation, scripts, and test with experimental
qgallouedec Jan 27, 2026
24c8582
Enhance DPOConfig documentation with detailed parameter descriptions …
qgallouedec Jan 27, 2026
045b853
Add precomputation options for reference model log probabilities in D…
qgallouedec Jan 27, 2026
913d757
removing unnecessary keyword argument
qgallouedec Jan 27, 2026
5a78275
Merge branch 'main' into refactor-dpo
qgallouedec Jan 27, 2026
ab79ec4
Merge branch 'main' into refactor-dpo
qgallouedec Jan 28, 2026
3e3b5c6
comment style
qgallouedec Jan 28, 2026
87e5c56
revert space
qgallouedec Jan 28, 2026
847dbbe
revert another space
qgallouedec Jan 28, 2026
20819a2
align comments
qgallouedec Jan 28, 2026
7736ec3
gradient_checkpointing=True is default
qgallouedec Jan 28, 2026
41716cd
align comments
qgallouedec Jan 28, 2026
c9bdbe2
Merge branch 'main' into refactor-dpo
qgallouedec Jan 28, 2026
3058500
Merge branch 'main' into refactor-dpo
qgallouedec Jan 28, 2026
5161288
Add tests for hash_module
qgallouedec Jan 28, 2026
86c4b9c
revert import change
qgallouedec Jan 28, 2026
375faa9
legacy tests
qgallouedec Jan 28, 2026
129b198
move experimental implementation to stable
qgallouedec Jan 28, 2026
f2b7242
fix imports
qgallouedec Jan 28, 2026
db871c9
ignore TestTokenizeRow
qgallouedec Jan 28, 2026
5c31d52
ignore ruff for legacy test + apply style
qgallouedec Jan 28, 2026
d03f191
remove `prompt_input_ids` access
qgallouedec Jan 28, 2026
7c0239d
gc kwargs
qgallouedec Jan 28, 2026
c785763
just ignore the relevant part of the test
qgallouedec Jan 28, 2026
0cd486e
ignore same length catching error
qgallouedec Jan 28, 2026
79caf4d
comment `test_dpo_loss_alpha_div_f`
qgallouedec Jan 28, 2026
cb513ad
comment test_dpo_loss_js_div_f
qgallouedec Jan 28, 2026
8a7b044
comment test_dpo_trainer_use_logits_to_keep
qgallouedec Jan 28, 2026
52377d6
logits to keep none by default
qgallouedec Jan 28, 2026
ee15d87
ValueError when ref_model is model
qgallouedec Jan 28, 2026
cde37e7
generate during eval deprecated
qgallouedec Jan 28, 2026
da9794e
Deprecate string usage for ref_model and update initialization logic
qgallouedec Jan 28, 2026
5b76884
comment access to prompt_input_ids
qgallouedec Jan 28, 2026
dd65760
don't check training_args.f_divergence_type type
qgallouedec Jan 29, 2026
046a555
support compute_metrics
qgallouedec Jan 29, 2026
a5d46dc
Implement Liger kernel compatibility checks and restrict unsupported …
qgallouedec Jan 29, 2026
4fd67bf
fp16 -> bf16 in qlora test
qgallouedec Jan 29, 2026
1d9b18a
fix link in doc
qgallouedec Jan 29, 2026
d866569
remove test_collators.py file
qgallouedec Jan 29, 2026
2f901ee
support tool in is_conversational
qgallouedec Jan 29, 2026
06473ea
Add conversational examples with tool calls to TestIsConversational
qgallouedec Jan 29, 2026
d7c1403
Add tests for training with compute metrics
qgallouedec Jan 29, 2026
db5812b
Update DPOTrainer config with higher learning rate and enable test in…
qgallouedec Jan 29, 2026
222e7d2
remove redundant cases
qgallouedec Jan 29, 2026
c79af1b
Update test parameters for DPOTrainer to address memory issues and ad…
qgallouedec Jan 29, 2026
94b09c8
align/fix test
qgallouedec Jan 29, 2026
b763453
fix dataset generation
qgallouedec Jan 29, 2026
f9ff05a
Merge branch 'main' into refactor-dpo
qgallouedec Jan 29, 2026
7a530ba
align legacy tests
qgallouedec Jan 29, 2026
7b49f02
Merge branch 'main' into refactor-dpo
qgallouedec Jan 29, 2026
256e121
Refactor test configurations in DPO and SFT trainers to improve clari…
qgallouedec Jan 29, 2026
be36e7d
fix robust
qgallouedec Jan 30, 2026
59e70b3
better names
qgallouedec Jan 30, 2026
c97e844
ipo: logits instead of delta-score
qgallouedec Jan 30, 2026
2634ec8
Revert "ipo: logits instead of delta-score"
qgallouedec Jan 30, 2026
1a97e58
Normalize IPO loss by completion length
qgallouedec Jan 30, 2026
dd1e074
Fix ipo normalization
qgallouedec Jan 30, 2026
86c5a27
Merge branch 'main' into refactor-dpo
qgallouedec Feb 3, 2026
24d0e78
Remove commented-out batch size adjustments in SFTTrainer tests
qgallouedec Feb 2, 2026
267ca1e
revert
qgallouedec Feb 3, 2026
0a27e3c
align compute_metrics tests
qgallouedec Feb 3, 2026
0f86f3f
align tests
qgallouedec Feb 3, 2026
ec87e99
remove estimat token
qgallouedec Feb 3, 2026
9788fa7
see https://github.com/huggingface/trl/pull/3950
qgallouedec Feb 3, 2026
1769d5a
Merge branch 'main' into refactor-dpo
qgallouedec Feb 3, 2026
48034d0
Update trl/trainer/dpo_trainer.py
qgallouedec Feb 3, 2026
13b8afe
Deprecate `ref_adapter_name` parameter in DPOConfig class
qgallouedec Feb 3, 2026
b61e4e3
memory efficient use_weighting
qgallouedec Feb 4, 2026
49d7faf
Update f_alpha_divergence_coef default value to 0.5 in DPOConfig and …
qgallouedec Feb 4, 2026
963d046
stable alpha_divergence
qgallouedec Feb 4, 2026
5a12cd9
Merge branch 'main' into refactor-dpo
qgallouedec Feb 4, 2026
1589da6
fix text-only vlm training
qgallouedec Feb 4, 2026
be4a3ee
Update test_dpo_trainer.py
qgallouedec Feb 4, 2026
bf53e4a
Merge branch 'main' into refactor-dpo
qgallouedec Feb 6, 2026
3b093f9
Update dataset configuration name for toolcall dataset loading
qgallouedec Feb 6, 2026
d33fb7d
Fix add_column in test_train_with_chat_template_kwargs
qgallouedec Feb 6, 2026
d659891
Remove max_prompt_length and max_completion_length parameters from DP…
qgallouedec Feb 6, 2026
c319d5a
Replace torch.allclose with torch.testing.assert_close in DPOTrainer …
qgallouedec Feb 6, 2026
7250cc3
Disallow installation of transformers 5.1.0 due to compatibility issu…
qgallouedec Feb 6, 2026
7cb4d72
try higher learning rate and smaller batch size
qgallouedec Feb 6, 2026
ecf7241
Update parameter comparison logic in DPOTrainer tests to exclude 'ref…
qgallouedec Feb 6, 2026
9ae192c
move some utils to experimental
qgallouedec Feb 6, 2026
77f18f2
fix deprecation version
qgallouedec Feb 6, 2026
c41bc67
fix max_length type in docstring
qgallouedec Feb 6, 2026
6d14396
Assert chat_template is applied in test_train_with_chat_template_kwar…
qgallouedec Feb 6, 2026
69d2f88
Merge branch 'main' into refactor-dpo
qgallouedec Feb 6, 2026
2fe89c3
Merge branch 'main' into refactor-dpo
qgallouedec Feb 6, 2026
b958bf2
Fix post_init warning stacklevel to 3 (#4993)
qgallouedec Feb 6, 2026
be82376
Pin transformers!=5.1.0 in deepspeed extra due to incompatibility (#4…
qgallouedec Feb 6, 2026
69c2c26
Merge branch 'main' into refactor-dpo
qgallouedec Feb 6, 2026
8c64814
style and merge main
qgallouedec Feb 6, 2026
76705c7
Merge branch 'main' into refactor-dpo
qgallouedec Feb 9, 2026
0c57ff3
update old tests
qgallouedec Feb 9, 2026
4691da2
remove deprecated
qgallouedec Feb 9, 2026
f52a407
remove deprecated aot_pair and use_liger_loss
qgallouedec Feb 9, 2026
39cee58
remove deprecated FDivergenceType and update related code to use stri…
qgallouedec Feb 10, 2026
3351691
style
qgallouedec Feb 10, 2026
254ed01
remove deprecated handling of ref_model as a string in DPOTrainer
qgallouedec Feb 10, 2026
d9e95b7
comment out deprecated test for DPOTrainer with tools
qgallouedec Feb 10, 2026
4748d49
revert section removal
qgallouedec Feb 10, 2026
df97352
Merge branch 'main' into refactor-dpo
qgallouedec Feb 10, 2026
ecaa25f
finish merge main
qgallouedec Feb 10, 2026
e5b9294
Remove deprecated CPO imports from trainer module
qgallouedec Feb 10, 2026
7bdbd56
Merge branch 'main' into refactor-dpo
qgallouedec Feb 10, 2026
7fbe52f
Merge branch 'main' into refactor-dpo
qgallouedec Feb 11, 2026
5eb0548
move peft_module_casting_to_bf16 to experimental
qgallouedec Feb 12, 2026
68f7975
removed from model utils
qgallouedec Feb 12, 2026
574a50d
Merge branch 'main' into refactor-dpo
qgallouedec Feb 16, 2026
b1964e4
move create_reference_model
qgallouedec Feb 12, 2026
46e05e3
sort papers
qgallouedec Feb 16, 2026
debb233
docs: update paper index to remove construction warning and clarify s…
qgallouedec Feb 16, 2026
801481f
Merge branch 'main' into refactor-dpo
qgallouedec Feb 16, 2026
d73f4eb
Replace logging with warnings for FSDP version warning in prepare_fsdp()
qgallouedec Feb 16, 2026
7596cf0
Merge branch 'main' into refactor-dpo
qgallouedec Feb 17, 2026
360a201
Merge branch 'main' into refactor-dpo
qgallouedec Feb 18, 2026
95a81a0
Merge branch 'main' into refactor-dpo
qgallouedec Feb 18, 2026
b18523b
Merge branch 'main' into refactor-dpo
qgallouedec Feb 19, 2026
51ffe84
Apply suggestions from code review
qgallouedec Feb 19, 2026
0dc5b95
Fix SFTTrainer support for single-image data
qgallouedec Feb 19, 2026
ea424ff
same for dpo trainer test
qgallouedec Feb 19, 2026
a3835fc
style
qgallouedec Feb 19, 2026
0d5202f
style
qgallouedec Feb 19, 2026
94726b2
Merge branch 'main' into refactor-dpo
qgallouedec Feb 19, 2026
226125b
Apply suggestions from code review
qgallouedec Feb 19, 2026
9d757ed
Update trl/trainer/dpo_config.py
qgallouedec Feb 19, 2026
cb10b3d
fix sync_ref_model doc
qgallouedec Feb 19, 2026
0ba8960
align
qgallouedec Feb 19, 2026
3d1cb40
Update `DataCollatorForVisionPreference` to support single image input
qgallouedec Feb 19, 2026
9bb0af5
Deprecation warning for `create_reference_model` moved to `trl.experi…
qgallouedec Feb 19, 2026
8286c92
Refactor hash_module to use hashlib for improved hashing
qgallouedec Feb 19, 2026
8287efe
Refactor input truncation logic into a separate method for improved r…
qgallouedec Feb 19, 2026
896cff3
synchronization after save
qgallouedec Feb 19, 2026
d9b6ac2
Fix error message in DataCollatorForVisionPreference
qgallouedec Feb 19, 2026
2e71cfd
fix tests
qgallouedec Feb 20, 2026
1462a8d
revert change in zen image dataset generation (other pr)
qgallouedec Feb 20, 2026
6ebfa4c
update doc
qgallouedec Feb 20, 2026
83c4d4c
Disable padding_free feature temporarily and log a warning for users
qgallouedec Feb 20, 2026
9f50918
revert changes in sft
qgallouedec Feb 20, 2026
11 changes: 3 additions & 8 deletions README.md
@@ -113,18 +113,13 @@ trainer.train()

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")

trainer = DPOTrainer(
model=model,
args=training_args,
model="Qwen3/Qwen-0.6B",
train_dataset=dataset,
processing_class=tokenizer
)
trainer.train()
```
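Since the rendered diff above does not mark added and removed lines, here is a rough sketch of what the updated README snippet reads like after this change. It assumes, as the changed lines suggest, that the refactored `DPOTrainer` accepts a Hub model ID string for `model` and that an explicit `DPOConfig` is optional for a minimal run:

```python
from datasets import load_dataset
from trl import DPOTrainer

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

trainer = DPOTrainer(
    model="Qwen3/Qwen-0.6B",  # model ID taken verbatim from the changed line in the diff above
    train_dataset=dataset,
)
trainer.train()
```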
16 changes: 3 additions & 13 deletions docs/source/bema_for_reference_model.md
@@ -7,26 +7,16 @@ This feature implements the BEMA algorithm to update the reference model during
```python
from trl.experimental.bema_for_ref_model import BEMACallback, DPOTrainer
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer


pref_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

bema_callback = BEMACallback(update_ref_model=True)

model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
tokenizer.pad_token = tokenizer.eos_token

trainer = DPOTrainer(
model=model,
ref_model=ref_model,
train_dataset=pref_dataset,
processing_class=tokenizer,
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
train_dataset=dataset,
callbacks=[bema_callback],
)

trainer.train()
```
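Likewise, a sketch of the BEMA example after this edit, assuming the experimental `DPOTrainer` resolves the model and tokenizer from the model ID string in the same way:

```python
from datasets import load_dataset
from trl.experimental.bema_for_ref_model import BEMACallback, DPOTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

# Update the reference model during training using BEMA
bema_callback = BEMACallback(update_ref_model=True)

trainer = DPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    train_dataset=dataset,
    callbacks=[bema_callback],
)
trainer.train()
```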

16 changes: 5 additions & 11 deletions docs/source/customization.md
@@ -1,32 +1,26 @@
# Training customization

TRL is designed with modularity in mind so that users are able to efficiently customize the training loop for their needs. Below are examples on how you can apply and test different techniques.
TRL is designed with modularity in mind so that users are able to efficiently customize the training loop for their needs. Below are examples on how you can apply and test different techniques.

> [!NOTE]
> Although these examples use the [`DPOTrainer`], these customization methods apply to most (if not all) trainers in TRL.
## Use different optimizers and schedulers

By default, the `DPOTrainer` creates a `torch.optim.AdamW` optimizer. You can create and define a different optimizer and pass it to `DPOTrainer` as follows:
By default, the [`DPOTrainer`] creates a `torch.optim.AdamW` optimizer. You can create and define a different optimizer and pass it to [`DPOTrainer`] as follows:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch import optim
from trl import DPOConfig, DPOTrainer
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")

optimizer = optim.SGD(model.parameters(), lr=training_args.learning_rate)

trainer = DPOTrainer(
model=model,
args=training_args,
model="Qwen/Qwen2.5-0.5B-Instruct",
train_dataset=dataset,
tokenizer=tokenizer,
optimizers=(optimizer, None),
)
trainer.train()
@@ -50,7 +44,7 @@ trainer = DPOTrainer(..., optimizers=(optimizer, lr_scheduler))
Another tool you can use for more memory efficient fine-tuning is to share layers between the reference model and the model you want to train.

```python
from trl import create_reference_model
from trl.experimental.utils import create_reference_model

ref_model = create_reference_model(model, num_shared_layers=6)

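For context, a minimal sketch of the layer-sharing setup after the import move. The checkpoint and model loading here are placeholders; the call itself keeps the `num_shared_layers` argument shown in the diff above:

```python
from transformers import AutoModelForCausalLM
from trl.experimental.utils import create_reference_model

# Placeholder checkpoint: any causal LM works here
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Share the first 6 layers between the trained model and its frozen reference copy,
# reducing memory compared to holding two full copies of the model.
ref_model = create_reference_model(model, num_shared_layers=6)
```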
375 changes: 187 additions & 188 deletions docs/source/dpo_trainer.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/source/example_overview.md
@@ -44,7 +44,7 @@ These notebooks are easier to run and are designed for quick experimentation wit

## Scripts

Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) and [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directories. They show how to use different trainers such as `SFTTrainer`, `PPOTrainer`, `DPOTrainer`, `GRPOTrainer`, and more.
Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) and [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directories. They show how to use different trainers such as [`SFTTrainer`], [`PPOTrainer`], [`DPOTrainer`], [`GRPOTrainer`], and more.

| File | Description |
| --- | --- |
1 change: 0 additions & 1 deletion docs/source/lora_without_regret.md
@@ -277,7 +277,6 @@ Here are the parameters we used to train the above models
| `--model_name_or_path` | HuggingFaceTB/SmolLM3-3B | HuggingFaceTB/SmolLM3-3B |
| `--dataset_name` | HuggingFaceH4/OpenR1-Math-220k-default-verified | HuggingFaceH4/OpenR1-Math-220k-default-verified |
| `--learning_rate` | 1.0e-5 | 1.0e-6 |
| `--max_prompt_length` | 1024 | 1024 |
| `--max_completion_length` | 4096 | 4096 |
| `--lora_r` | 1 | - |
| `--lora_alpha` | 32 | - |
4 changes: 0 additions & 4 deletions docs/source/model_utils.md
@@ -7,7 +7,3 @@
## disable_gradient_checkpointing

[[autodoc]] models.utils.disable_gradient_checkpointing

## create_reference_model

[[autodoc]] create_reference_model