From 6323d2504df2390e4012544a690dab411e0143dd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 16 Mar 2026 10:17:05 +0000 Subject: [PATCH 1/2] Fix VRAM regression with transformers 5.2+ gradient checkpointing Transformers 5.2 changed the default use_reentrant from True to False in gradient_checkpointing_enable(). This causes the non-reentrant checkpoint path to be taken, completely bypassing UnslothCheckpointFunction which provides smart CPU offloading of gradients. Three fixes: 1. Force use_reentrant=True in unsloth_checkpoint() unconditionally. Since unsloth_checkpoint is only active when smart gradient checkpointing is patched (and the original torch checkpoint is restored on unpatch), this is safe for all code paths including distributed and vision models. 2. Patch transformers.modeling_utils.checkpoint in patch_unsloth_smart_gradient_checkpointing(). Previously only torch.utils.checkpoint.checkpoint was patched, but gradient_checkpointing_enable() creates its partial from the transformers.modeling_utils namespace. This must be done unconditionally (outside the existing name check) because the torch.utils.checkpoint.checkpoint may already be patched while transformers.modeling_utils.checkpoint is not. 3. Restore transformers.modeling_utils.checkpoint in the unpatch function to match the new patching. Backwards compatible with transformers 4.57.x (where use_reentrant defaults to True) since forcing True is a no-op in that case. --- unsloth_zoo/gradient_checkpointing.py | 28 ++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 2d77458a5..fb126c446 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -754,17 +754,11 @@ def unsloth_checkpoint( Returns: Output of running :attr:`function` on :attr:`*args` """ - if use_reentrant is None: - warnings.warn( - "torch.utils.checkpoint: the use_reentrant parameter should be " - "passed explicitly. In version 2.5 we will raise an exception " - "if use_reentrant is not passed. use_reentrant=False is " - "recommended, but if you need to preserve the current default " - "behavior, you can pass use_reentrant=True. Refer to docs for more " - "details on the differences between the two variants.", - stacklevel=2 - ) - use_reentrant = True + # Force use_reentrant=True so UnslothCheckpointFunction (smart CPU offloading) + # is always used. This is safe because unsloth_checkpoint is only active when + # smart GC is patched; when unpatched, the original torch checkpoint is restored. + # Fixes transformers 5.2 which defaults use_reentrant=False, bypassing Unsloth. + use_reentrant = True # Hack to mix *args with **kwargs in a python 2.7-compliant way preserve = kwargs.pop("preserve_rng_state", True) @@ -805,6 +799,15 @@ def patch_unsloth_smart_gradient_checkpointing(dtype = None): if torch.utils.checkpoint.checkpoint.__name__ != "unsloth_checkpoint": torch.utils.checkpoint._old_checkpoint = torch.utils.checkpoint.checkpoint torch.utils.checkpoint.checkpoint = unsloth_checkpoint + + # Always patch transformers.modeling_utils.checkpoint so that + # gradient_checkpointing_enable() wraps unsloth_checkpoint, not the original. + # Without this, transformers 5.2's use_reentrant=False default bypasses + # UnslothCheckpointFunction entirely. + # Must be outside the conditional above since torch.utils.checkpoint.checkpoint + # may already be patched while transformers.modeling_utils.checkpoint is not. + import transformers.modeling_utils + transformers.modeling_utils.checkpoint = unsloth_checkpoint pass @@ -831,6 +834,9 @@ def unpatch_unsloth_smart_gradient_checkpointing(): hasattr(torch.utils.checkpoint, "_old_checkpoint"): torch.utils.checkpoint.checkpoint = torch.utils.checkpoint._old_checkpoint + # Also restore transformers.modeling_utils.checkpoint + import transformers.modeling_utils + transformers.modeling_utils.checkpoint = torch.utils.checkpoint._old_checkpoint pass From 4f853be1de993011e5ae44fe994f59938dd2e4d7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 16 Mar 2026 10:38:05 +0000 Subject: [PATCH 2/2] Fix unpatch ordering: restore transformers.modeling_utils independently When unpatch_unsloth_gradient_checkpointing() runs before unpatch_unsloth_smart_gradient_checkpointing() (as in training_utils.py:200-202), the non-smart unpatch deletes _old_checkpoint and restores torch.utils.checkpoint.checkpoint, causing the smart unpatch's condition to be False. This left transformers.modeling_utils.checkpoint stuck on unsloth_checkpoint. Fix by making the transformers.modeling_utils restore independent: - Check identity directly (is unsloth_checkpoint) rather than relying on the torch.utils.checkpoint.checkpoint name - Fall back to torch.utils.checkpoint.checkpoint (already restored) when _old_checkpoint was deleted by the non-smart unpatch --- unsloth_zoo/gradient_checkpointing.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index fb126c446..e62de5848 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -834,9 +834,19 @@ def unpatch_unsloth_smart_gradient_checkpointing(): hasattr(torch.utils.checkpoint, "_old_checkpoint"): torch.utils.checkpoint.checkpoint = torch.utils.checkpoint._old_checkpoint - # Also restore transformers.modeling_utils.checkpoint - import transformers.modeling_utils - transformers.modeling_utils.checkpoint = torch.utils.checkpoint._old_checkpoint + + # Restore transformers.modeling_utils.checkpoint independently. + # Must be outside the conditional above because unpatch_unsloth_gradient_checkpointing() + # may run first (e.g. training_utils.py:201), deleting _old_checkpoint and restoring + # torch.utils.checkpoint.checkpoint, which makes the condition above False. + # Use _old_checkpoint if still available, otherwise torch.utils.checkpoint.checkpoint + # (which has already been restored to the original at that point). + import transformers.modeling_utils + if getattr(transformers.modeling_utils, "checkpoint", None) is unsloth_checkpoint: + transformers.modeling_utils.checkpoint = getattr( + torch.utils.checkpoint, "_old_checkpoint", + torch.utils.checkpoint.checkpoint + ) pass