cutlass: enable SM121-gated MXFP4 MoE kernel path#3038

Open
christopherowen wants to merge 4 commits into NVIDIA:main from christopherowen:sm121_mxfp4

Conversation

@christopherowen

Add SM121-gated MXFP4 kernel wiring and launch config updates for MoE inference paths.

Validation:

  • Builds cleanly on SM121 toolchains.
  • Runtime sanity and end-to-end vLLM mxfp4 serve checks pass on SM121.
  • Heavily community-tested across DGX Spark/SM121 setups.

@@ -0,0 +1,14 @@
if (CUTLASS_NVCC_ARCHS MATCHES 121a)
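For readers following along, a gate like the one above typically wraps the arch-specific wiring in the affected CMakeLists. A minimal sketch of the pattern (the target and file names below are illustrative placeholders, not the PR's actual additions):

```cmake
# Sketch only: build SM121-specific blockscaled sources only when the
# 121a architecture is in the requested arch list.
# Target and file names are illustrative, not taken from this PR.
if (CUTLASS_NVCC_ARCHS MATCHES 121a)
  cutlass_test_unit_gemm_device_add_executable(
    cutlass_test_unit_blockscaled_gemm_device_tensorop_sm121
    sm121_blockscaled_gemm_example.cu   # placeholder source file
  )
endif()
```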
Collaborator
Please add copyright to this file.

Author
Thanks for catching this. I’ve added the standard NVIDIA copyright header.

@Junkai-Wu
Collaborator

I've ported this PR to the internal repository to run the pipeline. I'll merge this PR after the code lands internally.

stride_a = InternalStrideA{};
stride_b = InternalStrideB{};
// However, TMA descriptor encoding requires valid non-zero strides.
// Use tile dimensions as placeholder values for the runtime stride components.
Collaborator

Why is this change needed?
The TMA descriptors created on the host (here, in to_underlying_arguments) mainly serve to initialize kernel-dependent parameters; the global tensor pointers, shapes, and strides are updated correctly on the device, based on the group, prior to any use.
Were you running into an issue with this?


@christopherowen please check


cutlass_test_unit_gemm_device_add_executable(
cutlass_test_unit_bs_grouped_gemm_device_tensorop_sm121
../sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_mxf8_mxf4_f32_group_gemm_fusion.cu
Collaborator

Porting some internal review comments:

It seems this PR adds support for CTA_TILE_M == 64, which is smaller than 128. Can we add a unit test in this file to cover that case?


if (CUTLASS_NVCC_ARCHS MATCHES 121a)

add_custom_target(
cutlass_test_unit_gemm_device_sm121_bs
Collaborator

Porting some internal review comments:

Let's spell out blockscaled in the target name rather than abbreviating it to bs.


@johnnynunez

@christopherowen could you address the review comments so this PR can move forward?
