From ab8665f1dfe95faf18a8e041de0dc01184b6cc9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Wichrowski?= Date: Mon, 20 Apr 2026 01:44:58 +0200 Subject: [PATCH] Extend support for small MMAv2 FP64: single 8x8x4 instructions. --- lib/Dialect/TritonGPU/IR/Dialect.cpp | 4 +- .../TritonGPU/IR/LinearLayoutConversions.cpp | 9 ++- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 2 +- python/test/unit/language/test_core.py | 12 ++- test/Conversion/tritongpu_to_llvm.mlir | 11 ++- test/TritonGPU/accelerate-matmul.mlir | 4 +- third_party/nvidia/backend/compiler.py | 2 + .../DotOpToLLVM/MMAv2.cpp | 76 ++++++++++--------- 8 files changed, 70 insertions(+), 50 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 3da2732e8cd3..7b28429a743d 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -2624,10 +2624,10 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef shape, int bitwidth, tileSize.push_back(1); } // warpSizeK * (warpRepK * VecBitWidth) - auto tileBitWidthK = bitwidth == 64 ? (2 * 256) : (4 * 64); + auto tileBitWidthK = bitwidth == 64 ? (1 * 256) : (4 * 64); if (opIdx == 0) { // m x k - tileSize.push_back(16); + tileSize.push_back(bitwidth == 64 ? 8 : 16); tileSize.push_back(tileBitWidthK / bitwidth); } else { // k x n diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index db8951268577..efdb98230c7b 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -943,13 +943,16 @@ LinearLayout nvidiaDotToLinearLayout(ArrayRef shape, MLIRContext *ctx = mma.getContext(); SmallVector tileShape(rank, 1); + unsigned instrM = mma.getInstrShape()[rank - 2]; + // For fp64 (instrM == 8), the native m8n8k4 instruction uses a smaller tile. + unsigned kTileMultiplier = instrM == 8 ? 4 : 8; if (isA) { - tileShape[rank - 2] = 16; - tileShape[rank - 1] = kWidth * 8; + tileShape[rank - 2] = instrM; + tileShape[rank - 1] = kWidth * kTileMultiplier; } else { // Hopper takes the rhs via shared memory assert(mma.isAmpere()); - tileShape[rank - 2] = kWidth * 8; + tileShape[rank - 2] = kWidth * kTileMultiplier; tileShape[rank - 1] = 8; } auto order = getOrderForDotOperand(dot.getOpIdx(), rank, /*kContig*/ true); diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 2e09475b418e..7e3196be67c2 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -37,7 +37,7 @@ SmallVector mmaVersionToInstrShape(int version, auto rank = shape.size(); SmallVector ret(rank, 1); ret[rank - 1] = 8; - ret[rank - 2] = 16; + ret[rank - 2] = eltType.isF64() ? 8 : 16; return ret; } else if (version == 3) { unsigned k = 256 / eltType.getIntOrFloatBitWidth(); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 6593950f07eb..b81574f1adbc 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3147,6 +3147,12 @@ def get_test_dot_base_cases(): if not (input_precision != 'ieee' and (in_dtype in ['float16']))] +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +def get_test_dot_small_fp64_cases(): + return [(*shape, 1, False, False, 'none', 'ieee', 'float64', 'float64', 1, None) + for shape in [(8, 8, 4), (8, 8, 8), (16, 8, 4), (8, 8, 16)]] + + # M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size def get_test_dot_softmax(): return [(128, 128, 64, 8, False, False, 'softmax', 'ieee', 'float16', 'float32', 1, None)] @@ -3275,7 +3281,8 @@ def get_test_small_dots_cases(): get_test_dot_small_mn_wmma_cases() + \ get_test_dot_small_k_wmma_cases() + \ get_test_dot_softmax() + \ - get_test_small_dots_cases()) + get_test_small_dots_cases() + \ + get_test_dot_small_fp64_cases()) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size, num_ctas, device): @@ -3286,7 +3293,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty pytest.skip(f"input_precision {input_precision} is not supported in the interpreter") else: if not is_hip() and K < 16: - pytest.skip("small dots are supported only on HIP at the moment") + if in_dtype != 'float64': + pytest.skip("small dots are supported only on HIP at the moment") if is_cuda(): capability = torch.cuda.get_device_capability() diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 89168d46523e..28483e84e6a1 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -2022,7 +2022,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr // ----- -#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 8]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [8, 8]}> #shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}> #shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 2, order = [1, 0]}> #smem = #ttg.shared_memory @@ -2319,7 +2319,7 @@ tt.func @gather_in_shared(%arg0: tensor<16x4xi32, #blocked1>, %arg1: tensor<8x4x // ----- -#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [1, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [8, 8]}> #dot = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=1}> #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> @@ -2327,10 +2327,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @gather_in_shared_dot_input(%arg0: tensor<16x4xi32, #blocked>, %arg1: tensor<8x4xf32, #dot>) { // CHECK-LABEL: gather_in_shared_dot_input + + // CHECK: [[S0:%.*]] = llvm.extractvalue %arg1[0] + // CHECK: [[SMEM_BASE:%.*]] = llvm.mlir.addressof @global_smem // CHECK-NEXT: [[SMEM:%.*]] = llvm.getelementptr [[SMEM_BASE]] - // CHECK-COUNT-4: store - // CHECK-NEXT: nvvm.barrier0 + // CHECK: insertelement [[S0]] + // CHECK: nvvm.barrier0 // CHECK: [[I0:%.*]] = llvm.extractvalue %arg0[0] diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index d3b03f7650ba..14164704e0bf 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -121,7 +121,7 @@ module attributes {"ttg.target" = "cuda:89", "ttg.num-ctas" = 1 : i32, "ttg.num- // ----- -// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}> +// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [8, 8]}> #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { tt.func public @fp64_dot( @@ -136,7 +136,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // ----- -// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}> +// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [8, 8]}> #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @fp64_dot_hopper( diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 818d4f338a6b..fc6362b24a3d 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -25,6 +25,8 @@ def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]: # [m, # For small M/N the input we can still use tensorcores with padding. if lhs_bitwidth == 8: return (1, 1, 32) + elif lhs_bitwidth == 64: + return (1, 1, 4) else: return (1, 1, 16) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index 9ca5add00eee..bc7d1f80aae9 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -307,8 +307,8 @@ static Type getMmaRetType(TensorCoreType mmaType, MLIRContext *ctx) { Type fp32Ty = type::f32Ty(ctx); Type fp16Ty = type::f16Ty(ctx); Type i32Ty = type::i32Ty(ctx); - Type fp64x4Ty = - LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, fp64Ty)); + Type fp64x2Ty = + LLVM::LLVMStructType::getLiteral(ctx, SmallVector(2, fp64Ty)); Type fp32x4Ty = LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, fp32Ty)); Type i32x4Ty = @@ -337,7 +337,7 @@ static Type getMmaRetType(TensorCoreType mmaType, MLIRContext *ctx) { case TensorCoreType::INT32_INT8_INT8_INT32: return i32x4Ty; case TensorCoreType::FP64_FP64_FP64_FP64: - return fp64x4Ty; + return fp64x2Ty; case TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32_SCALE_VEC_1X: case TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32_SCALE_VEC_1X: case TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32_SCALE_VEC_1X: @@ -588,42 +588,44 @@ static void callMmaTuringFp16(PTXBuilder &builder, int b, mma(retArgs, aArgs2, bArgs2, cArgs); } -// Repeat m8n8k4 (2, 1, 4) times, as m16n8k16 on hopper. +// Emit m8n8k4 fp64 MMA instructions. +// With numRegisters.m=1, numRegisters.k=1: emits a single m8n8k4. +// With numRegisters.m=2, numRegisters.k=2: emits 2*2=4 m8n8k4 grouped as +// m16n8k8 static void callMmaAmpereFp64(PTXBuilder &builder, int b, const BaseOffset &base, mlir::triton::PTXInstr &mma, unsigned numMmaRets, unsigned colsPerThread, int numCPackedElem, unsigned batchOffset, ValueTableV2 &ha, ValueTableV2 &hb, const SmallVector &fc, - int kRegs) { - auto retArgs1 = builder.newListOperand(numMmaRets / 2, "=d"); - auto retArgs2 = builder.newListOperand(numMmaRets / 2, "=d"); - auto cArgs1 = builder.newListOperand(); - for (int i = 0; i < numMmaRets / 2; ++i) { - cArgs1->listAppend(builder.newOperand( - fc[(base.m * colsPerThread + 4 * base.n) / numCPackedElem + i + - batchOffset * b], - std::to_string(i))); - // reuse the output registers - } - auto cArgs2 = builder.newListOperand(); - for (int i = numMmaRets / 2; i < numMmaRets; ++i) { - cArgs2->listAppend(builder.newOperand( - fc[(base.m * colsPerThread + 4 * base.n) / numCPackedElem + i + - batchOffset * b], - std::to_string(i))); - // reuse the output registers + int kRegs, int mRegs) { + // Each m sub-tile gets numMmaRets/mRegs results (2 f64 values per m8n8k4). + int retsPerM = numMmaRets / mRegs; + + // Build ret/c operand lists for each m sub-tile. + SmallVector retArgsList, cArgsList; + for (int vm = 0; vm < mRegs; ++vm) { + auto *retArgs = builder.newListOperand(retsPerM, "=d"); + auto *cArgs = builder.newListOperand(); + for (int i = 0; i < retsPerM; ++i) { + cArgs->listAppend( + builder.newOperand(fc[((base.m + vm) * colsPerThread + + numMmaRets * numCPackedElem * base.n) / + numCPackedElem + + i + batchOffset * b], + std::to_string(i))); + } + retArgsList.push_back(retArgs); + cArgsList.push_back(cArgs); } + for (int vk = 0; vk < kRegs; ++vk) { - auto aArgs1 = builder.newListOperand({ - {ha[{b, base.m, base.k + vk}], "d"}, - }); auto bArgs = builder.newListOperand({{hb[{b, base.n, base.k + vk}], "d"}}); - auto aArgs2 = builder.newListOperand({ - {ha[{b, base.m + 1, base.k + vk}], "d"}, - }); - mma(retArgs1, aArgs1, bArgs, cArgs1); - mma(retArgs2, aArgs2, bArgs, cArgs2); + for (int vm = 0; vm < mRegs; ++vm) { + auto aArgs = + builder.newListOperand({{ha[{b, base.m + vm, base.k + vk}], "d"}}); + mma(retArgsList[vm], aArgs, bArgs, cArgsList[vm]); + } } } @@ -776,8 +778,8 @@ convertMMAImpl(DotOpInterface op, Value llvmA, Value llvmB, Value llvmC, auto fc = unpackLLElements(loc, loadedC, rewriter); int bitwidthRet = dTensorTy.getElementType().getIntOrFloatBitWidth(); - auto numMmaRets = bitwidthRet == 64 ? 4 : bitwidthRet / 8; - int numCPackedElem = 4 / numMmaRets; + auto numMmaRets = bitwidthRet == 64 ? 2 : bitwidthRet / 8; + int numCPackedElem = bitwidthRet == 64 ? 1 : 4 / numMmaRets; if (mmaInstructions.find(mmaType) == mmaInstructions.end()) { return emitError(loc, "Unsupported MMA instruction for the given mma type"); @@ -801,7 +803,7 @@ convertMMAImpl(DotOpInterface op, Value llvmA, Value llvmB, Value llvmC, Type elemTy = cast(mmaOut.getType()).getBody()[0]; for (int i = 0; i < numMmaRets; ++i) { fc[(numRegisters.m * static_cast(m) * colsPerThread + - 4 * numRegisters.n * static_cast(n)) / + numMmaRets * numCPackedElem * numRegisters.n * static_cast(n)) / numCPackedElem + i + batchOffset * b] = tb.extract_val(elemTy, mmaOut, i); } @@ -845,7 +847,9 @@ LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, auto dTensorTy = op.getD().getType(); TensorCoreType mmaType = getMmaTypeDot(op, aTensorTy, bTensorTy, dTensorTy); - NumRegisters numRegisters = {2, 1, 2}; + NumRegisters numRegisters = (mmaType == TensorCoreType::FP64_FP64_FP64_FP64) + ? NumRegisters{1, 1, 1} + : NumRegisters{2, 1, 2}; const auto &instrMap = isTuring ? mmaInstrPtxTuring : mmaInstrPtxAmpere; EmitMmaCallback emit = [&](PTXBuilder &builder, int b, int m, int n, int k, @@ -854,10 +858,10 @@ LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, ValueTableV2 &ha, ValueTableV2 &hb, const SmallVector &fc, RankedTensorType dTy, int /*repK*/) { - const unsigned numCPackedElem = 4u / numMmaRets; bool isIntMMA = dTy.getElementType().isInteger(32); bool isAccF16 = dTy.getElementType().isF16(); bool isFp64MMA = dTy.getElementType().isF64(); + const unsigned numCPackedElem = isFp64MMA ? 1u : 4u / numMmaRets; BaseOffset base{numRegisters.m * m, numRegisters.n * n, numRegisters.k * k}; if (isTuring) { assert(b == 0 && "Turing only supports batch size 1"); @@ -871,7 +875,7 @@ LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, if (isFp64MMA) { callMmaAmpereFp64(builder, b, base, mma, numMmaRets, colsPerThread, numCPackedElem, batchOffset, ha, hb, fc, - numRegisters.k); + numRegisters.k, numRegisters.m); } else { callMmaV2(builder, b, base, mma, numMmaRets, colsPerThread, numCPackedElem, batchOffset, ha, hb, fc,