
Switch gradient checkpointing default to use_reentrant=False (PyTorch recommended)#4811

Merged
qgallouedec merged 11 commits into main from non-reentrant on Jan 14, 2026

Conversation

@qgallouedec (Member) commented on Jan 12, 2026

Summary

Set use_reentrant=False by default

This PR defaults gradient checkpointing to use_reentrant=False in TRL when no explicit value is provided.

PyTorch now recommends the non-reentrant checkpointing variant; see https://docs.pytorch.org/docs/stable/checkpoint.html. However, Transformers still defaults to use_reentrant=True because it was explicitly pinned in the past to silence a PyTorch warning during a transition period, and the default was never updated afterward.

Until this is fixed upstream and released (see huggingface/transformers#43203), TRL aligns with the current PyTorch recommendation by setting use_reentrant=False by default, while fully preserving any user-provided value.
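
For illustration, here is a minimal sketch of the defaulting logic (not the PR's exact code; the helper name `resolve_gradient_checkpointing_kwargs` is made up for this example):

```python
# Illustrative sketch, not TRL's actual implementation: apply the
# PyTorch-recommended non-reentrant default only when the user has not
# set use_reentrant themselves.
def resolve_gradient_checkpointing_kwargs(gradient_checkpointing_kwargs=None):
    kwargs = dict(gradient_checkpointing_kwargs or {})
    kwargs.setdefault("use_reentrant", False)  # an explicit user value always wins
    return kwargs

# The resolved kwargs would then be forwarded to Transformers, e.g.:
# model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=kwargs)
```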

Fixes "Expected to mark a variable ready only once"

This PR also fixes #4782, an issue that seems unrelated but is in fact connected.

Remove ScriptArguments.gradient_checkpointing_use_reentrant

ScriptArguments.gradient_checkpointing_use_reentrant exists but is never used anywhere. Because it is misleading, this PR removes the argument. An example of the user-facing behavior follows below.
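
As an illustration (assuming an SFT-style setup; gradient_checkpointing and gradient_checkpointing_kwargs come from transformers.TrainingArguments, which TRL configs inherit), an explicit value still takes precedence over the new default:

```python
from trl import SFTConfig

config = SFTConfig(
    output_dir="out",  # hypothetical path for the example
    gradient_checkpointing=True,
    # Explicitly opting back into the reentrant variant is still honored;
    # only the unset case now defaults to use_reentrant=False.
    gradient_checkpointing_kwargs={"use_reentrant": True},
)
```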


@albertvillanova (Member) left a comment

Thanks for the fix! I have a couple of questions below.

Additionally, there are two new CI errors; we should check whether they are caused by this PR and, if so, whether they should be fixed here as well:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

```
FAILED tests/experimental/test_nash_md_trainer.py::TestNashMDTrainer::test_training_pre_pefted_model_implicit_ref_with_reward_model - RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
FAILED tests/experimental/test_xpo_trainer.py::TestXPOTrainer::test_training_pre_pefted_model_implicit_ref - RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```

@qgallouedec (Member, Author) replied:

> Additionally, there are two new CI errors; we should check whether they are caused by this PR and, if so, whether they should be fixed here as well.

Fixed: enable_input_require_grads() is needed not only for reentrant checkpointing, but for the non-reentrant variant as well.
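
A sketch of what the fix amounts to (assuming a PEFT-wrapped model with frozen base weights; this is not the PR's exact diff):

```python
# With a PEFT model whose base weights (including the embeddings) are frozen,
# checkpointed segments can produce activations detached from the autograd
# graph, which raises "element 0 of tensors does not require grad and does
# not have a grad_fn".
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
# Make the embedding outputs require grad so checkpointed segments stay
# connected to the graph. Needed for both reentrant and non-reentrant variants.
model.enable_input_require_grads()
```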

qgallouedec merged commit 40d8759 into main on Jan 14, 2026
10 of 14 checks passed
qgallouedec deleted the non-reentrant branch on January 14, 2026 at 14:14


Development

Successfully merging this pull request may close these issues.

RuntimeError: Expected to mark a variable ready only once when training with PEFT
