fix: Support K=64 block-scaled GEMM tiles on SM120/SM121#3121
RobTand wants to merge 1 commit into NVIDIA:main from
Conversation
Two fixes that enable K=64 tile shapes for block-scaled MoE GEMM on SM120 (RTX 5090/PRO 6000) and SM121 (DGX Spark GB10):

1. TMA zero-stride basis handling (`copy_traits_sm90_tma.hpp`): When K=64 with SFVectorSize=32, scale factor folding creates a broadcast dimension with zero stride. The existing code passes this to `basis_get()`, which produces invalid TMA descriptors. Fix: detect the zero-stride basis via `is_constant<0>` and emit `shape=1, stride=0`.

2. Scale factor block size clamping (`sm120_blockscaled_mma_builder.inl`): K=64 with SFVectorSize=32 gives `NumSFAlongK=2`, but `Blk_SF=4`. The division `Blk_SF/MMA_NSF` overflows. Fix: clamp the effective block size (`EffBlk_SF`) to `min(NumSFAlongK, Blk_SF)` and conditionally fold it into `kBasicBlock` to keep the TMA layout flat.

Together these enable K=64 CTA shapes ([128,128,64], [128,256,64], [256,128,64]), which achieve 7-11 pipeline stages vs 2 with K=128, giving ~2x single-user decode throughput on SM120/SM121.

Tested on DGX Spark (SM121) with Nemotron-3-Super-120B and Qwen3.5-122B NVFP4 models via the FlashInfer CUTLASS MoE backend.

Related: FlashInfer PR flashinfer-ai/flashinfer#2786 adds the K=64 tile shapes to FlashInfer's kernel generation but depends on these CUTLASS fixes for correctness.

Signed-off-by: Rob Tand <robert.tand@icloud.com>
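The zero-stride handling in fix 1 can be illustrated with a self-contained sketch. This is plain C++ standing in for CuTe: `TmaMode` and `fill_mode` are hypothetical stand-ins for the descriptor modes filled by `fill_tma_gmem_shape_stride()`, and `std::integral_constant<int, 0>` stands in for `cute::Int<0>` / `is_constant<0>`.

```cpp
#include <cstdint>
#include <type_traits>

// Simplified model of one TMA descriptor mode: each GEMM mode
// contributes a (shape, stride) pair.
struct TmaMode { uint64_t shape; uint64_t stride; };

// A compile-time zero stride marks a broadcast dimension created by
// scale-factor folding. Passing it through the normal basis lookup
// would yield an invalid descriptor entry; the fix emits a degenerate
// mode (shape=1, stride=0) instead.
template <class Shape, class Stride>
constexpr TmaMode fill_mode(Shape shape, Stride stride) {
  if constexpr (std::is_same_v<Stride, std::integral_constant<int, 0>>) {
    return {1, 0};  // broadcast: degenerate TMA mode
  } else {
    return {static_cast<uint64_t>(shape), static_cast<uint64_t>(stride)};
  }
}

// The broadcast mode collapses; a normal strided mode passes through.
static_assert(fill_mode(2, std::integral_constant<int, 0>{}).shape == 1);
static_assert(fill_mode(64, 16).stride == 16);
```

Because the branch is `if constexpr` on the stride's type, runtime strides that happen to be zero are unaffected, mirroring the PR's claim that only compile-time `Int<0>` strides take the new path.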
Thank you for the mention. I'm not working on it at this time, as I've been busy with an SM120-optimized attention kernel, which is proving harder than I expected. I'm happy to coordinate with you on anything you need.
Testing on RTX PRO 6000 (SM120, TP=4)

We applied both this PR and FlashInfer #2786 together on 4x RTX PRO 6000 Blackwell Server Edition (SM120, PCIe, 96GB each) running SGLang.

Setup:
Results (single-user decode throughput):
No measurable difference on this configuration. The patches applied cleanly and FlashInfer JIT compiled the K=64 kernels without errors, but the dispatch doesn't seem to select K=64 tiles for this model/shape combination, or the benefit is offset by TP=4 communication overhead. Questions:
Both patches are functionally correct on SM120 — no crashes, no compilation errors. The CUTLASS TMA zero-stride fix and EffBlk_SF clamping work as described.
Thanks for taking the time to review this — really appreciate it. To address your question about forcing K=64 tiles and demonstrating necessity:

SMEM constraint on SM120/121

SM120/SM121 devices (RTX 5090, DGX Spark GB10) have 99 KB of opt-in shared memory per block, compared to 227 KB on SM100. This changes which tile shapes are viable for block-scaled GEMM.

Pipeline stage impact

The builder computes pipeline stages as
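The exact builder formula is elided above; as a rough illustrative model (my own sketch, not the actual CollectiveBuilder code — the scale-factor byte estimate assumes one byte of SF per 32 elements):

```cpp
// Rough model: stages ~= available SMEM / per-stage footprint, where
// one pipeline stage buffers the A tile, the B tile, and their scale
// factors. Halving TileK roughly halves the per-stage footprint, so
// the same SMEM budget yields roughly twice the stages.
constexpr int stages(int smem_bytes, int tile_m, int tile_n, int tile_k,
                     int bits_per_elem) {
  int a_bytes  = tile_m * tile_k * bits_per_elem / 8;
  int b_bytes  = tile_n * tile_k * bits_per_elem / 8;
  int sf_bytes = (tile_m + tile_n) * (tile_k / 32);  // assumed: 1B SF per 32 elems
  return smem_bytes / (a_bytes + b_bytes + sf_bytes);
}

// SM120/121: ~99 KB opt-in SMEM; NVFP4 is 4 bits per element.
static_assert(stages(99 * 1024, 128, 128, 64, 4)  == 11);  // K=64
static_assert(stages(99 * 1024, 128, 128, 128, 4) == 5);   // K=128
```

Under the same toy model, the 227 KB SM100 budget gives double-digit stages even at K=128, consistent with the point below that the small-K paths were never exercised there.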
K=128 tiles do work on SM120/121 (3–5 stages), so this isn't strictly about making things functional at K=128. The issue is that K=64 tiles — the natural choice for the constrained SMEM budget — cannot be instantiated at all.

Why K=64 fails without this fix

Two bugs surface only when
Why these were never caught on SM100

With 227 KB of SMEM, K=128 tiles get 8–12 pipeline stages — more than sufficient. K=64 tiles are never needed, so the code path was never exercised. Both bugs are latent but only manifest with small K values.

Performance expectations

The performance improvement from deeper pipelining may be modest on SM120/121, since these devices tend to be memory-bandwidth-bound. The primary motivation is correctness: K=64 tile shapes should be valid instantiations of the block-scaled GEMM builder, and currently they are not.

How to test

The simplest way to verify is a one-line change to an existing unit test: change `using TileShape = Shape<_128,_128,_256>;` to `using TileShape = Shape<_128,_128,_64>;`. Without this PR, this will either fail to compile or produce corrupt results. With the fix, it should compile and pass. Both FlashInfer (v0.6.6) and vLLM (v0.18.0) already define K=64 tile configs in their SM120 enums (
I got no reason to
I think his benchmarks are fair. I did my own benchmarking here: https://github.com/brandonmmusic-max/sm120-moe-bench, but I was testing for effects on different things. The fix may help more on the prefill side (I was getting 17k or so). The prefill numbers are pretty good, but I haven't run benchmarks in a manner that excludes any increase being related to p2p or MTP; I was optimizing for my particular neuro-symbolic pipeline for my local workflow. I'm glad to hear you were getting a 2x speedup! Would love to hear more about your setup!
Update: Prefill benchmarks + methodology fix

My earlier decode results were correct (no difference), but the initial prefill numbers were wrong — the "+19%" was a cold-start artifact (the first request, without warmup, hit a radix cache miss plus JIT warmup). Re-ran with proper methodology: 3 warmup + 5 measured runs per prompt size.

Hardware: 4x RTX PRO 6000 Blackwell Server Edition (SM120, 96GB GDDR7 each, PCIe Gen5)

Decode throughput (single user, 512 output tokens)
Prefill throughput (3 warmup + 5 measured, max_tokens=1)
All differences are within noise (<1%). Standard deviation across runs was 0.1–3.1ms.

Analysis

The MoE GEMM dimensions for Qwen3.5-397B with TP=4 are:
For decode (M=1): the operation is a GEMV, entirely memory-bandwidth bound. Pipeline stages (K=64's advantage) don't help because compute is not the bottleneck — reading weights from GDDR7 dominates.

For prefill (M=885–7099): with TP=4, per-GPU M is ~220–1775. Even at these sizes, GDDR7 bandwidth (1.79 TB/s per GPU) appears sufficient that the additional pipeline depth from K=64 tiles doesn't provide measurable benefit over K=128.

Note on hardware difference: your PR reports a 2x speedup on DGX Spark (SM121, unified LPDDR5X). LPDDR5X has significantly lower bandwidth (~273 GB/s) than GDDR7 (1.79 TB/s per GPU × 4 = 7.16 TB/s aggregate). This likely explains why K=64's compute efficiency gains are visible on DGX Spark (compute-bound) but not on discrete RTX PRO 6000 GPUs (memory-bandwidth bound even for prefill).

The patches are functionally correct — JIT compilation succeeds, kernels load, no errors. They just don't provide a throughput benefit on this specific hardware configuration.
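The bandwidth argument can be checked with back-of-envelope arithmetic. This sketch uses a hypothetical 4096x4096 layer, not the actual Qwen3.5 GEMM shapes:

```cpp
// For decode (M=1) the GEMM degenerates to a GEMV, so a lower bound on
// latency is simply weight bytes divided by memory bandwidth. Compute
// time is negligible by comparison, which is why extra pipeline stages
// cannot help on bandwidth-bound hardware.
constexpr double gemv_mem_time_us(double n, double k, double bits_per_elem,
                                  double bw_gb_per_s) {
  double bytes = n * k * bits_per_elem / 8.0;  // one full pass over the weights
  return bytes / (bw_gb_per_s * 1e9) * 1e6;
}

// Hypothetical 4096x4096 NVFP4 layer (~8.4 MB of weights): GDDR7 at
// 1790 GB/s streams it in a few microseconds; LPDDR5X at 273 GB/s takes
// ~6.5x longer, so compute-side efficiency gains become visible there.
constexpr double t_gddr7   = gemv_mem_time_us(4096, 4096, 4, 1790);
constexpr double t_lpddr5x = gemv_mem_time_us(4096, 4096, 4, 273);
static_assert(t_lpddr5x > 6.0 * t_gddr7);
```

The ratio of the two times is just the ratio of bandwidths, which is the crux of the DGX Spark vs RTX PRO 6000 difference described above.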

Summary
Two fixes that enable K=64 tile shapes for block-scaled MoE GEMM on SM120 (RTX 5090/PRO 6000) and SM121 (DGX Spark GB10). Without these, K=64 tiles produce invalid TMA descriptors or overflow scale factor layout computations.
1. TMA zero-stride basis handling (`copy_traits_sm90_tma.hpp`)

When K=64 with SFVectorSize=32, scale factor folding creates a broadcast dimension with zero stride in `fill_tma_gmem_shape_stride()`. The existing code passes this to `basis_get()`, which produces invalid TMA descriptors. Fix: detect the zero-stride basis via `is_constant<0>` and emit `shape=1, stride=0`.

2. Scale factor block size clamping (`sm120_blockscaled_mma_builder.inl`)

K=64 with SFVectorSize=32 gives `NumSFAlongK=2`, but `Blk_SF=4`. The existing division `Blk_SF/MMA_NSF` overflows. Fix: clamp the effective block size (`EffBlk_SF`) to `min(NumSFAlongK, Blk_SF)` and conditionally fold it into `kBasicBlock` to keep the TMA layout flat.

Impact
Together these enable K=64 CTA shapes (`[128,128,64]`, `[128,256,64]`, `[256,128,64]`), which achieve 7-11 pipeline stages vs 2 with K=128, giving ~2x single-user decode throughput on SM120/SM121 for NVFP4 MoE models.

Behavior for existing tile sizes
The EffBlk_SF clamping only activates when `NumSFAlongK < Blk_SF` (i.e., small K values). For K >= 128 the code is identical to the current behavior:

Testing
Tested on DGX Spark (SM121, 128 GB unified LPDDR5X) with:
Caveats for reviewers
`copy_traits_sm90_tma.hpp` is core CuTe infrastructure used by all TMA operations across all architectures. The `if constexpr` branch only triggers for compile-time zero-stride constants (`Int<0>`), so there is no runtime cost and no change to existing non-zero-stride code paths. However, this file has broad blast radius — please verify that `shape=1, stride=0` is the universally correct interpretation for a zero-stride TMA basis (we believe broadcast is the only valid case).

Only tested on SM121 (DGX Spark). The EffBlk_SF clamping is architecture-independent and should work on SM120 (RTX 5090) and SM100 (datacenter Blackwell) if K=64 tiles are used there, but we have not validated those configurations.
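To make the clamping concrete across K values, a small sketch of the EffBlk_SF computation. The names `NumSFAlongK`, `Blk_SF`, and `EffBlk_SF` come from the PR; the function itself is an illustrative reconstruction, not the builder's actual code:

```cpp
// With SFVectorSize=32, NumSFAlongK = TileK / 32. Blk_SF is the fixed
// scale-factor block size (4). For K=64 only 2 SFs exist along K, so
// the effective block is clamped to min(NumSFAlongK, Blk_SF); for
// K >= 128 the clamp is a no-op and behavior is unchanged.
constexpr int Blk_SF = 4;

constexpr int eff_blk_sf(int tile_k, int sf_vector_size) {
  int num_sf_along_k = tile_k / sf_vector_size;
  return num_sf_along_k < Blk_SF ? num_sf_along_k : Blk_SF;
}

static_assert(eff_blk_sf(64, 32)  == 2);  // clamped: the K=64 fix
static_assert(eff_blk_sf(128, 32) == 4);  // unchanged existing behavior
static_assert(eff_blk_sf(32, 32)  == 1);  // the untested K=32 case
```

This also shows why existing tile sizes are unaffected: the clamp only changes the result when `NumSFAlongK < Blk_SF`.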
The multi-mode branch (`tma_i_rank != 1`) in `fill_tma_gmem_shape_stride` also calls `basis_get` and could theoretically encounter zero strides. That branch has existing `gcd` and `!= 0` guards that may already handle it, but we only patched the rank-1 path.

K=32 is untested. `NumSFAlongK=1` with `Blk_SF=4` would produce `EffBlk_SF=1` and `FoldSFIntoBasicBlock=false` (since `1 > MMA_NSF` is false). This takes a different path than K=64 and is unverified.

Related work
cc @brandonmmusic-max — your FlashInfer PR #2786 adds the K=64 shapes but the underlying CUTLASS TMA and scale factor layout fixes aren't in CUTLASS 4.4.2 yet. This PR provides those. Happy to coordinate or withdraw if you're working on the CUTLASS side separately.