Merged
33 commits
7e5c7d1
Switch gradient checkpointing default to use_reentrant=False (PyTorch…
qgallouedec Jan 9, 2026
7ea2cbe
All testers
qgallouedec Jan 9, 2026
340224c
revert
qgallouedec Jan 9, 2026
68c0711
Merge branch 'main' into no-reentrant
qgallouedec Jan 9, 2026
d642e52
skip for the right reason
qgallouedec Jan 9, 2026
1210ec2
skip for right reason and this one is fixed
qgallouedec Jan 9, 2026
453d5cc
skip for the right reason
qgallouedec Jan 9, 2026
4356551
fix up to blip2 excluded
qgallouedec Jan 9, 2026
4ac2cee
reactivate all tests
qgallouedec Jan 10, 2026
71df8d6
skip non trainable and xfail know issues
qgallouedec Jan 11, 2026
d415471
audio encoder after lm initialization and add xfail tests for gradien…
qgallouedec Jan 11, 2026
09a01fc
continue fixing tests
qgallouedec Jan 11, 2026
aa7ed85
some last cases
qgallouedec Jan 11, 2026
1e2c55c
Merge branch 'main' into no-reentrant
qgallouedec Jan 11, 2026
c2688be
Merge branch 'no-reentrant' of https://github.com/huggingface/transfo…
qgallouedec Jan 11, 2026
525c1f3
fix check_training_gradient_checkpointing
qgallouedec Jan 11, 2026
d2e1b8f
ignore and x fail other tests
qgallouedec Jan 11, 2026
20ec219
this module doesn't support training
qgallouedec Jan 11, 2026
3d090e1
forgot one
qgallouedec Jan 11, 2026
0dbb70f
mra
qgallouedec Jan 11, 2026
e0535d8
style
qgallouedec Jan 11, 2026
314f63c
fix recurrent gemma when enable input require grads with gradient che…
qgallouedec Jan 11, 2026
da4c49b
mark training test as expected failure for reentrant compatibility
qgallouedec Jan 11, 2026
a4603d5
style
qgallouedec Jan 11, 2026
8444141
reverting, moshi won't work anyway
qgallouedec Jan 11, 2026
09012a4
remove skip training tests for Siglip2VisionModel
qgallouedec Jan 12, 2026
8c09f17
Merge branch 'main' into no-reentrant
qgallouedec Jan 12, 2026
183d3a8
Merge branch 'main' into no-reentrant
qgallouedec Jan 12, 2026
36cb28d
Merge branch 'main' into no-reentrant
SunMarc Jan 12, 2026
c7bf3e2
Merge branch 'main' into no-reentrant
SunMarc Jan 12, 2026
a9c9a66
Merge branch 'main' into no-reentrant
SunMarc Jan 13, 2026
1926ac8
Merge branch 'main' into no-reentrant
SunMarc Jan 14, 2026
75655a8
Merge branch 'main' into no-reentrant
SunMarc Jan 14, 2026
2 changes: 1 addition & 1 deletion src/transformers/modeling_utils.py
@@ -3063,7 +3063,7 @@ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}
gradient_checkpointing_kwargs = {"use_reentrant": False}
Comment from the PR author (qgallouedec):

This is the main change


         gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)

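A minimal usage sketch of what this default change means for callers (illustrative only; the checkpoint name is a placeholder, not part of the PR). PyTorch recommends the non-reentrant implementation (use_reentrant=False), which lifts several restrictions of the reentrant variant, e.g. it does not require at least one input with requires_grad=True:

from transformers import AutoModelForCausalLM

# Placeholder checkpoint, used only for illustration.
model = AutoModelForCausalLM.from_pretrained("gpt2")

# With this PR, calling the method with no kwargs now means non-reentrant checkpointing:
model.gradient_checkpointing_enable()  # same as gradient_checkpointing_kwargs={"use_reentrant": False}

# The previous reentrant behaviour can still be requested explicitly:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})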
src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -487,6 +487,8 @@ def forward(
                 x_branch = x_branch.unsqueeze(-1)
                 self.conv1d_state = conv_state[:, :, 1:]
         else:
+            self.conv1d_state = None
+            self.rg_lru.recurrent_states = None
Comment on lines +490 to +491 from @qgallouedec (Jan 11, 2026):

This clears the recurrent cache state when use_cache=False, so the recurrent block doesn't reuse graph-attached state between forward passes. That fixes the double-backward error (see below) in the gradient checkpointing tests.

Result before the change:

$ pytest tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py::RecurrentGemmaModelTest::test_training_gradient_checkpointing
================================================ test session starts ================================================
platform linux -- Python 3.12.12, pytest-8.4.2, pluggy-1.6.0
rootdir: /fsx/qgallouedec/transformers
configfile: pyproject.toml
plugins: timeout-2.4.0, asyncio-1.3.0, rich-0.2.0, anyio-4.12.0, cov-7.0.0, xdist-3.8.0, hypothesis-6.150.0, order-1.3.0, rerunfailures-15.1
asyncio: mode=Mode.STRICT, debug=False, asyncio_default_fixture_loop_scope=function, asyncio_default_test_loop_scope=function
collected 1 item                                                                                                    

tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py::RecurrentGemmaModelTest::test_training_gradient_checkpointing FAILED [100%]

===================================================== FAILURES ======================================================
___________________________ RecurrentGemmaModelTest.test_training_gradient_checkpointing ____________________________

self = <tests.models.recurrent_gemma.test_modeling_recurrent_gemma.RecurrentGemmaModelTest testMethod=test_training_gradient_checkpointing>

    def test_training_gradient_checkpointing(self):
        # Scenario - 1 default behaviour
>       self.check_training_gradient_checkpointing()

tests/test_modeling_common.py:1620: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/test_modeling_common.py:1572: in check_training_gradient_checkpointing
    loss.backward()
../miniconda3/envs/trl/lib/python3.12/site-packages/torch/_tensor.py:625: in backward
    torch.autograd.backward(
../miniconda3/envs/trl/lib/python3.12/site-packages/torch/autograd/__init__.py:354: in backward
    _engine_run_backward(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

t_outputs = (tensor(4.6717, device='cuda:0', grad_fn=<NllLossBackward0>),)
args = ((tensor(1., device='cuda:0'),), False, False, ())
kwargs = {'accumulate_grad': True, 'allow_unreachable': True}, attach_logging_hooks = False

    def _engine_run_backward(
        t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
        *args: Any,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, ...]:
        attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
        if attach_logging_hooks:
            unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
        try:
>           return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
                t_outputs, *args, **kwargs
            )  # Calls into the C++ engine to run the backward pass
E           RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

../miniconda3/envs/trl/lib/python3.12/site-packages/torch/autograd/graph.py:841: RuntimeError
============================================== short test summary info ==============================================
FAILED tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py::RecurrentGemmaModelTest::test_training_gradient_checkpointing - RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they ha...
================================================ 1 failed in 10.78s =================================================
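To make the failure mode concrete, here is a minimal standalone sketch (a toy module, not RecurrentGemma code) of how carrying graph-attached state across forward passes produces exactly this error:

import torch

# Cached state that survives across forward passes, as the recurrent block did.
state = None

def forward(x, weight):
    global state
    hidden = x @ weight
    # The cached state still references the previous call's autograd graph.
    state = hidden if state is None else state + hidden
    return state.sum()

weight = torch.randn(4, 4, requires_grad=True)
forward(torch.randn(2, 4), weight).backward()  # first backward: fine
forward(torch.randn(2, 4), weight).backward()  # RuntimeError: trying to backward through the graph a second time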

             x_branch = self.conv_1d(x_branch)[..., :seq_len]
 
         x_branch = self.rg_lru(x_branch.transpose(1, 2), position_ids)
16 changes: 6 additions & 10 deletions tests/models/align/test_modeling_align.py
@@ -344,24 +344,20 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)
 
-    @unittest.skip
+    @unittest.skip(reason="This module does not support standalone training")
     def test_training(self):
         pass
 
-    @unittest.skip
+    @unittest.skip(reason="This module does not support standalone training")
     def test_training_gradient_checkpointing(self):
         pass
 
-    @unittest.skip(
-        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing_use_reentrant(self):
+    @unittest.skip(reason="This module does not support standalone training")
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
         pass
 
-    @unittest.skip(
-        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing_use_reentrant_false(self):
+    @unittest.skip(reason="This module does not support standalone training")
+    def test_training_gradient_checkpointing_use_reentrant_true(self):
         pass
 
     @unittest.skip(reason="ALIGN does not use inputs_embeds")
32 changes: 12 additions & 20 deletions tests/models/altclip/test_modeling_altclip.py
@@ -172,24 +172,20 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)
 
-    @unittest.skip
+    @unittest.skip(reason="This module does not support standalone training")
     def test_training(self):
         pass
 
-    @unittest.skip
+    @unittest.skip(reason="This module does not support standalone training")
    def test_training_gradient_checkpointing(self):
         pass
 
-    @unittest.skip(
-        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing_use_reentrant(self):
+    @unittest.skip(reason="This module does not support standalone training")
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
         pass
 
-    @unittest.skip(
-        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing_use_reentrant_false(self):
+    @unittest.skip(reason="This module does not support standalone training")
+    def test_training_gradient_checkpointing_use_reentrant_true(self):
         pass
 
     @unittest.skip(reason="AltCLIPVisionModel use the same cv backbone with CLIP model.")
@@ -309,24 +305,20 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)
 
-    @unittest.skip
+    @unittest.skip(reason="This module does not support standalone training")
     def test_training(self):
         pass
 
-    @unittest.skip
+    @unittest.skip(reason="This module does not support standalone training")
     def test_training_gradient_checkpointing(self):
         pass
 
-    @unittest.skip(
-        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing_use_reentrant(self):
+    @unittest.skip(reason="This module does not support standalone training")
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
         pass
 
-    @unittest.skip(
-        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing_use_reentrant_false(self):
+    @unittest.skip(reason="This module does not support standalone training")
+    def test_training_gradient_checkpointing_use_reentrant_true(self):
         pass
 
     def test_model_outputs_equivalence(self):
17 changes: 9 additions & 8 deletions tests/models/aria/test_modeling_aria.py
@@ -15,6 +15,7 @@

 import unittest
 
+import pytest
 import requests
 
 from transformers import (
@@ -197,23 +198,23 @@ def setUp(self):
         self.model_tester = AriaVisionText2TextModelTester(self)
         self.config_tester = ConfigTester(self, config_class=AriaConfig, has_text_modality=False)
 
-    @unittest.skip(
+    @pytest.mark.xfail(
         reason="This architecture seems to not compute gradients for the last vision-layernorm because the model uses hidden states pre-norm"
     )
     def test_training_gradient_checkpointing(self):
-        pass
+        super().test_training_gradient_checkpointing()
 
-    @unittest.skip(
+    @pytest.mark.xfail(
         reason="This architecture seems to not compute gradients for the last vision-layernorm because the model uses hidden states pre-norm"
     )
-    def test_training_gradient_checkpointing_use_reentrant(self):
-        pass
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        super().test_training_gradient_checkpointing_use_reentrant_false()
 
-    @unittest.skip(
+    @pytest.mark.xfail(
         reason="This architecture seems to not compute gradients for the last vision-layernorm because the model uses hidden states pre-norm"
     )
-    def test_training_gradient_checkpointing_use_reentrant_false(self):
-        pass
+    def test_training_gradient_checkpointing_use_reentrant_true(self):
+        super().test_training_gradient_checkpointing_use_reentrant_true()


SKIP = False
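Throughout these test files the PR swaps @unittest.skip for @pytest.mark.xfail on the checkpointing tests. A rough toy sketch of the difference (illustrative only, not code from the PR): an xfail-marked test is still executed and reported as XFAIL when it fails, or XPASS once the known issue gets fixed, whereas a skipped test never runs, so a fix would go unnoticed.

import pytest

# Toy example: the xfail-marked test runs and its failure is recorded as expected.
@pytest.mark.xfail(reason="known issue: gradient not computed for this layer")
def test_known_gradient_gap():
    assert False  # stands in for the failing gradient check

# The skipped test is never executed at all.
@pytest.mark.skip(reason="this module does not support standalone training")
def test_unsupported_training():
    assert True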
21 changes: 8 additions & 13 deletions tests/models/autoformer/test_modeling_autoformer.py
@@ -17,6 +17,7 @@
 import tempfile
 import unittest
 
+import pytest
 from huggingface_hub import hf_hub_download
 
 from transformers import is_torch_available
@@ -242,23 +243,17 @@ def test_encoder_decoder_model_standalone(self):
     def test_resize_tokens_embeddings(self):
         pass
 
-    @unittest.skip(
-        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
+    @pytest.mark.xfail(reason="This architecture seems to not compute gradients for some layer.")
     def test_training_gradient_checkpointing(self):
-        pass
+        super().test_training_gradient_checkpointing()
 
-    @unittest.skip(
-        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
+    @pytest.mark.xfail(reason="This architecture seems to not compute gradients for some layer.")
     def test_training_gradient_checkpointing_use_reentrant_false(self):
-        pass
+        super().test_training_gradient_checkpointing_use_reentrant_false()
 
-    @unittest.skip(
-        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing_use_reentrant(self):
-        pass
+    @pytest.mark.xfail(reason="This architecture seems to not compute gradients for some layer.")
+    def test_training_gradient_checkpointing_use_reentrant_true(self):
+        super().test_training_gradient_checkpointing_use_reentrant_true()
 
     # # Input is 'static_categorical_features' not 'input_ids'
     def test_model_main_input_name(self):
16 changes: 8 additions & 8 deletions tests/models/aya_vision/test_modeling_aya_vision.py
@@ -184,17 +184,17 @@ def test_config(self):
     def test_training(self):
         pass
 
-    @unittest.skip(reason="SiglipVisionModel does not support standalone training")
+    @pytest.mark.xfail(reason="This architecture seems to not compute gradients for some layer.")
     def test_training_gradient_checkpointing(self):
-        pass
-
-    @unittest.skip(reason="SiglipVisionModel does not support standalone training")
-    def test_training_gradient_checkpointing_use_reentrant(self):
-        pass
+        super().test_training_gradient_checkpointing()
 
-    @unittest.skip(reason="SiglipVisionModel does not support standalone training")
+    @pytest.mark.xfail(reason="This architecture seems to not compute gradients for some layer.")
     def test_training_gradient_checkpointing_use_reentrant_false(self):
-        pass
+        super().test_training_gradient_checkpointing_use_reentrant_false()
+
+    @pytest.mark.xfail(reason="This architecture seems to not compute gradients for some layer.")
+    def test_training_gradient_checkpointing_use_reentrant_true(self):
+        super().test_training_gradient_checkpointing_use_reentrant_true()
 
     @unittest.skip(reason="Compile not yet supported because in LLava models")
     @pytest.mark.torch_compile_test
16 changes: 2 additions & 14 deletions tests/models/beit/test_modeling_beit.py
@@ -340,7 +340,7 @@ def test_training(self):
             loss = model(**inputs).loss
             loss.backward()
 
-    def test_training_gradient_checkpointing(self):
+    def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         if not self.model_tester.is_training:
             self.skipTest(reason="model_tester.is_training is set to False")
@@ -362,25 +362,13 @@ def test_training_gradient_checkpointing(self):
                 continue
 
             model = model_class(config)
-            model.gradient_checkpointing_enable()
+            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
             model.to(torch_device)
             model.train()
             inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
             loss = model(**inputs).loss
             loss.backward()
-
-    @unittest.skip(
-        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing_use_reentrant(self):
-        pass
-
-    @unittest.skip(
-        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing_use_reentrant_false(self):
-        pass
 
     @slow
     def test_model_from_pretrained(self):
         model_name = "microsoft/beit-base-patch16-224"
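For context, a rough sketch of how the shared mixin is assumed to drive this helper (based on the call site visible in the traceback above, tests/test_modeling_common.py:1620; the exact bodies may differ). The three public checkpointing tests call the same check with different kwargs, which is why BEiT now overrides check_training_gradient_checkpointing instead of the individual tests:

# Assumed sketch of the shared driver in tests/test_modeling_common.py (ModelTesterMixin).
class ModelTesterMixin:
    def test_training_gradient_checkpointing(self):
        self.check_training_gradient_checkpointing()

    def test_training_gradient_checkpointing_use_reentrant_true(self):
        self.check_training_gradient_checkpointing(gradient_checkpointing_kwargs={"use_reentrant": True})

    def test_training_gradient_checkpointing_use_reentrant_false(self):
        self.check_training_gradient_checkpointing(gradient_checkpointing_kwargs={"use_reentrant": False})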
24 changes: 10 additions & 14 deletions tests/models/big_bird/test_modeling_big_bird.py
@@ -15,6 +15,8 @@

 import unittest
 
+import pytest
+
 from transformers import BigBirdConfig, is_torch_available
 from transformers.models.auto import get_values
 from transformers.models.big_bird.tokenization_big_bird import BigBirdTokenizer
@@ -579,23 +581,17 @@ def test_for_change_to_full_attn(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_change_to_full_attn(*config_and_inputs)
 
-    @unittest.skip(
-        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
+    @pytest.mark.xfail(reason="This architecture seems to not compute gradients for some layer.")
     def test_training_gradient_checkpointing(self):
-        pass
-
-    @unittest.skip(
-        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing_use_reentrant(self):
-        pass
+        super().test_training_gradient_checkpointing()
 
-    @unittest.skip(
-        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
+    @pytest.mark.xfail(reason="This architecture seems to not compute gradients for some layer.")
     def test_training_gradient_checkpointing_use_reentrant_false(self):
-        pass
+        super().test_training_gradient_checkpointing_use_reentrant_false()
+
+    @pytest.mark.xfail(reason="This architecture seems to not compute gradients for some layer.")
+    def test_training_gradient_checkpointing_use_reentrant_true(self):
+        super().test_training_gradient_checkpointing_use_reentrant_true()
 
 
 @require_torch