cast local_scale_tensor to fp32 for precompute of fp8 dynamic scaling #713
Conversation
Signed-off-by: Masaki Kozuki <[email protected]>
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/713
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 1d18db7 with merge base ac8ce4c.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
lgtm, thank you
@@ -67,7 +67,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
     if amax_tensor.dtype is torch.float16:
         scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
-    local_scale_tensor = scale_tensor.to_local()
+    local_scale_tensor = scale_tensor.to_local().to(dtype=torch.float32)
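For context, a minimal sketch of why the explicit cast matters (a standalone illustration, not the torchao code path; the tensor value is made up): dividing the fp8 max by a bf16 amax keeps bf16 through type promotion, so the scale only becomes float32 if cast explicitly.

```python
import torch

# Hedged sketch: a Python float divided by a bf16 tensor promotes to the
# tensor's dtype, so the resulting scale stays bf16.
amax_bf16 = torch.tensor([3.0], dtype=torch.bfloat16)
scale = torch.finfo(torch.float8_e4m3fn).max / amax_bf16
print(scale.dtype)  # torch.bfloat16

# The added .to(dtype=torch.float32) pins the precomputed scale to float32.
scale_fp32 = scale.to(dtype=torch.float32)
print(scale_fp32.dtype)  # torch.float32
```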
thanks for pinpointing the specific line. I can also take a look at why `max_weights = torch._foreach_norm(weights, ord=math.inf)` does not end up as float32. I was expecting local_scale_tensor to have the same dtype as model.parameters()
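A quick way to check that expectation (a hedged sketch; `torch._foreach_norm` is the private op quoted above, and the weights here are made up):

```python
import math

import torch

# Hedged sketch: _foreach_norm returns one norm per input tensor, each in
# that tensor's own dtype, so bf16 weights yield bf16 amax values rather
# than float32.
weights = [torch.randn(4, 4, dtype=torch.bfloat16) for _ in range(2)]
max_weights = torch._foreach_norm(weights, ord=math.inf)
print([t.dtype for t in max_weights])  # [torch.bfloat16, torch.bfloat16]
```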
just realized that we cast to float32 in ao/torchao/float8/float8_utils.py (line 55 in 227d4bf):
return res.to(torch.float32)
it's good to keep precompute in sync with float8_utils.py. Thanks for the fix.
Happy to approve, but the PR is in draft mode.
Marked ready
Signed-off-by: Masaki Kozuki <[email protected]>
ah I should use
test/float8/test_fsdp2/test_fsdp2.py (Outdated)
@@ -106,6 +109,7 @@ def _test_transformer_parity(
         precompute: bool,
         scaling_type_weight: ScalingType,
         compile_transformer_block: bool,
+        dtype: torch.dtype | None = None,
maybe change this to `dtype: Optional[torch.dtype] = None` as well. I guess Python 3.9 probably does not support `torch.dtype | None`.
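For reference, a hedged sketch of the suggested spelling (the helper name and trimmed parameter list are illustrative, not the actual test code): on Python 3.9, `torch.dtype | None` is only valid in annotations under `from __future__ import annotations`, since the runtime `|` operator on types arrived in 3.10, while `Optional[torch.dtype]` works on both.

```python
from typing import Optional

import torch

# Hedged sketch of the suggested annotation; the name and parameters are
# illustrative and the body is elided. On Python 3.9, `torch.dtype | None`
# raises a TypeError at definition time unless
# `from __future__ import annotations` is in effect.
def _test_transformer_parity_sketch(
    precompute: bool,
    compile_transformer_block: bool,
    dtype: Optional[torch.dtype] = None,
) -> None:
    ...
```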
When a model is in bfloat16, the precomputed scales seem to be in bfloat16 as well. This seems to cause a dtype mismatch, specifically scale_b being bf16, after the first call of precompute.
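As a hedged illustration of that failure mode (an assumption-laden sketch, not the torchao call site: it assumes an fp8-capable GPU, and `torch._scaled_mm` is a private op whose exact signature has shifted across PyTorch releases), a bf16 scale_b trips the float32-scale check in the scaled matmul:

```python
import torch

# Hedged sketch: torch._scaled_mm expects float32 scales, so a bf16 scale_b
# left over from precompute is expected to raise a dtype error on the next
# matmul. Assumes an fp8-capable GPU (e.g. H100); _scaled_mm is a private op.
a = torch.randn(16, 16, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(16, 16, device="cuda").to(torch.float8_e4m3fn).t()
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.bfloat16)  # the bad case
torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b)  # expected to raise
```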