From f40d3817ca4fe96012efc2a49e15c86ff8acaa08 Mon Sep 17 00:00:00 2001
From: Ian Barber
Date: Sun, 12 Apr 2026 15:40:46 -0700
Subject: [PATCH] [SM121] Enable native block-scaled dot_scaled for DGX Spark (GB10)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

SM121 (GB10 DGX Spark) supports the same mma.sync block-scaled
instructions as SM120 (RTX 5090) but was excluded from the native
lowering path by exact compute capability checks.

Without this fix, dot_scaled on SM121 falls through to
DecomposeScaledBlocked which upcasts to bf16 — ~10 TFLOPS vs ~270 TFLOPS
with native mma.sync block-scaled FP4.

Tested on GB10 with both MXFP4 (scale_vec::2X, ue8m0) and NVFP4
(scale_vec::4X, ue4m3).
---
 lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index d7898d72e35d..5d28f5d99fdc 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -677,7 +677,7 @@ class ScaledBlockedToMMA : public mlir::OpRewritePattern<triton::DotScaledOp> {
   mlir::LogicalResult
   matchAndRewrite(triton::DotScaledOp dotOp,
                   mlir::PatternRewriter &rewriter) const override {
-    if (computeCapability != 120)
+    if (computeCapability / 10 != 12)
       return failure();
 
     auto numCTAs = lookupNumCTAs(rewriter);
@@ -924,7 +924,7 @@ static bool mmav2SupportsFp8Operands(int computeCapability) {
   // although PTX instructions for mma v2 w/ fp8 operands exist for sm90 and
   // sm100, they are emulated as fp16 upcasts + fp16 HMMA in SASS. sm120 has
   // hardware support for fp8 operands w/ mmav2.
-  return computeCapability == 89 || computeCapability == 120;
+  return computeCapability == 89 || computeCapability / 10 == 12;
 }
 
 // promote operands of dot op if the existing combination is not natively