float8 delayed scaling: private API to fix user overriding buffers #1292
Summary:
Context: pytorch/torchtitan#654
If the user has delayed scaling and FSDP float8 all-gather on, there is a subtle bug that can happen if the user calls `model.to_empty(device="cuda")`: `to_empty` replaces the registered buffers with new tensors, which invalidates the extra references our delayed-scaling logic keeps to those buffers.

I couldn't think of an easy and clean way to auto-fix this, since we can't expect `torch.nn.Module` to know that our logic has multiple references to the same buffer, so this PR exposes a private API for now until we can think of something better.
With the current fix, the user can then call `_maybe_fixup_delayed_scaling_buffers` manually to relink the buffers to the correct new versions.

Test Plan: CI
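A minimal usage sketch of the new private API follows. The exact location and signature of the helper are what this PR introduces, so the call site below (a per-module method, invoked after `to_empty`) is an assumption and is only illustrative; `build_model` is a hypothetical model constructor.

```python
import torch
from torchao.float8 import convert_to_float8_training  # assumed torchao entry point

model = build_model()  # hypothetical model constructor
convert_to_float8_training(model)  # delayed scaling + FSDP float8 all-gather configured elsewhere
model.to_empty(device="cuda")
# ... re-initialize weights on device as usual ...

# Re-link the delayed-scaling buffers that to_empty() just replaced.
# Assumption: the private helper is exposed as a method on each converted float8 module.
for module in model.modules():
    if hasattr(module, "_maybe_fixup_delayed_scaling_buffers"):
        module._maybe_fixup_delayed_scaling_buffers()
```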
Reviewers:
Subscribers:
Tasks:
Tags: