enable torch.compile for mxfp8_cublas recipe #1841
Open
vkuzo wants to merge 24 commits into main from gh/vkuzo/53/head
+102 −9
Conversation
Stack from ghstack (oldest at bottom):

🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1841
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures
As of commit cb824b2 with merge base 4a5ab2d. NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
vkuzo added a commit that referenced this pull request on Mar 5, 2025
Summary: This PR enables `MXLinear` with `mxfp8_cublas` recipe to use
torch.compile.

The current approach is a short term workaround until
pytorch/pytorch#148461 is done. Since we can't use e8m0 in torchinductor
or triton yet, we create a custom op wrapper around `torch._scaled_mm`
which takes `uint8` scales and does the cast to e8m0 inside the wrapper,
where torchinductor can't see it.

Test Plan:

```
// this now works (although performance is not ideal due to #1788)
python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250305_test --mx_recipe_name mxfp8_cublas

// we can also uncomment the hardware check and run the unit test
pytest test/prototype/mx_formats -s -k test_linear_compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 033d817549f80d7d0d8cf549f748411cc1f3ac6a
ghstack-comment-id: 2701679811
Pull Request resolved: #1841
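To make the workaround concrete, here is a minimal sketch of the custom-op pattern the commit message describes. This is not the PR's actual code: the op namespace and names below are hypothetical, it assumes a PyTorch recent enough to have `torch.library.custom_op` and the `torch.float8_e8m0fnu` dtype, and it glosses over `torch._scaled_mm`'s layout requirements (row-/column-major operands, blocked scale layout).

```
# Hedged sketch of the custom-op wrapper pattern (hypothetical names, not the
# PR's code). The op body is opaque to torchinductor, so the uint8 -> e8m0
# reinterpretation never appears in the traced graph.
import torch
from torch.library import custom_op

@custom_op("mylib::mx_scaled_mm", mutates_args=())
def mx_scaled_mm(
    a: torch.Tensor,           # fp8 activation
    b: torch.Tensor,           # fp8 weight
    a_scale_u8: torch.Tensor,  # per-block scales, carried as uint8
    b_scale_u8: torch.Tensor,  # per-block scales, carried as uint8
) -> torch.Tensor:
    # The e8m0 cast happens inside the wrapper, where inductor can't see it.
    # uint8 and float8_e8m0fnu are both one byte, so view() reinterprets bits.
    a_scale = a_scale_u8.view(torch.float8_e8m0fnu)
    b_scale = b_scale_u8.view(torch.float8_e8m0fnu)
    return torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=torch.bfloat16)

@mx_scaled_mm.register_fake
def _(a, b, a_scale_u8, b_scale_u8):
    # Shape/dtype propagation so torch.compile can trace through the op
    # without running the real kernel.
    return a.new_empty(a.shape[0], b.shape[1], dtype=torch.bfloat16)
```

Because torch.compile only sees the fake registration, the unsupported e8m0 dtype stays entirely inside the op boundary, which is what sidesteps the missing inductor/triton support.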
vkuzo added a commit that referenced this pull request on Mar 5, 2025
Summary: This PR enables `MXLinear` with `mxfp8_cublas` recipe to use
torch.compile.

The current approach is a short term workaround until
pytorch/pytorch#147873 is done. Since we can't use e8m0 in torchinductor
or triton yet, we create a custom op wrapper around `torch._scaled_mm`
which takes `uint8` scales and does the cast to e8m0 inside the wrapper,
where torchinductor can't see it.

Test Plan:

```
// this now works (although performance is not ideal due to #1788)
python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250305_test --mx_recipe_name mxfp8_cublas

// we can also uncomment the hardware check and run the unit test
pytest test/prototype/mx_formats -s -k test_linear_compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: f3ebd12edcb746b8abf992d00711ce2bdbb7fcf2
ghstack-comment-id: 2701679811
Pull Request resolved: #1841
vkuzo added a commit that referenced this pull request on Mar 5, 2025
Summary: This PR enables `MXLinear` with `mxfp8_cublas` recipe to use
torch.compile.

The current approach is a short term workaround until
pytorch/pytorch#147873 is done. Since we can't use e8m0 in torchinductor
or triton yet, we create a custom op wrapper around `torch._scaled_mm`
which takes `uint8` scales and does the cast to e8m0 inside the wrapper,
where torchinductor can't see it.

Test Plan:

```
// this now works (although performance is not ideal due to #1788)
python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250305_test --mx_recipe_name mxfp8_cublas

// we can also uncomment the hardware check and run the unit test
pytest test/prototype/mx_formats -s -k test_linear_compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: e5687e308db0a54c6083c58cfec5cc49626622f1
ghstack-comment-id: 2701679811
Pull Request resolved: #1841
Labels
CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
topic: performance: Use this tag if this PR improves the performance of a feature.
Summary:

This PR enables `MXLinear` with the `mxfp8_cublas` recipe to use torch.compile.

The current approach is a short-term workaround until pytorch/pytorch#148461 is done. Since we can't use e8m0 in torchinductor or triton yet, we create a custom op wrapper around `torch._scaled_mm` which takes `uint8` scales and does the cast to e8m0 inside the wrapper, where torchinductor can't see it.

Test Plan:

```
// this now works (although performance is not ideal due to #1788)
python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250305_test --mx_recipe_name mxfp8_cublas

// we can also uncomment the hardware check and run the unit test
pytest test/prototype/mx_formats -s -k test_linear_compile
```

Reviewers:

Subscribers:

Tasks:

Tags:
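For completeness, a hedged sketch of what the test plan exercises end to end. The import paths, `MXLinearConfig.from_recipe_name`, and `swap_linear_with_mx_linear` are assumptions about the prototype API that may differ across torchao versions, and the `mxfp8_cublas` recipe requires a GPU with native mxfp8 `torch._scaled_mm` support (hence the hardware check the test plan mentions).

```
# Hedged usage sketch (API names assumed; check torchao.prototype.mx_formats
# for the exact entry points in your version). Needs a GPU with native
# mxfp8 support for the cublas-backed recipe.
import torch
import torch.nn as nn
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

m = nn.Sequential(nn.Linear(256, 256, bias=False)).cuda().to(torch.bfloat16)
config = MXLinearConfig.from_recipe_name("mxfp8_cublas")  # assumed helper
swap_linear_with_mx_linear(m, config=config)

m = torch.compile(m)  # works after this PR's custom-op workaround
x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
m(x).sum().backward()
```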