From a77e33a8aaf15581e33e3f3fdc76ac24751b0a35 Mon Sep 17 00:00:00 2001 From: Ilya Cherniavski Date: Wed, 13 Nov 2024 16:24:54 +0000 Subject: [PATCH 01/10] [AMD] Use wave shuffle for MFMA to Dot operand layout conversion (FP8) Adding a case for MFMA->Dot (FP8) layout conversion that avoids using shared memory, to speed up FP8 attention kernels. (Note: can be replaced by the LinearLayout based warp shuffle conversion when it is ready) Test: pytest pytest third_party/amd/python/test/test_chained_dot_fp8.py lit test ctest -j32 --- include/triton/Analysis/Utility.h | 5 + lib/Analysis/Utility.cpp | 23 ++- test/Conversion/amd/mfma-shortcut.mlir | 18 +- .../ConvertLayoutOpToLLVM.cpp | 91 ++++++++++ .../amd/python/test/test_chained_dot_fp8.py | 164 ++++++++++++++++++ 5 files changed, 299 insertions(+), 2 deletions(-) create mode 100644 third_party/amd/python/test/test_chained_dot_fp8.py diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index df6029db0de2..ae517912fbb4 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -218,6 +218,11 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, RankedTensorType dstTy); +// Check if MFMA layout can be converted to the dot operand +// layout using warp shuffle. +bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy, + RankedTensorType dstTy); + // TODO: Move utility functions that belong to ConvertLayoutOp to class // ConvertLayoutOpHelper in the future bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index ac72b4f26cd6..d5cb1f0bd1f0 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -10,6 +10,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Matchers.h" #include "mlir/Support/LLVM.h" +#include "triton/Conversion/MLIRTypes.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -631,6 +632,23 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, return ans; } +bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy, + RankedTensorType dstTy) { + auto mfmaLayout = dyn_cast(srcTy.getEncoding()); + auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); + if (!mfmaLayout || !dotOperandLayout) + return false; + + // Currently supporting 32x32 FP8 MFMA -> dot operand case + return dotOperandLayout.getParent() == mfmaLayout && + dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() && + dotOperandLayout.getKWidth() == 8 && + getContigPerThread(mfmaLayout)[1] == 4 && mfmaLayout.getMDim() == 32 && + mfmaLayout.getNDim() == 32 && + triton::type::isFloat8(srcTy.getElementType()) && + mfmaLayout.getWarpsPerCTA()[1] == 1; +} + // We get the smallest submap of srcTy^{-1} * dstTy that is not the identity // under kBlock, kWarp or kLane (in that order). The idea here is that if we // have a transformation that's the identity on kBlock, we don't need to use @@ -729,7 +747,10 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) { // supported yet in Triton's backend. return !cvtReordersRegisters(srcTy, dstTy) && !isBlockedToDotShortcut(srcTy, dstTy) && - !matchMmaV3AndDotOperandLayout(srcTy, dstTy); + !matchMmaV3AndDotOperandLayout(srcTy, dstTy) && + // to be removed when generalized warp shuffle conversions + // are ready: + !matchMFMAAndDotOperandShuffleCase(srcTy, dstTy); } bool atomicNeedsSharedMemory(Value value) { diff --git a/test/Conversion/amd/mfma-shortcut.mlir b/test/Conversion/amd/mfma-shortcut.mlir index a2c8f48718d9..5b0036d4c590 100644 --- a/test/Conversion/amd/mfma-shortcut.mlir +++ b/test/Conversion/amd/mfma-shortcut.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx90a" -split-input-file | FileCheck %s +// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx942" -split-input-file | FileCheck %s #mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> #dotop = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}> @@ -27,3 +27,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: mfma_dot_cvt_f8 + tt.func public @mfma_dot_cvt_f8(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) { + // CHECK-NOT: store + // CHECK-NOT: load + // CHECK: rocdl.ds_bpermute + %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0> + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index d3ffaed2e8fc..c1c2ca6bd217 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -115,6 +115,96 @@ struct LocalLoadOpConversion } }; +struct ConvertLayoutOpMFMAToDotOpConversion + : public ConvertOpToLLVMPattern { +public: + explicit ConvertLayoutOpMFMAToDotOpConversion( + LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcType = cast(op.getSrc().getType()); + auto dstType = cast(op.getType()); + + if (!matchMFMAAndDotOperandShuffleCase(srcType, dstType)) + return failure(); + + /* + Using wave shuffle to convert layouts: + 1) Input MMA layout (32x32, fp8, 16 values): + _____________________________________________________________ + |(t0 v0 v1 v2 v3) (t32 v0 v1 v2 v3) ... (t32 v12 v13 v14 v15)| + | ... ... | + |(t31 v0 v1 v2 v3) (t63 v0 v1 v2 v3) ... (t63 v12 v13 v14 v15)| + |_____________________________________________________________| + + 2) Output Dot operand layout (two 32x16 tiles, fp8, 8 values each): + ____________________________________________________________ ___ + |(t0 v0 v1 v2 v3 v4 v5 v6 v7) (t32 v0 v1 v2 v3 v4 v5 v6 v7) | | + | ... ... | |... + |(t31 v0 v1 v2 v3 v4 v5 v6 v7) (t63 v0 v1 v2 v3 v4 v5 v6 v7) | | + |____________________________________________________________| |___ + */ + + auto loc = op.getLoc(); + + SmallVector inVals = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + if (inVals.empty() || inVals.size() % 8 != 0) + return failure(); + + Value threadId = tid_val(); + Value warpSize = i32_val(64); // MFMA Warp Size + Value laneId = urem(threadId, warpSize); + Value laneOffset = i32_val(32); + Value mask = icmp_slt(laneId, laneOffset); + Value addr0 = select(mask, add(laneId, laneOffset), laneId); + Value addr1 = select(mask, laneId, sub(laneId, laneOffset)); + + SmallVector outVals; + auto elemTy = int_ty(8); + auto vecTy = vec_ty(elemTy, 4); + for (size_t startIdx = 0; startIdx < inVals.size(); startIdx += 8) { + Value vec0 = undef(vecTy); + for (size_t vIdx = 0; vIdx < 4; ++vIdx) { + vec0 = insert_element(vecTy, vec0, inVals[startIdx + vIdx], i32_val(vIdx)); + } + Value vec1 = undef(vecTy); + for (size_t vIdx = 0; vIdx < 4; ++vIdx) { + vec1 = insert_element(vecTy, vec1, inVals[startIdx + vIdx + 4], i32_val(vIdx)); + } + + Value shflVec0 = bitcast( + targetInfo.shuffleIdx(rewriter, loc, bitcast(vec0, int_ty(32)), addr0), vecTy); + Value shflVec1 = bitcast( + targetInfo.shuffleIdx(rewriter, loc, bitcast(vec1, int_ty(32)), addr1), vecTy); + + Value firstVec = select(mask, vec0, shflVec1); + Value secondVec = select(mask, shflVec0, vec1); + + for (size_t vIdx = 0; vIdx < 4; ++vIdx) { + outVals.push_back(extract_element(elemTy, firstVec, i32_val(vIdx))); + } + for (size_t vIdx = 0; vIdx < 4; ++vIdx) { + outVals.push_back(extract_element(elemTy, secondVec, i32_val(vIdx))); + } + } + + Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, + op.getType()); + rewriter.replaceOp(op, result); + return success(); + } + +protected: + const TargetInfoBase &targetInfo; +}; + } // namespace namespace mlir::triton::AMD { @@ -123,5 +213,6 @@ void populateConvertLayoutOpToLLVMPatterns( RewritePatternSet &patterns, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { patterns.add(typeConverter, benefit); + patterns.add(typeConverter, targetInfo, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/python/test/test_chained_dot_fp8.py b/third_party/amd/python/test/test_chained_dot_fp8.py new file mode 100644 index 000000000000..abe7a47e9647 --- /dev/null +++ b/third_party/amd/python/test/test_chained_dot_fp8.py @@ -0,0 +1,164 @@ +""" +Testing the (FP8) case of a dot op that consumes the output (MFMA) of +another dot op as an input. + +""" + +import math +import pytest +import sys +import torch + +import triton +import triton.language as tl + +TORCH_HAS_FP8E4 = hasattr(torch, 'float8_e4m3fnuz') +float8: tl.constexpr = None if not TORCH_HAS_FP8E4 else torch.float8_e4m3fnuz +torch.manual_seed(42) + + +@triton.jit +def _chained_dot( + Q, + K, + V, + Out, + q_desc, + k_desc, + v_desc, + s_sc, + s_desc, + o_sc, + stride_qz, + stride_qm, + stride_qd, + stride_kz, + stride_kn, + stride_kd, + stride_vz, + stride_vd, + stride_vn, + stride_oz, + stride_om, + stride_od, + Z, + N, + BLOCK_D: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + USE_FP8: tl.constexpr, +): + start_m = tl.program_id(0) + off_z = tl.program_id(1) + qkv_offset = off_z * stride_qz + Q_block_ptr = tl.make_block_ptr(base=Q + qkv_offset, shape=(N, BLOCK_D), strides=(stride_qm, stride_qd), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_D), order=(1, 0)) + K_block_ptr = tl.make_block_ptr(base=K + qkv_offset, shape=(BLOCK_D, N), strides=(stride_kd, stride_kn), + offsets=(0, 0), block_shape=(BLOCK_D, BLOCK_N), order=(0, 1)) + V_block_ptr = tl.make_block_ptr(base=V + qkv_offset, shape=(N, BLOCK_D), strides=(stride_vn, stride_vd), + offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_D), order=(0, 1)) + + s_scale = q_desc * k_desc * s_sc + acc_scale = s_desc * v_desc * o_sc + + q = tl.load(Q_block_ptr) + + acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32) + lo, hi = 0, N + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + + k = tl.load(K_block_ptr) + s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + s += tl.dot(q, k) + + if USE_FP8: + s *= s_scale + + v = tl.load(V_block_ptr) + acc += tl.dot(s.to(v.dtype), v) + + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + + if USE_FP8: + acc *= acc_scale + + O_block_ptr = tl.make_block_ptr(base=Out + qkv_offset, shape=(N, BLOCK_D), strides=(stride_om, stride_od), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_D), order=(1, 0)) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + + +class chained_dot_fn(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, q_desc=1.0, k_desc=1.0, v_desc=1.0, s_sc=1.0, s_desc=1.0, o_sc=1.0): + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-2] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q, dtype=v.dtype) + + BLOCK_M = 128 if q.dtype == float8 else 256 + BLOCK_N = 32 + waves_per_eu = 2 + num_warps = BLOCK_M // 32 + num_stages = 1 + + grid = (triton.cdiv(q.shape[1], BLOCK_M), q.shape[0], 1) + + _chained_dot[grid](q, k, v, o, q_desc, + k_desc, v_desc, s_sc, s_desc, o_sc, q.stride(0), q.stride(1), q.stride(2), k.stride(0), + k.stride(1), k.stride(2), v.stride(0), v.stride(1), v.stride(2), o.stride(0), o.stride(1), + o.stride(2), Z=q.shape[0], N=q.shape[1], BLOCK_D=Lk, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + USE_FP8=(q.dtype == float8), waves_per_eu=waves_per_eu, num_warps=num_warps, + num_stages=num_stages) + + return o + + +chained_dot = chained_dot_fn.apply + + +def to_float8(x, dtype=float8, margin: float = 1.0): + finfo = torch.finfo(dtype) + scale = finfo.max / x.abs().max().clamp(min=1e-12) + scale = math.pow(2, math.floor(math.log2(scale.float().item())) - margin) + x_scaled = (x.float() * scale).clamp(min=finfo.min, max=finfo.max) + return x_scaled.to(dtype), scale, 1.0 / scale + + +@pytest.mark.parametrize('N, D, dtype', [(*shape, dtype) for shape in [(128, 32), (256, 128)] for dtype in ['fp8']]) +def test_chained_dot(N, D, dtype): + if dtype == 'fp8': + assert float8 is not None + + BATCH = 1 + q = torch.empty((BATCH, N, D), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5) + k = torch.empty((BATCH, N, D), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5) + v = torch.empty((BATCH, D, N), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5) + + if dtype == 'fp8': + q_f8, _, q_desc = to_float8(q) + k_f8, _, k_desc = to_float8(k) + v_f8, _, v_desc = to_float8(v) + + s = torch._scaled_mm(q_f8[0], k_f8[0].transpose(0, 1), out_dtype=torch.float32, + scale_a=torch.tensor(q_desc, dtype=torch.float32, device="cuda"), + scale_b=torch.tensor(k_desc, dtype=torch.float32, device="cuda")) + s_f8, s_sc, s_desc = to_float8(s) + ref = torch._scaled_mm(s_f8, v_f8[0].transpose(0, 1), out_dtype=torch.float32, + scale_a=torch.tensor(s_desc, dtype=torch.float32, device="cuda"), + scale_b=torch.tensor(v_desc, dtype=torch.float32, device="cuda")) + ref_f8, ref_sc, _ = to_float8(ref) + + tri_out = chained_dot(q_f8, k_f8, v_f8, q_desc, k_desc, v_desc, s_sc, s_desc, ref_sc) + + assert tri_out.isnan().sum() == 0 + torch.testing.assert_close(tri_out[0].float(), ref_f8.float(), atol=1e-2, rtol=3e-3) + + else: + s = torch.matmul(q, k.transpose(1, 2)) + ref = torch.matmul(s, v.transpose(1, 2)) + + tri_out = chained_dot(q, k, v) + torch.testing.assert_close(tri_out, ref, atol=1e-2, rtol=3e-3) From 4cf8fc324a17b1ec063eae4ad841b7b08252155c Mon Sep 17 00:00:00 2001 From: Ilya Cherniavski Date: Thu, 14 Nov 2024 18:08:56 +0000 Subject: [PATCH 02/10] (linter) --- .../ConvertLayoutOpToLLVM.cpp | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index c1c2ca6bd217..94e83f755839 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -119,10 +119,10 @@ struct ConvertLayoutOpMFMAToDotOpConversion : public ConvertOpToLLVMPattern { public: explicit ConvertLayoutOpMFMAToDotOpConversion( - LLVMTypeConverter &typeConverter, - const TargetInfoBase &targetInfo, - PatternBenefit benefit) - : ConvertOpToLLVMPattern(typeConverter, benefit), + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, + benefit), targetInfo(targetInfo) {} LogicalResult @@ -172,17 +172,23 @@ struct ConvertLayoutOpMFMAToDotOpConversion for (size_t startIdx = 0; startIdx < inVals.size(); startIdx += 8) { Value vec0 = undef(vecTy); for (size_t vIdx = 0; vIdx < 4; ++vIdx) { - vec0 = insert_element(vecTy, vec0, inVals[startIdx + vIdx], i32_val(vIdx)); + vec0 = + insert_element(vecTy, vec0, inVals[startIdx + vIdx], i32_val(vIdx)); } Value vec1 = undef(vecTy); for (size_t vIdx = 0; vIdx < 4; ++vIdx) { - vec1 = insert_element(vecTy, vec1, inVals[startIdx + vIdx + 4], i32_val(vIdx)); + vec1 = insert_element(vecTy, vec1, inVals[startIdx + vIdx + 4], + i32_val(vIdx)); } - Value shflVec0 = bitcast( - targetInfo.shuffleIdx(rewriter, loc, bitcast(vec0, int_ty(32)), addr0), vecTy); - Value shflVec1 = bitcast( - targetInfo.shuffleIdx(rewriter, loc, bitcast(vec1, int_ty(32)), addr1), vecTy); + Value shflVec0 = + bitcast(targetInfo.shuffleIdx(rewriter, loc, + bitcast(vec0, int_ty(32)), addr0), + vecTy); + Value shflVec1 = + bitcast(targetInfo.shuffleIdx(rewriter, loc, + bitcast(vec1, int_ty(32)), addr1), + vecTy); Value firstVec = select(mask, vec0, shflVec1); Value secondVec = select(mask, shflVec0, vec1); @@ -213,6 +219,7 @@ void populateConvertLayoutOpToLLVMPatterns( RewritePatternSet &patterns, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { patterns.add(typeConverter, benefit); - patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, targetInfo, + benefit); } } // namespace mlir::triton::AMD From df150ab347fba5f961453b0c97d760c0c69b1f3b Mon Sep 17 00:00:00 2001 From: Ilya Cherniavski Date: Thu, 14 Nov 2024 18:34:10 +0000 Subject: [PATCH 03/10] (adj rtol=0) --- third_party/amd/python/test/test_chained_dot_fp8.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/amd/python/test/test_chained_dot_fp8.py b/third_party/amd/python/test/test_chained_dot_fp8.py index abe7a47e9647..5aedb896633c 100644 --- a/third_party/amd/python/test/test_chained_dot_fp8.py +++ b/third_party/amd/python/test/test_chained_dot_fp8.py @@ -154,11 +154,11 @@ def test_chained_dot(N, D, dtype): tri_out = chained_dot(q_f8, k_f8, v_f8, q_desc, k_desc, v_desc, s_sc, s_desc, ref_sc) assert tri_out.isnan().sum() == 0 - torch.testing.assert_close(tri_out[0].float(), ref_f8.float(), atol=1e-2, rtol=3e-3) + torch.testing.assert_close(tri_out[0].float(), ref_f8.float(), atol=1e-2, rtol=0) else: s = torch.matmul(q, k.transpose(1, 2)) ref = torch.matmul(s, v.transpose(1, 2)) tri_out = chained_dot(q, k, v) - torch.testing.assert_close(tri_out, ref, atol=1e-2, rtol=3e-3) + torch.testing.assert_close(tri_out, ref, atol=1e-2, rtol=0) From 21a45d0612cdbce3c18490d716487286048d0570 Mon Sep 17 00:00:00 2001 From: Ilya Cherniavski Date: Tue, 19 Nov 2024 05:30:25 +0000 Subject: [PATCH 04/10] (get the warp size from the layout) --- .../amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 94e83f755839..5deebe591b48 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -159,7 +159,8 @@ struct ConvertLayoutOpMFMAToDotOpConversion return failure(); Value threadId = tid_val(); - Value warpSize = i32_val(64); // MFMA Warp Size + auto mfmaLayout = dyn_cast(srcType.getEncoding()); + Value warpSize = i32_val(triton::gpu::getWarpSize(mfmaLayout)); Value laneId = urem(threadId, warpSize); Value laneOffset = i32_val(32); Value mask = icmp_slt(laneId, laneOffset); From 1994150c488c01c32f94646d249fa353a30859e7 Mon Sep 17 00:00:00 2001 From: Ilya Cherniavski Date: Wed, 20 Nov 2024 22:56:27 +0000 Subject: [PATCH 05/10] (mfma_16 support) --- lib/Analysis/Utility.cpp | 8 +- test/Conversion/amd/mfma-shortcut.mlir | 52 +++++++- .../ConvertLayoutOpToLLVM.cpp | 120 +++++++++++++----- .../amd/python/test/test_chained_dot_fp8.py | 29 +++-- 4 files changed, 159 insertions(+), 50 deletions(-) diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index d5cb1f0bd1f0..1a2c82330c70 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -639,13 +639,15 @@ bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy, if (!mfmaLayout || !dotOperandLayout) return false; - // Currently supporting 32x32 FP8 MFMA -> dot operand case + // Currently supporting 32x32 and 16x16 FP8 MFMA -> dot operand case return dotOperandLayout.getParent() == mfmaLayout && dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() && dotOperandLayout.getKWidth() == 8 && - getContigPerThread(mfmaLayout)[1] == 4 && mfmaLayout.getMDim() == 32 && - mfmaLayout.getNDim() == 32 && + getContigPerThread(mfmaLayout)[1] == 4 && + ((mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16) || + (mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32)) && triton::type::isFloat8(srcTy.getElementType()) && + triton::type::isFloat8(dstTy.getElementType()) && mfmaLayout.getWarpsPerCTA()[1] == 1; } diff --git a/test/Conversion/amd/mfma-shortcut.mlir b/test/Conversion/amd/mfma-shortcut.mlir index 5b0036d4c590..33a16ca611b9 100644 --- a/test/Conversion/amd/mfma-shortcut.mlir +++ b/test/Conversion/amd/mfma-shortcut.mlir @@ -34,8 +34,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : #dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: mfma_dot_cvt_f8 - tt.func public @mfma_dot_cvt_f8(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) { + // CHECK-LABEL: mfma_dot_cvt_f8_mfma32 + tt.func public @mfma_dot_cvt_f8_mfma32(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) { // CHECK-NOT: store // CHECK-NOT: load // CHECK: rocdl.ds_bpermute @@ -43,3 +43,51 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: mfma_dot_cvt_bf8_mfma32 + tt.func public @mfma_dot_cvt_bf8_mfma32(%arg0: tensor<128x32xf8E5M2, #mfma>) { + // CHECK-NOT: store + // CHECK-NOT: load + // CHECK: rocdl.ds_bpermute + %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0> + tt.return + } +} + +// ----- + +#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: mfma_dot_cvt_f8_mfma16 + tt.func public @mfma_dot_cvt_f8_mfma16(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) { + // CHECK-NOT: store + // CHECK-NOT: load + // CHECK: rocdl.ds_bpermute + %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0> + tt.return + } +} + +// ----- + +#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: mfma_dot_cvt_bf8_mfma16 + tt.func public @mfma_dot_cvt_bf8_mfma16(%arg0: tensor<128x32xf8E5M2, #mfma>) { + // CHECK-NOT: store + // CHECK-NOT: load + // CHECK: rocdl.ds_bpermute + %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0> + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 5deebe591b48..401b2139385c 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -134,23 +134,6 @@ struct ConvertLayoutOpMFMAToDotOpConversion if (!matchMFMAAndDotOperandShuffleCase(srcType, dstType)) return failure(); - /* - Using wave shuffle to convert layouts: - 1) Input MMA layout (32x32, fp8, 16 values): - _____________________________________________________________ - |(t0 v0 v1 v2 v3) (t32 v0 v1 v2 v3) ... (t32 v12 v13 v14 v15)| - | ... ... | - |(t31 v0 v1 v2 v3) (t63 v0 v1 v2 v3) ... (t63 v12 v13 v14 v15)| - |_____________________________________________________________| - - 2) Output Dot operand layout (two 32x16 tiles, fp8, 8 values each): - ____________________________________________________________ ___ - |(t0 v0 v1 v2 v3 v4 v5 v6 v7) (t32 v0 v1 v2 v3 v4 v5 v6 v7) | | - | ... ... | |... - |(t31 v0 v1 v2 v3 v4 v5 v6 v7) (t63 v0 v1 v2 v3 v4 v5 v6 v7) | | - |____________________________________________________________| |___ - */ - auto loc = op.getLoc(); SmallVector inVals = @@ -160,16 +143,24 @@ struct ConvertLayoutOpMFMAToDotOpConversion Value threadId = tid_val(); auto mfmaLayout = dyn_cast(srcType.getEncoding()); - Value warpSize = i32_val(triton::gpu::getWarpSize(mfmaLayout)); + assert((mfmaLayout.getMDim() == 16 || mfmaLayout.getMDim() == 32) && + "Expected MFMA size 16 or 32"); + assert(triton::gpu::getWarpSize(mfmaLayout) == 64 && + "Expected warp size 64 for MFMA"); + Value warpSize = i32_val(64); Value laneId = urem(threadId, warpSize); - Value laneOffset = i32_val(32); - Value mask = icmp_slt(laneId, laneOffset); - Value addr0 = select(mask, add(laneId, laneOffset), laneId); - Value addr1 = select(mask, laneId, sub(laneId, laneOffset)); - SmallVector outVals; auto elemTy = int_ty(8); auto vecTy = vec_ty(elemTy, 4); + + Value mask0 = icmp_slt(laneId, i32_val(32)); + Value mask1 = icmp_slt(urem(laneId, i32_val(32)), i32_val(16)); + + Value addrShift16 = urem(add(laneId, i32_val(16)), warpSize); + Value addrShift32 = urem(add(laneId, i32_val(32)), warpSize); + Value addrShift48 = urem(add(laneId, i32_val(48)), warpSize); + + SmallVector outVals; for (size_t startIdx = 0; startIdx < inVals.size(); startIdx += 8) { Value vec0 = undef(vecTy); for (size_t vIdx = 0; vIdx < 4; ++vIdx) { @@ -182,23 +173,82 @@ struct ConvertLayoutOpMFMAToDotOpConversion i32_val(vIdx)); } - Value shflVec0 = - bitcast(targetInfo.shuffleIdx(rewriter, loc, - bitcast(vec0, int_ty(32)), addr0), - vecTy); - Value shflVec1 = - bitcast(targetInfo.shuffleIdx(rewriter, loc, - bitcast(vec1, int_ty(32)), addr1), - vecTy); + Value resVec0, resVec1; + if (mfmaLayout.getMDim() == 32) { + /* + Using wave shuffle to convert layouts (32x32x16 case): + 1) Input MMA layout (32x32, fp8, 16 values): + _____________________________________________________________ + |(t0 v0 v1 v2 v3) (t32 v0 v1 v2 v3) ... (t32 v12 v13 v14 v15)| + | ... ... | + |(t31 v0 v1 v2 v3) (t63 v0 v1 v2 v3) ... (t63 v12 v13 v14 v15)| + |_____________________________________________________________| - Value firstVec = select(mask, vec0, shflVec1); - Value secondVec = select(mask, shflVec0, vec1); + 2) Output Dot operand layout (two 32x16 tiles, fp8, 8 values each): + ____________________________________________________________ ___ + |(t0 v0 v1 v2 v3 v4 v5 v6 v7) (t32 v0 v1 v2 v3 v4 v5 v6 v7) | | + | ... ... | |... + |(t31 v0 v1 v2 v3 v4 v5 v6 v7) (t63 v0 v1 v2 v3 v4 v5 v6 v7) | | + |____________________________________________________________| |___ + */ + + Value shflVec0 = + bitcast(targetInfo.shuffleIdx( + rewriter, loc, bitcast(vec0, int_ty(32)), addrShift32), + vecTy); + Value shflVec1 = + bitcast(targetInfo.shuffleIdx( + rewriter, loc, bitcast(vec1, int_ty(32)), addrShift32), + vecTy); + + resVec0 = select(mask0, vec0, shflVec1); + resVec1 = select(mask0, shflVec0, vec1); + } else if (mfmaLayout.getMDim() == 16) { + /* + 16x16x32 case: + 1) Input MMA layout (two 16x16, fp8, 4 values each): + _________________________________________________________ ___________ + |(t0 v0 v1 v2 v3) (t16 v0 v1 v2 v3) ... (t48 v0 v1 v2 v3)||(t0 v4 ... + | ... ... || ... + |(t15 v0 v1 v2 v3) (t31 v0 v1 v2 v3) ... (t63 v0 v1 v2 v3)||(t15 v4 ... + |_________________________________________________________||___________ + + 2) Output Dot operand layout (16x32 tile, fp8, 8 values): + ________________________________________________________________ + |(t0 v0 v1 v2 v3 v4 v5 v6 v7) ... (t48 v0 v1 v2 v3 v4 v5 v6 v7) | + | ... ... | + |(t15 v0 v1 v2 v3 v4 v5 v6 v7) ... (t63 v0 v1 v2 v3 v4 v5 v6 v7) | + |________________________________________________________________| + */ + + Value shflVec0_16 = + bitcast(targetInfo.shuffleIdx( + rewriter, loc, bitcast(vec0, int_ty(32)), addrShift16), + vecTy); + Value shflVec0_32 = + bitcast(targetInfo.shuffleIdx( + rewriter, loc, bitcast(vec0, int_ty(32)), addrShift32), + vecTy); + Value shflVec1_32 = + bitcast(targetInfo.shuffleIdx( + rewriter, loc, bitcast(vec1, int_ty(32)), addrShift32), + vecTy); + Value shflVec1_48 = + bitcast(targetInfo.shuffleIdx( + rewriter, loc, bitcast(vec1, int_ty(32)), addrShift48), + vecTy); + + resVec0 = select(mask0, select(mask1, vec0, shflVec0_16), + select(mask1, shflVec1_32, shflVec1_48)); + resVec1 = select(mask0, select(mask1, shflVec0_16, shflVec0_32), + select(mask1, shflVec1_48, vec1)); + } for (size_t vIdx = 0; vIdx < 4; ++vIdx) { - outVals.push_back(extract_element(elemTy, firstVec, i32_val(vIdx))); + outVals.push_back(extract_element(elemTy, resVec0, i32_val(vIdx))); } for (size_t vIdx = 0; vIdx < 4; ++vIdx) { - outVals.push_back(extract_element(elemTy, secondVec, i32_val(vIdx))); + outVals.push_back(extract_element(elemTy, resVec1, i32_val(vIdx))); } } diff --git a/third_party/amd/python/test/test_chained_dot_fp8.py b/third_party/amd/python/test/test_chained_dot_fp8.py index 5aedb896633c..922babd35320 100644 --- a/third_party/amd/python/test/test_chained_dot_fp8.py +++ b/third_party/amd/python/test/test_chained_dot_fp8.py @@ -42,6 +42,7 @@ def _chained_dot( stride_om, stride_od, Z, + M, N, BLOCK_D: tl.constexpr, BLOCK_M: tl.constexpr, @@ -92,16 +93,21 @@ def _chained_dot( class chained_dot_fn(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, q_desc=1.0, k_desc=1.0, v_desc=1.0, s_sc=1.0, s_desc=1.0, o_sc=1.0): + def forward(ctx, q, k, v, msize=32, q_desc=1.0, k_desc=1.0, v_desc=1.0, s_sc=1.0, s_desc=1.0, o_sc=1.0): Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-2] assert Lq == Lk and Lk == Lv assert Lk in {16, 32, 64, 128} + assert msize in {16, 32} o = torch.empty_like(q, dtype=v.dtype) BLOCK_M = 128 if q.dtype == float8 else 256 + if BLOCK_M > q.shape[1]: + BLOCK_M = int(math.pow(2, math.floor(math.log2(q.shape[1])))) BLOCK_N = 32 + if BLOCK_N > k.shape[1]: + BLOCK_N = int(math.pow(2, math.floor(math.log2(k.shape[1])))) waves_per_eu = 2 - num_warps = BLOCK_M // 32 + num_warps = 4 if q.dtype == float8 else 8 num_stages = 1 grid = (triton.cdiv(q.shape[1], BLOCK_M), q.shape[0], 1) @@ -109,9 +115,9 @@ def forward(ctx, q, k, v, q_desc=1.0, k_desc=1.0, v_desc=1.0, s_sc=1.0, s_desc=1 _chained_dot[grid](q, k, v, o, q_desc, k_desc, v_desc, s_sc, s_desc, o_sc, q.stride(0), q.stride(1), q.stride(2), k.stride(0), k.stride(1), k.stride(2), v.stride(0), v.stride(1), v.stride(2), o.stride(0), o.stride(1), - o.stride(2), Z=q.shape[0], N=q.shape[1], BLOCK_D=Lk, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - USE_FP8=(q.dtype == float8), waves_per_eu=waves_per_eu, num_warps=num_warps, - num_stages=num_stages) + o.stride(2), Z=q.shape[0], M=q.shape[1], N=k.shape[1], BLOCK_D=Lk, BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, USE_FP8=(q.dtype == float8), waves_per_eu=waves_per_eu, num_warps=num_warps, + num_stages=num_stages, matrix_instr_nonkdim=msize) return o @@ -127,13 +133,16 @@ def to_float8(x, dtype=float8, margin: float = 1.0): return x_scaled.to(dtype), scale, 1.0 / scale -@pytest.mark.parametrize('N, D, dtype', [(*shape, dtype) for shape in [(128, 32), (256, 128)] for dtype in ['fp8']]) -def test_chained_dot(N, D, dtype): +@pytest.mark.parametrize('M, N, D, dtype, msize', [(*shape, dtype, msize) + for shape in [(128, 64, 32), (256, 128, 128)] + for dtype in ['fp8'] + for msize in [16, 32]]) +def test_chained_dot(M, N, D, dtype, msize): if dtype == 'fp8': assert float8 is not None BATCH = 1 - q = torch.empty((BATCH, N, D), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5) + q = torch.empty((BATCH, M, D), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5) k = torch.empty((BATCH, N, D), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5) v = torch.empty((BATCH, D, N), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5) @@ -151,7 +160,7 @@ def test_chained_dot(N, D, dtype): scale_b=torch.tensor(v_desc, dtype=torch.float32, device="cuda")) ref_f8, ref_sc, _ = to_float8(ref) - tri_out = chained_dot(q_f8, k_f8, v_f8, q_desc, k_desc, v_desc, s_sc, s_desc, ref_sc) + tri_out = chained_dot(q_f8, k_f8, v_f8, msize, q_desc, k_desc, v_desc, s_sc, s_desc, ref_sc) assert tri_out.isnan().sum() == 0 torch.testing.assert_close(tri_out[0].float(), ref_f8.float(), atol=1e-2, rtol=0) @@ -160,5 +169,5 @@ def test_chained_dot(N, D, dtype): s = torch.matmul(q, k.transpose(1, 2)) ref = torch.matmul(s, v.transpose(1, 2)) - tri_out = chained_dot(q, k, v) + tri_out = chained_dot(q, k, v, msize) torch.testing.assert_close(tri_out, ref, atol=1e-2, rtol=0) From a055fc4159be06bf15c9437e143ea708c2d4b651 Mon Sep 17 00:00:00 2001 From: Ilya Cherniavski Date: Wed, 20 Nov 2024 23:56:43 +0000 Subject: [PATCH 06/10] (move out test_chained_dot_fp8.py) --- .../amd/python/test/test_chained_dot_fp8.py | 173 ------------------ 1 file changed, 173 deletions(-) delete mode 100644 third_party/amd/python/test/test_chained_dot_fp8.py diff --git a/third_party/amd/python/test/test_chained_dot_fp8.py b/third_party/amd/python/test/test_chained_dot_fp8.py deleted file mode 100644 index 922babd35320..000000000000 --- a/third_party/amd/python/test/test_chained_dot_fp8.py +++ /dev/null @@ -1,173 +0,0 @@ -""" -Testing the (FP8) case of a dot op that consumes the output (MFMA) of -another dot op as an input. - -""" - -import math -import pytest -import sys -import torch - -import triton -import triton.language as tl - -TORCH_HAS_FP8E4 = hasattr(torch, 'float8_e4m3fnuz') -float8: tl.constexpr = None if not TORCH_HAS_FP8E4 else torch.float8_e4m3fnuz -torch.manual_seed(42) - - -@triton.jit -def _chained_dot( - Q, - K, - V, - Out, - q_desc, - k_desc, - v_desc, - s_sc, - s_desc, - o_sc, - stride_qz, - stride_qm, - stride_qd, - stride_kz, - stride_kn, - stride_kd, - stride_vz, - stride_vd, - stride_vn, - stride_oz, - stride_om, - stride_od, - Z, - M, - N, - BLOCK_D: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - USE_FP8: tl.constexpr, -): - start_m = tl.program_id(0) - off_z = tl.program_id(1) - qkv_offset = off_z * stride_qz - Q_block_ptr = tl.make_block_ptr(base=Q + qkv_offset, shape=(N, BLOCK_D), strides=(stride_qm, stride_qd), - offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_D), order=(1, 0)) - K_block_ptr = tl.make_block_ptr(base=K + qkv_offset, shape=(BLOCK_D, N), strides=(stride_kd, stride_kn), - offsets=(0, 0), block_shape=(BLOCK_D, BLOCK_N), order=(0, 1)) - V_block_ptr = tl.make_block_ptr(base=V + qkv_offset, shape=(N, BLOCK_D), strides=(stride_vn, stride_vd), - offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_D), order=(0, 1)) - - s_scale = q_desc * k_desc * s_sc - acc_scale = s_desc * v_desc * o_sc - - q = tl.load(Q_block_ptr) - - acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32) - lo, hi = 0, N - for start_n in range(lo, hi, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - - k = tl.load(K_block_ptr) - s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - s += tl.dot(q, k) - - if USE_FP8: - s *= s_scale - - v = tl.load(V_block_ptr) - acc += tl.dot(s.to(v.dtype), v) - - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - - if USE_FP8: - acc *= acc_scale - - O_block_ptr = tl.make_block_ptr(base=Out + qkv_offset, shape=(N, BLOCK_D), strides=(stride_om, stride_od), - offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_D), order=(1, 0)) - tl.store(O_block_ptr, acc.to(Out.type.element_ty)) - - -class chained_dot_fn(torch.autograd.Function): - - @staticmethod - def forward(ctx, q, k, v, msize=32, q_desc=1.0, k_desc=1.0, v_desc=1.0, s_sc=1.0, s_desc=1.0, o_sc=1.0): - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-2] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - assert msize in {16, 32} - o = torch.empty_like(q, dtype=v.dtype) - - BLOCK_M = 128 if q.dtype == float8 else 256 - if BLOCK_M > q.shape[1]: - BLOCK_M = int(math.pow(2, math.floor(math.log2(q.shape[1])))) - BLOCK_N = 32 - if BLOCK_N > k.shape[1]: - BLOCK_N = int(math.pow(2, math.floor(math.log2(k.shape[1])))) - waves_per_eu = 2 - num_warps = 4 if q.dtype == float8 else 8 - num_stages = 1 - - grid = (triton.cdiv(q.shape[1], BLOCK_M), q.shape[0], 1) - - _chained_dot[grid](q, k, v, o, q_desc, - k_desc, v_desc, s_sc, s_desc, o_sc, q.stride(0), q.stride(1), q.stride(2), k.stride(0), - k.stride(1), k.stride(2), v.stride(0), v.stride(1), v.stride(2), o.stride(0), o.stride(1), - o.stride(2), Z=q.shape[0], M=q.shape[1], N=k.shape[1], BLOCK_D=Lk, BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, USE_FP8=(q.dtype == float8), waves_per_eu=waves_per_eu, num_warps=num_warps, - num_stages=num_stages, matrix_instr_nonkdim=msize) - - return o - - -chained_dot = chained_dot_fn.apply - - -def to_float8(x, dtype=float8, margin: float = 1.0): - finfo = torch.finfo(dtype) - scale = finfo.max / x.abs().max().clamp(min=1e-12) - scale = math.pow(2, math.floor(math.log2(scale.float().item())) - margin) - x_scaled = (x.float() * scale).clamp(min=finfo.min, max=finfo.max) - return x_scaled.to(dtype), scale, 1.0 / scale - - -@pytest.mark.parametrize('M, N, D, dtype, msize', [(*shape, dtype, msize) - for shape in [(128, 64, 32), (256, 128, 128)] - for dtype in ['fp8'] - for msize in [16, 32]]) -def test_chained_dot(M, N, D, dtype, msize): - if dtype == 'fp8': - assert float8 is not None - - BATCH = 1 - q = torch.empty((BATCH, M, D), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5) - k = torch.empty((BATCH, N, D), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5) - v = torch.empty((BATCH, D, N), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5) - - if dtype == 'fp8': - q_f8, _, q_desc = to_float8(q) - k_f8, _, k_desc = to_float8(k) - v_f8, _, v_desc = to_float8(v) - - s = torch._scaled_mm(q_f8[0], k_f8[0].transpose(0, 1), out_dtype=torch.float32, - scale_a=torch.tensor(q_desc, dtype=torch.float32, device="cuda"), - scale_b=torch.tensor(k_desc, dtype=torch.float32, device="cuda")) - s_f8, s_sc, s_desc = to_float8(s) - ref = torch._scaled_mm(s_f8, v_f8[0].transpose(0, 1), out_dtype=torch.float32, - scale_a=torch.tensor(s_desc, dtype=torch.float32, device="cuda"), - scale_b=torch.tensor(v_desc, dtype=torch.float32, device="cuda")) - ref_f8, ref_sc, _ = to_float8(ref) - - tri_out = chained_dot(q_f8, k_f8, v_f8, msize, q_desc, k_desc, v_desc, s_sc, s_desc, ref_sc) - - assert tri_out.isnan().sum() == 0 - torch.testing.assert_close(tri_out[0].float(), ref_f8.float(), atol=1e-2, rtol=0) - - else: - s = torch.matmul(q, k.transpose(1, 2)) - ref = torch.matmul(s, v.transpose(1, 2)) - - tri_out = chained_dot(q, k, v, msize) - torch.testing.assert_close(tri_out, ref, atol=1e-2, rtol=0) From 50c3f1d933ad1f6fbfc3821beab70167d5371561 Mon Sep 17 00:00:00 2001 From: Ilya Cherniavski Date: Thu, 21 Nov 2024 08:12:32 +0000 Subject: [PATCH 07/10] (more lit tests, merge) --- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 6 + test/Conversion/amd/mfma-shortcut.mlir | 143 +++++++++++++++++- .../ConvertLayoutOpToLLVM.cpp | 10 +- 3 files changed, 152 insertions(+), 7 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 62499d8208cf..63c7a03df308 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -406,6 +406,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion return failure(); } + // The following check can be removed when generalized warp shuffle + // conversions are ready: + if (matchMFMAAndDotOperandShuffleCase(srcTy, dstTy)) { + return failure(); + } + assert(cvtNeedsSharedMemory(srcTy, dstTy)); SmallVector inVals = diff --git a/test/Conversion/amd/mfma-shortcut.mlir b/test/Conversion/amd/mfma-shortcut.mlir index 33a16ca611b9..97226c2eb5cd 100644 --- a/test/Conversion/amd/mfma-shortcut.mlir +++ b/test/Conversion/amd/mfma-shortcut.mlir @@ -38,7 +38,66 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.func public @mfma_dot_cvt_f8_mfma32(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) { // CHECK-NOT: store // CHECK-NOT: load - // CHECK: rocdl.ds_bpermute + + // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3] + // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7] + + // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x + // CHECK: [[warpSize:%.*]] = llvm.mlir.constant(64 : i32) + // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[warpSize]] + // CHECK: [[c32:%.*]] = llvm.mlir.constant(32 : i32) + // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] + + // CHECK-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32) + // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) + // CHECK: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]] + // CHECK: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]] + + // CHECK: [[c16:%.*]] = llvm.mlir.constant(16 : i32) + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c16]] + // CHECK: [[addr16:%.*]] = llvm.urem [[shflLaneId]], [[warpSize]] + + // CHECK: [[c32:%.*]] = llvm.mlir.constant(32 : i32) + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]] + // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[warpSize]] + + // CHECK: [[c48:%.*]] = llvm.mlir.constant(48 : i32) + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]] + // CHECK: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[warpSize]] + + // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> + // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> + + // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // CHECK: [[bShflVec0:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] + // CHECK: [[shflVec0:%.*]] = llvm.bitcast [[bShflVec0]] + + // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // CHECK: [[bShflVec1:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] + // CHECK: [[shflVec1:%.*]] = llvm.bitcast [[bShflVec1]] + + // Input (8 values): (vec0, vec1) + // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64): + // resVec0 resVec1 + // lanes 0-31: (vec0 , vec0 >> 32) (mask0=1) + // lanes 32-63: (vec1 >> 32, vec1 ) (mask0=0) + + // CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[vec0]], [[shflVec1]] + // CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[shflVec0]], [[vec1]] + + // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) + // CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8> + // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32 + // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8> + + // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3] + // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7] + + // CHECK: llvm.return %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0> tt.return } @@ -55,6 +114,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NOT: store // CHECK-NOT: load // CHECK: rocdl.ds_bpermute + // CHECK: llvm.return %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0> tt.return } @@ -70,7 +130,85 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.func public @mfma_dot_cvt_f8_mfma16(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) { // CHECK-NOT: store // CHECK-NOT: load - // CHECK: rocdl.ds_bpermute + + // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3] + // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7] + + // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x + // CHECK: [[warpSize:%.*]] = llvm.mlir.constant(64 : i32) + // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[warpSize]] + // CHECK: [[c32:%.*]] = llvm.mlir.constant(32 : i32) + // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] + + // CHECK-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32) + // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) + // CHECK: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]] + // CHECK: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]] + + // CHECK: [[c16:%.*]] = llvm.mlir.constant(16 : i32) + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c16]] + // CHECK: [[addr16:%.*]] = llvm.urem [[shflLaneId]], [[warpSize]] + + // CHECK: [[c32:%.*]] = llvm.mlir.constant(32 : i32) + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]] + // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[warpSize]] + + // CHECK: [[c48:%.*]] = llvm.mlir.constant(48 : i32) + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]] + // CHECK: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[warpSize]] + + // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> + // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> + + // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr16]], [[c2]] + // CHECK: [[bShflVec0_16:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] + // CHECK: [[shflVec0_16:%.*]] = llvm.bitcast [[bShflVec0_16]] + + // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // CHECK: [[bShflVec0_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] + // CHECK: [[shflVec0_32:%.*]] = llvm.bitcast [[bShflVec0_32]] + + // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // CHECK: [[bShflVec1_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] + // CHECK: [[shflVec1_32:%.*]] = llvm.bitcast [[bShflVec1_32]] + + // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr48]], [[c2]] + // CHECK: [[bShflVec1_48:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] + // CHECK: [[shflVec1_48:%.*]] = llvm.bitcast [[bShflVec1_48]] + + // Input (8 values): (vec0, vec1) + // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64): + // resVec0 resVec1 + // lanes 0-15: (vec0 , vec0 >> 16) (mask0=1, mask1=1) + // lanes 16-31: (vec0 >> 16, vec0 >> 32) (mask0=1, mask1=0) + // lanes 32-47: (vec1 >> 32, vec1 >> 48) (mask0=0, mask1=1) + // lanes 48-63: (vec1 >> 48, vec1 ) (mask0=0, mask1=0) + + // CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[vec0]], [[shflVec0_16]] : i1, vector<4xi8> + // CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_32]], [[shflVec1_48]] : i1, vector<4xi8> + // CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8> + + // CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[shflVec0_16]], [[shflVec0_32]] : i1, vector<4xi8> + // CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_48]], [[vec1]] : i1, vector<4xi8> + // CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8> + + // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) + // CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8> + // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32 + // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8> + + // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3] + // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7] + + // CHECK: llvm.return %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0> tt.return } @@ -87,6 +225,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NOT: store // CHECK-NOT: load // CHECK: rocdl.ds_bpermute + // CHECK: llvm.return %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0> tt.return } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 401b2139385c..fb07e76a2f4b 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -185,11 +185,11 @@ struct ConvertLayoutOpMFMAToDotOpConversion |_____________________________________________________________| 2) Output Dot operand layout (two 32x16 tiles, fp8, 8 values each): - ____________________________________________________________ ___ - |(t0 v0 v1 v2 v3 v4 v5 v6 v7) (t32 v0 v1 v2 v3 v4 v5 v6 v7) | | - | ... ... | |... - |(t31 v0 v1 v2 v3 v4 v5 v6 v7) (t63 v0 v1 v2 v3 v4 v5 v6 v7) | | - |____________________________________________________________| |___ + ____________________________________________________________ ___ + |(t0 v0 v1 v2 v3 v4 v5 v6 v7) (t32 v0 v1 v2 v3 v4 v5 v6 v7) || + | ... ... ||... + |(t31 v0 v1 v2 v3 v4 v5 v6 v7) (t63 v0 v1 v2 v3 v4 v5 v6 v7) || + |____________________________________________________________||___ */ Value shflVec0 = From 62083cf8f22e59c8b13c07b8817a4a54c8e51fb2 Mon Sep 17 00:00:00 2001 From: Ilya Cherniavski Date: Thu, 21 Nov 2024 15:49:16 +0000 Subject: [PATCH 08/10] (linter) --- test/Conversion/amd/mfma-shortcut.mlir | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/Conversion/amd/mfma-shortcut.mlir b/test/Conversion/amd/mfma-shortcut.mlir index 97226c2eb5cd..10deaaab1ede 100644 --- a/test/Conversion/amd/mfma-shortcut.mlir +++ b/test/Conversion/amd/mfma-shortcut.mlir @@ -65,7 +65,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]] // CHECK: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[warpSize]] - // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> + // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] @@ -94,7 +94,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32 // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8> - // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3] + // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3] // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7] // CHECK: llvm.return @@ -157,7 +157,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]] // CHECK: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[warpSize]] - // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> + // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] @@ -205,7 +205,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32 // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8> - // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3] + // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3] // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7] // CHECK: llvm.return From e179da1c00b552111a189d53da62350c82d09298 Mon Sep 17 00:00:00 2001 From: Ilya Cherniavski Date: Thu, 21 Nov 2024 16:16:31 +0000 Subject: [PATCH 09/10] (lit) --- test/Conversion/amd/mfma-shortcut.mlir | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/Conversion/amd/mfma-shortcut.mlir b/test/Conversion/amd/mfma-shortcut.mlir index 10deaaab1ede..c5885af5dde2 100644 --- a/test/Conversion/amd/mfma-shortcut.mlir +++ b/test/Conversion/amd/mfma-shortcut.mlir @@ -42,15 +42,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3] // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7] - // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x - // CHECK: [[warpSize:%.*]] = llvm.mlir.constant(64 : i32) + // CHECK-DAG: [[threadId:%.*]] = rocdl.workitem.id.x + // CHECK-DAG: [[warpSize:%.*]] = llvm.mlir.constant(64 : i32) // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[warpSize]] // CHECK: [[c32:%.*]] = llvm.mlir.constant(32 : i32) // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] // CHECK-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32) // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) - // CHECK: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]] + // CHECK-DAG: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]] // CHECK: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]] // CHECK: [[c16:%.*]] = llvm.mlir.constant(16 : i32) @@ -134,15 +134,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3] // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7] - // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x - // CHECK: [[warpSize:%.*]] = llvm.mlir.constant(64 : i32) + // CHECK-DAG: [[threadId:%.*]] = rocdl.workitem.id.x + // CHECK-DAG: [[warpSize:%.*]] = llvm.mlir.constant(64 : i32) // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[warpSize]] // CHECK: [[c32:%.*]] = llvm.mlir.constant(32 : i32) // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] // CHECK-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32) // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) - // CHECK: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]] + // CHECK-DAG: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]] // CHECK: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]] // CHECK: [[c16:%.*]] = llvm.mlir.constant(16 : i32) From 247432aa71637202db042c551630e97d9744b8b5 Mon Sep 17 00:00:00 2001 From: Ilya Cherniavski Date: Thu, 21 Nov 2024 16:39:30 +0000 Subject: [PATCH 10/10] (simplify lit) --- test/Conversion/amd/mfma-shortcut.mlir | 47 +++++++------------ .../ConvertLayoutOpToLLVM.cpp | 21 +++++---- 2 files changed, 29 insertions(+), 39 deletions(-) diff --git a/test/Conversion/amd/mfma-shortcut.mlir b/test/Conversion/amd/mfma-shortcut.mlir index c5885af5dde2..bcbc7eff590e 100644 --- a/test/Conversion/amd/mfma-shortcut.mlir +++ b/test/Conversion/amd/mfma-shortcut.mlir @@ -42,28 +42,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3] // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7] - // CHECK-DAG: [[threadId:%.*]] = rocdl.workitem.id.x - // CHECK-DAG: [[warpSize:%.*]] = llvm.mlir.constant(64 : i32) - // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[warpSize]] - // CHECK: [[c32:%.*]] = llvm.mlir.constant(32 : i32) - // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] - - // CHECK-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32) // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) - // CHECK-DAG: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]] - // CHECK: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]] + // CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32) - // CHECK: [[c16:%.*]] = llvm.mlir.constant(16 : i32) - // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c16]] - // CHECK: [[addr16:%.*]] = llvm.urem [[shflLaneId]], [[warpSize]] + // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x + // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]] + // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] - // CHECK: [[c32:%.*]] = llvm.mlir.constant(32 : i32) // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]] - // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[warpSize]] - - // CHECK: [[c48:%.*]] = llvm.mlir.constant(48 : i32) - // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]] - // CHECK: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[warpSize]] + // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]] // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> @@ -134,28 +121,26 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3] // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7] - // CHECK-DAG: [[threadId:%.*]] = rocdl.workitem.id.x - // CHECK-DAG: [[warpSize:%.*]] = llvm.mlir.constant(64 : i32) - // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[warpSize]] - // CHECK: [[c32:%.*]] = llvm.mlir.constant(32 : i32) - // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] - // CHECK-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32) // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) - // CHECK-DAG: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]] + // CHECK-DAG: [[c48:%.*]] = llvm.mlir.constant(48 : i32) + // CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32) + + // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x + // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]] + // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] + + // CHECK: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]] // CHECK: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]] - // CHECK: [[c16:%.*]] = llvm.mlir.constant(16 : i32) // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c16]] - // CHECK: [[addr16:%.*]] = llvm.urem [[shflLaneId]], [[warpSize]] + // CHECK: [[addr16:%.*]] = llvm.urem [[shflLaneId]], [[c64]] - // CHECK: [[c32:%.*]] = llvm.mlir.constant(32 : i32) // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]] - // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[warpSize]] + // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]] - // CHECK: [[c48:%.*]] = llvm.mlir.constant(48 : i32) // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]] - // CHECK: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[warpSize]] + // CHECK: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[c64]] // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index fb07e76a2f4b..632ca182e4e8 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -141,24 +141,29 @@ struct ConvertLayoutOpMFMAToDotOpConversion if (inVals.empty() || inVals.size() % 8 != 0) return failure(); - Value threadId = tid_val(); auto mfmaLayout = dyn_cast(srcType.getEncoding()); assert((mfmaLayout.getMDim() == 16 || mfmaLayout.getMDim() == 32) && "Expected MFMA size 16 or 32"); assert(triton::gpu::getWarpSize(mfmaLayout) == 64 && "Expected warp size 64 for MFMA"); - Value warpSize = i32_val(64); - Value laneId = urem(threadId, warpSize); auto elemTy = int_ty(8); auto vecTy = vec_ty(elemTy, 4); - Value mask0 = icmp_slt(laneId, i32_val(32)); - Value mask1 = icmp_slt(urem(laneId, i32_val(32)), i32_val(16)); + Value c16 = i32_val(16); + Value c32 = i32_val(32); + Value c48 = i32_val(48); + Value c64 = i32_val(64); + + Value threadId = tid_val(); + Value laneId = urem(threadId, c64); + + Value mask0 = icmp_slt(laneId, c32); + Value mask1 = icmp_slt(urem(laneId, c32), c16); - Value addrShift16 = urem(add(laneId, i32_val(16)), warpSize); - Value addrShift32 = urem(add(laneId, i32_val(32)), warpSize); - Value addrShift48 = urem(add(laneId, i32_val(48)), warpSize); + Value addrShift16 = urem(add(laneId, c16), c64); + Value addrShift32 = urem(add(laneId, c32), c64); + Value addrShift48 = urem(add(laneId, c48), c64); SmallVector outVals; for (size_t startIdx = 0; startIdx < inVals.size(); startIdx += 8) {