
fix: Support K=64 block-scaled GEMM tiles on SM120/SM121#3121

Open
RobTand wants to merge 1 commit into NVIDIA:main from RobTand:fix/sm120-k64-blockscaled-tma-layout

Conversation


@RobTand RobTand commented Mar 20, 2026

Summary

Two fixes that enable K=64 tile shapes for block-scaled MoE GEMM on SM120 (RTX 5090/PRO 6000) and SM121 (DGX Spark GB10). Without these, K=64 tiles produce invalid TMA descriptors or overflow scale factor layout computations.

1. TMA zero-stride basis handling (copy_traits_sm90_tma.hpp)

When K=64 with SFVectorSize=32, scale factor folding creates a broadcast dimension with zero stride in fill_tma_gmem_shape_stride(). The existing code passes this to basis_get(), which produces invalid TMA descriptors. Fix: detect zero-stride basis via is_constant<0> and emit shape=1, stride=0.
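The shape of that fix can be sketched in isolation. This is an illustrative stand-in, not the actual CuTe code: `Int` and `is_constant_zero_v` mimic `cute::Int<0>` / `cute::is_constant<0>`, and `emit_shape_stride` is a hypothetical helper standing in for the descriptor-filling logic in `fill_tma_gmem_shape_stride()`.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>
#include <utility>

// Stand-in for cute::Int<N>
template <int N> using Int = std::integral_constant<int, N>;

// Stand-in for cute::is_constant<0, Stride>: true only for a
// compile-time zero stride.
template <class Stride>
inline constexpr bool is_constant_zero_v = std::is_same_v<Stride, Int<0>>;

// Hypothetical helper mirroring the fix: the (shape, stride) pair that
// would be written into the TMA descriptor for one gmem mode.
template <class Shape, class Stride>
constexpr std::pair<uint64_t, uint64_t>
emit_shape_stride(Shape shape, Stride stride) {
  if constexpr (is_constant_zero_v<Stride>) {
    // Broadcast mode: collapse to shape=1, stride=0 instead of feeding
    // the zero-stride basis to basis_get().
    (void)shape; (void)stride;
    return {1, 0};
  } else {
    return {static_cast<uint64_t>(shape), static_cast<uint64_t>(stride)};
  }
}
```

Because the branch is `if constexpr` on a compile-time constant, non-broadcast modes compile to exactly the old code path.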

2. Scale factor block size clamping (sm120_blockscaled_mma_builder.inl)

K=64 with SFVectorSize=32 gives NumSFAlongK=2, but Blk_SF=4. The existing division Blk_SF/MMA_NSF overflows. Fix: clamp effective block size (EffBlk_SF) to min(NumSFAlongK, Blk_SF) and conditionally fold into kBasicBlock to keep the TMA layout flat.
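The clamping logic, sketched as a runnable stand-in. The names (`EffBlk_SF`, `FoldSFIntoBasicBlock`, `MMA_NSF`) follow the PR text, but the exact expressions are assumptions inferred from the table and caveats below, not the literal `.inl` code; `MMA_NSF = 1` in the checks is likewise an assumption.

```cpp
#include <algorithm>
#include <cassert>

struct SFConfig {
  int eff_blk_sf;              // clamped scale-factor block size (EffBlk_SF)
  bool fold_into_basic_block;  // fold SFs into kBasicBlock?
};

constexpr SFConfig sf_config(int K, int SFVectorSize, int Blk_SF, int MMA_NSF) {
  int num_sf_along_k = K / SFVectorSize;
  // Clamp so the Blk_SF/MMA_NSF-style divisions cannot overflow when
  // num_sf_along_k < Blk_SF (e.g. K=64 with SFVectorSize=32).
  int eff = std::min(num_sf_along_k, Blk_SF);
  // Fold only when clamping actually happened and the clamped block still
  // spans more than one MMA's worth of SFs (so K=32's eff=1 does not fold).
  bool fold = (eff < Blk_SF) && (eff > MMA_NSF);
  return {eff, fold};
}
```

With these assumptions, K=64 yields `{2, true}`, K=128 yields `{4, false}` (unchanged path), and K=32 yields `{1, false}`, matching the behavior table and caveat 4 below.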

Impact

Together these enable K=64 CTA shapes ([128,128,64], [128,256,64], [256,128,64]) which achieve 7-11 pipeline stages vs 2 with K=128, giving ~2x single-user decode throughput on SM120/SM121 for NVFP4 MoE models.

Behavior for existing tile sizes

The EffBlk_SF clamping only activates when NumSFAlongK < Blk_SF (i.e., small K values). For K >= 128 the code is identical to the current behavior:

| K   | SFVectorSize | NumSFAlongK | Blk_SF | EffBlk_SF | FoldSF? | Behavior                 |
|-----|--------------|-------------|--------|-----------|---------|--------------------------|
| 64  | 32           | 2           | 4      | 2         | yes     | New: clamped and folded  |
| 128 | 32           | 4           | 4      | 4         | no      | Unchanged                |
| 256 | 32           | 8           | 4      | 4         | no      | Unchanged                |

Testing

Tested on DGX Spark (SM121, 128 GB unified LPDDR5X) with:

  • Nemotron-3-Super-120B-A12B-NVFP4 — 24 tok/s (up from ~12 tok/s without K=64)
  • Qwen3.5-122B-A10B-NVFP4 — 26 tok/s

Caveats for reviewers

  1. copy_traits_sm90_tma.hpp is core CuTe infrastructure used by all TMA operations across all architectures. The if constexpr branch only triggers for compile-time zero-stride constants (Int<0>), so there is no runtime cost and no change to existing non-zero-stride code paths. However, this file has broad blast radius — please verify that shape=1, stride=0 is the universally correct interpretation for a zero-stride TMA basis (we believe broadcast is the only valid case).

  2. Only tested on SM121 (DGX Spark). The EffBlk_SF clamping is architecture-independent and should work on SM120 (RTX 5090) and SM100 (datacenter Blackwell) if K=64 tiles are used there, but we have not validated those configurations.

  3. The multi-mode branch (tma_i_rank != 1) in fill_tma_gmem_shape_stride also calls basis_get and could theoretically encounter zero strides. That branch has existing gcd and != 0 guards that may already handle it, but we only patched the rank-1 path.

  4. K=32 is untested. NumSFAlongK=1 with Blk_SF=4 would produce EffBlk_SF=1 and FoldSFIntoBasicBlock=false (since 1 > MMA_NSF is false). This takes a different path than K=64 and is unverified.

Related work

  • FlashInfer flashinfer-ai/flashinfer#2786 — adds K=64 tile shapes to FlashInfer's kernel generation and dispatch. Depends on these CUTLASS fixes for correctness.
  • CUTLASS #3096 — original issue describing SM120 grouped GEMM failures
  • CUTLASS #3120 — our companion PR excluding SM12x from E2M1 PTX (separate issue, same hardware)

cc @brandonmmusic-max — your FlashInfer PR #2786 adds the K=64 shapes but the underlying CUTLASS TMA and scale factor layout fixes aren't in CUTLASS 4.4.2 yet. This PR provides those. Happy to coordinate or withdraw if you're working on the CUTLASS side separately.

Two fixes that enable K=64 tile shapes for block-scaled MoE GEMM on
SM120 (RTX 5090/PRO 6000) and SM121 (DGX Spark GB10):

1. TMA zero-stride basis handling (copy_traits_sm90_tma.hpp):
   When K=64 with SFVectorSize=32, scale factor folding creates a
   broadcast dimension with zero stride. The existing code passes this
   to basis_get() which produces invalid TMA descriptors. Fix: detect
   zero-stride basis via is_constant<0> and emit shape=1, stride=0.

2. Scale factor block size clamping (sm120_blockscaled_mma_builder.inl):
   K=64 with SFVectorSize=32 gives NumSFAlongK=2, but Blk_SF=4. The
   division Blk_SF/MMA_NSF overflows. Fix: clamp effective block size
   (EffBlk_SF) to min(NumSFAlongK, Blk_SF) and conditionally fold
   into kBasicBlock to keep TMA layout flat.

Together these enable K=64 CTA shapes ([128,128,64], [128,256,64],
[256,128,64]) which achieve 7-11 pipeline stages vs 2 with K=128,
giving ~2x single-user decode throughput on SM120/SM121.

Tested on DGX Spark (SM121) with Nemotron-3-Super-120B and
Qwen3.5-122B NVFP4 models via FlashInfer CUTLASS MoE backend.

Related: FlashInfer PR flashinfer-ai/flashinfer#2786 adds the K=64
tile shapes to FlashInfer's kernel generation but depends on these
CUTLASS fixes for correctness.

Signed-off-by: Rob Tand <robert.tand@icloud.com>

brandonmmusic-max commented Mar 20, 2026 via email

@voipmonitor

Testing on RTX PRO 6000 (SM120, TP=4)

We applied both this PR and FlashInfer #2786 together on 4x RTX PRO 6000 Blackwell Server Edition (SM120, PCIe, 96GB each) running SGLang with nvidia/Qwen3.5-397B-A17B-NVFP4.

Setup:

  • Docker image: voipmonitor/sglang:test-cu132 (CUDA 13.2, PyTorch nightly cu132)
  • CUTLASS 4.4.2 + this PR patch applied to headers
  • FlashInfer 0.6.6 reinstalled from the PR #2786 branch
  • Backend: --fp4-gemm-backend flashinfer_cutlass --moe-runner-backend flashinfer_cutlass
  • No speculative decoding, no torch.compile
  • Single-request decode benchmark: 512 tokens, temperature=0

Results (single-user decode throughput):

| Config                       | Run 1      | Run 2      | Run 3      |
|------------------------------|------------|------------|------------|
| Baseline (K=128 only)        | 68.5 tok/s | 70.2 tok/s | 70.3 tok/s |
| With K=64 patches (both PRs) | 69.9 tok/s | 70.2 tok/s | 70.2 tok/s |

No measurable difference on this configuration. The patches applied cleanly and FlashInfer JIT compiled the K=64 kernels without errors, but the dispatch doesn't seem to select K=64 tiles for this model/shape combination, or the benefit is offset by TP=4 communication overhead.

Questions:

  1. What model and configuration did you benchmark for the ~2x speedup? (The PR mentions Nemotron-3 on DGX Spark SM121 — was that single-GPU unified memory?)
  2. Does K=64 benefit primarily smaller batch sizes or specific expert dimensions?
  3. Is there a way to force K=64 tile selection to verify it's being dispatched?

Both patches are functionally correct on SM120 — no crashes, no compilation errors. The CUTLASS TMA zero-stride fix and EffBlk_SF clamping work as described.

Author

RobTand commented Mar 22, 2026

Thanks for taking the time to review this — really appreciate it.

To address your question about forcing K=64 tiles and demonstrating necessity:

SMEM constraint on SM120/121

SM120/SM121 devices (RTX 5090, DGX Spark GB10) have 99 KB of opt-in shared memory per block, compared to 227 KB on SM100. This changes which tile shapes are viable for block-scaled GEMM.

Pipeline stage impact

The builder computes pipeline stages as available_smem / bytes_per_stage. For NVFP4 with SFVectorSize=32:

| Tile Shape      | Bytes/Stage | SM100 Stages | SM120/121 Stages |
|-----------------|-------------|--------------|------------------|
| 128×128×K=128   | ~17 KB      | 12           | 5                |
| 128×256×K=128   | ~26 KB      | 8            | 3                |
| 128×128×K=64    | ~9 KB       | 25           | 10               |
| 128×256×K=64    | ~13 KB      | 17           | 7                |
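The arithmetic behind the stage counts can be checked with a back-of-envelope sketch. The 8 KB carveout below is an assumed reserved-SMEM figure chosen so the formula roughly reproduces the SM120/121 column; the real builder's carveout and exact per-stage byte counts differ slightly (the SM100 column does not land exactly with these approximations).

```cpp
#include <cassert>

// Assumed epilogue/mbarrier reservation, not the builder's actual constant.
constexpr int kCarveoutKB = 8;

// stages ≈ (SMEM capacity - carveout) / bytes per pipeline stage
constexpr int pipeline_stages(int smem_capacity_kb, int bytes_per_stage_kb) {
  return (smem_capacity_kb - kCarveoutKB) / bytes_per_stage_kb;
}
```

With the 99 KB SM120/121 budget this gives 5, 3, 10, and 7 stages for the four tile shapes, matching the table's right-hand column.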

K=128 tiles do work on SM120/121 (3–5 stages), so this isn't strictly about making things functional at K=128. The issue is that K=64 tiles — the natural choice for the constrained SMEM budget — cannot be instantiated at all.

Why K=64 fails without this fix

Two bugs surface only when K/SFVectorSize < Blk_SF (i.e., K=64 with SFVectorSize=32 gives NumSFAlongK=2, but Blk_SF=4):

  • TMA descriptor corruption (copy_traits_sm90_tma.hpp): Scale factor folding creates a broadcast dimension with zero stride. basis_get() on a zero-stride basis produces undefined gmem_prob_shape/gmem_prob_stride values, resulting in an invalid TMA descriptor.

  • Scale factor layout overflow (sm120_blockscaled_mma_builder.inl): The division Blk_SF/MMA_NSF assumes Blk_SF ≤ NumSFAlongK. When that doesn't hold, the resulting Blk_Elems layout overflows, producing an incorrect TMA tensor map for scale factors.

Why these were never caught on SM100

With 227 KB of SMEM, K=128 tiles get 8–12 pipeline stages — more than sufficient. K=64 tiles are never needed, so the code path was never exercised. Both bugs are latent but only manifest with small K values.

Performance expectations

The performance improvement from deeper pipelining may be modest on SM120/121, since these devices tend to be memory-bandwidth-bound. The primary motivation is correctness: K=64 tile shapes should be valid instantiations of the block-scaled GEMM builder, and currently they are not.

How to test

The simplest way to verify is a one-line change to an existing unit test. In test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_bf16.cu, change:

using TileShape = Shape<_128,_128,_256>;

to:

using TileShape = Shape<_128,_128,_64>;

Without this PR, this will either fail to compile or produce corrupt results. With the fix, it should compile and pass TestSmall correctness checks.

Both FlashInfer (v0.6.6) and vLLM (v0.18.0) already define K=64 tile configs in their SM120 enums (CtaShape128x128x64B, etc.) but deliberately exclude them from dispatch tables because they can't be instantiated against upstream CUTLASS today.

@brandonmmusic-max



I think his benchmarks are fair. I did my own benchmarking here: https://github.com/brandonmmusic-max/sm120-moe-bench, but I was testing for effects on different things. The fix may help more on the prefill side (I was getting 17k tok/s or so); the prefill numbers are pretty good. But I haven't run benchmarks in a manner that lets me exclude any increase being related to P2P or MTP; I was optimizing for my particular neuro-symbolic pipeline in my local workflow. I'm glad to hear you were getting a 2x speedup! Would love to hear more about your setup!

@voipmonitor

Update: Prefill benchmarks + methodology fix

My earlier decode results were correct (no difference), but the initial prefill numbers were wrong — the "+19%" was a cold-start artifact (first request without warmup hitting radix cache miss + JIT warmup). Re-ran with proper methodology: 3 warmup + 5 measured runs per prompt size.

Hardware: 4x RTX PRO 6000 Blackwell Server Edition (SM120, 96GB GDDR7 each, PCIe Gen5)
Model: nvidia/Qwen3.5-397B-A17B-NVFP4, TP=4, flashinfer_cutlass backend
Patches: CUTLASS #3121 + FlashInfer #2786 both applied, JIT cache cleared, FlashInfer reinstalled from PR branch

Decode throughput (single user, 512 output tokens)

| Config                | Run 1      | Run 2      | Run 3      |
|-----------------------|------------|------------|------------|
| Baseline (K=128 only) | 68.5 tok/s | 70.2 tok/s | 70.3 tok/s |
| With K=64 patches     | 69.9 tok/s | 70.2 tok/s | 70.2 tok/s |

Prefill throughput (3 warmup + 5 measured, max_tokens=1)

| Prompt tokens | Baseline (K=128) | K=64 patch   | Diff  |
|---------------|------------------|--------------|-------|
| 885           | 8,761 tok/s      | 8,767 tok/s  | +0.1% |
| 1,769         | 10,507 tok/s     | 10,563 tok/s | +0.5% |
| 3,550         | 11,855 tok/s     | 11,967 tok/s | +0.9% |
| 7,099         | 12,445 tok/s     | 12,549 tok/s | +0.8% |

All differences are within noise (<1%). Standard deviation across runs was 0.1–3.1ms.

Analysis

The MoE GEMM dimensions for Qwen3.5-397B with TP=4 are:

  • GEMM1 (gate+up): M=tokens, N=512, K=4096
  • GEMM2 (down): M=tokens, N=4096, K=256

For decode (M=1): the operation is a GEMV, entirely memory-bandwidth bound. Pipeline stages (K=64's advantage) don't help because compute is not the bottleneck — reading weights from GDDR7 dominates.

For prefill (M=885–7099): with TP=4, per-GPU M is ~220–1775. Even at these sizes, GDDR7 bandwidth (1.79 TB/s per GPU) appears sufficient that the additional pipeline depth from K=64 tiles doesn't provide measurable benefit over K=128.

Note on hardware difference: Your PR reports 2x speedup on DGX Spark (SM121, unified LPDDR5X). LPDDR5X has significantly lower bandwidth (~273 GB/s) than GDDR7 (1.79 TB/s per GPU × 4 = 7.16 TB/s aggregate). This likely explains why K=64's compute efficiency gains are visible on DGX Spark (compute-bound) but not on discrete RTX PRO 6000 GPUs (memory-bandwidth bound even for prefill).

The patches are functionally correct — JIT compilation succeeds, kernels load, no errors. They just don't provide a throughput benefit on this specific hardware configuration.

@johnnynunez

@depaulmillz

@depaulmillz
Contributor

When instantiating the 128x128x64 NVFP4 tile I am seeing refcheck failures with this MR.

For the 128x128x64 with MXFP4, the reason you are hitting issues is the layout for SMEM that you have computed is (((32, 4),1),((32,2),1,1),10) : (((16, 4),256),((0,1),2,256),256).


Along the contiguous dimension you are copying 2 elements (shaded in green) then skipping over 2 elements (not shaded) and repeating this pattern. This is not possible to copy since we require 16B of contiguous elements.

To use this tiling pattern, it requires switching to a universal copy atom instead (copies are 2B contiguous) or loading 1 SFA/SFB tile per every 2 MMA tiles at least.
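The copy-2-then-skip-2 problem can be demonstrated with a toy offset enumerator. This uses a simplified rank-2 mode (2,4):(1,4) as a stand-in for the contiguous dimension of the layout quoted above, not the literal SMEM layout; the helper names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Offsets of a rank-2 shape:stride layout (s0, s1):(d0, d1), inner mode first.
std::vector<int> layout_offsets(int s0, int d0, int s1, int d1) {
  std::vector<int> out;
  for (int j = 0; j < s1; ++j)
    for (int i = 0; i < s0; ++i)
      out.push_back(i * d0 + j * d1);
  return out;
}

// Length of the longest stride-1 run in the offset sequence.
int max_contiguous_run(const std::vector<int>& offs) {
  int best = 1, run = 1;
  for (std::size_t k = 1; k < offs.size(); ++k) {
    run = (offs[k] == offs[k - 1] + 1) ? run + 1 : 1;
    if (run > best) best = run;
  }
  return best;
}
```

For (2,4):(1,4) the offsets come out as 0,1, 4,5, 8,9, 12,13: runs of two contiguous elements separated by gaps. With 1-byte scale factors that is only 2B of contiguity per run, far short of the 16B the TMA copy requires, which is why the universal copy atom (2B contiguous) is the fallback.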
