float8 delayed scaling: private API to fix user overriding buffers
Summary:
Context: pytorch/torchtitan#654
If the user has delayed scaling and FSDP float8 all-gather enabled, there is
a subtle bug that can occur when the user calls
`model.to_empty(device="cuda")`:
1. `to_empty` recreates the registered buffers that track the weight amax and scale
2. (1) leaves the old buffers still referenced by `Float8Linear.weight._amax_buffer`, etc. orphaned, because those extra references do not participate in `to_empty` (see the sketch below)
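A minimal sketch of the aliasing problem, using a toy module rather than the actual `Float8Linear` internals (the names `Toy` and `extra_ref` are hypothetical and for illustration only):

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # registered buffer, analogous to the delayed scaling amax buffer
        self.register_buffer("amax", torch.zeros(1))
        # second Python reference to the same tensor, analogous to
        # Float8Linear.weight._amax_buffer
        self.extra_ref = self.amax

m = Toy()
assert m.extra_ref is m.amax

# to_empty replaces every registered buffer with a new, uninitialized tensor
m.to_empty(device="cpu")

# the buffer registered on the module was swapped out, but the extra
# reference was not, so it now points at an orphaned tensor
assert m.extra_ref is not m.amax
```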
I couldn't think of an easy and clean way to fix this automatically, since we can't expect
`torch.nn.Module` to know that our logic holds multiple references to the same
buffer, so this PR exposes a private API for now until we can think of something better.
With the current fix, the user can call
`_maybe_fixup_delayed_scaling_buffers` manually after `to_empty` to relink the buffers to
the correct new versions.
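A hedged usage sketch, not verbatim torchao/torchtitan code: the import path, the float8 config (omitted here), and whether `_maybe_fixup_delayed_scaling_buffers` is a per-module method or a free function are assumptions, so the loop below guards with `hasattr`.

```python
import torch.nn as nn
from torchao.float8 import convert_to_float8_training  # assumed import path

model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))
# delayed scaling + FSDP float8 all-gather config omitted for brevity
convert_to_float8_training(model)

# meta-device init / checkpoint loading flows call to_empty, which
# recreates the amax/scale buffers and orphans the old references
model.to_empty(device="cuda")

# manually relink the delayed scaling buffers to the newly created ones
for module in model.modules():
    if hasattr(module, "_maybe_fixup_delayed_scaling_buffers"):
        module._maybe_fixup_delayed_scaling_buffers()
```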
Test Plan: CI
Reviewers:
Subscribers:
Tasks:
Tags: