
Commit 56bf2e8
float8 delayed scaling: remove need to use workaround for AC (#1291)
Summary:

The logic to check whether the user has called `sync_float8_amax_and_scale_history`, while nice, hasn't been that useful in practice. Removing this logic to simplify the integration with manual activation checkpointing - now it "just works" without needing a non-standard config as a workaround.

Test Plan:

```
pytest test/float8/test_base.py -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 6735461 commit 56bf2e8
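Below is a minimal sketch of the usage pattern this change enables, based only on the APIs visible in the test diff further down (`Float8Linear.from_float`, `linear_requires_sync`, `sync_float8_amax_and_scale_history`, `torch.utils.checkpoint.checkpoint`). The import paths, the positional `config` argument, and the placeholder `Float8LinearConfig()` are assumptions, and running it requires a CUDA GPU with float8 support; treat it as an illustration, not as the commit's own code.

```python
import copy

import torch
import torch.nn as nn
import torch.utils.checkpoint

# Import paths are assumptions inferred from the test file in this commit.
from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
    linear_requires_sync,
    sync_float8_amax_and_scale_history,
)

# Hypothetical setup: a delayed-scaling Float8LinearConfig would be passed here.
config = Float8LinearConfig()
m_ref = nn.Linear(16, 32, device="cuda", dtype=torch.bfloat16)
m_fp8 = Float8Linear.from_float(copy.deepcopy(m_ref), config)
x = torch.randn(4, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True)

for _ in range(2):
    # With delayed scaling, amaxes and scales are still synced once per iteration.
    if linear_requires_sync(config):
        sync_float8_amax_and_scale_history(m_fp8)
    # After this commit, wrapping the float8 module in manual activation
    # checkpointing needs no non-standard config.
    y = torch.utils.checkpoint.checkpoint(m_fp8, x, use_reentrant=False)
    y.sum().backward()
```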

2 files changed: +12 −17 lines changed

test/float8/test_base.py (+12 −2)
```diff
@@ -261,6 +261,7 @@ def _test_linear_impl(
         x,
         m_ref,
         config: Float8LinearConfig,
+        use_ac: bool = False,
     ):
         m_fp8 = Float8Linear.from_float(
             copy.deepcopy(m_ref),
@@ -269,9 +270,15 @@ def _test_linear_impl(
         for _ in range(2):
             if linear_requires_sync(config):
                 sync_float8_amax_and_scale_history(m_fp8)
-            y_fp8 = m_fp8(x)
+            if use_ac:
+                y_fp8 = torch.utils.checkpoint.checkpoint(m_fp8, x, use_reentrant=False)
+            else:
+                y_fp8 = m_fp8(x)
             y_fp8.sum().backward()
-            y_ref = m_ref(x)
+            if use_ac:
+                y_ref = torch.utils.checkpoint.checkpoint(m_ref, x, use_reentrant=False)
+            else:
+                y_ref = m_ref(x)
             y_ref.sum().backward()
 
             assert y_ref.shape == y_fp8.shape
@@ -344,6 +351,7 @@ def _test_linear_impl(
     )
     @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
     @pytest.mark.parametrize("linear_bias", [False, True])
+    @pytest.mark.parametrize("use_ac", [False, True])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_linear_from_config_params(
         self,
@@ -354,6 +362,7 @@ def test_linear_from_config_params(
         scaling_type_grad_output: ScalingType,
         linear_dtype: torch.dtype,
         linear_bias: bool,
+        use_ac: bool,
     ):
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
@@ -369,6 +378,7 @@ def test_linear_from_config_params(
             x,
             m_ref,
             config,
+            use_ac,
         )
 
     # Note: there are now too many config combinations to test all of
```
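The new `use_ac=True` parametrization routes the forward through non-reentrant activation checkpointing. For reference, a self-contained sketch of just that call, using a plain `nn.Linear` purely for illustration (not part of the commit):

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint

m = nn.Linear(16, 32)
x = torch.randn(4, 16, requires_grad=True)

# Non-reentrant checkpointing: intermediate activations inside `m` are not
# kept for backward; the forward of `m` is re-run during the backward pass.
y = torch.utils.checkpoint.checkpoint(m, x, use_reentrant=False)
y.sum().backward()
assert x.grad is not None and m.weight.grad is not None
```

Because the wrapped forward is re-run during backward, a pre-forward check that forbids a second forward without a fresh sync would presumably also fire during that recomputation, which is the kind of interaction the removed check (see the next file) made awkward.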

torchao/float8/float8_linear.py (−15)
```diff
@@ -334,10 +334,6 @@ def __init__(self, *args, **kwargs):
         # TODO(future PR): add serialization for this flag
         self.is_amax_initialized = not self.config.enable_amax_init
 
-        # Syncing of amaxes and scales happens outside of this function. This
-        # flag is here to enforce that the user does not forget to do this.
-        self.amax_and_scale_synced = not self.config.enable_amax_init
-
         # This is needed to properly handle autocast in the amax/scale
         # update function for torch.float16
         self.last_seen_input_dtype = None
@@ -544,23 +540,12 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
     def float8_pre_forward(self, input):
         if not self.enable_pre_and_post_forward:
             return
-        if (
-            self.is_amax_initialized
-            and (not self.amax_and_scale_synced)
-            and torch.is_grad_enabled()
-        ):
-            raise AssertionError(
-                "amaxes and scales not synced, please call `sync_float8_amax_and_scale_history` before forward"
-            )
         self.last_seen_input_dtype = input.dtype
 
     def float8_post_forward(self):
         if not self.enable_pre_and_post_forward:
             return
-        # Ensure that calling forward again will fail until the user syncs
-        # amaxes and scales
         self.is_amax_initialized = True
-        self.amax_and_scale_synced = False
 
     def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor:
         has_any_axiswise_scaling = (
```
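The net behavioral change from the lines removed above: `float8_post_forward` no longer marks the module as un-synced, and `float8_pre_forward` no longer raises the "amaxes and scales not synced" `AssertionError` when forward runs again without an intervening sync. A hedged before/after sketch, with the same assumed import paths and hypothetical config as the sketch near the top of the page (requires float8-capable CUDA hardware):

```python
import torch
import torch.nn as nn

# Assumed import paths; see the sketch near the top of the page.
from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
    linear_requires_sync,
    sync_float8_amax_and_scale_history,
)

# Hypothetical setup; a delayed-scaling config would be passed here.
config = Float8LinearConfig()
m_fp8 = Float8Linear.from_float(
    nn.Linear(16, 32, device="cuda", dtype=torch.bfloat16), config
)
x = torch.randn(4, 16, device="cuda", dtype=torch.bfloat16)

if linear_requires_sync(config):
    sync_float8_amax_and_scale_history(m_fp8)
y1 = m_fp8(x)
# Before this commit: once amaxes were initialized, a second forward under grad
# mode without another sync raised the AssertionError removed above.
# After this commit: no error; calling sync_float8_amax_and_scale_history once
# per iteration is simply the training loop's responsibility.
y2 = m_fp8(x)
```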
