diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index f7b4d388c78f..cb5b2d80066f 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -1472,7 +1472,7 @@ vecIdx (index of the element in the quad; this is always along the k-dim) return $_get(context, opIdx, parent, 0); // For MMAV2 and V3 unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); - unsigned kWidth = 32 / bitwidth; + unsigned kWidth = std::max(32 / bitwidth, 1u); return $_get(context, opIdx, parent, kWidth); }]> ]; diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 0d49b632191b..06a32a229511 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -711,7 +711,7 @@ bool supportMMA(Value value, int version) { bool isFP8 = llvm::isa(elemTy); return isFP8 || elemTy.isF16() || elemTy.isBF16() || - (elemTy.isF32() && version >= 2) || + ((elemTy.isF32() || elemTy.isF64()) && version >= 2) || (elemTy.isInteger(8) && version >= 2); } diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index fb78360adaac..1f61b224912b 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -235,7 +235,7 @@ struct ElementwiseInlineAsmOpConversion Type elemTy = getElementType(op.getOperand(i)); unsigned bitWidth = elemTy.isIntOrFloat() ? elemTy.getIntOrFloatBitWidth() : 64; - unsigned numElementPerReg = bitWidth < 32 ? 32 / bitWidth : 1; + unsigned numElementPerReg = std::max(32 / bitWidth, 1u); numElementPerReg = std::min(numElementPerReg, numPackedElements); for (int j = 0; j < numPackedElements; j += numElementPerReg) { if (numElementPerReg == 1) { @@ -278,7 +278,7 @@ struct ElementwiseInlineAsmOpConversion // Pack return elements into 32-bits. unsigned bitWidth = ty.isIntOrFloat() ? ty.getIntOrFloatBitWidth() : 64; unsigned numElemsPerReg = - std::min(bitWidth < 32 ? 32 / bitWidth : 1, op.getPackedElement()); + std::min(std::max(32 / bitWidth, 1u), op.getPackedElement()); assert(op.getPackedElement() % numElemsPerReg == 0); if (numElemsPerReg > 1) { ty = vec_ty(ty, numElemsPerReg); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 1697b014cdc6..79526e33da1b 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -2129,9 +2129,9 @@ NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const { SmallVector NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef shape, int bitwidth, int kWidth, int opIdx) const { - assert( - kWidth >= 32 / bitwidth && - "kWidth must be >= 32 / bitwidth for this function to be well-defined"); + assert(kWidth >= std::max(32 / bitwidth, 1) && + "kWidth must be >= max(32 / bitwidth, 1) for this function to be " + "well-defined"); auto rank = shape.size(); // Broadcast long K auto warpsPerCTA = to_vector(getWarpsPerCTA()); @@ -2142,16 +2142,18 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef shape, int bitwidth, if (rank == 3) { tileSize.push_back(1); } + // warpSizeK * (warpRepK * VecBitWidth) + auto tileBitWidthK = (isAmpere() && bitwidth == 64) ? (4 * 256) : (4 * 64); if (opIdx == 0) { // m x k tileSize.push_back(16); - tileSize.push_back(4 * 64 / bitwidth); + tileSize.push_back(tileBitWidthK / bitwidth); } else { // k x n // Hopper path never uses the n value, since this method is only invoked // for in-RF (dotOpEnc) operands, but WGMMA only supports in A to be in RF // so it's fine if the n is incorrect here - tileSize.push_back(4 * 64 / bitwidth); + tileSize.push_back(tileBitWidthK / bitwidth); tileSize.push_back(8); } diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index c4ba73dadd71..0160ac6c63b9 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -333,6 +333,15 @@ class BlockedToMMA : public mlir::OpRewritePattern { auto oldBType = cast(b.getType()); auto oldRetType = cast(dotOp.getType()); + // Enable F64 MMA only on SM80/SM90 with high performance F64 tensorcore. + // Otherwise, fallback to F64 FMA for better performance. + if ((oldAType.getElementType().isF64() || + oldBType.getElementType().isF64() || + oldRetType.getElementType().isF64()) && + !(computeCapability == 80 || computeCapability == 90)) { + return failure(); + } + // get MMA encoding for the given number of warps auto CTALayout = getCTALayout(oldRetType.getEncoding()); auto retShapePerCTA = getShapePerCTA(oldRetType); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index bd739a627556..effe3a67f127 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -130,6 +130,18 @@ def check_type_supported(dtype, device): pytest.skip("bfloat16 is not supported in the interpreter") +def get_src_element_ty_size(dtype_str): + if dtype_str in ["int8", "uint8", "float8e4b15"]: + return 1 + if dtype_str == "float16": + return 2 + if dtype_str == "float32" or dtype_str == "tensorfloat32": + return 4 + if dtype_str == "float64": + return 8 + raise ValueError(f"Unknown dtype {dtype_str}") + + class MfmaLayout: def __init__(self, version, warps_per_cta, tiles_per_warp, instr_shape, is_transposed): @@ -3732,7 +3744,9 @@ def get_test_dot_base_cases(): for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)] for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] for input_precision in ['tf32', 'tf32x3', 'ieee'] - for in_dtype, out_dtype in [('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')] + for in_dtype, out_dtype in [('float16', 'float16'), ('float16', + 'float32'), ('float32', + 'float32'), ('float64', 'float64')] if not (input_precision != 'ieee' and (in_dtype in ['float16']))] @@ -3865,6 +3879,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty pytest.skip("Only test out_dtype=float16 on devices with sm >=80") if capability[0] < 9 and in_dtype == 'float8e4nv': pytest.skip("float8e4nv not supported on sm <= 80") + if in_dtype == 'float64' and input_precision != 'ieee': + pytest.skip("Only IEEE precision is supported for float64 dot") if is_hip(): if in_dtype in ("float8e5", "float8e4nv") and not (is_hip_cdna4() or is_hip_gfx12()): @@ -3875,6 +3891,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty pytest.skip(f"{input_precision} not supported on HIP") if kpack == 2 and in_dtype == 'int8' and K < 64: pytest.skip("kpack too large for K") + if in_dtype == 'float64': + pytest.skip("float64 not supported on HIP yet") + if not is_hip() and kpack == 2: pytest.skip("Skip duplicated tests on nv path") @@ -4036,11 +4055,17 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid # make sure ld/st are vectorized ptx = pgm.asm['ptx'] + if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4): # XXX: skip small sizes because they are not vectorized - assert 'ld.global.v4' in ptx + if 'float64' in in_dtype: + assert 'ld.global.v2.b64' in ptx + else: + assert 'ld.global.v4' in ptx if 'float8' in in_dtype: assert 'st.global.v2' in ptx + elif 'float64' in in_dtype: + assert 'st.global.v2.b64' in ptx else: assert 'st.global.v4' in ptx @@ -4349,23 +4374,24 @@ def make_finite(x, dtype): @pytest.mark.interpreter -@pytest.mark.parametrize("B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str", - [(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str) - for B in [1, 2, 4, 8] - for num_warps in [1, 2, 4, 8, 16] - for BLOCK_M, BLOCK_N in [(32, 32)] - for M, N, K in [(64, 64, 64), (32, 32, 32)] - for in_dtype_str, out_dtype_str in [('int8', 'int8'), ('float16', 'float16'), - ('float16', 'float32'), ('float32', 'float32')]] + - # Large block sizes - [(4, 4, 128, 128, 64, 64, 64, 'float16', 'float16')] + - # Small block sizes - [(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str) - for B in [1, 2, 8] - for num_warps in [1, 2, 4] - for BLOCK_M, BLOCK_N in [(1, 32), (32, 2), (8, 8)] - for M, N, K in [(32, 32, 32)] - for in_dtype_str, out_dtype_str in [('float16', 'float16'), ('float32', 'float32')]]) +@pytest.mark.parametrize( + "B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str", + [(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str) + for B in [1, 2, 4, 8] + for num_warps in [1, 2, 4, 8, 16] + for BLOCK_M, BLOCK_N in [(32, 32)] + for M, N, K in [(64, 64, 64), (32, 32, 32)] + for in_dtype_str, out_dtype_str in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'), + ('float32', 'float32'), ('float64', 'float64')]] + + # Large block sizes + [(4, 4, 128, 128, 64, 64, 64, 'float16', 'float16')] + + # Small block sizes + [(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str) + for B in [1, 2, 8] + for num_warps in [1, 2, 4] + for BLOCK_M, BLOCK_N in [(1, 32), (32, 2), (8, 8)] + for M, N, K in [(32, 32, 32)] + for in_dtype_str, out_dtype_str in [('float16', 'float16'), ('float32', 'float32')]]) def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str, device): if is_hip(): # hip does not support tf32 precision, so use ieee for all tests @@ -4376,17 +4402,17 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_ pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d") if out_dtype_str == "float16": pytest.skip(f"{out_dtype_str} has low precision in WMMA dot") + if in_dtype_str == "float64": + pytest.skip("float64 not supported on HIP yet") else: input_precision = "tf32" if is_cuda() and in_dtype_str == 'float32' else "ieee" if not is_interpreter() and (BLOCK_M < 16 or BLOCK_N < 16): pytest.skip("small dots are supported only on HIP at the moment") - if B == 8 and M == 64 and in_dtype_str == "float32" and out_dtype_str == "float32": - if not is_interpreter() and triton.runtime.driver.active.utils.get_device_properties( - triton.runtime.driver.active.get_current_device())["max_shared_mem"] < 131072: - pytest.skip( - "Skipping tests with B = 8, M = 64, in_type = float32, out_type = float32 due to insufficient shared memory (less than 128 KB per SM) on this GPU." - ) + shared_mem_accum = B * (BLOCK_M * K + K * BLOCK_N) * get_src_element_ty_size(in_dtype_str) + if not is_interpreter() and triton.runtime.driver.active.utils.get_device_properties( + triton.runtime.driver.active.get_current_device())["max_shared_mem"] < shared_mem_accum: + pytest.skip("Skipped due to insufficient shared memory on this GPU.") @triton.jit def kernel( diff --git a/python/test/unit/language/test_matmul.py b/python/test/unit/language/test_matmul.py index 02cf270b7e72..30a870d1565c 100644 --- a/python/test/unit/language/test_matmul.py +++ b/python/test/unit/language/test_matmul.py @@ -82,11 +82,13 @@ def get_src_element_ty_size(dtype_str): return 2 if dtype_str == "float32" or dtype_str == "tensorfloat32": return 4 + if dtype_str == "float64": + return 8 raise ValueError(f"Unknown dtype {dtype_str}") -@pytest.mark.parametrize("dtype_src_str", ["float32", "tensorfloat32", "float16", "float8e5"]) -@pytest.mark.parametrize("dtype_dst_str", ["float32", "float16"]) +@pytest.mark.parametrize("dtype_src_str", ["float32", "tensorfloat32", "float16", "float8e5", "float64"]) +@pytest.mark.parametrize("dtype_dst_str", ["float32", "float16", "float64"]) @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES", [(128, 128, 16, 4), (64, 128, 32, 4), (32, 32, 32, 4), (256, 128, 32, 4), (64, 512, 32, 2), (512, 64, 32, 2), (64, 16, 16, 4)]) @@ -98,15 +100,20 @@ def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K, EPILOGUE_SUBTILE, LAYOUT_16x256, monkeypatch): if NUM_CTAS > 1 and (not is_cuda() or torch.cuda.get_device_capability()[0] < 9): pytest.skip("Clusters requires nvidia compute capability >= 9") - if is_hip() and ((BLOCK_K * BLOCK_M + BLOCK_K * BLOCK_N) * NUM_STAGES * get_src_element_ty_size(dtype_src_str) - > 65536): - pytest.skip("HIP path requires less than 64KB of shared memory") + shared_mem_accum = (BLOCK_K * BLOCK_M + BLOCK_K * BLOCK_N) * NUM_STAGES * get_src_element_ty_size(dtype_src_str) + shared_mem_avail = triton.runtime.driver.active.utils.get_device_properties(0)["max_shared_mem"] + if shared_mem_accum > shared_mem_avail: + pytest.skip("Skipped due to insufficient shared memory on this GPU.") if is_hip() and (not is_hip_cdna3()) and dtype_src_str == "tensorfloat32": pytest.skip("tensorfloat32 is only supported on HIP CDNA3") if dtype_src_str == "float8e5" and BLOCK_K == 16: pytest.skip("Skipping cases small K for float8") if dtype_src_str == "float8e5" and device == "cuda" and torch.cuda.get_device_capability()[0] < 9: pytest.skip("Float8 requires compute capability >= 9") + if (dtype_src_str == "float64") != (dtype_dst_str == "float64"): + pytest.skip("Skipping unsupported case") + if dtype_src_str == "float64" and not is_cuda(): + pytest.skip("Float64 not supported on HIP yet") if "float32" in dtype_src_str and dtype_dst_str == "float16": pytest.skip("Skipping unsupported case") if "float32" == dtype_src_str and NUM_CTAS > 1: @@ -160,7 +167,8 @@ def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K, # This applies only if TCv5 MMA is used (M % 64 == 0 and N % 8 == 0) and # when MMA arguments loads are pipelined (N > 16) if (device == "cuda" and torch.cuda.get_device_capability()[0] == 10 and NUM_STAGES > 1 and BLOCK_M % 64 == 0 - and BLOCK_N % 8 == 0 and BLOCK_N > 16 and not (precision == "ieee" and dtype_src_str == "float32")): + and BLOCK_N % 8 == 0 and BLOCK_N > 16 + and not (precision == "ieee" and (dtype_src_str == "float32" or dtype_src_str == "float64"))): ttgir = k.asm["ttgir"] count = ttgir.count("ttng.tc_gen5_mma") assert count == 2, "The TTGIR does not match the expected pattern." diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 0da22adef6c9..f5afeac07523 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1472,10 +1472,10 @@ def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Opti # All combinations of supported fp8 x fp8 are permitted pass else: - assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, - tl.float32), f"Unsupported lhs dtype {lhs.dtype}" - assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, - tl.float32), f"Unsupported rhs dtype {rhs.dtype}" + assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, tl.float32, + tl.float64), f"Unsupported lhs dtype {lhs.dtype}" + assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, tl.float32, + tl.float64), f"Unsupported rhs dtype {rhs.dtype}" assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}" if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15(): @@ -1514,6 +1514,9 @@ def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Opti elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16(): _0 = self.builder.get_fp32(0) ret_scalar_ty = tl.float32 + elif lhs.type.scalar.is_fp64(): + _0 = self.builder.get_fp64(0) + ret_scalar_ty = tl.float64 else: _0 = self.builder.get_fp16(0) if out_dtype.is_fp16() else self.builder.get_fp32(0) ret_scalar_ty = out_dtype diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 7e71cddf8b52..4167241aa193 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1875,6 +1875,35 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr } } +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 8]}> +#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 2, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 4096 : i32, ttg.target = "cuda:80", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { + tt.func public @f64_mma_cvt() { + %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<16x16xf64, #shared, #smem, mutable> + %1 = ttg.local_alloc {allocation.offset = 2048 : i32} : () -> !ttg.memdesc<16x16xf64, #shared1, #smem, mutable> + + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf64, #mma> + + %2 = ttg.local_load %0 : !ttg.memdesc<16x16xf64, #shared, #smem, mutable> -> tensor<16x16xf64, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + + %3 = ttg.local_load %1 : !ttg.memdesc<16x16xf64, #shared1, #smem, mutable> -> tensor<16x16xf64, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 + + %out = tt.dot %2, %3, %cst, inputPrecision = tf32 : tensor<16x16xf64, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<16x16xf64, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<16x16xf64, #mma> + + tt.return + } +} + + // ----- #blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 4d9ca27b99e4..c21189082062 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -480,3 +480,31 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- tt.return } } + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 8]}> +#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 2, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 4096 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { + tt.func public @hopper_f64_mma_cvt() { + %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<16x16xf64, #shared, #smem, mutable> + %1 = ttg.local_alloc {allocation.offset = 2048 : i32} : () -> !ttg.memdesc<16x16xf64, #shared1, #smem, mutable> + + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf64, #mma> + + %2 = ttg.local_load %0 : !ttg.memdesc<16x16xf64, #shared, #smem, mutable> -> tensor<16x16xf64, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + + %3 = ttg.local_load %1 : !ttg.memdesc<16x16xf64, #shared1, #smem, mutable> -> tensor<16x16xf64, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64 + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64 + + %out = tt.dot %2, %3, %cst, inputPrecision = tf32 : tensor<16x16xf64, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<16x16xf64, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<16x16xf64, #mma> + + tt.return + } +} diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 2e344a389fbe..98678487ee05 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -121,6 +121,36 @@ module attributes {"ttg.target" = "cuda:89", "ttg.num-ctas" = 1 : i32, "ttg.num- // ----- +// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @fp64_dot( + %arg0: tensor<128x32xf64, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<32x128xf64, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<128x128xf64, #blocked> { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf64, #blocked> + // CHECK: tt.dot {{.*}} : tensor<128x32xf64, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x128xf64, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<128x128xf64, #mma> + %d = tt.dot %arg0, %arg1, %cst, inputPrecision = tf32 : tensor<128x32xf64, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x128xf64, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf64, #blocked> + tt.return %d : tensor<128x128xf64, #blocked> + } +} + +// ----- + +// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @fp64_dot_hopper( + %arg0: tensor<128x32xf64, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<32x128xf64, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<128x128xf64, #blocked> { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf64, #blocked> + // CHECK: tt.dot {{.*}} : tensor<128x32xf64, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x128xf64, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<128x128xf64, #mma> + %d = tt.dot %arg0, %arg1, %cst, inputPrecision = tf32 : tensor<128x32xf64, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x128xf64, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf64, #blocked> + tt.return %d : tensor<128x128xf64, #blocked> + } +} + +// ----- + // CHECK-DAG: #[[MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> // CHECK-DAG: #[[MMA1:.+]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1, 1], instrShape = [1, 16, 8]}> diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp index d169f4993f30..29e286fcd4ef 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp @@ -9,13 +9,10 @@ using namespace mlir::triton; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; -LogicalResult convertMMA1688(triton::DotOp op, triton::DotOp::Adaptor adaptor, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter); - -LogicalResult convertMMA16816(triton::DotOp op, triton::DotOp::Adaptor adaptor, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter); +LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, bool isTuring, + bool isHopperF64); LogicalResult convertWGMMA(triton::nvidia_gpu::WarpGroupDotOp op, triton::nvidia_gpu::WarpGroupDotOp::Adaptor adaptor, @@ -25,6 +22,11 @@ namespace { struct DotOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + DotOpConversion(LLVMTypeConverter &converter, int computeCapability, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + computeCapability(computeCapability) {} + LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -42,10 +44,13 @@ struct DotOpConversion : public ConvertOpToLLVMPattern { NvidiaMmaEncodingAttr mmaLayout = dyn_cast( cast(D.getType()).getEncoding()); if (!isOuter && mmaLayout && supportMMA(op, mmaLayout.getVersionMajor())) { - if (mmaLayout.isTuring()) - return convertMMA1688(op, adaptor, getTypeConverter(), rewriter); - if (mmaLayout.isAmpere()) - return convertMMA16816(op, adaptor, getTypeConverter(), rewriter); + if (mmaLayout.getVersionMajor() == 2) { + bool isHopperF64 = + computeCapability == 90 && + cast(A.getType()).getElementType().isF64(); + return convertMMA(op, adaptor, getTypeConverter(), rewriter, + mmaLayout.isTuring(), isHopperF64); + } llvm::report_fatal_error( "Unsupported MMA kind found when converting DotOp to LLVM."); @@ -58,6 +63,9 @@ struct DotOpConversion : public ConvertOpToLLVMPattern { llvm::report_fatal_error( "Unsupported DotOp found when converting TritonGPU to LLVM."); } + +private: + int computeCapability; }; struct WarpGroupDotOpConversion @@ -155,8 +163,8 @@ struct WarpGroupDotWaitOpConversion void mlir::triton::NVIDIA::populateDotOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit) { - patterns.add(typeConverter, benefit); + int computeCapability, PatternBenefit benefit) { + patterns.add(typeConverter, computeCapability, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index 6eff6bac6c7a..ac40db345681 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -32,8 +32,8 @@ Value loadC(Value tensor, Value llTensor, "mma layout."); auto numMmaRets = tensorTy.getElementType().getIntOrFloatBitWidth() / 8; - assert(numMmaRets == 4 || numMmaRets == 2); - if (numMmaRets == 4) { + assert(numMmaRets == 8 || numMmaRets == 4 || numMmaRets == 2); + if (numMmaRets == 8 || numMmaRets == 4) { return llTensor; } else if (numMmaRets == 2) { auto cPack = SmallVector(); @@ -63,14 +63,14 @@ Value loadC(Value tensor, Value llTensor, ValueTableV2 getValuesFromDotOperandLayoutStruct( const LLVMTypeConverter *typeConverter, Location loc, ConversionPatternRewriter &rewriter, Value value, int batch, int repOuter, - int repK, RankedTensorType type) { + int repK, RankedTensorType type, bool isHopperF64) { auto b = TritonLLVMOpBuilder(loc, rewriter); auto elems = unpackLLElements(loc, value, rewriter); auto eltTy = typeConverter->convertType(type.getElementType()); int offset{}; ValueTableV2 vals; auto bitwidth = eltTy.getIntOrFloatBitWidth(); - auto numElemsPerVec = 32 / bitwidth; + auto numElemsPerVec = std::max(32 / bitwidth, 1u); auto vecTy = vec_ty(eltTy, numElemsPerVec); auto packVec = [&](std::array dstIdx) { @@ -79,13 +79,21 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct( vec = b.insert_element(vec, b.bitcast(elems[offset + i], eltTy), b.i32_val(i)); } - vals[dstIdx] = b.bitcast(vec, i32_ty); + if (bitwidth == 64) { + vals[dstIdx] = vec; + } else { + vals[dstIdx] = b.bitcast(vec, i32_ty); + } offset += numElemsPerVec; }; auto dot = cast(type.getEncoding()); auto kWidth = dot.getKWidth(); - auto largeK = bitwidth * kWidth > 32; + auto largeK = bitwidth * kWidth > std::max(32u, bitwidth); + + assert((bitwidth != 64 || largeK == false) && + "Currently fp64 don't support largeK MMA"); + if (largeK) { // For layouts with a large K dimension, the original register layout needs // to be divided into multiple MMAs, where each MMA has contiguous 32 bits @@ -94,7 +102,7 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct( // we split the MMA into 4 sub-MMAs, each with a stride 4 x 32-bit along the // K dimension. llvm::SmallVector si; - auto kIters = kWidth / (32 / bitwidth); + auto kIters = kWidth / (std::max(32 / bitwidth, 1u)); if (dot.getOpIdx() == 0) { // Original register layout: @@ -221,22 +229,24 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct( } } + auto numVecM = 2; + auto numVecN = 1; + auto numVecK = bitwidth == 64 ? 4 : 2; + if (dot.getOpIdx() == 0) { for (auto b = 0; b < batch; ++b) for (auto m = 0; m < repOuter; ++m) - for (auto k = 0; k < repK; ++k) { - packVec({b, 2 * m, 2 * k}); - packVec({b, 2 * m + 1, 2 * k}); - packVec({b, 2 * m, 2 * k + 1}); - packVec({b, 2 * m + 1, 2 * k + 1}); - } + for (auto k = 0; k < repK; ++k) + for (auto vk = 0; vk < numVecK; ++vk) + for (auto vm = 0; vm < numVecM; ++vm) + packVec({b, m * numVecM + vm, k * numVecK + vk}); } else { for (auto b = 0; b < batch; ++b) for (auto n = 0; n < repOuter; ++n) - for (auto k = 0; k < repK; ++k) { - packVec({b, n, 2 * k}); - packVec({b, n, 2 * k + 1}); - } + for (auto k = 0; k < repK; ++k) + for (auto vk = 0; vk < numVecK; ++vk) + for (auto vn = 0; vn < numVecN; ++vn) + packVec({b, n * numVecN + vn, k * numVecK + vk}); } return vals; } @@ -255,14 +265,19 @@ enum class TensorCoreType : uint8_t { INT32_INT1_INT1_INT32, // Not implemented INT32_INT4_INT4_INT32, // Not implemented INT32_INT8_INT8_INT32, // Not implemented + // double precision tensor core instr + FP64_FP64_FP64_FP64, // NOT_APPLICABLE, }; static Type getMmaRetType(TensorCoreType mmaType, MLIRContext *ctx) { + Type fp64Ty = type::f64Ty(ctx); Type fp32Ty = type::f32Ty(ctx); Type fp16Ty = type::f16Ty(ctx); Type i32Ty = type::i32Ty(ctx); + Type fp64x4Ty = + LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, fp64Ty)); Type fp32x4Ty = LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, fp32Ty)); Type i32x4Ty = @@ -285,6 +300,8 @@ static Type getMmaRetType(TensorCoreType mmaType, MLIRContext *ctx) { return fp32x4Ty; case TensorCoreType::INT32_INT8_INT8_INT32: return i32x4Ty; + case TensorCoreType::FP64_FP64_FP64_FP64: + return fp64x4Ty; default: llvm::report_fatal_error("Unsupported mma type found"); } @@ -324,6 +341,9 @@ static TensorCoreType getMmaType(triton::DotOp op) { } else if (dTy.getElementType().isF16()) { if (aTy.getElementType().isF16() && bTy.getElementType().isF16()) return TensorCoreType::FP16_FP16_FP16_FP16; + } else if (dTy.getElementType().isF64()) { + if (aTy.getElementType().isF64() && bTy.getElementType().isF64()) + return TensorCoreType::FP64_FP64_FP64_FP64; } return TensorCoreType::NOT_APPLICABLE; @@ -366,6 +386,14 @@ inline static const std::map mmaInstrPtxAmpere = { "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32"}, {TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32, "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32"}, + + {TensorCoreType::FP64_FP64_FP64_FP64, + "mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64"}, +}; + +inline static const std::map mmaInstrPtxHopper = { + {TensorCoreType::FP64_FP64_FP64_FP64, + "mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64"}, }; static void callMmaTuringInt8(PTXBuilder &builder, int b, int m, int n, int k, @@ -442,14 +470,51 @@ static void callMmaTuringFp16(PTXBuilder &builder, int b, int m, int n, int k, mma(retArgs, aArgs2, bArgs2, cArgs); } -static void callMmaAmpere(PTXBuilder &builder, int b, int m, int n, int k, - mlir::triton::PTXInstr &mma, unsigned numMmaRets, - unsigned colsPerThread, int numCPackedElem, - unsigned batchOffset, ValueTableV2 &ha, - ValueTableV2 &hb, const SmallVector &fc, - bool isAccF16, bool isIntMMA) { - auto retArgs = - builder.newListOperand(numMmaRets, isIntMMA || isAccF16 ? "=r" : "=f"); +// Repeat m8n8k4 (2, 1, 4) times, as m16n8k16 on hopper. +static void callMmaAmpereFp64(PTXBuilder &builder, int b, int m, int n, int k, + mlir::triton::PTXInstr &mma, unsigned numMmaRets, + unsigned colsPerThread, int numCPackedElem, + unsigned batchOffset, ValueTableV2 &ha, + ValueTableV2 &hb, const SmallVector &fc) { + auto retArgs1 = builder.newListOperand(numMmaRets / 2, "=d"); + auto retArgs2 = builder.newListOperand(numMmaRets / 2, "=d"); + auto cArgs1 = builder.newListOperand(); + for (int i = 0; i < numMmaRets / 2; ++i) { + cArgs1->listAppend(builder.newOperand( + fc[(m * colsPerThread + 4 * n) / numCPackedElem + i + batchOffset * b], + std::to_string(i))); + // reuse the output registers + } + auto cArgs2 = builder.newListOperand(); + for (int i = numMmaRets / 2; i < numMmaRets; ++i) { + cArgs2->listAppend(builder.newOperand( + fc[(m * colsPerThread + 4 * n) / numCPackedElem + i + batchOffset * b], + std::to_string(i))); + // reuse the output registers + } + + for (int vk = 0; vk < 4; ++vk) { + auto aArgs1 = builder.newListOperand({ + {ha[{b, m, k + vk}], "d"}, + }); + auto bArgs = builder.newListOperand({{hb[{b, n, k + vk}], "d"}}); + auto aArgs2 = builder.newListOperand({ + {ha[{b, m + 1, k + vk}], "d"}, + }); + mma(retArgs1, aArgs1, bArgs, cArgs1); + mma(retArgs2, aArgs2, bArgs, cArgs2); + } +} + +// Unified MMAV2 function for Ampere and HopperF64 architectures +static void callMmaV2(PTXBuilder &builder, int b, int m, int n, int k, + mlir::triton::PTXInstr &mma, unsigned numMmaRets, + unsigned colsPerThread, int numCPackedElem, + unsigned batchOffset, ValueTableV2 &ha, ValueTableV2 &hb, + const SmallVector &fc, + const std::string &constraintRet, + const std::string &constraintAB, int numVecK) { + auto retArgs = builder.newListOperand(numMmaRets, constraintRet); auto cArgs = builder.newListOperand(); for (int i = 0; i < numMmaRets; ++i) { cArgs->listAppend(builder.newOperand( @@ -457,14 +522,18 @@ static void callMmaAmpere(PTXBuilder &builder, int b, int m, int n, int k, std::to_string(i))); // reuse the output registers } - auto aArgs = builder.newListOperand({ - {ha[{b, m, k}], "r"}, - {ha[{b, m + 1, k}], "r"}, - {ha[{b, m, k + 1}], "r"}, - {ha[{b, m + 1, k + 1}], "r"}, - }); - auto bArgs = - builder.newListOperand({{hb[{b, n, k}], "r"}, {hb[{b, n, k + 1}], "r"}}); + + auto aArgs = builder.newListOperand(); + for (int vk = 0; vk < numVecK; ++vk) { + aArgs->listAppend(builder.newOperand(ha[{b, m, k + vk}], constraintAB)); + aArgs->listAppend(builder.newOperand(ha[{b, m + 1, k + vk}], constraintAB)); + } + + auto bArgs = builder.newListOperand(); + for (int vk = 0; vk < numVecK; ++vk) { + bArgs->listAppend(builder.newOperand(hb[{b, n, k + vk}], constraintAB)); + } + mma(retArgs, aArgs, bArgs, cArgs); } @@ -472,7 +541,8 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc, Value a, Value b, Value c, Value d, Value loadedA, Value loadedB, Value loadedC, DotOp op, - DotOpAdaptor adaptor, bool isTuring) { + DotOpAdaptor adaptor, bool isTuring, + bool isHopperF64) { auto tb = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = c.getContext(); auto aTensorTy = cast(a.getType()); @@ -504,23 +574,28 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, assert(dotOpA.getRepOrder() == getOrderForDotOperand(dotOpA.getOpIdx(), aShapePerCTA.size(), /*kContig=*/true)); - auto ha = getValuesFromDotOperandLayoutStruct( - typeConverter, loc, rewriter, loadedA, repBatch, repM, repK, aTensorTy); + auto ha = getValuesFromDotOperandLayoutStruct(typeConverter, loc, rewriter, + loadedA, repBatch, repM, repK, + aTensorTy, isHopperF64); assert(dotOpB.getRepOrder() == getOrderForDotOperand(dotOpB.getOpIdx(), bShapePerCTA.size(), /*kContig=*/true)); - auto hb = getValuesFromDotOperandLayoutStruct( - typeConverter, loc, rewriter, loadedB, repBatch, repN, repK, bTensorTy); + auto hb = getValuesFromDotOperandLayoutStruct(typeConverter, loc, rewriter, + loadedB, repBatch, repN, repK, + bTensorTy, isHopperF64); auto fc = unpackLLElements(loc, loadedC, rewriter); - auto numMmaRets = dTensorTy.getElementType().getIntOrFloatBitWidth() / 8; + + int bitwidthRet = dTensorTy.getElementType().getIntOrFloatBitWidth(); + auto numMmaRets = bitwidthRet == 64 ? 4 : bitwidthRet / 8; int numCPackedElem = 4 / numMmaRets; auto mmaType = getMmaType(op); - const auto &mmaInstructions = - isTuring ? mmaInstrPtxTuring : mmaInstrPtxAmpere; + const auto &mmaInstructions = isTuring ? mmaInstrPtxTuring + : isHopperF64 ? mmaInstrPtxHopper + : mmaInstrPtxAmpere; if (mmaInstructions.find(mmaType) == mmaInstructions.end()) { return emitError(loc, "Unsupported MMA instruction for the given mma type"); } @@ -535,6 +610,7 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, // using =r for float32 works but leads to less readable ptx. bool isIntMMA = dTensorTy.getElementType().isInteger(32); bool isAccF16 = dTensorTy.getElementType().isF16(); + bool isFp64MMA = dTensorTy.getElementType().isF64(); if (isTuring) { assert(b == 0 && "Turing only supports batch size 1"); @@ -544,10 +620,20 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, else // Turing fp16 callMmaTuringFp16(builder, b, m, n, k, mma, numMmaRets, colsPerThread, numCPackedElem, ha, hb, fc, isAccF16); - } else { // Ampere - callMmaAmpere(builder, b, m, n, k, mma, numMmaRets, colsPerThread, - numCPackedElem, batchOffset, ha, hb, fc, isAccF16, - isIntMMA); + } else { // Ampere and later + if (isFp64MMA) { + if (!isHopperF64) { + callMmaAmpereFp64(builder, b, m, n, k, mma, numMmaRets, colsPerThread, + numCPackedElem, batchOffset, ha, hb, fc); + } else { + callMmaV2(builder, b, m, n, k, mma, numMmaRets, colsPerThread, + numCPackedElem, batchOffset, ha, hb, fc, "=d", "d", 4); + } + } else { + callMmaV2(builder, b, m, n, k, mma, numMmaRets, colsPerThread, + numCPackedElem, batchOffset, ha, hb, fc, + isIntMMA || isAccF16 ? "=r" : "=f", "r", 2); + } } Value mmaOut = @@ -563,8 +649,10 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, for (int b = 0; b < repBatch; ++b) for (int k = 0; k < repK; ++k) for (int m = 0; m < repM; ++m) - for (int n = 0; n < repN; ++n) - callMma(b, 2 * m, n, 2 * k); + for (int n = 0; n < repN; ++n) { + auto numVecK = bitwidth == 64 ? 4 : 2; + callMma(b, 2 * m, n, k * numVecK); + } Type resElemTy = dTensorTy.getElementType(); @@ -587,9 +675,11 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, return success(); } +// Convert to mma.m16n8k? LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter, bool isTuring) { + ConversionPatternRewriter &rewriter, bool isTuring, + bool isHopperF64) { assert(mlir::isa(op.getA().getType().getEncoding()) && mlir::isa(op.getB().getType().getEncoding()) && "Both $a and %b should be DotOperand layout."); @@ -598,19 +688,5 @@ LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, loadC(op.getC(), adaptor.getC(), typeConverter, op.getLoc(), rewriter); return convertDot(typeConverter, rewriter, op.getLoc(), op.getA(), op.getB(), op.getC(), op.getD(), adaptor.getA(), adaptor.getB(), - loadedC, op, adaptor, isTuring); -} - -// Convert to mma.m16n8k8 -LogicalResult convertMMA1688(triton::DotOp op, triton::DotOp::Adaptor adaptor, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter) { - return convertMMA(op, adaptor, typeConverter, rewriter, true /*isTuring*/); -} - -// Convert to mma.m16n8k16 -LogicalResult convertMMA16816(triton::DotOp op, triton::DotOp::Adaptor adaptor, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter) { - return convertMMA(op, adaptor, typeConverter, rewriter, false /*isTuring*/); + loadedC, op, adaptor, isTuring, isHopperF64); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h index 8eb28a5c7fc6..f8e0adde25d8 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -35,7 +35,7 @@ void populateConvertLayoutOpToLLVMOptimizedPatterns( void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit); + int computeCapability, PatternBenefit benefit); void populateElementwiseOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp index 8945ecc9735c..fba6143232da 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp @@ -113,7 +113,8 @@ struct ConvertTritonGPUToLLVM typeConverter, patterns, patternBenefitNvidiaTensorCoreSubviewPattern); mlir::triton::NVIDIA::populateTMAToLLVMPatterns(typeConverter, targetInfo, patterns, benefit); - populateDotOpToLLVMPatterns(typeConverter, patterns, benefit); + populateDotOpToLLVMPatterns(typeConverter, patterns, computeCapability, + benefit); populateElementwiseOpToLLVMPatterns(typeConverter, patterns, axisInfoAnalysis, computeCapability, targetInfo, benefit);