diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 074e8c295d7d..43e6e183037e 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -562,10 +562,11 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
   auto dstLayout = dstTy.getEncoding();
   auto mmaLayout = srcLayout.cast<NvidiaMmaEncodingAttr>();
   auto dotOperandLayout = dstLayout.cast<DotOperandEncodingAttr>();
-  auto ans =
-      mmaLayout.getVersionMajor() == 3 && dotOperandLayout.getOpIdx() == 0 &&
-      isMmaToMmaShortcut(dotOperandLayout.getParent(), srcLayout) &&
-      (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
+  int elementTypeSize = srcTy.getElementType().getIntOrFloatBitWidth();
+  auto ans = mmaLayout.getVersionMajor() == 3 &&
+             dotOperandLayout.getOpIdx() == 0 &&
+             isMmaToMmaShortcut(dotOperandLayout.getParent(), srcLayout) &&
+             (elementTypeSize == 16 || elementTypeSize == 8);
   return ans;
 }
 
diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
index baf845377fd2..8d5162a390f0 100644
--- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -291,18 +291,14 @@ struct MMAV3UseRegOperand : public OpRewritePattern<DotOp> {
     if (!srcEnc || srcEnc.getVersionMajor() != 3 || !dstEnc ||
         dstEnc.getVersionMajor() != 3)
       return failure();
-
-    // We currently only support convert from f16 and bf16 mma to f16 and bf16
-    // dot operand, as the other types require shuffling data across threads.
-    // TODO: extend it to more types.
     auto srcTy = alloc.getInit().getType().cast<RankedTensorType>();
-    if (!(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16()))
-      return failure();
-
     auto dotOperandEnc = DotOperandEncodingAttr::get(
         dotOp.getContext(), /*opIdx=*/0, srcEnc, /*kWidth=*/0);
     auto newTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(),
                                        dotOperandEnc);
+    if (!isMmaToDotShortcut(srcTy, newTy))
+      return failure();
+
     Value newOperand = rewriter.create<ConvertLayoutOp>(dotOp.getLoc(), newTy,
                                                         alloc.getInit());
     rewriter.modifyOpInPlace(dotOp, [&]() { dotOp.setOperand(0, newOperand); });
diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
index 28985e49ffaf..d92a048af753 100644
--- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
@@ -151,6 +151,7 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
   SmallVector<Value> queue = {op->getResult(0)};
   SetVector<Operation *> forwardSlice;
   llvm::SmallDenseSet<Value> seen;
+  bool isMMAV3 = encoding.cast<NvidiaMmaEncodingAttr>().getVersionMajor() == 3;
   while (!queue.empty()) {
     Value currentValue = queue.back();
     queue.pop_back();
@@ -164,6 +165,8 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
       if (dstEncoding.isa<NvidiaMmaEncodingAttr>())
         return encoding.cast<NvidiaMmaEncodingAttr>().getVersionMajor() > 1;
     }
+    if (isMMAV3 && isa<LocalAllocOp>(op))
+      return true;
    auto yield = dyn_cast<scf::YieldOp>(op);
    if (!yield)
      continue;
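
Note on the three hunks above: the old f16/bf16 allowlist becomes a bit-width test, so the MMAv3 register shortcut now also accepts 8-bit element types, and `MMAV3UseRegOperand` delegates the legality question to `isMmaToDotShortcut` instead of duplicating it. The sketch below restates the new gate as a standalone predicate; the function name and plain-int signature are mine, and it deliberately omits the `isMmaToMmaShortcut` layout check:

```cpp
#include <cstdio>

// Illustrative mirror of the new matchMmaV3AndDotOperandLayout gate: the
// MMAv3 accumulator -> dot-operand shortcut applies to operand A only, and
// now to any 16-bit or 8-bit element type instead of just f16/bf16. Wider
// types would still need a different cross-thread data exchange.
static bool mmaV3DotOperandShortcut(int versionMajor, int opIdx,
                                    int elementBitWidth) {
  return versionMajor == 3 && opIdx == 0 &&
         (elementBitWidth == 16 || elementBitWidth == 8);
}

int main() {
  printf("f16: %d\n", mmaV3DotOperandShortcut(3, 0, 16)); // 1: register no-op
  printf("fp8: %d\n", mmaV3DotOperandShortcut(3, 0, 8));  // 1: new shuffle path
  printf("f32: %d\n", mmaV3DotOperandShortcut(3, 0, 32)); // 0: still rejected
}
```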
assert "Unsupported float8 dtype" + + @pytest.mark.interpreter @pytest.mark.parametrize( "M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype", @@ -2761,7 +2769,9 @@ def kernel(In, Out, # 'float32')]] + [(64, 64, 64, 4, col_a, col_b, 'none', False, 'float32', 'float32') for col_a in [True, False] - for col_b in [True, False]] + [(64, 64, 64, 4, False, False, 'chain-dot', False, 'bfloat16', 'float32')]) + for col_b in [True, False]] + [(64, 64, 64, 4, False, False, 'chain-dot', False, 'bfloat16', 'float32')] + + [(128, 128, 64, 4, False, False, 'chain-dot', False, float8_type, 'float32') + for float8_type in ["float8e5", "float8e4nv"]]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, num_ctas, device): check_cuda_only(device) @@ -2781,6 +2791,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o if out_dtype == 'float16': # TODO: support out_dtype=float16 for tl.dot on V100 pytest.skip("Only test out_dtype=float16 on devices with sm >=80") + if capability[0] < 9 and in_dtype == 'float8e4nv': + pytest.skip("float8e4nv not supported on sm <= 80") if is_interpreter() and in_dtype == 'int8': pytest.skip( "numpy.dot with int8 inputs will overflow while tl.dot doesn't because MMA instruction's accumulator is 32-bit" @@ -2839,16 +2851,16 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid else: y = numpy_random((K, N), dtype_str=in_dtype, rs=rs) w = numpy_random((N, N), dtype_str=in_dtype, rs=rs) - if 'int' not in in_dtype: + if 'int' not in in_dtype and 'float8' not in in_dtype: x *= .1 y *= .1 if in_dtype == 'float32' and allow_tf32: x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32') y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32') w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32') - x_tri = to_triton(x, device=device) - y_tri = to_triton(y, device=device) - w_tri = to_triton(w, device=device) + x_tri = to_triton(x, device=device, dst_type=in_dtype) + y_tri = to_triton(y, device=device, dst_type=in_dtype) + w_tri = to_triton(w, device=device, dst_type=in_dtype) # triton result if out_dtype == 'int8': z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs) @@ -2894,6 +2906,10 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid # torch result if in_dtype == 'int8': z_ref = np.matmul(x.astype(np.float32), y.astype(np.float32())).astype(np.int32) + elif 'float8' in in_dtype: + x = convert_fp8_to_fp32(x, device, in_dtype) + y = convert_fp8_to_fp32(y, device, in_dtype) + z_ref = to_numpy(torch.matmul(x, y)) else: z_ref = np.matmul(x, y) @@ -2908,12 +2924,14 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid denom = np.sum(num, axis=-1, keepdims=True) z_ref = num / denom if epilogue == 'chain-dot': + if 'float8' in in_dtype: + w = to_numpy(convert_fp8_to_fp32(w, device, in_dtype)) z_ref = np.matmul(z_ref, w) # compare if in_dtype == 'float32': # XXX: Somehow there's a larger difference when we use float32 np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) - elif out_dtype == tl.float16: + elif out_dtype == tl.float16 or in_dtype == 'bfloat16': np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2) else: # added atol, to loose precision for float16xfloat16->float32 case @@ -2925,7 +2943,10 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid if (K > 16 or 
diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir
index 7042aaa946d5..b09e4ee9d3ce 100644
--- a/test/Conversion/tritongpu_to_llvm_hopper.mlir
+++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir
@@ -179,3 +179,21 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
     tt.return
   }
 }
+
+// -----
+
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
+#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+// CHECK-LABEL: cvt_mma_to_dot_fp8
+// CHECK: prmt.b32
+// CHECK: prmt.b32
+// CHECK: nvvm.shfl.sync
+// CHECK: nvvm.shfl.sync
+// CHECK: prmt.b32
+// CHECK: prmt.b32
+  tt.func @cvt_mma_to_dot_fp8(%a: tensor<128x64xf8E5M2, #mma>) {
+    %opA = triton_gpu.convert_layout %a : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
+    tt.return
+  }
+}
diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir
index 7716e83b58fb..bdef81286803 100644
--- a/test/TritonGPU/dot-operands.mlir
+++ b/test/TritonGPU/dot-operands.mlir
@@ -154,6 +154,22 @@ tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !tt.memdes
 }
 }
 
+// -----
+
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
+#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
+#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
+module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+// CHECK: tt.func @mma_v3_reg_operand_A_fp8
+// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
+// CHECK: tt.dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma>
+tt.func @mma_v3_reg_operand_A_fp8(%arg0: tensor<128x64xf8E5M2, #mma>, %arg1: !tt.memdesc<64x64xf8E5M2, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{
+  %A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf8E5M2, #mma>) -> !tt.memdesc<128x64xf8E5M2, #shared1>
+  %r = tt.dot %A, %arg1, %arg2 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : !tt.memdesc<128x64xf8E5M2, #shared1> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma>
+  tt.return %r : tensor<128x64xf32, #mma>
+}
+}
+
 // -----
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
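
A quick consistency check on the `cvt_mma_to_dot_fp8` lit test above (my arithmetic, not part of the patch): a 128x64 fp8 tensor spread over 4 warps leaves 64 one-byte values per thread, which the new conversion consumes in chunks of 8, each chunk producing the prmt/shfl/prmt mix the CHECK lines assert:

```cpp
#include <cstdio>

int main() {
  // Shapes from the cvt_mma_to_dot_fp8 lit test above.
  const int M = 128, N = 64;       // tensor<128x64xf8E5M2, #mma>
  const int warps = 4, laneCount = 32;
  const int threads = warps * laneCount;
  const int bytesPerThread = M * N / threads; // 64 fp8 values per thread
  const int chunks = bytesPerThread / 8;      // conversion handles 8 at a time
  printf("values/thread: %d, 8-value chunks: %d\n", bytesPerThread, chunks);
  // Each chunk emits prmt, prmt, shfl.idx, shfl.idx, prmt, prmt - the
  // instruction mix (though not the count) that the CHECK lines look for.
}
```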
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index 497911a998c4..e8e0524794e7 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -1,10 +1,9 @@
 #include "PatternTritonGPUOpToLLVM.h"
 #include "Utility.h"
-
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "triton/Analysis/Allocation.h"
-#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
-
 #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
+#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 
 using mlir::isLayoutMmaV1;
 using ::mlir::LLVM::getMultiDimOffset;
@@ -584,6 +583,83 @@ struct ConvertLayoutOpConversion
     return success();
   }
 
+  // Convert from accumulator MMA layout to 8bit dot operand layout.
+  // The conversion logic is taken from:
+  // https://github.com/ColfaxResearch/cutlass-kernels/blob/a9de6446c1c0415c926025cea284210c799b11f8/src/fmha-pipeline/reg2reg.h#L45
+  void
+  convertMMAV3To8BitsDotOperand(triton::gpu::ConvertLayoutOp op,
+                                OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+    auto loc = op.getLoc();
+    auto dstTy = op.getType();
+    auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
+    SmallVector<Value> retVals;
+    for (int i = 0; i < vals.size(); i += 8) {
+      Value upper = undef(vec_ty(i8_ty, 4));
+      for (int j = 0; j < 4; j++) {
+        upper =
+            insert_element(vec_ty(i8_ty, 4), upper, vals[i + j], i32_val(j));
+      }
+      upper = bitcast(upper, i32_ty);
+      Value lower = undef(vec_ty(i8_ty, 4));
+      for (int j = 0; j < 4; j++) {
+        lower = insert_element(vec_ty(i8_ty, 4), lower, vals[i + 4 + j],
+                               i32_val(j));
+      }
+      lower = bitcast(lower, i32_ty);
+
+      Value threadIdMod4 = urem(getThreadId(rewriter, loc), i32_val(4));
+      Value cnd = or_(icmp_eq(threadIdMod4, i32_val(0)),
+                      icmp_eq(threadIdMod4, i32_val(3)));
+      Value selectorEx0 = select(cnd, i32_val(0x3210), i32_val(0x7654));
+      Value selectorEx1 = select(cnd, i32_val(0x7654), i32_val(0x3210));
+      Value selectorEx4 = select(cnd, i32_val(0x5410), i32_val(0x1054));
+      Value selectorEx5 = select(cnd, i32_val(0x7632), i32_val(0x3276));
+
+      Value isOne = icmp_eq(threadIdMod4, i32_val(1));
+      Value isTwo = icmp_eq(threadIdMod4, i32_val(2));
+      Value isThree = icmp_eq(threadIdMod4, i32_val(3));
+      Value upperIdx = i32_val(0);
+      upperIdx = select(isOne, i32_val(3), upperIdx);
+      upperIdx = select(isTwo, i32_val(1), upperIdx);
+      upperIdx = select(isThree, i32_val(2), upperIdx);
+
+      Value lowerIdx = i32_val(1);
+      lowerIdx = select(isOne, i32_val(2), lowerIdx);
+      lowerIdx = select(isTwo, i32_val(0), lowerIdx);
+      lowerIdx = select(isThree, i32_val(3), lowerIdx);
+
+      Value upper0 =
+          LLVM::NVIDIA::permute(loc, rewriter, upper, lower, selectorEx0);
+      Value lower0 =
+          LLVM::NVIDIA::permute(loc, rewriter, upper, lower, selectorEx1);
+      Value mask = i32_val(0xFFFFFFFF);
+      // Set clamp to shuffle only within 4 lanes.
+      Value clamp = i32_val(0x1C1F);
+      upper0 =
+          rewriter.create<NVVM::ShflOp>(loc, i32_ty, mask, upper0, upperIdx,
+                                        clamp, NVVM::ShflKind::idx, UnitAttr());
+      lower0 =
+          rewriter.create<NVVM::ShflOp>(loc, i32_ty, mask, lower0, lowerIdx,
+                                        clamp, NVVM::ShflKind::idx, UnitAttr());
+      Value upper1 =
+          LLVM::NVIDIA::permute(loc, rewriter, upper0, lower0, selectorEx4);
+      Value vecVal = bitcast(upper1, vec_ty(i8_ty, 4));
+      for (int i = 0; i < 4; i++) {
+        retVals.push_back(extract_element(i8_ty, vecVal, i32_val(i)));
+      }
+      Value lower1 =
+          LLVM::NVIDIA::permute(loc, rewriter, upper0, lower0, selectorEx5);
+      vecVal = bitcast(lower1, vec_ty(i8_ty, 4));
+      for (int i = 0; i < 4; i++) {
+        retVals.push_back(extract_element(i8_ty, vecVal, i32_val(i)));
+      }
+    }
+    Value result =
+        packLLElements(loc, getTypeConverter(), retVals, rewriter, dstTy);
+    rewriter.replaceOp(op, result);
+  }
+
   // mma -> dot_operand
   LogicalResult lowerMmaToDotOperand(triton::gpu::ConvertLayoutOp op,
                                      OpAdaptor adaptor,
@@ -592,7 +668,13 @@ struct ConvertLayoutOpConversion
     auto srcTy = op.getSrc().getType();
     auto dstTy = op.getType();
     if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) {
-      rewriter.replaceOp(op, adaptor.getSrc());
+      if (srcTy.getElementType().getIntOrFloatBitWidth() == 16) {
+        rewriter.replaceOp(op, adaptor.getSrc());
+        return success();
+      }
+      assert(srcTy.getElementType().getIntOrFloatBitWidth() == 8 &&
+             "Unsupported type size.");
+      convertMMAV3To8BitsDotOperand(op, adaptor, rewriter);
       return success();
     }
 
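
The selector tables in `convertMMAV3To8BitsDotOperand` are hard to eyeball. The host-side model below replays one 4-lane quad of the network — pack 8 labeled bytes into two registers, prmt, exchange via shfl.idx (modeled as quad-local indexing, which is what clamp 0x1C1F restricts the hardware shuffle to), prmt again — and prints where every byte ends up. The selectors and lane-index tables are copied from the function above; the `0x(lane)(slot)` byte labeling is mine:

```cpp
#include <cstdint>
#include <cstdio>

// Default-mode PTX prmt.b32: nibble k of sel picks byte ((sel >> 4k) & 7)
// from the 8-byte pool {b:a} (a supplies bytes 0-3, b supplies bytes 4-7).
static uint32_t prmt(uint32_t a, uint32_t b, uint32_t sel) {
  uint64_t pool = ((uint64_t)b << 32) | a;
  uint32_t d = 0;
  for (int k = 0; k < 4; ++k) {
    uint32_t byteIdx = (sel >> (4 * k)) & 0x7;
    d |= (uint32_t)((pool >> (8 * byteIdx)) & 0xFF) << (8 * k);
  }
  return d;
}

int main() {
  // Each lane t of a quad holds 8 bytes in two registers; byte 0xTS means
  // "lane T, slot S" so the final placement is readable in the output.
  uint32_t upper[4], lower[4], upper0[4], lower0[4];
  for (uint32_t t = 0; t < 4; ++t) {
    upper[t] = lower[t] = 0;
    for (uint32_t s = 0; s < 4; ++s) {
      upper[t] |= (0x10 * t + s) << (8 * s);
      lower[t] |= (0x10 * t + 4 + s) << (8 * s);
    }
  }
  for (int t = 0; t < 4; ++t) { // first prmt stage (selectorEx0/Ex1)
    bool cnd = (t == 0 || t == 3);
    upper0[t] = prmt(upper[t], lower[t], cnd ? 0x3210 : 0x7654);
    lower0[t] = prmt(upper[t], lower[t], cnd ? 0x7654 : 0x3210);
  }
  // shfl.idx stage: lane t reads upper0 from lane upperIdx[t] of its quad.
  const int upperIdx[4] = {0, 3, 1, 2}, lowerIdx[4] = {1, 2, 0, 3};
  for (int t = 0; t < 4; ++t) { // final prmt stage (selectorEx4/Ex5)
    bool cnd = (t == 0 || t == 3);
    uint32_t u = upper0[upperIdx[t]], l = lower0[lowerIdx[t]];
    printf("lane %d: %08x %08x\n", t, prmt(u, l, cnd ? 0x5410 : 0x1054),
           prmt(u, l, cnd ? 0x7632 : 0x3276));
  }
}
```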
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp
index 2e67710fec70..b744630685a1 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp
@@ -8,8 +8,9 @@ namespace LLVM {
 namespace NVIDIA {
 using namespace mlir::triton;
 
-Value shuffleCommon(Location loc, ConversionPatternRewriter &rewriter,
-                    Value val, Value i, NVVM::ShflKind mode, Value clamp) {
+static Value shuffleCommon(Location loc, ConversionPatternRewriter &rewriter,
+                           Value val, Value i, NVVM::ShflKind mode,
+                           Value clamp) {
   unsigned bits = val.getType().getIntOrFloatBitWidth();
 
   if (bits == 64) {
@@ -90,6 +91,19 @@ Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr) {
   Value val = builder.launch(b, loc, b.getIntegerType(32), false);
   return val;
 }
+
+Value permute(Location loc, ConversionPatternRewriter &rewriter, Value a,
+              Value b, Value mask) {
+  PTXBuilder builder;
+  auto &prmt = builder.create("prmt")->o("b32");
+  auto *destOpr = builder.newOperand("=r");
+  auto *aOperand = builder.newOperand(a, "r");
+  auto *bOperand = builder.newOperand(b, "r");
+  auto *maskOperand = builder.newOperand(mask, "r");
+  prmt(destOpr, aOperand, bOperand, maskOperand);
+  return builder.launch(rewriter, loc, rewriter.getIntegerType(32), false);
+}
+
 } // namespace NVIDIA
 } // namespace LLVM
 } // namespace mlir
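
The new `permute` helper wraps PTX `prmt.b32` in default mode: each nibble of the mask, from low to high, selects one byte from the 8-byte pool `{b, a}` (with `a` supplying bytes 0-3; nibble bit 3 would enable sign replication, unused here — my reading of the PTX ISA). A worked example, repeating the small emulator from the earlier sketch so this compiles on its own:

```cpp
// Worked example of prmt.b32 d, a, b, sel in default mode:
//   a = 0x33221100 (bytes 0-3), b = 0x77665544 (bytes 4-7)
//   sel = 0x5410 -> nibbles 0,1,4,5 -> bytes 0x00, 0x11, 0x44, 0x55
//   d   = 0x55441100: the low halves of a and b interleaved, which is
//   what selectorEx4 = 0x5410 achieves in the conversion above.
#include <cassert>
#include <cstdint>

static uint32_t prmtDefault(uint32_t a, uint32_t b, uint32_t sel) {
  uint64_t pool = ((uint64_t)b << 32) | a;
  uint32_t d = 0;
  for (int k = 0; k < 4; ++k)
    d |= (uint32_t)((pool >> (8 * ((sel >> (4 * k)) & 7))) & 0xFF) << (8 * k);
  return d;
}

int main() {
  assert(prmtDefault(0x33221100, 0x77665544, 0x5410) == 0x55441100);
  // selectorEx5 = 0x7632 interleaves the high halves instead.
  assert(prmtDefault(0x33221100, 0x77665544, 0x7632) == 0x77663322);
  return 0;
}
```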
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h
index 816c8b599a81..ff14c2eb2b9a 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h
@@ -43,6 +43,8 @@ Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val,
                  int i);
 Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val,
                  Value i);
+Value permute(Location loc, ConversionPatternRewriter &rewriter, Value a,
+              Value b, Value mask);
 Value llGetPid(Location loc, ConversionPatternRewriter &rewriter,
                ModuleOp moduleOp, int axis);