diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 2d77458a5..e62de5848 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -754,17 +754,11 @@ def unsloth_checkpoint( Returns: Output of running :attr:`function` on :attr:`*args` """ - if use_reentrant is None: - warnings.warn( - "torch.utils.checkpoint: the use_reentrant parameter should be " - "passed explicitly. In version 2.5 we will raise an exception " - "if use_reentrant is not passed. use_reentrant=False is " - "recommended, but if you need to preserve the current default " - "behavior, you can pass use_reentrant=True. Refer to docs for more " - "details on the differences between the two variants.", - stacklevel=2 - ) - use_reentrant = True + # Force use_reentrant=True so UnslothCheckpointFunction (smart CPU offloading) + # is always used. This is safe because unsloth_checkpoint is only active when + # smart GC is patched; when unpatched, the original torch checkpoint is restored. + # Fixes transformers 5.2 which defaults use_reentrant=False, bypassing Unsloth. + use_reentrant = True # Hack to mix *args with **kwargs in a python 2.7-compliant way preserve = kwargs.pop("preserve_rng_state", True) @@ -805,6 +799,15 @@ def patch_unsloth_smart_gradient_checkpointing(dtype = None): if torch.utils.checkpoint.checkpoint.__name__ != "unsloth_checkpoint": torch.utils.checkpoint._old_checkpoint = torch.utils.checkpoint.checkpoint torch.utils.checkpoint.checkpoint = unsloth_checkpoint + + # Always patch transformers.modeling_utils.checkpoint so that + # gradient_checkpointing_enable() wraps unsloth_checkpoint, not the original. + # Without this, transformers 5.2's use_reentrant=False default bypasses + # UnslothCheckpointFunction entirely. + # Must be outside the conditional above since torch.utils.checkpoint.checkpoint + # may already be patched while transformers.modeling_utils.checkpoint is not. + import transformers.modeling_utils + transformers.modeling_utils.checkpoint = unsloth_checkpoint pass @@ -831,6 +834,19 @@ def unpatch_unsloth_smart_gradient_checkpointing(): hasattr(torch.utils.checkpoint, "_old_checkpoint"): torch.utils.checkpoint.checkpoint = torch.utils.checkpoint._old_checkpoint + + # Restore transformers.modeling_utils.checkpoint independently. + # Must be outside the conditional above because unpatch_unsloth_gradient_checkpointing() + # may run first (e.g. training_utils.py:201), deleting _old_checkpoint and restoring + # torch.utils.checkpoint.checkpoint, which makes the condition above False. + # Use _old_checkpoint if still available, otherwise torch.utils.checkpoint.checkpoint + # (which has already been restored to the original at that point). + import transformers.modeling_utils + if getattr(transformers.modeling_utils, "checkpoint", None) is unsloth_checkpoint: + transformers.modeling_utils.checkpoint = getattr( + torch.utils.checkpoint, "_old_checkpoint", + torch.utils.checkpoint.checkpoint + ) pass