Make scaling type configurable for MoE training #2642
stack-info: PR: #2642, branch: danielvegamyhre/stack/26
## Summary
- For mxfp8, token group sizes must be multiples of `block_size`. In the backward pass for `grad_weight = grad_output_t @ input`, the "M" (token) dimension is the contracting dimension, and each token group is a logically distinct subtensor, so each group is scaled separately. The size of each token group along the contracting dimension must therefore be divisible by the mxfp8 `block_size` (default 32). Here is a diagram showing the problem: https://www.internalfb.com/excalidraw/EX521879
- To solve this, this PR makes the token group M alignment configurable.

## Test plan
- Integration test with torchao passes: pytorch/ao#2642
- Did manual test run with llama4 debug model using bf16
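
The alignment requirement described in the summary can be illustrated with a minimal sketch. It is not the code from this PR: the helper names `round_up`, `aligned_group_sizes`, and `aligned_group_offsets` are hypothetical, and the sketch assumes that alignment is achieved by rounding each token group size up to the next multiple of the configured value (32 for the default mxfp8 `block_size`).

```python
from typing import List


def round_up(n: int, multiple: int) -> int:
    """Round n up to the nearest multiple of `multiple` (0 stays 0)."""
    return ((n + multiple - 1) // multiple) * multiple


def aligned_group_sizes(group_sizes: List[int], alignment: int = 32) -> List[int]:
    """Hypothetical helper: pad each token group so its size along M
    (the contracting dimension of grad_output_t @ input) is divisible
    by `alignment`."""
    if alignment <= 0:
        raise ValueError("alignment must be a positive integer")
    return [round_up(size, alignment) for size in group_sizes]


def aligned_group_offsets(group_sizes: List[int], alignment: int = 32) -> List[int]:
    """Cumulative end offsets of the padded groups along the M dimension."""
    offsets = []
    total = 0
    for size in aligned_group_sizes(group_sizes, alignment):
        total += size
        offsets.append(total)
    return offsets


# Example: four experts receive 5, 32, 33, and 0 tokens respectively.
sizes = [5, 32, 33, 0]
padded = aligned_group_sizes(sizes, alignment=32)
offsets = aligned_group_offsets(sizes, alignment=32)
# Every padded group can now be split into whole blocks of 32 along M.
all_divisible = all(size % 32 == 0 for size in padded)
```

Making `alignment` a parameter rather than a constant mirrors the intent stated in the summary: different recipes can then request different token group alignments along M.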
vkuzo
left a comment
lg for prototype, we might need to change this later
Stacked PRs:
Make scaling type configurable for MoE training
Summary
Test plan