Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -562,10 +562,11 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
auto dstLayout = dstTy.getEncoding();
auto mmaLayout = srcLayout.cast<NvidiaMmaEncodingAttr>();
auto dotOperandLayout = dstLayout.cast<DotOperandEncodingAttr>();
auto ans =
mmaLayout.getVersionMajor() == 3 && dotOperandLayout.getOpIdx() == 0 &&
isMmaToMmaShortcut(dotOperandLayout.getParent(), srcLayout) &&
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
int elementTypeSize = srcTy.getElementType().getIntOrFloatBitWidth();
auto ans = mmaLayout.getVersionMajor() == 3 &&
dotOperandLayout.getOpIdx() == 0 &&
isMmaToMmaShortcut(dotOperandLayout.getParent(), srcLayout) &&
(elementTypeSize == 16 || elementTypeSize == 8);
return ans;
}

Expand Down
10 changes: 3 additions & 7 deletions lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,18 +291,14 @@ struct MMAV3UseRegOperand : public OpRewritePattern<DotOp> {
if (!srcEnc || srcEnc.getVersionMajor() != 3 || !dstEnc ||
dstEnc.getVersionMajor() != 3)
return failure();

// We currently only support convert from f16 and bf16 mma to f16 and bf16
// dot operand, as the other types require shuffling data across threads.
// TODO: extend it to more types.
auto srcTy = alloc.getInit().getType().cast<RankedTensorType>();
if (!(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16()))
return failure();

auto dotOperandEnc = DotOperandEncodingAttr::get(
dotOp.getContext(), /*opIdx=*/0, srcEnc, /*kWidth=*/0);
auto newTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(),
dotOperandEnc);
if (!isMmaToDotShortcut(srcTy, newTy))
return failure();

Value newOperand = rewriter.create<ConvertLayoutOp>(dotOp.getLoc(), newTy,
alloc.getInit());
rewriter.modifyOpInPlace(dotOp, [&]() { dotOp.setOperand(0, newOperand); });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
SmallVector<Value> queue = {op->getResult(0)};
SetVector<Operation *> forwardSlice;
llvm::SmallDenseSet<Value> seen;
bool isMMAV3 = encoding.cast<NvidiaMmaEncodingAttr>().getVersionMajor() == 3;
while (!queue.empty()) {
Value currentValue = queue.back();
queue.pop_back();
Expand All @@ -164,6 +165,8 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
if (dstEncoding.isa<DotOperandEncodingAttr>())
return encoding.cast<NvidiaMmaEncodingAttr>().getVersionMajor() > 1;
}
if (isMMAV3 && isa<LocalAllocOp>(op))
return true;
auto yield = dyn_cast<scf::YieldOp>(op);
if (!yield)
continue;
Expand Down
41 changes: 34 additions & 7 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2742,6 +2742,14 @@ def kernel(In, Out, #
# ---------------


def convert_fp8_to_fp32(x, device, dtype_str):
if dtype_str == 'float8e4nv':
return torch.tensor(x, device=device).view(torch.float8_e4m3fn).to(torch.float32)
elif dtype_str == 'float8e5':
return torch.tensor(x, device=device).view(torch.float8_e5m2).to(torch.float32)
assert "Unsupported float8 dtype"


@pytest.mark.interpreter
@pytest.mark.parametrize(
"M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype",
Expand All @@ -2761,7 +2769,9 @@ def kernel(In, Out, #
'float32')]] +
[(64, 64, 64, 4, col_a, col_b, 'none', False, 'float32', 'float32')
for col_a in [True, False]
for col_b in [True, False]] + [(64, 64, 64, 4, False, False, 'chain-dot', False, 'bfloat16', 'float32')])
for col_b in [True, False]] + [(64, 64, 64, 4, False, False, 'chain-dot', False, 'bfloat16', 'float32')] +
[(128, 128, 64, 4, False, False, 'chain-dot', False, float8_type, 'float32')
for float8_type in ["float8e5", "float8e4nv"]])
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, num_ctas, device):
check_cuda_only(device)
Expand All @@ -2781,6 +2791,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
if out_dtype == 'float16':
# TODO: support out_dtype=float16 for tl.dot on V100
pytest.skip("Only test out_dtype=float16 on devices with sm >=80")
if capability[0] < 9 and in_dtype == 'float8e4nv':
pytest.skip("float8e4nv not supported on sm <= 80")
if is_interpreter() and in_dtype == 'int8':
pytest.skip(
"numpy.dot with int8 inputs will overflow while tl.dot doesn't because MMA instruction's accumulator is 32-bit"
Expand Down Expand Up @@ -2839,16 +2851,16 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
else:
y = numpy_random((K, N), dtype_str=in_dtype, rs=rs)
w = numpy_random((N, N), dtype_str=in_dtype, rs=rs)
if 'int' not in in_dtype:
if 'int' not in in_dtype and 'float8' not in in_dtype:
x *= .1
y *= .1
if in_dtype == 'float32' and allow_tf32:
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
x_tri = to_triton(x, device=device)
y_tri = to_triton(y, device=device)
w_tri = to_triton(w, device=device)
x_tri = to_triton(x, device=device, dst_type=in_dtype)
y_tri = to_triton(y, device=device, dst_type=in_dtype)
w_tri = to_triton(w, device=device, dst_type=in_dtype)
# triton result
if out_dtype == 'int8':
z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs)
Expand Down Expand Up @@ -2894,6 +2906,10 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
# torch result
if in_dtype == 'int8':
z_ref = np.matmul(x.astype(np.float32), y.astype(np.float32())).astype(np.int32)
elif 'float8' in in_dtype:
x = convert_fp8_to_fp32(x, device, in_dtype)
y = convert_fp8_to_fp32(y, device, in_dtype)
z_ref = to_numpy(torch.matmul(x, y))
else:
z_ref = np.matmul(x, y)

Expand All @@ -2908,12 +2924,14 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
denom = np.sum(num, axis=-1, keepdims=True)
z_ref = num / denom
if epilogue == 'chain-dot':
if 'float8' in in_dtype:
w = to_numpy(convert_fp8_to_fp32(w, device, in_dtype))
z_ref = np.matmul(z_ref, w)
# compare
if in_dtype == 'float32':
# XXX: Somehow there's a larger difference when we use float32
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
elif out_dtype == tl.float16:
elif out_dtype == tl.float16 or in_dtype == 'bfloat16':
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2)
else:
# added atol, to loose precision for float16xfloat16->float32 case
Expand All @@ -2925,7 +2943,10 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4):
# XXX: skip small sizes because they are not vectorized
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
if 'float8' in in_dtype:
assert 'st.global.v2' in ptx
else:
assert 'st.global.v4' in ptx
if in_dtype == 'float32' and allow_tf32:
assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.tf32.tf32', ptx)
elif in_dtype == 'float16' and out_dtype == tl.float32:
Expand All @@ -2944,6 +2965,12 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
else:
assert 'wgmma.mma_async.sync.aligned' in ptx or\
'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
elif in_dtype == "float8e5" and out_dtype == tl.float32:
if capability[0] == 9:
assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2' in ptx
elif in_dtype == "float8e4nv" and out_dtype == tl.float32:
if capability[0] == 9:
assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx


@pytest.mark.parametrize("B", [1, 2, 4, 8])
Expand Down
18 changes: 18 additions & 0 deletions test/Conversion/tritongpu_to_llvm_hopper.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,21 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
tt.return
}
}

// -----

#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
  // Converting an 8-bit (f8E5M2) MMA v3 accumulator layout to a dot-operand
  // layout is lowered to in-register byte permutes plus intra-warp shuffles
  // (prmt.b32 / shfl.sync below), rather than a plain no-op bit reuse as in
  // the 16-bit case.
  // CHECK-LABEL: cvt_mma_to_dot_fp8
  // CHECK: prmt.b32
  // CHECK: prmt.b32
  // CHECK: nvvm.shfl.sync
  // CHECK: nvvm.shfl.sync
  // CHECK: prmt.b32
  // CHECK: prmt.b32
  tt.func @cvt_mma_to_dot_fp8(%a: tensor<128x64xf8E5M2, #mma>) {
    %opA = triton_gpu.convert_layout %a : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
    tt.return
  }
}
16 changes: 16 additions & 0 deletions test/TritonGPU/dot-operands.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,22 @@ tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !tt.memdes
}
}

// -----

#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// The optimizer should drop the local_alloc round-trip for the fp8 A operand
// of an MMA v3 dot and feed the dot directly from registers through a
// convert_layout to the dot-operand encoding (same rewrite as for f16/bf16).
// CHECK: tt.func @mma_v3_reg_operand_A_fp8
// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
// CHECK: tt.dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma>
tt.func @mma_v3_reg_operand_A_fp8(%arg0: tensor<128x64xf8E5M2, #mma>, %arg1: !tt.memdesc<64x64xf8E5M2, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{
  %A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf8E5M2, #mma>) -> !tt.memdesc<128x64xf8E5M2, #shared1>
  %r = tt.dot %A, %arg1, %arg2 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : !tt.memdesc<128x64xf8E5M2, #shared1> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma>
  tt.return %r : tensor<128x64xf32, #mma>
}
}

// -----
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
#include "PatternTritonGPUOpToLLVM.h"
#include "Utility.h"

#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

using mlir::isLayoutMmaV1;
using ::mlir::LLVM::getMultiDimOffset;
Expand Down Expand Up @@ -584,6 +583,83 @@ struct ConvertLayoutOpConversion
return success();
}

// Convert from accumulator MMA layout to 8bit dot operand layout.
// The conversion logic is taken from:
// https://github.com/ColfaxResearch/cutlass-kernels/blob/a9de6446c1c0415c926025cea284210c799b11f8/src/fmha-pipeline/reg2reg.h#L45
//
// Elements are processed in chunks of 8 bytes (two packed 32-bit words):
// bytes are rearranged with `prmt.b32`, and words are exchanged between the
// four lanes of a quad with indexed `shfl.sync` (clamp 0x1C1F restricts the
// shuffle to groups of 4 lanes).
void
convertMMAV3To8BitsDotOperand(triton::gpu::ConvertLayoutOp op,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  auto dstTy = op.getType();
  auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter);

  // The permute selectors and shuffle source lanes depend only on the thread
  // id, so compute them once instead of re-emitting them for every chunk.
  Value threadIdMod4 = urem(getThreadId(rewriter, loc), i32_val(4));
  Value cnd = or_(icmp_eq(threadIdMod4, i32_val(0)),
                  icmp_eq(threadIdMod4, i32_val(3)));
  Value selectorEx0 = select(cnd, i32_val(0x3210), i32_val(0x7654));
  Value selectorEx1 = select(cnd, i32_val(0x7654), i32_val(0x3210));
  Value selectorEx4 = select(cnd, i32_val(0x5410), i32_val(0x1054));
  Value selectorEx5 = select(cnd, i32_val(0x7632), i32_val(0x3276));

  Value isOne = icmp_eq(threadIdMod4, i32_val(1));
  Value isTwo = icmp_eq(threadIdMod4, i32_val(2));
  Value isThree = icmp_eq(threadIdMod4, i32_val(3));
  // Source lane (within the quad) for the upper word, per lane-id mod 4:
  // lane 0 -> 0, lane 1 -> 3, lane 2 -> 1, lane 3 -> 2.
  Value upperIdx = i32_val(0);
  upperIdx = select(isOne, i32_val(3), upperIdx);
  upperIdx = select(isTwo, i32_val(1), upperIdx);
  upperIdx = select(isThree, i32_val(2), upperIdx);

  // Source lane for the lower word: lane 0 -> 1, 1 -> 2, 2 -> 0, 3 -> 3.
  Value lowerIdx = i32_val(1);
  lowerIdx = select(isOne, i32_val(2), lowerIdx);
  lowerIdx = select(isTwo, i32_val(0), lowerIdx);
  lowerIdx = select(isThree, i32_val(3), lowerIdx);

  Value mask = i32_val(0xFFFFFFFF);
  // Set clamp to shuffle only within 4 lanes.
  Value clamp = i32_val(0x1C1F);

  SmallVector<Value> retVals;
  for (size_t i = 0; i < vals.size(); i += 8) {
    // Pack the chunk's first/last four i8 elements into two i32 words.
    Value upper = undef(vec_ty(i8_ty, 4));
    for (int j = 0; j < 4; j++)
      upper = insert_element(vec_ty(i8_ty, 4), upper, vals[i + j], i32_val(j));
    upper = bitcast(upper, i32_ty);
    Value lower = undef(vec_ty(i8_ty, 4));
    for (int j = 0; j < 4; j++)
      lower = insert_element(vec_ty(i8_ty, 4), lower, vals[i + 4 + j],
                             i32_val(j));
    lower = bitcast(lower, i32_ty);

    Value upper0 =
        LLVM::NVIDIA::permute(loc, rewriter, upper, lower, selectorEx0);
    Value lower0 =
        LLVM::NVIDIA::permute(loc, rewriter, upper, lower, selectorEx1);
    upper0 =
        rewriter.create<NVVM::ShflOp>(loc, i32_ty, mask, upper0, upperIdx,
                                      clamp, NVVM::ShflKind::idx, UnitAttr());
    lower0 =
        rewriter.create<NVVM::ShflOp>(loc, i32_ty, mask, lower0, lowerIdx,
                                      clamp, NVVM::ShflKind::idx, UnitAttr());
    // Final byte permutes, then unpack the two words back to i8 elements.
    Value upper1 =
        LLVM::NVIDIA::permute(loc, rewriter, upper0, lower0, selectorEx4);
    Value vecVal = bitcast(upper1, vec_ty(i8_ty, 4));
    for (int j = 0; j < 4; j++)
      retVals.push_back(extract_element(i8_ty, vecVal, i32_val(j)));
    Value lower1 =
        LLVM::NVIDIA::permute(loc, rewriter, upper0, lower0, selectorEx5);
    vecVal = bitcast(lower1, vec_ty(i8_ty, 4));
    for (int j = 0; j < 4; j++)
      retVals.push_back(extract_element(i8_ty, vecVal, i32_val(j)));
  }
  Value result =
      packLLElements(loc, getTypeConverter(), retVals, rewriter, dstTy);
  rewriter.replaceOp(op, result);
}

// mma -> dot_operand
LogicalResult
lowerMmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
Expand All @@ -592,7 +668,13 @@ struct ConvertLayoutOpConversion
auto srcTy = op.getSrc().getType();
auto dstTy = op.getType();
if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) {
rewriter.replaceOp(op, adaptor.getSrc());
if (srcTy.getElementType().getIntOrFloatBitWidth() == 16) {
rewriter.replaceOp(op, adaptor.getSrc());
return success();
}
assert(srcTy.getElementType().getIntOrFloatBitWidth() == 8 &&
"Unsupported type size.");
convertMMAV3To8BitsDotOperand(op, adaptor, rewriter);
return success();
}

Expand Down
18 changes: 16 additions & 2 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ namespace LLVM {
namespace NVIDIA {
using namespace mlir::triton;

Value shuffleCommon(Location loc, ConversionPatternRewriter &rewriter,
Value val, Value i, NVVM::ShflKind mode, Value clamp) {
static Value shuffleCommon(Location loc, ConversionPatternRewriter &rewriter,
Value val, Value i, NVVM::ShflKind mode,
Value clamp) {
unsigned bits = val.getType().getIntOrFloatBitWidth();

if (bits == 64) {
Expand Down Expand Up @@ -90,6 +91,19 @@ Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr) {
Value val = builder.launch(b, loc, b.getIntegerType(32), false);
return val;
}

Value permute(Location loc, ConversionPatternRewriter &rewriter, Value a,
              Value b, Value mask) {
  // Emit a PTX `prmt.b32 dst, a, b, mask` and return the resulting i32: the
  // instruction picks bytes out of the (a, b) pair as selected by `mask`.
  PTXBuilder ptxBuilder;
  auto &prmtInsn = ptxBuilder.create("prmt")->o("b32");
  auto *resOpr = ptxBuilder.newOperand("=r");
  auto *lhsOpr = ptxBuilder.newOperand(a, "r");
  auto *rhsOpr = ptxBuilder.newOperand(b, "r");
  auto *selOpr = ptxBuilder.newOperand(mask, "r");
  prmtInsn(resOpr, lhsOpr, rhsOpr, selOpr);
  return ptxBuilder.launch(rewriter, loc, rewriter.getIntegerType(32), false);
}

} // namespace NVIDIA
} // namespace LLVM
} // namespace mlir
2 changes: 2 additions & 0 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val,
int i);
Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val,
Value i);
Value permute(Location loc, ConversionPatternRewriter &rewriter, Value a,
Value b, Value mask);

Value llGetPid(Location loc, ConversionPatternRewriter &rewriter,
ModuleOp moduleOp, int axis);
Expand Down