enable torch.compile for mxfp8_cublas recipe
Summary:

This PR enables `MXLinear` with the `mxfp8_cublas` recipe to use torch.compile. The current approach is a short-term workaround until pytorch/pytorch#148461 is done. Since we can't use e8m0 in torchinductor or triton yet, we create a custom op wrapper around `torch._scaled_mm` which takes `uint8` scales and does the cast to e8m0 inside the wrapper, where torchinductor can't see it.

Test Plan:

```
// this now works (although performance is not ideal due to #1788)
python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250305_test --mx_recipe_name mxfp8_cublas

// we can also uncomment the hardware check and run the unit test
pytest test/prototype/mx_formats -s -k test_linear_compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 033d817549f80d7d0d8cf549f748411cc1f3ac6a
ghstack-comment-id: 2701679811
Pull Request resolved: #1841
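For context on the `uint8`-to-e8m0 trick above: e8m0 (from the OCP MX spec) is an exponent-only 8-bit format with bias 127, no sign and no mantissa bits, and `0xFF` reserved for NaN, so each raw `uint8` scale byte `b` simply encodes the power of two `2**(b - 127)`. A minimal pure-Python sketch of that encoding (hypothetical helpers for illustration, not code from this PR):

```python
import math

# Hypothetical helpers illustrating the e8m0 encoding that the custom op
# wrapper hides from torchinductor; a sketch, not code from this PR.

def e8m0_to_float(b: int) -> float:
    """Decode one e8m0 byte: an 8-bit biased exponent (bias 127),
    no sign or mantissa bits. 0xFF is the NaN encoding."""
    if not 0 <= b <= 0xFF:
        raise ValueError("e8m0 is a single byte")
    if b == 0xFF:
        return float("nan")
    return 2.0 ** (b - 127)

def float_to_e8m0(x: float) -> int:
    """Encode a positive power of two as an e8m0 byte (illustrative inverse)."""
    exp = int(math.log2(x))
    assert 2.0 ** exp == x, "e8m0 can only represent powers of two"
    return exp + 127
```

Because the cast happens inside an opaque custom op, inductor only ever traces `uint8` tensors and never needs to reason about the e8m0 dtype itself.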
Showing 2 changed files with 91 additions and 6 deletions.