Fix VRAM regression with transformers 5.2+ gradient checkpointing #549
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -754,17 +754,11 @@ def unsloth_checkpoint( | |
| Returns: | ||
| Output of running :attr:`function` on :attr:`*args` | ||
| """ | ||
| if use_reentrant is None: | ||
| warnings.warn( | ||
| "torch.utils.checkpoint: the use_reentrant parameter should be " | ||
| "passed explicitly. In version 2.5 we will raise an exception " | ||
| "if use_reentrant is not passed. use_reentrant=False is " | ||
| "recommended, but if you need to preserve the current default " | ||
| "behavior, you can pass use_reentrant=True. Refer to docs for more " | ||
| "details on the differences between the two variants.", | ||
| stacklevel=2 | ||
| ) | ||
| use_reentrant = True | ||
| # Force use_reentrant=True so UnslothCheckpointFunction (smart CPU offloading) | ||
| # is always used. This is safe because unsloth_checkpoint is only active when | ||
| # smart GC is patched; when unpatched, the original torch checkpoint is restored. | ||
| # Fixes transformers 5.2 which defaults use_reentrant=False, bypassing Unsloth. | ||
| use_reentrant = True | ||
|
|
||
| # Hack to mix *args with **kwargs in a python 2.7-compliant way | ||
| preserve = kwargs.pop("preserve_rng_state", True) | ||
|
|
@@ -805,6 +799,15 @@ def patch_unsloth_smart_gradient_checkpointing(dtype = None): | |
| if torch.utils.checkpoint.checkpoint.__name__ != "unsloth_checkpoint": | ||
| torch.utils.checkpoint._old_checkpoint = torch.utils.checkpoint.checkpoint | ||
| torch.utils.checkpoint.checkpoint = unsloth_checkpoint | ||
|
|
||
| # Always patch transformers.modeling_utils.checkpoint so that | ||
| # gradient_checkpointing_enable() wraps unsloth_checkpoint, not the original. | ||
| # Without this, transformers 5.2's use_reentrant=False default bypasses | ||
| # UnslothCheckpointFunction entirely. | ||
| # Must be outside the conditional above since torch.utils.checkpoint.checkpoint | ||
| # may already be patched while transformers.modeling_utils.checkpoint is not. | ||
| import transformers.modeling_utils | ||
| transformers.modeling_utils.checkpoint = unsloth_checkpoint | ||
|
Good catch. `training_utils.py:200-202` calls `unpatch_unsloth_gradient_checkpointing()` before `unpatch_unsloth_smart_gradient_checkpointing()`. The non-smart unpatch deletes `_old_checkpoint` and restores `torch.utils.checkpoint.checkpoint`, so by the time the smart unpatch runs, its `checkpoint.__name__ == "unsloth_checkpoint"` condition is already False and the restore block is skipped entirely. Fixed in 4f853be. The `transformers.modeling_utils.checkpoint` restore is now independent: it uses an identity check (`is unsloth_checkpoint`) and falls back to `torch.utils.checkpoint.checkpoint` (already restored to the original) when `_old_checkpoint` was deleted.
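To make the ordering problem concrete, here is a minimal stdlib-only sketch of the patch/unpatch sequence described in this comment. It is an illustration under assumptions, not the repository's code: `torch_ckpt` and `modeling_utils` are `SimpleNamespace` stand-ins for `torch.utils.checkpoint` and `transformers.modeling_utils`, and `unpatch_non_smart` only models the behaviour stated above (restore the original, then delete `_old_checkpoint`).

```python
from types import SimpleNamespace


def original_checkpoint(function, *args, **kwargs):
    """Stand-in for the stock torch checkpoint function."""
    return function(*args)


def unsloth_checkpoint(function, *args, **kwargs):
    """Stand-in for the patched checkpoint function."""
    return function(*args)


# Hypothetical stand-ins for the two real modules.
torch_ckpt = SimpleNamespace(checkpoint=original_checkpoint)
modeling_utils = SimpleNamespace(checkpoint=original_checkpoint)


def patch_smart():
    if torch_ckpt.checkpoint.__name__ != "unsloth_checkpoint":
        torch_ckpt._old_checkpoint = torch_ckpt.checkpoint
        torch_ckpt.checkpoint = unsloth_checkpoint
    # Patched unconditionally, mirroring the diff.
    modeling_utils.checkpoint = unsloth_checkpoint


def unpatch_non_smart():
    # Assumed behaviour: restore the original and delete the saved reference.
    if hasattr(torch_ckpt, "_old_checkpoint"):
        torch_ckpt.checkpoint = torch_ckpt._old_checkpoint
        del torch_ckpt._old_checkpoint


def unpatch_smart():
    # This guard is already False if unpatch_non_smart() ran first.
    if (torch_ckpt.checkpoint.__name__ == "unsloth_checkpoint"
            and hasattr(torch_ckpt, "_old_checkpoint")):
        torch_ckpt.checkpoint = torch_ckpt._old_checkpoint
    # Independent restore: identity check plus fallback to the current
    # (already restored) checkpoint when _old_checkpoint is gone.
    if getattr(modeling_utils, "checkpoint", None) is unsloth_checkpoint:
        modeling_utils.checkpoint = getattr(
            torch_ckpt, "_old_checkpoint", torch_ckpt.checkpoint
        )


patch_smart()
unpatch_non_smart()   # runs first, as in training_utils.py
unpatch_smart()
```

With the independent restore in place, both stand-in modules end up pointing at `original_checkpoint` again even though the first guard in `unpatch_smart` was skipped.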
||
| pass | ||
|
|
||
|
|
||
|
|
@@ -831,6 +834,19 @@ def unpatch_unsloth_smart_gradient_checkpointing(): | |
| hasattr(torch.utils.checkpoint, "_old_checkpoint"): | ||
|
|
||
| torch.utils.checkpoint.checkpoint = torch.utils.checkpoint._old_checkpoint | ||
|
|
||
| # Restore transformers.modeling_utils.checkpoint independently. | ||
| # Must be outside the conditional above because unpatch_unsloth_gradient_checkpointing() | ||
| # may run first (e.g. training_utils.py:201), deleting _old_checkpoint and restoring | ||
| # torch.utils.checkpoint.checkpoint, which makes the condition above False. | ||
| # Use _old_checkpoint if still available, otherwise torch.utils.checkpoint.checkpoint | ||
| # (which has already been restored to the original at that point). | ||
| import transformers.modeling_utils | ||
| if getattr(transformers.modeling_utils, "checkpoint", None) is unsloth_checkpoint: | ||
| transformers.modeling_utils.checkpoint = getattr( | ||
| torch.utils.checkpoint, "_old_checkpoint", | ||
| torch.utils.checkpoint.checkpoint | ||
| ) | ||
| pass | ||
|
|
||
|
|
||
|
|
||
Hard-coding `use_reentrant = True` here regresses `gradient_checkpointing_kwargs` support from `transformers` by ignoring callers that explicitly request `use_reentrant=False` and non-reentrant-only options. In that case, settings like `context_fn`/`debug` that were previously valid now hit the later `ValueError` path (`context_fn`/`debug` are rejected when reentrant), so configurations that worked before this commit can fail once smart checkpointing is patched.
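One possible way to address this concern, sketched here purely as an illustration and not as what the PR does, is to force the reentrant path only when the caller has not asked for non-reentrant behaviour. The helper name `choose_use_reentrant` is hypothetical, and using `None` as the "not provided" value for `context_fn` is a simplification of this sketch.

```python
def choose_use_reentrant(use_reentrant=None, context_fn=None, debug=False):
    """Hypothetical helper: pick the reentrant mode without overriding
    an explicit caller choice.

    Returns True (the path that reaches the smart offloading function)
    unless the caller explicitly asked for use_reentrant=False or passed
    an option that is only valid in non-reentrant mode.
    """
    needs_non_reentrant = context_fn is not None or bool(debug)

    if use_reentrant is True and needs_non_reentrant:
        # Mirrors the existing rejection of context_fn/debug when reentrant.
        raise ValueError(
            "context_fn and debug are only supported when use_reentrant=False"
        )

    if use_reentrant is False or needs_non_reentrant:
        return False

    # Nothing requested explicitly: default to the reentrant path.
    return True
```

The trade-off is that any caller who lands on the `False` branch bypasses `UnslothCheckpointFunction` again, so the VRAM regression this PR fixes would return for those configurations. Whether that is acceptable is a decision for the maintainers.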