Skip to content

Commit

Permalink
float8 with delayed scaling: fix autocast handling (#1306)
Browse files Browse the repository at this point in the history
Summary:

Fixes a bug with delayed scaling + autocast.

Before, the last input dtype when in autocast was queried from the input
to `torch._scaled_mm`:

```
x_hp -> {query_dtype_here} -> to_autocast -> torch._scaled_mm
```

This is incorrect because the dtype was saved from before the place
where autocast could change it.  This happened to work if `x_hp` was
already of the correct dtype, but did not work in cases such as the new
test case added in this PR, or real models such as the repro from
#1297.  The reason we haven't caught
this for so long is we've been using FSDP's mixed precision and not
single-GPU autocast.

The fix I'm taking here is to query the original post-autocast dtype based
on the output of `torch._scaled_mm`.  Since this dtype is based on the
dtype of the input to `torch._scaled_mm`, this will properly capture
autocasting:

```
x_hp -> to_autocast -> x_autocast_dtype -> to_fp8 -> x_fp8 -> torch._scaled_mm -> {query_dtype_here}
```

Test Plan:

```
// first, test the updated test case - it passes

// second - test a modified version of the repro in
// #1297:
// code: https://gist.github.com/vkuzo/6c53a1deca19856238d38746b1e52ee7
// logs: https://gist.github.com/vkuzo/60846b1f6b2822f91d2dfa67cab10a10
// we now see a speedup with float8
```

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
vkuzo authored Nov 19, 2024
1 parent 26648c2 commit b714026
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
13 changes: 8 additions & 5 deletions test/float8/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,33 +424,36 @@ def test_autocast_outputs(
emulate: bool,
linear_dtype: torch.dtype,
):
m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
m_ref = nn.Sequential(
nn.Linear(32, 32, device="cuda", dtype=linear_dtype),
nn.Linear(32, 32, device="cuda", dtype=linear_dtype),
)
config = Float8LinearConfig(
cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
emulate=emulate,
)
m = Float8Linear.from_float(copy.deepcopy(m_ref), config)
m = convert_to_float8_training(copy.deepcopy(m_ref), config=config)

# autocast off
x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
y = m(x)
if linear_requires_sync(config):
sync_float8_amax_and_scale_history(m)
y = m(x)
assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}"

# autocast on
with torch.autocast("cuda"):
y = m(x)
if linear_requires_sync(config):
sync_float8_amax_and_scale_history(m)
y = m(x)
assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}"

with torch.autocast("cuda", dtype=torch.bfloat16):
y = m(x)
if linear_requires_sync(config):
sync_float8_amax_and_scale_history(m)
y = m(x)
assert (
y.dtype == torch.bfloat16
), f"y.dtype is {y.dtype}, expected {torch.bfloat16}"
Expand Down
8 changes: 6 additions & 2 deletions torchao/float8/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def __init__(self, *args, **kwargs):

# This is needed to properly handle autocast in the amax/scale
# update function for torch.float16
self.last_seen_input_dtype = None
self.last_seen_output_dtype = None

# pre_forward and post_forward are currently broken with FSDP
# and torch.compile, this option can disable them
Expand Down Expand Up @@ -538,11 +538,14 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
return output

def float8_pre_forward(self, input):
# TODO(future PR): deprecate these functions and the corresponding
# config setting
if not self.enable_pre_and_post_forward:
return
self.last_seen_input_dtype = input.dtype

def float8_post_forward(self):
# TODO(future PR): deprecate these functions and the corresponding
# config setting
if not self.enable_pre_and_post_forward:
return
self.is_amax_initialized = True
Expand Down Expand Up @@ -624,6 +627,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

if self.has_any_delayed_scaling:
self.float8_post_forward()
self.last_seen_output_dtype = output.dtype
return output

def extra_repr(self):
Expand Down
2 changes: 1 addition & 1 deletion torchao/float8/float8_linear_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def inner_func():
fp8_weight_amax_history_stack[idx] = child.fp8_amax_history_weight
fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output

x_dtypes.add(child.last_seen_input_dtype)
x_dtypes.add(child.last_seen_output_dtype)
scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name)

# TODO This way to get the activation dtype is not ideal
Expand Down

0 comments on commit b714026

Please sign in to comment.