cutlass: enable SM121-gated MXFP4 MoE kernel path #3038
christopherowen wants to merge 4 commits into NVIDIA:main
@@ -0,0 +1,14 @@
if (CUTLASS_NVCC_ARCHS MATCHES 121a)
Please add copyright to this file.
Thanks for catching this. I’ve added the standard NVIDIA copyright header.
I've ported this PR to the internal repository to run the pipeline. I'll merge this PR once the code is merged in the internal repository.
stride_a = InternalStrideA{};
stride_b = InternalStrideB{};
// However, TMA descriptor encoding requires valid non-zero strides.
// Use tile dimensions as placeholder values for the runtime stride components.
Why is this change needed?
The TMA descriptors created on the host (here, in to_underlying_arguments) serve mainly to initialize the kernel-dependent parameters. The global tensor pointer, shapes, and strides are updated on the device, per group, before any use.
Were you running into any issues with this?
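For readers following this thread, the host/device split described above can be sketched with a toy model. This is only an illustration under stated assumptions: `ToyDescriptor`, `GroupProblem`, `make_host_descriptor`, and `update_for_group` are hypothetical names, not CUTLASS or CUDA APIs, and the "device" step is ordinary host code standing in for the per-group update done in the kernel.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a TMA descriptor. NOT the CUTLASS or CUDA type.
struct ToyDescriptor {
    const void* base = nullptr;   // global tensor pointer
    std::int64_t rows = 0;        // tensor shape
    std::int64_t cols = 0;
    std::int64_t row_stride = 0;  // stride in elements
};

// Hypothetical per-group problem description for a grouped GEMM operand.
struct GroupProblem {
    const float* data;
    std::int64_t rows;
    std::int64_t cols;
};

// "Host" step: only kernel-dependent parameters are meaningful here.
// The shape and stride are placeholders taken from the tile dimensions,
// chosen only so that nothing encoded into the descriptor is zero.
ToyDescriptor make_host_descriptor(std::int64_t tile_rows, std::int64_t tile_cols) {
    assert(tile_rows > 0 && tile_cols > 0);
    ToyDescriptor d;
    d.rows = tile_rows;
    d.cols = tile_cols;
    d.row_stride = tile_cols;  // placeholder, not the real stride
    return d;
}

// "Device" step: before a group is processed, the pointer, shape, and
// stride are overwritten with that group's real values.
// Assumes a densely packed row-major layout for the illustration.
ToyDescriptor update_for_group(ToyDescriptor d, const GroupProblem& g) {
    d.base = g.data;
    d.rows = g.rows;
    d.cols = g.cols;
    d.row_stride = g.cols;
    return d;
}
```

In this model the placeholder values never reach a load: every field that depends on the group is replaced by `update_for_group` first, which is the reason the reviewer asks whether the host-side change fixes an observed failure.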
cutlass_test_unit_gemm_device_add_executable(
cutlass_test_unit_bs_grouped_gemm_device_tensorop_sm121
../sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_mxf8_mxf4_f32_group_gemm_fusion.cu
Porting some internal review comments:
It seems this PR adds support for CTA_TILE_M == 64, which is smaller than 128. Can we add a unit test in this file to cover this case?
if (CUTLASS_NVCC_ARCHS MATCHES 121a)
add_custom_target(
cutlass_test_unit_gemm_device_sm121_bs
Porting some internal review comments:
Let's just spell out blockscaled and not call the kernel bs.
@christopherowen could you address these comments so we can move this PR forward?
Add SM121-gated MXFP4 kernel wiring and launch config updates for MoE inference paths.
Validation: