Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cast local_scale_tensor to fp32 for precompute of fp8 dynamic scaling #713

Merged
merged 4 commits into from
Aug 22, 2024

Conversation

crcrpar
Copy link
Contributor

@crcrpar crcrpar commented Aug 20, 2024

When a model is in bfloat16, precomputed scales seem to be in bfloat16.
This seems to cause the dtype mismatch, especially scale_b being bf16, after the first call of precompute:

# [rank0]:[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torchao/float8/float8_linear.py", line 363, in forward
# [rank0]:[rank0]:     output = manual_float8_matmul.apply(input_fp8, weight_fp8.t())
# [rank0]:[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 575, in apply
# [rank0]:[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
# [rank0]:[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torchao/float8/float8_linear.py", line 60, in forward
# [rank0]:[rank0]:     res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t)
# [rank0]:[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torchao/float8/float8_tensor.py", line 359, in __torch_dispatch__
# [rank0]:[rank0]:     return FLOAT8_OPS_TABLE[func](func, args, kwargs)
# [rank0]:[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torchao/float8/float8_ops.py", line 181, in float8_mm
# [rank0]:[rank0]:     tensor_out = addmm_float8_unwrapped(
# [rank0]:[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torchao/float8/float8_python_api.py", line 54, in addmm_float8_unwrapped
# [rank0]:[rank0]:     output = torch._scaled_mm(
# [rank0]:[rank0]: RuntimeError: Both scale_a and scale_b must be float (fp32) tensors.

Copy link

pytorch-bot bot commented Aug 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/713

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1d18db7 with merge base ac8ce4c (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 20, 2024
@crcrpar crcrpar marked this pull request as draft August 20, 2024 08:02
@msaroufim msaroufim requested a review from vkuzo August 20, 2024 16:37
@vkuzo vkuzo requested a review from weifengpy August 20, 2024 16:46
Copy link
Contributor

@vkuzo vkuzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, thank you

@@ -67,7 +67,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate
if amax_tensor.dtype is torch.float16:
scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
local_scale_tensor = scale_tensor.to_local()
local_scale_tensor = scale_tensor.to_local().to(dtype=torch.float32)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for pinpoint to a specific line. I can also take a look at why max_weights = torch._foreach_norm(weights, ord=math.inf) does not ends up with float32. I was expecting local_scale_tensor to be the same dtype to model.parameters()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just realized that we cast to float32 in

return res.to(torch.float32)

it's good to keep precompute in sync with float8_utils.py. thanks for the fix

@weifengpy
Copy link
Contributor

happy to approve but the PR is in draft mode

@crcrpar crcrpar marked this pull request as ready for review August 21, 2024 00:46
@crcrpar
Copy link
Contributor Author

crcrpar commented Aug 22, 2024

Marked ready

@crcrpar
Copy link
Contributor Author

crcrpar commented Aug 22, 2024

ah I should use dtype kwarg of _foreach_norm instead?
ref: pytorch/pytorch#125665

@@ -106,6 +109,7 @@ def _test_transformer_parity(
precompute: bool,
scaling_type_weight: ScalingType,
compile_transformer_block: bool,
dtype: torch.dtype | None = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe change this to dtype: Optional[torch.dtype] = None as well. I guess python 3.9 probably does not support torch.dtype | None

@msaroufim msaroufim merged commit cdb25a4 into pytorch:main Aug 22, 2024
16 checks passed
@crcrpar crcrpar deleted the cast_precompute_scale_to_fp32 branch October 18, 2024 07:11
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
* help outputs the default model dir
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants