From 4d1f0dd3e61e322378190b158e0563577a760536 Mon Sep 17 00:00:00 2001 From: Michal Wichrowski Date: Wed, 13 May 2026 20:46:29 +0200 Subject: [PATCH] Add support for the m16n8k4 path for FP64 (sm_90+). --- .../Dialect/TritonGPU/Transforms/Utility.h | 3 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 10 +- .../TritonGPU/IR/LinearLayoutConversions.cpp | 23 ++-- .../TritonGPU/Transforms/AccelerateMatmul.cpp | 3 +- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 23 +++- test/Conversion/tritongpu_to_llvm_hopper.mlir | 3 +- test/TritonGPU/accelerate-matmul.mlir | 3 +- .../DotOpToLLVM/MMAv2.cpp | 111 ++++++++++++++---- 8 files changed, 139 insertions(+), 40 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index 230cf96bde35..7f6bfd27ff5f 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -30,7 +30,8 @@ class SwizzledSharedEncodingAttr; // Version = 3: SmallVector mmaVersionToInstrShape(int version, const ArrayRef &shape, - Type type, int numWarps); + Type type, int numWarps, + int computeCapability = 80); // Gets the order of a tensor from its contiguity. Places the dimensions with // the largest contiguity as the inner most dimension. If the contiguity is diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 5df13a344f00..fa8f5f439558 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -2783,9 +2783,17 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef shape, int bitwidth, } // warpSizeK * (warpRepK * VecBitWidth) auto tileBitWidthK = bitwidth == 64 ? (1 * 256) : (4 * 64); + // FP64: instrShape[M] disambiguates m16n8k4 (M=16) from legacy m8n8k4 (M=8). 
+ unsigned mTile = 16; + if (bitwidth == 64) { + auto instrShape = getInstrShape(); + unsigned r = shape.size(); + if (instrShape[r - 2] == 8) + mTile = 8; + } if (opIdx == 0) { // m x k - tileSize.push_back(bitwidth == 64 ? 8 : 16); + tileSize.push_back(mTile); tileSize.push_back(tileBitWidthK / bitwidth); } else { // k x n diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 764d9df1115f..f82659cdcd03 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -986,8 +986,9 @@ NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef shape) const { SmallVector tileShape; if (isAmpere()) { - // Ampere.getInstrShape() returns the tile shape - tileShape = SmallVector(getInstrShape()); + // FP64 instrShape may carry a trailing K-dim; accumulator tile is M×N only. + auto instr = getInstrShape(); + tileShape = SmallVector(instr.take_front(rank)); } else { assert(isHopper()); auto instrShapeMNK = getInstrShape(); @@ -1014,16 +1015,24 @@ LinearLayout nvidiaDotToLinearLayout(ArrayRef shape, MLIRContext *ctx = mma.getContext(); SmallVector tileShape(rank, 1); - unsigned instrM = mma.getInstrShape()[rank - 2]; - // For fp64 (instrM == 8), the native m8n8k4 instruction uses a smaller tile. - unsigned kTileMultiplier = instrM == 8 ? 4 : 8; + auto instrShape = mma.getInstrShape(); + unsigned instrM = instrShape[rank - 2]; + // K-tile: FP64 m16-family stores K explicitly in instrShape; legacy + // m8n8k4 has K=4 implicit; everything else uses kWidth*8. 
+ unsigned kTile; + if (instrShape.size() > static_cast(rank)) + kTile = instrShape.back(); + else if (instrM == 8) + kTile = 4; + else + kTile = kWidth * 8; if (isA) { tileShape[rank - 2] = instrM; - tileShape[rank - 1] = kWidth * kTileMultiplier; + tileShape[rank - 1] = kTile; } else { // Hopper takes the rhs via shared memory assert(mma.isAmpere()); - tileShape[rank - 2] = kWidth * kTileMultiplier; + tileShape[rank - 2] = kTile; tileShape[rank - 1] = 8; } auto order = getOrderForDotOperand(dot.getOpIdx(), rank, /*kContig*/ true); diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 82789d800f9a..ac9155929b3e 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -339,7 +339,8 @@ static MMAEncodingResult createMMAEncodingForDot(DotOpInterface dotOp, auto CGALayout = getCGALayout(oldRetType.getEncoding()); auto retShapePerCTA = getShapePerCTA(oldRetType); auto instrShape = mmaVersionToInstrShape(versionMajor, retShapePerCTA, - oldAType.getElementType(), numWarps); + oldAType.getElementType(), numWarps, + computeCapability); auto warpsPerTile = getWarpsPerTile(dotOp, retShapePerCTA, versionMajor, numWarps, instrShape); diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 427769628636..1fdc9cbbf595 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -28,16 +28,33 @@ namespace mlir { using namespace triton; +// FP64 m16-family K-tile selector. Returns 0 below sm_90 (falls back to legacy +// m8n8k4). Extend here to support K=8/16 — instrShape carries the +// chosen K downstream. 
+static unsigned pickFp64MmaK(int computeCapability) { + if (computeCapability < 90) + return 0; + return 4; +} + SmallVector mmaVersionToInstrShape(int version, const ArrayRef &shape, - Type eltType, int numWarps) { + Type eltType, int numWarps, + int computeCapability) { if (version == 1) return {16, 16}; else if (version == 2) { auto rank = shape.size(); - SmallVector ret(rank, 1); + // FP64 sm_90+: m16n8kK with K stored as a trailing instrShape elt. + // FP64 sm_80: legacy m8n8k4 (2-elt instrShape, K=4 implicit). + bool isF64 = eltType.isF64(); + unsigned f64K = isF64 ? pickFp64MmaK(computeCapability) : 0; + bool isF64M16 = f64K != 0; + SmallVector ret(rank + (isF64M16 ? 1u : 0u), 1); ret[rank - 1] = 8; - ret[rank - 2] = eltType.isF64() ? 8 : 16; + ret[rank - 2] = isF64 && !isF64M16 ? 8 : 16; + if (isF64M16) + ret[rank] = f64K; return ret; } else if (version == 3) { unsigned k = 256 / eltType.getIntOrFloatBitWidth(); diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 537647f3d1bb..000b1f8940d3 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -639,7 +639,8 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // ----- -#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 8]}> +// Legacy m8n8k4: instrShape=[8, 8] (default FP64 path uses m16n8k4 — see below). 
+#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [8, 8]}> #shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}> #shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 2, order = [1, 0]}> #smem = #ttg.shared_memory diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 4a2756cb5110..dd7f88d099af 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -121,6 +121,7 @@ module attributes {"ttg.target" = "cuda:89", "ttg.num-ctas" = 1 : i32, "ttg.num- // ----- +// sm_80: legacy m8n8k4 (m16n8k4.f64 needs sm_90+). // CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [8, 8]}> #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { @@ -136,7 +137,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // ----- -// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [8, 8]}> +// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8, 4]}> #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @fp64_dot_hopper( diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index 478d067fa901..4cb493fdcc75 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ 
b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -289,7 +289,8 @@ enum class TensorCoreType : uint8_t { INT32_INT4_INT4_INT32, // Not implemented INT32_INT8_INT8_INT32, // Not implemented // double precision tensor core instr - FP64_FP64_FP64_FP64, + FP64_FP64_FP64_FP64, // m16n8k4.f64 (Phase 1+) + FP64_FP64_FP64_FP64_M8, // m8n8k4.f64 (legacy, used when instrShape[M] = 8) // scaled mxfp8 x mxfp8 matmul FP32_FP8E5M2_FP8E5M2_FP32_SCALE_VEC_1X, FP32_FP8E5M2_FP8E4M3FN_FP32_SCALE_VEC_1X, @@ -309,6 +310,8 @@ static Type getMmaRetType(TensorCoreType mmaType, MLIRContext *ctx) { Type i32Ty = type::i32Ty(ctx); Type fp64x2Ty = LLVM::LLVMStructType::getLiteral(ctx, SmallVector(2, fp64Ty)); + Type fp64x4Ty = + LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, fp64Ty)); Type fp32x4Ty = LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, fp32Ty)); Type i32x4Ty = @@ -337,6 +340,8 @@ static Type getMmaRetType(TensorCoreType mmaType, MLIRContext *ctx) { case TensorCoreType::INT32_INT8_INT8_INT32: return i32x4Ty; case TensorCoreType::FP64_FP64_FP64_FP64: + return fp64x4Ty; + case TensorCoreType::FP64_FP64_FP64_FP64_M8: return fp64x2Ty; case TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32_SCALE_VEC_1X: case TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32_SCALE_VEC_1X: @@ -427,8 +432,18 @@ static TensorCoreType getMmaTypeDot(DotOp op, RankedTensorType aTy, llvm::isa(bTy.getElementType())) return TensorCoreType::FP16_FP8E4M3FN_FP8E4M3FN_FP16; } else if (dTy.getElementType().isF64()) { - if (aTy.getElementType().isF64() && bTy.getElementType().isF64()) + if (aTy.getElementType().isF64() && bTy.getElementType().isF64()) { + // instrShape[M] selects m16n8k4 (M=16) vs legacy m8n8k4 (M=8). 
+ auto mmaEnc = dyn_cast(dTy.getEncoding()); + if (mmaEnc) { + auto instrShape = mmaEnc.getInstrShape(); + unsigned rank = dTy.getRank(); + unsigned instrM = instrShape[rank - 2]; + if (instrM == 8) + return TensorCoreType::FP64_FP64_FP64_FP64_M8; + } return TensorCoreType::FP64_FP64_FP64_FP64; + } } return TensorCoreType::NOT_APPLICABLE; @@ -482,6 +497,8 @@ inline static const std::map mmaInstrPtxAmpere = { "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16"}, {TensorCoreType::FP64_FP64_FP64_FP64, + "mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64"}, + {TensorCoreType::FP64_FP64_FP64_FP64_M8, "mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64"}, }; @@ -588,21 +605,15 @@ static void callMmaTuringFp16(PTXBuilder &builder, int b, mma(retArgs, aArgs2, bArgs2, cArgs); } -// Emit m8n8k4 fp64 MMA instructions. -// With numRegisters.m=1, numRegisters.k=1: emits a single m8n8k4. -// With numRegisters.m=2, numRegisters.k=2: emits 2*2=4 m8n8k4 grouped as -// m16n8k8 -static void callMmaAmpereFp64(PTXBuilder &builder, int b, - const BaseOffset &base, - mlir::triton::PTXInstr &mma, unsigned numMmaRets, - unsigned colsPerThread, int numCPackedElem, - unsigned batchOffset, ValueTableV2 &ha, - ValueTableV2 &hb, const SmallVector &fc, - int kRegs, int mRegs) { - // Each m sub-tile gets numMmaRets/mRegs results (2 f64 values per m8n8k4). +// Legacy m8n8k4.f64 emit (sm_80 fallback; selected when instrShape=[8, 8]). +static void +callMmaAmpereFp64M8K4(PTXBuilder &builder, int b, const BaseOffset &base, + mlir::triton::PTXInstr &mma, unsigned numMmaRets, + unsigned colsPerThread, int numCPackedElem, + unsigned batchOffset, ValueTableV2 &ha, ValueTableV2 &hb, + const SmallVector &fc, int kRegs, int mRegs) { int retsPerM = numMmaRets / mRegs; - // Build ret/c operand lists for each m sub-tile. 
SmallVector retArgsList, cArgsList; for (int vm = 0; vm < mRegs; ++vm) { auto *retArgs = builder.newListOperand(retsPerM, "=d"); @@ -629,6 +640,39 @@ static void callMmaAmpereFp64(PTXBuilder &builder, int b, } } +// Emit one m16n8k4.f64 per call. A=2 regs/thread (gid, gid+8), B=1 reg, +// C=4 regs. The two A regs come in as adjacent ha[] entries. +// TODO(fp64 m16 family): extend for m16n8k{8,16} — A/B reg counts grow +// with kWidth. K-selector to extend: pickFp64MmaK in +// TritonGPU/Transforms/Utility.cpp. +static void callMmaAmpereFp64M16K4(PTXBuilder &builder, int b, + const BaseOffset &base, + mlir::triton::PTXInstr &mma, + unsigned numMmaRets, unsigned colsPerThread, + int numCPackedElem, unsigned batchOffset, + ValueTableV2 &ha, ValueTableV2 &hb, + const SmallVector &fc) { + assert(numMmaRets == 4 && "m16n8k4.f64 produces 4 f64 outputs/thread"); + + auto *retArgs = builder.newListOperand(numMmaRets, "=d"); + auto *cArgs = builder.newListOperand(); + unsigned cBase = + (base.m * colsPerThread + numMmaRets * numCPackedElem * base.n) / + numCPackedElem + + batchOffset * b; + for (unsigned i = 0; i < numMmaRets; ++i) { + cArgs->listAppend(builder.newOperand(fc[cBase + i], std::to_string(i))); + } + + auto *aArgs = builder.newListOperand({ + {ha[{b, base.m + 0, base.k}], "d"}, + {ha[{b, base.m + 1, base.k}], "d"}, + }); + auto *bArgs = builder.newListOperand({{hb[{b, base.n, base.k}], "d"}}); + + mma(retArgs, aArgs, bArgs, cArgs); +} + // Unified MMAV2 function for Ampere and HopperF64 architectures static void callMmaV2(PTXBuilder &builder, int b, const BaseOffset &base, mlir::triton::PTXInstr &mma, unsigned numMmaRets, @@ -778,7 +822,12 @@ convertMMAImpl(DotOpInterface op, Value llvmA, Value llvmB, Value llvmC, auto fc = unpackLLElements(loc, loadedC, rewriter); int bitwidthRet = dTensorTy.getElementType().getIntOrFloatBitWidth(); - auto numMmaRets = bitwidthRet == 64 ? 
2 : bitwidthRet / 8; + // f64 m16n8k4 returns 4 f64 outputs/thread; legacy m8n8k4 returns 2. + unsigned numMmaRets; + if (bitwidthRet == 64) + numMmaRets = (mmaType == TensorCoreType::FP64_FP64_FP64_FP64_M8) ? 2 : 4; + else + numMmaRets = bitwidthRet / 8; int numCPackedElem = bitwidthRet == 64 ? 1 : 4 / numMmaRets; auto rank = dTensorTy.getRank(); @@ -849,11 +898,20 @@ LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, return op.emitError( "unsupported MMA instruction for the given operand/result types"); - NumRegisters numRegisters = (mmaType == TensorCoreType::FP64_FP64_FP64_FP64) - ? NumRegisters{1, 1, 1} - : NumRegisters{2, 1, 2}; + // f64 m16n8k4: A=2 regs (gid, gid+8). f64 m8n8k4 legacy: 1 A reg. + NumRegisters numRegisters; + if (mmaType == TensorCoreType::FP64_FP64_FP64_FP64) + numRegisters = NumRegisters{2, 1, 1}; + else if (mmaType == TensorCoreType::FP64_FP64_FP64_FP64_M8) + numRegisters = NumRegisters{1, 1, 1}; + else + numRegisters = NumRegisters{2, 1, 2}; - EmitMmaCallback emit = [&](PTXBuilder &builder, int b, int m, int n, int k, + bool isFp64M16 = mmaType == TensorCoreType::FP64_FP64_FP64_FP64; + bool isFp64M8 = mmaType == TensorCoreType::FP64_FP64_FP64_FP64_M8; + + EmitMmaCallback emit = [&, isFp64M16, isFp64M8]( + PTXBuilder &builder, int b, int m, int n, int k, mlir::triton::PTXInstr &mma, unsigned numMmaRets, unsigned colsPerThread, unsigned batchOffset, ValueTableV2 &ha, ValueTableV2 &hb, @@ -861,7 +919,7 @@ LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, int /*repK*/) { bool isIntMMA = dTy.getElementType().isInteger(32); bool isAccF16 = dTy.getElementType().isF16(); - bool isFp64MMA = dTy.getElementType().isF64(); + bool isFp64MMA = isFp64M16 || isFp64M8; const unsigned numCPackedElem = isFp64MMA ? 
1u : 4u / numMmaRets; BaseOffset base{numRegisters.m * m, numRegisters.n * n, numRegisters.k * k}; if (isTuring) { @@ -873,10 +931,13 @@ LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, callMmaTuringFp16(builder, b, base, mma, numMmaRets, colsPerThread, numCPackedElem, ha, hb, fc, isAccF16); } else { - if (isFp64MMA) { - callMmaAmpereFp64(builder, b, base, mma, numMmaRets, colsPerThread, - numCPackedElem, batchOffset, ha, hb, fc, - numRegisters.k, numRegisters.m); + if (isFp64M16) { + callMmaAmpereFp64M16K4(builder, b, base, mma, numMmaRets, colsPerThread, + numCPackedElem, batchOffset, ha, hb, fc); + } else if (isFp64M8) { + callMmaAmpereFp64M8K4(builder, b, base, mma, numMmaRets, colsPerThread, + numCPackedElem, batchOffset, ha, hb, fc, + numRegisters.k, numRegisters.m); } else { callMmaV2(builder, b, base, mma, numMmaRets, colsPerThread, numCPackedElem, batchOffset, ha, hb, fc,