Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 27 additions & 11 deletions unsloth_zoo/gradient_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,17 +754,11 @@ def unsloth_checkpoint(
Returns:
Output of running :attr:`function` on :attr:`*args`
"""
if use_reentrant is None:
warnings.warn(
"torch.utils.checkpoint: the use_reentrant parameter should be "
"passed explicitly. In version 2.5 we will raise an exception "
"if use_reentrant is not passed. use_reentrant=False is "
"recommended, but if you need to preserve the current default "
"behavior, you can pass use_reentrant=True. Refer to docs for more "
"details on the differences between the two variants.",
stacklevel=2
)
use_reentrant = True
# Force use_reentrant=True so UnslothCheckpointFunction (smart CPU offloading)
# is always used. This is safe because unsloth_checkpoint is only active when
# smart GC is patched; when unpatched, the original torch checkpoint is restored.
# Fixes transformers 5.2 which defaults use_reentrant=False, bypassing Unsloth.
use_reentrant = True
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Respect explicit non-reentrant checkpoint requests

Hard-coding use_reentrant = True here regresses gradient_checkpointing_kwargs support from transformers by ignoring callers that explicitly request use_reentrant=False and non-reentrant-only options. In that case, settings like context_fn/debug that were previously valid now hit the later ValueError path (context_fn/debug are rejected when reentrant), so configurations that worked before this commit can fail once smart checkpointing is patched.

Useful? React with 👍 / 👎.


# Hack to mix *args with **kwargs in a python 2.7-compliant way
preserve = kwargs.pop("preserve_rng_state", True)
Expand Down Expand Up @@ -805,6 +799,15 @@ def patch_unsloth_smart_gradient_checkpointing(dtype = None):
if torch.utils.checkpoint.checkpoint.__name__ != "unsloth_checkpoint":
torch.utils.checkpoint._old_checkpoint = torch.utils.checkpoint.checkpoint
torch.utils.checkpoint.checkpoint = unsloth_checkpoint

# Always patch transformers.modeling_utils.checkpoint so that
# gradient_checkpointing_enable() wraps unsloth_checkpoint, not the original.
# Without this, transformers 5.2's use_reentrant=False default bypasses
# UnslothCheckpointFunction entirely.
# Must be outside the conditional above since torch.utils.checkpoint.checkpoint
# may already be patched while transformers.modeling_utils.checkpoint is not.
import transformers.modeling_utils
transformers.modeling_utils.checkpoint = unsloth_checkpoint
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep a dedicated backup before patching modeling_utils checkpoint

patch_unsloth_smart_gradient_checkpointing() now overwrites transformers.modeling_utils.checkpoint, but its restore path relies on torch.utils.checkpoint._old_checkpoint still existing; that assumption is broken by the current teardown flow in training_utils.py (unpatch_unsloth_gradient_checkpointing() runs first and deletes _old_checkpoint at lines 201-202), so unpatch_unsloth_smart_gradient_checkpointing() never reaches the restore branch and leaves transformers.modeling_utils.checkpoint stuck on unsloth_checkpoint. This causes later “vanilla” gradient-checkpointing runs to keep using the Unsloth wrapper unexpectedly.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. training_utils.py:200-202 calls unpatch_unsloth_gradient_checkpointing() before unpatch_unsloth_smart_gradient_checkpointing(). The non-smart unpatch deletes _old_checkpoint and restores torch.utils.checkpoint.checkpoint, so by the time the smart unpatch runs, its checkpoint.name == "unsloth_checkpoint" condition is already False and the restore block is skipped entirely.

Fixed in 4f853be. The transformers.modeling_utils.checkpoint restore is now independent -- uses an identity check (is unsloth_checkpoint) and falls back to torch.utils.checkpoint.checkpoint (already restored to original) when _old_checkpoint was deleted.

pass


Expand All @@ -831,6 +834,19 @@ def unpatch_unsloth_smart_gradient_checkpointing():
hasattr(torch.utils.checkpoint, "_old_checkpoint"):

torch.utils.checkpoint.checkpoint = torch.utils.checkpoint._old_checkpoint

# Restore transformers.modeling_utils.checkpoint independently.
# Must be outside the conditional above because unpatch_unsloth_gradient_checkpointing()
# may run first (e.g. training_utils.py:201), deleting _old_checkpoint and restoring
# torch.utils.checkpoint.checkpoint, which makes the condition above False.
# Use _old_checkpoint if still available, otherwise torch.utils.checkpoint.checkpoint
# (which has already been restored to the original at that point).
import transformers.modeling_utils
if getattr(transformers.modeling_utils, "checkpoint", None) is unsloth_checkpoint:
transformers.modeling_utils.checkpoint = getattr(
torch.utils.checkpoint, "_old_checkpoint",
torch.utils.checkpoint.checkpoint
)
pass


Expand Down
Loading