
Commit 9457814

float8 delayed scaling: private API to fix user overriding buffers
Summary:

Context: pytorch/torchtitan#654

If the user has delayed scaling and FSDP float8 all-gather on, there is a subtle bug that can happen when the user calls `model.to_empty(device="cuda")`:

1. `to_empty` recreates the buffers that track the weight amax and scale
2. (1) leaves the buffers referenced by `Float8Linear.weight._amax_buffer`, etc. orphaned, because they do not participate in `to_empty`

I couldn't think of an easy and clean way to auto-fix this, since we can't expect `torch.nn.Module` to know that our logic holds multiple references to the same buffer, so this commit exposes a private API for now, until we can think of something better. With the current fix, the user can call `_maybe_fixup_delayed_scaling_buffers` manually to re-link the buffers to the correct new versions.

Test Plan: CI
1 parent 56bf2e8 commit 9457814
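
For reference, a minimal sketch of the user-side flow this commit enables, assuming the public `torchao.float8` exports used in the test below; the toy `nn.Sequential` model, its sizes, and the `modules()` loop are illustrative only, not part of this commit:

```python
import torch
import torch.nn as nn

from torchao.float8 import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
    convert_to_float8_training,
)

# delayed scaling everywhere, with FSDP float8 all-gather enabled
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
    enable_fsdp_float8_all_gather=True,
)

# initialize on the meta device, then materialize on CUDA
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(32, 32))
model = convert_to_float8_training(model, config=config)
model.to_empty(device="cuda")

# `to_empty` recreated `fp8_amax_weight` and friends, leaving the weight
# wrapper's references stale; re-link them via the new private API
for module in model.modules():
    if hasattr(module, "_maybe_fixup_delayed_scaling_buffers"):
        module._maybe_fixup_delayed_scaling_buffers()
```

Guarding the loop with `hasattr` avoids hard-coding which submodules were swapped to `Float8Linear` by `convert_to_float8_training`.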

File tree

2 files changed (+36, -0)

test/float8/test_base.py (+23)
@@ -530,6 +530,29 @@ def test_inference_mode(self):
         with torch.inference_mode(mode=True):
             y = m(x)

+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    def test_to_empty_delayed_scaling_with_float8_all_gather(self):
+        with torch.device("meta"):
+            m_ref = nn.Sequential(nn.Linear(32, 32))
+        config = Float8LinearConfig(
+            cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
+            cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
+            cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
+            enable_fsdp_float8_all_gather=True,
+        )
+        m_fp8 = convert_to_float8_training(m_ref, config=config)
+
+        assert m_fp8[0].fp8_amax_weight is m_fp8[0].weight._amax_buffer
+        assert m_fp8[0].fp8_amax_history_weight is m_fp8[0].weight._amax_history_buffer
+        assert m_fp8[0].fp8_scale_weight is m_fp8[0].weight._scale_buffer
+
+        m_fp8.to_empty(device="cuda")
+        m_fp8[0]._maybe_fixup_delayed_scaling_buffers()
+
+        assert m_fp8[0].fp8_amax_weight is m_fp8[0].weight._amax_buffer
+        assert m_fp8[0].fp8_amax_history_weight is m_fp8[0].weight._amax_history_buffer
+        assert m_fp8[0].fp8_scale_weight is m_fp8[0].weight._scale_buffer
+

 class TestScaledMM:
     @unittest.skipIf(

torchao/float8/float8_linear.py (+13)
@@ -644,6 +644,19 @@ def extra_repr(self):
         s = f'{super().extra_repr()}, cast_configs={cast_config_str}"'
         return s

+    def _maybe_fixup_delayed_scaling_buffers(self):
+        if (
+            self.config.enable_fsdp_float8_all_gather
+            and self.config.cast_config_weight.scaling_type is ScalingType.DELAYED
+        ):
+            # in case the module weight-related buffers got overwritten by
+            # the user (such as when calling `model.to_empty`), we
+            # re-link the weight wrapper buffers to point to the correct
+            # location
+            self.weight._amax_buffer = self.fp8_amax_weight
+            self.weight._amax_history_buffer = self.fp8_amax_history_weight
+            self.weight._scale_buffer = self.fp8_scale_weight
+
     @classmethod
     def from_float(
         cls,