cutlass: enable SM121-gated MXFP4 MoE kernel path #3038
base: main
Changes from 3 commits
24cc8c8
aaaff2a
cac6eff
476b496
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| if (CUTLASS_NVCC_ARCHS MATCHES 121a) | ||
Collaborator
Please add copyright to this file.
Author
Thanks for catching this. I’ve added the standard NVIDIA copyright header.
| add_custom_target( | ||
| cutlass_test_unit_gemm_device_sm121_bs | ||
Collaborator
Porting some internal review comments: Let's just spell out
@christopherowen check
| DEPENDS | ||
| cutlass_test_unit_bs_grouped_gemm_device_tensorop_sm121 | ||
| ) | ||
| cutlass_test_unit_gemm_device_add_executable( | ||
| cutlass_test_unit_bs_grouped_gemm_device_tensorop_sm121 | ||
| ../sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_mxf8_mxf4_f32_group_gemm_fusion.cu | ||
Collaborator
Porting some internal review comments: It seems like this PR added support for CTA_TILE_M == 64, which is smaller than 128. Can we add one unit test in this file to test this case?
@christopherowen check
| ) | ||
| endif() | ||
Why is this change needed?
TMA descriptors created on the host (here in to_underlying_arguments) mainly serve to initialize kernel-dependent parameters; the global tensor pointer, shapes, and strides are updated on the device, based on the group, prior to any use.
Were you running into any issues with this?
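To make the split described above concrete, here is a minimal host-only C++ sketch. Everything in it is hypothetical: `DescriptorSketch`, `GroupProblem`, `make_host_descriptor`, and `update_for_group` are illustrative names, not CUTLASS or CUDA APIs, and the struct is not the real tensor-map layout. It assumes a row-major 2D tensor per group and only models which fields are fixed on the host and which are rewritten per group before use.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a TMA descriptor (NOT the real tensor-map layout).
struct DescriptorSketch {
    // Kernel-dependent fields: fixed once on the host.
    std::array<std::uint32_t, 2> box_shape{};  // tile extent moved per copy
    std::uint32_t element_bytes = 0;
    // Problem-dependent fields: placeholders until the per-group update.
    const void* global_ptr = nullptr;
    std::array<std::uint64_t, 2> shape{};
    std::array<std::uint64_t, 2> stride_bytes{};
};

// Hypothetical description of one group's tensor (row-major is an assumption).
struct GroupProblem {
    const void* ptr;
    std::uint64_t rows;
    std::uint64_t cols;
};

// Host side: set only what depends on the kernel configuration.
// Pointer, shape, and strides stay at their placeholder values.
inline DescriptorSketch make_host_descriptor(std::uint32_t tile_rows,
                                             std::uint32_t tile_cols,
                                             std::uint32_t element_bytes) {
    DescriptorSketch d;
    d.box_shape = {tile_rows, tile_cols};
    d.element_bytes = element_bytes;
    return d;
}

// Stand-in for the device-side step: overwrite the problem-dependent fields
// for the selected group, leaving the kernel-dependent fields untouched.
inline DescriptorSketch update_for_group(DescriptorSketch d, const GroupProblem& g) {
    assert(d.element_bytes != 0 && "host-side initialization must run first");
    d.global_ptr = g.ptr;
    d.shape = {g.rows, g.cols};
    d.stride_bytes = {g.cols * d.element_bytes, d.element_bytes};
    return d;
}
```

Under these assumptions, whatever pointer, shape, or stride the host writes is overwritten by the per-group update before the descriptor is used, which is why the host-side values would only need to be valid placeholders.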
@christopherowen please check