[float8] improve eager numerics for dynamic scales and get on par with torch.compile #904
@@ -15,6 +15,9 @@
 import torch
 import torch.nn as nn
+from torchao.float8.float8_scaling_utils import (
+    hp_tensor_to_float8_dynamic,
+)
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
@@ -604,6 +607,40 @@ def test_small_amax_float16(self, float8_dtype):
         x = torch.tensor([target_amax], dtype=torch.float16, device="cuda")
         scale = tensor_to_scale(x, float8_dtype)
         assert not torch.any(torch.isinf(scale))
+
+    @unittest.skipIf(
+        not is_cuda_8_9,
+        "CUDA not available",
+    )
+    @pytest.mark.parametrize(
+        "dtype",
+        [
+            torch.float32,
+            torch.bfloat16,
+            torch.float16,
+        ],
+    )
+    def test_dynamic_scale_parity(self, dtype: torch.dtype):
+        scaling_type_weight = ScalingType.DYNAMIC
+        torch.manual_seed(0)
+        hp_tensor = torch.randn(768, 32, device="cuda", dtype=dtype)
+        float8_config = Float8LinearConfig(
+            cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+        )
+        float8_eager = hp_tensor_to_float8_dynamic(
+            hp_tensor,
+            torch.float8_e4m3fn,
+            float8_config,
+            gemm_input_role=GemmInputRole.WEIGHT,
+        )
+        float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
+            hp_tensor,
+            torch.float8_e4m3fn,
+            float8_config,
+            gemm_input_role=GemmInputRole.WEIGHT,
+        )
+        assert torch.equal(float8_eager._scale, float8_compile._scale)
+        torch.testing.assert_close(float8_eager._data, float8_compile._data)


 class TestFloat8LinearUtils(unittest.TestCase):

Review comment (on the scale parity assert): without the PR, the numerics look like the following …; after, eager is also 106.1925…
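For context, the property the new test asserts can also be reproduced standalone. A minimal sketch follows (not part of the diff; it assumes a CUDA device, a PyTorch build with float8 dtypes, and that tensor_to_scale is importable from torchao.float8.float8_utils):

import torch
from torchao.float8.float8_utils import tensor_to_scale  # import path is an assumption

# A dynamic scale computed in eager mode should be bitwise identical to the
# scale produced by the compiled version of the same function.
x = torch.randn(768, 32, device="cuda", dtype=torch.bfloat16)
eager_scale = tensor_to_scale(x, torch.float8_e4m3fn)
compiled_scale = torch.compile(tensor_to_scale)(x, torch.float8_e4m3fn)
assert torch.equal(eager_scale, compiled_scale)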
@@ -163,7 +163,8 @@ def forward(

         DTensor Invariant: DTensor must always be the outer most tensor subclass
         """
-        tensor_scaled = tensor * scale
+        # scale is float32 thus upcasting tensor to match
+        tensor_scaled = tensor.to(torch.float32) * scale
         bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)

         if isinstance(bits_fp8, DTensor):

Review comments (on the upcast):
- can we make this comment contain the context? Something like …
- without upcasting, the eager numerics look like …
- torch.compile upcasts the tensor ahead of time, see …
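To illustrate the review discussion above, here is a minimal sketch (not part of the diff) of why the explicit upcast matters: under PyTorch type promotion, multiplying a bfloat16 tensor by a 0-dim float32 scale keeps the product in bfloat16, so the intermediate is rounded before the fp8 cast, whereas upcasting first keeps the intermediate in float32, matching what the compiled graph does.

import torch

x = torch.randn(8, dtype=torch.bfloat16)
scale = torch.tensor(123.456, dtype=torch.float32)  # hypothetical stand-in for a dynamic scale

low = x * scale                     # 0-dim scale does not promote: product stays bfloat16
high = x.to(torch.float32) * scale  # full float32 intermediate, as under torch.compile

print(low.dtype, high.dtype)                    # torch.bfloat16 torch.float32
print((low.float() - high).abs().max().item())  # non-zero rounding difference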
@@ -42,6 +42,8 @@ def amax_to_scale(
         float8_dtype: The float8 dtype.
         orig_dtype: The original dtype of the tensor.
     """
+    # _scaled_mm requires float32 scale
+    amax = amax.to(torch.float64)
     if float8_dtype in FP8_TYPES:
         res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
     else:

Review comments (on the amax upcast):
- nit: can we describe in more detail why we are upcasting here
- could you share why the upcasting happens?
- I can look into inductor more on how it achieved fp64
- torch.compile actually upcasts to float32 with …; the float32 numeric difference can be verified with …
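The truncated remark above ("the float32 numeric difference can be verified with …") can be approximated with a small sketch (not from the PR): dividing the fp8 max by an amax left in bfloat16 yields a slightly different scale than dividing by the same amax upcast to a wider dtype, because the low-precision division result is rounded.

import torch

amax_bf16 = torch.tensor(4.123, dtype=torch.bfloat16)  # hypothetical amax value
fp8_max = torch.finfo(torch.float8_e4m3fn).max          # 448.0

scale_low = fp8_max / amax_bf16                         # division result rounded to bfloat16
scale_high = fp8_max / amax_bf16.to(torch.float64)      # division carried out in float64

print(scale_low.item(), scale_high.item())  # the two scales differ in the low-order bits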
@@ -59,17 +59,17 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         return

     # inf-norm is equivalent to max(abs(w))
-    max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
+    max_weights = torch._foreach_norm(weights, ord=math.inf, dtype=torch.float64)  # Partial
     amax_tensor = torch.stack(max_weights)  # Partial
     # clamp is dispatched through DTensor
     # it will issue a single all-reduce
     amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
     scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
     if amax_tensor.dtype is torch.float16:
         scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
-    local_scale_tensor = scale_tensor.to_local()
+    local_scale_tensor = scale_tensor.to_local().to(torch.float32)
     for i, float8_linear in enumerate(float8_linears):
-        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i].to(torch.float32)
+        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]


 # FSDP pads its local tensor on dim-0. The subclass should be preserved such

Review comments (on the _foreach_norm change):
- add comment to describe upcasting
- improved comment
- good question! Actually I just updated the code to do …; back to your question, I checked …
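A short sketch (not from the PR) of what the updated line computes, assuming the installed PyTorch accepts the dtype argument on torch._foreach_norm (the diff itself relies on this): the inf-norm of each weight equals max(abs(w)), and passing dtype keeps the reduction result in float64.

import math
import torch

weights = [torch.randn(16, 16, dtype=torch.bfloat16) for _ in range(3)]

# inf-norm per tensor == max(abs(w)); dtype=torch.float64 returns the amax in float64
amaxes = torch._foreach_norm(weights, ord=math.inf, dtype=torch.float64)
reference = [w.abs().amax().to(torch.float64) for w in weights]

for got, want in zip(amaxes, reference):
    torch.testing.assert_close(got, want)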
Review comment (on the new parity test): nit: move to test_compile.py since this is testing compile vs eager?