Merged

Changes from all commits

18 commits
a5129a0
[BACKEND] Support fp64 simt fma for common use and fp64 mma for SM80.
kzwrime Jun 25, 2025
abc3d03
[BACKEND] Replace (bitwidth < 32 ? 32 / bitwidth : 1) with std::max(3…
kzwrime Jun 26, 2025
ac506a5
[BACKEND] Replace (bitwidth < 32 ? 32 : bitwidth) with std::max(32u, …
kzwrime Jun 26, 2025
d8f9cef
[BACKEND] Fix format
kzwrime Jun 26, 2025
253f078
[BACKEND] Better check for whether fp64 MMA is supported
kzwrime Jun 26, 2025
2f577af
[BACKEND] Revert changes in lowerLdStMatrix
kzwrime Jun 26, 2025
a1aef12
[BACKEND] Revert changes in lowerSharedToDotOperandTransLL
kzwrime Jun 26, 2025
041b421
[TESTS] Check cuda_shared_mem_avail using shared_memory_per_block_opt…
kzwrime Jun 27, 2025
a963038
[BACKEND] Remove useless include in MMAv2
kzwrime Jun 27, 2025
04ae6e8
[BACKEND] More fp64 MMA support
kzwrime Jun 28, 2025
143f1d0
[TEST] Enable test_simple_matmul[float64] on all cuda archs and fix f…
kzwrime Jun 30, 2025
4362933
[BACKEND] Improved the comments regarding F64 MMA/FMA selection.
kzwrime Jun 30, 2025
77c61bd
[BACKEND] Unify SM80 F64 MMA with SM90 and remove the design of MMAv2.2
kzwrime Jul 2, 2025
6ee5178
[BACKEND] Revert the change of versionMinor
kzwrime Jul 3, 2025
04bf98f
[TESTS] Revert the change of versionMinor
kzwrime Jul 4, 2025
e6bf188
[TESTS] Fix: prohibit mmav5 check on float64 in test_simple_matmul
kzwrime Jul 4, 2025
b4187d1
Merge branch 'main' into dot-fp64-mma
kzwrime Jul 7, 2025
87261ca
[BACKEND] Restrict the F64 case of tileBitWidthK to isAmpere() && bit…
kzwrime Jul 8, 2025
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -1472,7 +1472,7 @@ vecIdx (index of the element in the quad; this is always along the k-dim)
return $_get(context, opIdx, parent, 0);
// For MMAV2 and V3
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
- unsigned kWidth = 32 / bitwidth;
+ unsigned kWidth = std::max(32 / bitwidth, 1u);
return $_get(context, opIdx, parent, kWidth);
}]>
];
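The clamp matters because kWidth was previously computed with plain integer division: for fp64 operands bitwidth is 64, so 32 / bitwidth truncates to 0 and the resulting attribute was ill-formed. A minimal sketch of the fixed computation (plain Python mirroring the C++ above; the helper name is illustrative):

```python
def k_width(bitwidth: int) -> int:
    # kWidth counts contiguous k-dim elements held per 32-bit register
    # slice; 64-bit elements span a register pair, so clamp to 1.
    return max(32 // bitwidth, 1)

assert k_width(8) == 4    # four int8/fp8 elements per 32-bit register
assert k_width(16) == 2   # two fp16/bf16 elements
assert k_width(32) == 1   # one fp32 element
assert k_width(64) == 1   # fp64: previously 32 // 64 == 0
```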
2 changes: 1 addition & 1 deletion lib/Analysis/Utility.cpp
@@ -711,7 +711,7 @@ bool supportMMA(Value value, int version) {
bool isFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType>(elemTy);
return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
-        (elemTy.isF32() && version >= 2) ||
+        ((elemTy.isF32() || elemTy.isF64()) && version >= 2) ||
(elemTy.isInteger(8) && version >= 2);
}

4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -235,7 +235,7 @@ struct ElementwiseInlineAsmOpConversion
Type elemTy = getElementType(op.getOperand(i));
unsigned bitWidth =
elemTy.isIntOrFloat() ? elemTy.getIntOrFloatBitWidth() : 64;
- unsigned numElementPerReg = bitWidth < 32 ? 32 / bitWidth : 1;
+ unsigned numElementPerReg = std::max(32 / bitWidth, 1u);
numElementPerReg = std::min(numElementPerReg, numPackedElements);
for (int j = 0; j < numPackedElements; j += numElementPerReg) {
if (numElementPerReg == 1) {
@@ -278,7 +278,7 @@ struct ElementwiseInlineAsmOpConversion
// Pack return elements into 32-bits.
unsigned bitWidth = ty.isIntOrFloat() ? ty.getIntOrFloatBitWidth() : 64;
unsigned numElemsPerReg =
-     std::min(bitWidth < 32 ? 32 / bitWidth : 1, op.getPackedElement());
+     std::min(std::max(32 / bitWidth, 1u), op.getPackedElement());
assert(op.getPackedElement() % numElemsPerReg == 0);
if (numElemsPerReg > 1) {
ty = vec_ty(ty, numElemsPerReg);
12 changes: 7 additions & 5 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2129,9 +2129,9 @@ NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
SmallVector<int64_t>
NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
int kWidth, int opIdx) const {
- assert(
-     kWidth >= 32 / bitwidth &&
-     "kWidth must be >= 32 / bitwidth for this function to be well-defined");
+ assert(kWidth >= std::max(32 / bitwidth, 1) &&
+        "kWidth must be >= max(32 / bitwidth, 1) for this function to be "
+        "well-defined");
auto rank = shape.size();
// Broadcast long K
auto warpsPerCTA = to_vector(getWarpsPerCTA());
@@ -2142,16 +2142,18 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
if (rank == 3) {
tileSize.push_back(1);
}
+ // warpSizeK * (warpRepK * VecBitWidth)
+ auto tileBitWidthK = (isAmpere() && bitwidth == 64) ? (4 * 256) : (4 * 64);
Contributor:
I'm a bit confused about this last change. What's the context for the Ampere tile to be larger? I thought it was the smaller one.

Contributor Author (@kzwrime, Jul 8, 2025):

It used to be auto tileBitWidthK = isHopperF64() ? (4 * 256) : (4 * 64), but the design was later changed to repeat the SM80 fp64 m8n8k4 tile (2,1,4) = 8 times, making it match the SM90 fp64 mma m16n8k16 and thereby unifying the mmav2 fp64 processing logic.

As previously discussed:

  • SM90 fp64 m16n8k16 belongs to mmav2, with tileBitWidthK being 4*256
  • SM80 fp64 m8n8k4 also belongs to mmav2, with tileBitWidthK being 4*64 (same as other mma instructions).

We needed to distinguish SM90 fp64 mma, so initially I designed it with versionMajor=2, versionMinor=2 and used isHopperF64() to identify this information, as #7310 (comment) shows.

@ThomasRaoux thought introducing a new versionMinor would be more confusing (#7310 (comment)), so I repeated the SM80 fp64 m8n8k4 tile (2,1,4) = 8 times, making it the same as the SM90 fp64 mma m16n8k16 and thereby achieving a unified mmav2 fp64 processing logic (#7310 (comment)).

if (opIdx == 0) {
// m x k
tileSize.push_back(16);
-   tileSize.push_back(4 * 64 / bitwidth);
+   tileSize.push_back(tileBitWidthK / bitwidth);
} else {
// k x n
// Hopper path never uses the n value, since this method is only invoked
// for in-RF (dotOpEnc) operands, but WGMMA only supports in A to be in RF
// so it's fine if the n is incorrect here
-   tileSize.push_back(4 * 64 / bitwidth);
+   tileSize.push_back(tileBitWidthK / bitwidth);
tileSize.push_back(8);
}

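To make the tile arithmetic from the review thread concrete, here is a hedged sketch (plain Python; mma_v2 stands in for isAmpere(), which tests for the mmav2 encoding with versionMajor = 2, covering both the SM80 and SM90 f64 paths exercised by the tests below):

```python
def k_tile_elems(bitwidth: int, mma_v2: bool) -> int:
    # Mirrors the hunk above: under mmav2, fp64 gets a widened K tile
    # (4 * 256 bits) so that eight SM80 m8n8k4 steps cover the same K
    # extent as a single SM90 m16n8k16 fp64 MMA.
    tile_bitwidth_k = 4 * 256 if (mma_v2 and bitwidth == 64) else 4 * 64
    return tile_bitwidth_k // bitwidth

assert k_tile_elems(64, mma_v2=True) == 16   # fp64: k = 16 (m16n8k16, or 8 x m8n8k4)
assert k_tile_elems(16, mma_v2=True) == 16   # fp16: k = 16 (m16n8k16)
assert k_tile_elems(32, mma_v2=True) == 8    # tf32: k = 8  (m16n8k8)
assert k_tile_elems(8, mma_v2=True) == 32    # int8: k = 32 (m16n8k32)
```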
9 changes: 9 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -333,6 +333,15 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
auto oldBType = cast<RankedTensorType>(b.getType());
auto oldRetType = cast<RankedTensorType>(dotOp.getType());

// Enable F64 MMA only on SM80/SM90 with high performance F64 tensorcore.
// Otherwise, fallback to F64 FMA for better performance.
if ((oldAType.getElementType().isF64() ||
oldBType.getElementType().isF64() ||
oldRetType.getElementType().isF64()) &&
!(computeCapability == 80 || computeCapability == 90)) {
return failure();
}

// get MMA encoding for the given number of warps
auto CTALayout = getCTALayout(oldRetType.getEncoding());
auto retShapePerCTA = getShapePerCTA(oldRetType);
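The guard added above is the dispatch point for fp64 dot. A sketch of the resulting behavior (illustrative helper, not the actual API):

```python
def f64_dot_lowering(compute_capability: int) -> str:
    # Only SM80 (A100) and SM90 (H100) have high-throughput fp64 tensor
    # cores; elsewhere BlockedToMMA returns failure() and fp64 dot lowers
    # to SIMT FMA, which outperforms the low-rate fp64 MMA on those chips.
    return "mma" if compute_capability in (80, 90) else "fma"

assert f64_dot_lowering(80) == "mma"   # A100: mma.sync m8n8k4
assert f64_dot_lowering(90) == "mma"   # H100: mma.sync m16n8k16
assert f64_dot_lowering(89) == "fma"   # e.g. consumer Ada: FMA fallback
```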
76 changes: 51 additions & 25 deletions python/test/unit/language/test_core.py
@@ -130,6 +130,18 @@ def check_type_supported(dtype, device):
pytest.skip("bfloat16 is not supported in the interpreter")


def get_src_element_ty_size(dtype_str):
if dtype_str in ["int8", "uint8", "float8e4b15"]:
return 1
if dtype_str == "float16":
return 2
if dtype_str == "float32" or dtype_str == "tensorfloat32":
return 4
if dtype_str == "float64":
return 8
raise ValueError(f"Unknown dtype {dtype_str}")


class MfmaLayout:

def __init__(self, version, warps_per_cta, tiles_per_warp, instr_shape, is_transposed):
@@ -3732,7 +3744,9 @@ def get_test_dot_base_cases():
for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)]
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
for input_precision in ['tf32', 'tf32x3', 'ieee']
- for in_dtype, out_dtype in [('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')]
+ for in_dtype, out_dtype in [('float16', 'float16'), ('float16',
+                             'float32'), ('float32',
+                             'float32'), ('float64', 'float64')]
if not (input_precision != 'ieee' and (in_dtype in ['float16']))]


@@ -3865,6 +3879,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty
pytest.skip("Only test out_dtype=float16 on devices with sm >=80")
if capability[0] < 9 and in_dtype == 'float8e4nv':
pytest.skip("float8e4nv not supported on sm <= 80")
if in_dtype == 'float64' and input_precision != 'ieee':
pytest.skip("Only IEEE precision is supported for float64 dot")

if is_hip():
if in_dtype in ("float8e5", "float8e4nv") and not (is_hip_cdna4() or is_hip_gfx12()):
@@ -3875,6 +3891,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty
pytest.skip(f"{input_precision} not supported on HIP")
if kpack == 2 and in_dtype == 'int8' and K < 64:
pytest.skip("kpack too large for K")
if in_dtype == 'float64':
pytest.skip("float64 not supported on HIP yet")

if not is_hip() and kpack == 2:
pytest.skip("Skip duplicated tests on nv path")

@@ -4036,11 +4055,17 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid

# make sure ld/st are vectorized
ptx = pgm.asm['ptx']

if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4):
# XXX: skip small sizes because they are not vectorized
-     assert 'ld.global.v4' in ptx
+     if 'float64' in in_dtype:
+         assert 'ld.global.v2.b64' in ptx
+     else:
+         assert 'ld.global.v4' in ptx
if 'float8' in in_dtype:
assert 'st.global.v2' in ptx
+     elif 'float64' in in_dtype:
+         assert 'st.global.v2.b64' in ptx
else:
assert 'st.global.v4' in ptx

@@ -4349,23 +4374,24 @@ def make_finite(x, dtype):


@pytest.mark.interpreter
@pytest.mark.parametrize("B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str",
[(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str)
for B in [1, 2, 4, 8]
for num_warps in [1, 2, 4, 8, 16]
for BLOCK_M, BLOCK_N in [(32, 32)]
for M, N, K in [(64, 64, 64), (32, 32, 32)]
for in_dtype_str, out_dtype_str in [('int8', 'int8'), ('float16', 'float16'),
('float16', 'float32'), ('float32', 'float32')]] +
# Large block sizes
[(4, 4, 128, 128, 64, 64, 64, 'float16', 'float16')] +
# Small block sizes
[(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str)
for B in [1, 2, 8]
for num_warps in [1, 2, 4]
for BLOCK_M, BLOCK_N in [(1, 32), (32, 2), (8, 8)]
for M, N, K in [(32, 32, 32)]
for in_dtype_str, out_dtype_str in [('float16', 'float16'), ('float32', 'float32')]])
@pytest.mark.parametrize(
"B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str",
[(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str)
for B in [1, 2, 4, 8]
for num_warps in [1, 2, 4, 8, 16]
for BLOCK_M, BLOCK_N in [(32, 32)]
for M, N, K in [(64, 64, 64), (32, 32, 32)]
for in_dtype_str, out_dtype_str in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'),
('float32', 'float32'), ('float64', 'float64')]] +
# Large block sizes
[(4, 4, 128, 128, 64, 64, 64, 'float16', 'float16')] +
# Small block sizes
[(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str)
for B in [1, 2, 8]
for num_warps in [1, 2, 4]
for BLOCK_M, BLOCK_N in [(1, 32), (32, 2), (8, 8)]
for M, N, K in [(32, 32, 32)]
for in_dtype_str, out_dtype_str in [('float16', 'float16'), ('float32', 'float32')]])
def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str, device):
if is_hip():
# hip does not support tf32 precision, so use ieee for all tests
@@ -4376,17 +4402,17 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_
pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d")
if out_dtype_str == "float16":
pytest.skip(f"{out_dtype_str} has low precision in WMMA dot")
if in_dtype_str == "float64":
pytest.skip("float64 not supported on HIP yet")
else:
input_precision = "tf32" if is_cuda() and in_dtype_str == 'float32' else "ieee"
if not is_interpreter() and (BLOCK_M < 16 or BLOCK_N < 16):
pytest.skip("small dots are supported only on HIP at the moment")

- if B == 8 and M == 64 and in_dtype_str == "float32" and out_dtype_str == "float32":
-     if not is_interpreter() and triton.runtime.driver.active.utils.get_device_properties(
-             triton.runtime.driver.active.get_current_device())["max_shared_mem"] < 131072:
-         pytest.skip(
-             "Skipping tests with B = 8, M = 64, in_type = float32, out_type = float32 due to insufficient shared memory (less than 128 KB per SM) on this GPU."
-         )
+ shared_mem_accum = B * (BLOCK_M * K + K * BLOCK_N) * get_src_element_ty_size(in_dtype_str)
+ if not is_interpreter() and triton.runtime.driver.active.utils.get_device_properties(
+         triton.runtime.driver.active.get_current_device())["max_shared_mem"] < shared_mem_accum:
+     pytest.skip("Skipped due to insufficient shared memory on this GPU.")

@triton.jit
def kernel(
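A quick sanity check of the new dynamic guard (plain Python mirroring the expression above): the old hard-coded 131072 threshold is exactly the float32 worst case of the parametrization, and float64 doubles it.

```python
ELT_SIZE = {"float32": 4, "float64": 8}

def dot3d_smem_bytes(B, BLOCK_M, BLOCK_N, K, dtype_str):
    # The kernel stages B batches of BLOCK_M x K and K x BLOCK_N tiles.
    return B * (BLOCK_M * K + K * BLOCK_N) * ELT_SIZE[dtype_str]

assert dot3d_smem_bytes(8, 32, 32, 64, "float32") == 131072  # 128 KiB, the old constant
assert dot3d_smem_bytes(8, 32, 32, 64, "float64") == 262144  # 256 KiB, beyond current GPUs
```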
20 changes: 14 additions & 6 deletions python/test/unit/language/test_matmul.py
@@ -82,11 +82,13 @@ def get_src_element_ty_size(dtype_str):
return 2
if dtype_str == "float32" or dtype_str == "tensorfloat32":
return 4
if dtype_str == "float64":
return 8
raise ValueError(f"Unknown dtype {dtype_str}")


@pytest.mark.parametrize("dtype_src_str", ["float32", "tensorfloat32", "float16", "float8e5"])
@pytest.mark.parametrize("dtype_dst_str", ["float32", "float16"])
@pytest.mark.parametrize("dtype_src_str", ["float32", "tensorfloat32", "float16", "float8e5", "float64"])
@pytest.mark.parametrize("dtype_dst_str", ["float32", "float16", "float64"])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES", [(128, 128, 16, 4), (64, 128, 32, 4), (32, 32, 32, 4),
(256, 128, 32, 4), (64, 512, 32, 2),
(512, 64, 32, 2), (64, 16, 16, 4)])
@@ -98,15 +100,20 @@ def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K,
EPILOGUE_SUBTILE, LAYOUT_16x256, monkeypatch):
if NUM_CTAS > 1 and (not is_cuda() or torch.cuda.get_device_capability()[0] < 9):
pytest.skip("Clusters requires nvidia compute capability >= 9")
- if is_hip() and ((BLOCK_K * BLOCK_M + BLOCK_K * BLOCK_N) * NUM_STAGES * get_src_element_ty_size(dtype_src_str)
-                  > 65536):
-     pytest.skip("HIP path requires less than 64KB of shared memory")
+ shared_mem_accum = (BLOCK_K * BLOCK_M + BLOCK_K * BLOCK_N) * NUM_STAGES * get_src_element_ty_size(dtype_src_str)
+ shared_mem_avail = triton.runtime.driver.active.utils.get_device_properties(0)["max_shared_mem"]
+ if shared_mem_accum > shared_mem_avail:
+     pytest.skip("Skipped due to insufficient shared memory on this GPU.")
if is_hip() and (not is_hip_cdna3()) and dtype_src_str == "tensorfloat32":
pytest.skip("tensorfloat32 is only supported on HIP CDNA3")
if dtype_src_str == "float8e5" and BLOCK_K == 16:
pytest.skip("Skipping cases small K for float8")
if dtype_src_str == "float8e5" and device == "cuda" and torch.cuda.get_device_capability()[0] < 9:
pytest.skip("Float8 requires compute capability >= 9")
if (dtype_src_str == "float64") != (dtype_dst_str == "float64"):
pytest.skip("Skipping unsupported case")
if dtype_src_str == "float64" and not is_cuda():
pytest.skip("Float64 not supported on HIP yet")
if "float32" in dtype_src_str and dtype_dst_str == "float16":
pytest.skip("Skipping unsupported case")
if "float32" == dtype_src_str and NUM_CTAS > 1:
@@ -160,7 +167,8 @@ def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K,
# This applies only if TCv5 MMA is used (M % 64 == 0 and N % 8 == 0) and
# when MMA arguments loads are pipelined (N > 16)
if (device == "cuda" and torch.cuda.get_device_capability()[0] == 10 and NUM_STAGES > 1 and BLOCK_M % 64 == 0
-         and BLOCK_N % 8 == 0 and BLOCK_N > 16 and not (precision == "ieee" and dtype_src_str == "float32")):
+         and BLOCK_N % 8 == 0 and BLOCK_N > 16
+         and not (precision == "ieee" and (dtype_src_str == "float32" or dtype_src_str == "float64"))):
ttgir = k.asm["ttgir"]
count = ttgir.count("ttng.tc_gen5_mma")
assert count == 2, "The TTGIR does not match the expected pattern."
11 changes: 7 additions & 4 deletions python/triton/language/semantic.py
@@ -1472,10 +1472,10 @@ def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Opti
# All combinations of supported fp8 x fp8 are permitted
pass
else:
- assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16,
-                      tl.float32), f"Unsupported lhs dtype {lhs.dtype}"
- assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16,
-                      tl.float32), f"Unsupported rhs dtype {rhs.dtype}"
+ assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, tl.float32,
+                      tl.float64), f"Unsupported lhs dtype {lhs.dtype}"
+ assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, tl.float32,
+                      tl.float64), f"Unsupported rhs dtype {rhs.dtype}"
assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}"

if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15():
@@ -1514,6 +1514,9 @@ def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Opti
elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16():
_0 = self.builder.get_fp32(0)
ret_scalar_ty = tl.float32
elif lhs.type.scalar.is_fp64():
_0 = self.builder.get_fp64(0)
ret_scalar_ty = tl.float64
else:
_0 = self.builder.get_fp16(0) if out_dtype.is_fp16() else self.builder.get_fp32(0)
ret_scalar_ty = out_dtype
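With the dtype checks relaxed, tl.dot now accepts float64 operands end to end. A minimal usage sketch (hedged: a simplified single-block kernel written for illustration, not taken from this PR; requires a CUDA device, and takes the MMA path only on SM80/SM90, falling back to FMA elsewhere):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def dot_f64_kernel(a_ptr, b_ptr, c_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])  # M x K fp64 tile
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])  # K x N fp64 tile
    # fp64 dot requires IEEE input precision and accumulates in float64.
    c = tl.dot(a, b, input_precision="ieee")
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)


a = torch.randn(16, 16, dtype=torch.float64, device="cuda")
b = torch.randn(16, 16, dtype=torch.float64, device="cuda")
c = torch.empty(16, 16, dtype=torch.float64, device="cuda")
dot_f64_kernel[(1, )](a, b, c, M=16, N=16, K=16)
torch.testing.assert_close(c, a @ b)
```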
29 changes: 29 additions & 0 deletions test/Conversion/tritongpu_to_llvm.mlir
@@ -1875,6 +1875,35 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
}
}

// -----

#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 8]}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 2, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 4096 : i32, ttg.target = "cuda:80", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} {
tt.func public @f64_mma_cvt() {
%0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<16x16xf64, #shared, #smem, mutable>
%1 = ttg.local_alloc {allocation.offset = 2048 : i32} : () -> !ttg.memdesc<16x16xf64, #shared1, #smem, mutable>

%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf64, #mma>

%2 = ttg.local_load %0 : !ttg.memdesc<16x16xf64, #shared, #smem, mutable> -> tensor<16x16xf64, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>

%3 = ttg.local_load %1 : !ttg.memdesc<16x16xf64, #shared1, #smem, mutable> -> tensor<16x16xf64, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>

// CHECK: llvm.inline_asm
// CHECK-SAME: mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64
// CHECK: llvm.inline_asm
// CHECK-SAME: mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64

%out = tt.dot %2, %3, %cst, inputPrecision = tf32 : tensor<16x16xf64, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<16x16xf64, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<16x16xf64, #mma>

tt.return
}
}


// -----

#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
28 changes: 28 additions & 0 deletions test/Conversion/tritongpu_to_llvm_hopper.mlir
@@ -480,3 +480,31 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
tt.return
}
}

// -----

#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 8]}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 2, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 4096 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} {
tt.func public @hopper_f64_mma_cvt() {
%0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<16x16xf64, #shared, #smem, mutable>
%1 = ttg.local_alloc {allocation.offset = 2048 : i32} : () -> !ttg.memdesc<16x16xf64, #shared1, #smem, mutable>

%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf64, #mma>

%2 = ttg.local_load %0 : !ttg.memdesc<16x16xf64, #shared, #smem, mutable> -> tensor<16x16xf64, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>

%3 = ttg.local_load %1 : !ttg.memdesc<16x16xf64, #shared1, #smem, mutable> -> tensor<16x16xf64, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>

// CHECK: llvm.inline_asm
// CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64
// CHECK: llvm.inline_asm
// CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64

%out = tt.dot %2, %3, %cst, inputPrecision = tf32 : tensor<16x16xf64, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<16x16xf64, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<16x16xf64, #mma>

tt.return
}
}