
Conversation

@qgallouedec
Member

Summary

This PR changes our gradient checkpointing default from use_reentrant=True to use_reentrant=False.

Two years ago we explicitly set use_reentrant=True in #28538 because PyTorch started warning that the default would change in the future, and recommending users choose a value explicitly:

/scratch/miniconda3/envs/brr/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: Warning: 
torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default 
value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass 
use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details 
on the differences between the two variants.

PyTorch warning shown with torch 2.3, see #28536

At the time, defaulting to True was the safest choice to preserve the behavior of earlier releases.

PyTorch now recommends the non-reentrant variant (use_reentrant=False), see https://docs.pytorch.org/docs/stable/checkpoint.html, and is moving toward making it the default. Aligning with this upstream recommendation gives us several benefits.
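
For context, a minimal sketch of the underlying PyTorch call with the flag passed explicitly (the module and inputs below are placeholders for illustration, not code from this PR):

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)             # placeholder module
x = torch.randn(4, 16, requires_grad=True)  # placeholder input

# Non-reentrant variant, the default this PR switches to:
out = checkpoint(layer, x, use_reentrant=False)

# Reentrant variant, the previous default, shown for comparison:
out_old = checkpoint(layer, x, use_reentrant=True)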

Note: training and checkpointing behavior remains functionally equivalent in typical use cases, with the main difference being how activations are recomputed during backward (non-reentrant uses a safer mechanism).
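
Users who need the previous behavior can still opt back in explicitly via gradient_checkpointing_kwargs; a minimal sketch, assuming a standard Transformers model (the checkpoint name is illustrative):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # any checkpoint, used only for illustration

# Explicitly request the reentrant variant instead of the new default:
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": True}
)

The same dictionary can also be passed through TrainingArguments(gradient_checkpointing=True, gradient_checkpointing_kwargs={"use_reentrant": True}) when training with the Trainer.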

@qgallouedec changed the title from "Switch gradient checkpointing default to use_reentrant=False (PyTorch…" to "Switch gradient checkpointing default to use_reentrant=False (PyTorch recommended)" on Jan 9, 2026
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines -212 to -227
@unittest.skip
def test_training_gradient_checkpointing(self):
    pass

@unittest.skip(
    reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant(self):
    pass

@unittest.skip(
    reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant_false(self):
    pass

Member Author

@qgallouedec Jan 9, 2026


It seems a large number of these skipped tests actually pass. I checked them all.

@github-actions
Contributor

github-actions bot commented Jan 9, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: align, altclip, aria, autoformer, aya_vision, beit, big_bird, blip, blip_2, canine, chinese_clip, clap, clip, clipseg, colpali, deit

@github-actions
Contributor

github-actions bot commented Jan 9, 2026

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43203&sha=435655

@qgallouedec changed the title from "Switch gradient checkpointing default to use_reentrant=False (PyTorch recommended)" to "[WIP] Switch gradient checkpointing default to use_reentrant=False (PyTorch recommended)" on Jan 10, 2026
