Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2624,10 +2624,10 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
tileSize.push_back(1);
}
// warpSizeK * (warpRepK * VecBitWidth)
auto tileBitWidthK = bitwidth == 64 ? (2 * 256) : (4 * 64);
auto tileBitWidthK = bitwidth == 64 ? (1 * 256) : (4 * 64);
if (opIdx == 0) {
// m x k
tileSize.push_back(16);
tileSize.push_back(bitwidth == 64 ? 8 : 16);
tileSize.push_back(tileBitWidthK / bitwidth);
} else {
// k x n
Expand Down
9 changes: 6 additions & 3 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -943,13 +943,16 @@ LinearLayout nvidiaDotToLinearLayout(ArrayRef<int64_t> shape,
MLIRContext *ctx = mma.getContext();

SmallVector<unsigned> tileShape(rank, 1);
unsigned instrM = mma.getInstrShape()[rank - 2];
// For fp64 (instrM == 8), the native m8n8k4 instruction uses a smaller tile.
unsigned kTileMultiplier = instrM == 8 ? 4 : 8;
if (isA) {
tileShape[rank - 2] = 16;
tileShape[rank - 1] = kWidth * 8;
tileShape[rank - 2] = instrM;
tileShape[rank - 1] = kWidth * kTileMultiplier;
} else {
// Hopper takes the rhs via shared memory
assert(mma.isAmpere());
tileShape[rank - 2] = kWidth * 8;
tileShape[rank - 2] = kWidth * kTileMultiplier;
tileShape[rank - 1] = 8;
}
auto order = getOrderForDotOperand(dot.getOpIdx(), rank, /*kContig*/ true);
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
auto rank = shape.size();
SmallVector<unsigned, 3> ret(rank, 1);
ret[rank - 1] = 8;
ret[rank - 2] = 16;
ret[rank - 2] = eltType.isF64() ? 8 : 16;
return ret;
} else if (version == 3) {
unsigned k = 256 / eltType.getIntOrFloatBitWidth();
Expand Down
12 changes: 10 additions & 2 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3147,6 +3147,12 @@ def get_test_dot_base_cases():
if not (input_precision != 'ieee' and (in_dtype in ['float16']))]


# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size
def get_test_dot_small_fp64_cases():
return [(*shape, 1, False, False, 'none', 'ieee', 'float64', 'float64', 1, None)
for shape in [(8, 8, 4), (8, 8, 8), (16, 8, 4), (8, 8, 16)]]


# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size
def get_test_dot_softmax():
return [(128, 128, 64, 8, False, False, 'softmax', 'ieee', 'float16', 'float32', 1, None)]
Expand Down Expand Up @@ -3275,7 +3281,8 @@ def get_test_small_dots_cases():
get_test_dot_small_mn_wmma_cases() + \
get_test_dot_small_k_wmma_cases() + \
get_test_dot_softmax() + \
get_test_small_dots_cases())
get_test_small_dots_cases() + \
get_test_dot_small_fp64_cases())
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size,
num_ctas, device):
Expand All @@ -3286,7 +3293,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty
pytest.skip(f"input_precision {input_precision} is not supported in the interpreter")
else:
if not is_hip() and K < 16:
pytest.skip("small dots are supported only on HIP at the moment")
if in_dtype != 'float64':
pytest.skip("small dots are supported only on HIP at the moment")
if is_cuda():
capability = torch.cuda.get_device_capability()

Expand Down
11 changes: 7 additions & 4 deletions test/Conversion/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2022,7 +2022,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr

// -----

#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 8]}>
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [8, 8]}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 2, order = [1, 0]}>
#smem = #ttg.shared_memory
Expand Down Expand Up @@ -2319,18 +2319,21 @@ tt.func @gather_in_shared(%arg0: tensor<16x4xi32, #blocked1>, %arg1: tensor<8x4x

// -----

#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [1, 1]}>
#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [8, 8]}>
#dot = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=1}>
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

tt.func @gather_in_shared_dot_input(%arg0: tensor<16x4xi32, #blocked>, %arg1: tensor<8x4xf32, #dot>) {
// CHECK-LABEL: gather_in_shared_dot_input

// CHECK: [[S0:%.*]] = llvm.extractvalue %arg1[0]

// CHECK: [[SMEM_BASE:%.*]] = llvm.mlir.addressof @global_smem
// CHECK-NEXT: [[SMEM:%.*]] = llvm.getelementptr [[SMEM_BASE]]
// CHECK-COUNT-4: store
// CHECK-NEXT: nvvm.barrier0
// CHECK: insertelement [[S0]]
// CHECK: nvvm.barrier0

// CHECK: [[I0:%.*]] = llvm.extractvalue %arg0[0]

Expand Down
4 changes: 2 additions & 2 deletions test/TritonGPU/accelerate-matmul.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ module attributes {"ttg.target" = "cuda:89", "ttg.num-ctas" = 1 : i32, "ttg.num-

// -----

// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [8, 8]}>
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @fp64_dot(
Expand All @@ -136,7 +136,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ

// -----

// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [8, 8]}>
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @fp64_dot_hopper(
Expand Down
2 changes: 2 additions & 0 deletions third_party/nvidia/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]: # [m,
# For small M/N the input we can still use tensorcores with padding.
if lhs_bitwidth == 8:
return (1, 1, 32)
elif lhs_bitwidth == 64:
return (1, 1, 4)
else:
return (1, 1, 16)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,8 @@ static Type getMmaRetType(TensorCoreType mmaType, MLIRContext *ctx) {
Type fp32Ty = type::f32Ty(ctx);
Type fp16Ty = type::f16Ty(ctx);
Type i32Ty = type::i32Ty(ctx);
Type fp64x4Ty =
LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, fp64Ty));
Type fp64x2Ty =
LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(2, fp64Ty));
Type fp32x4Ty =
LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, fp32Ty));
Type i32x4Ty =
Expand Down Expand Up @@ -337,7 +337,7 @@ static Type getMmaRetType(TensorCoreType mmaType, MLIRContext *ctx) {
case TensorCoreType::INT32_INT8_INT8_INT32:
return i32x4Ty;
case TensorCoreType::FP64_FP64_FP64_FP64:
return fp64x4Ty;
return fp64x2Ty;
case TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32_SCALE_VEC_1X:
case TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32_SCALE_VEC_1X:
case TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32_SCALE_VEC_1X:
Expand Down Expand Up @@ -588,42 +588,44 @@ static void callMmaTuringFp16(PTXBuilder &builder, int b,
mma(retArgs, aArgs2, bArgs2, cArgs);
}

// Repeat m8n8k4 (2, 1, 4) times, as m16n8k16 on hopper.
// Emit m8n8k4 fp64 MMA instructions.
// With numRegisters.m=1, numRegisters.k=1: emits a single m8n8k4.
// With numRegisters.m=2, numRegisters.k=2: emits 2*2=4 m8n8k4 grouped as
// m16n8k8
static void callMmaAmpereFp64(PTXBuilder &builder, int b,
const BaseOffset &base,
mlir::triton::PTXInstr &mma, unsigned numMmaRets,
unsigned colsPerThread, int numCPackedElem,
unsigned batchOffset, ValueTableV2 &ha,
ValueTableV2 &hb, const SmallVector<Value> &fc,
int kRegs) {
auto retArgs1 = builder.newListOperand(numMmaRets / 2, "=d");
auto retArgs2 = builder.newListOperand(numMmaRets / 2, "=d");
auto cArgs1 = builder.newListOperand();
for (int i = 0; i < numMmaRets / 2; ++i) {
cArgs1->listAppend(builder.newOperand(
fc[(base.m * colsPerThread + 4 * base.n) / numCPackedElem + i +
batchOffset * b],
std::to_string(i)));
// reuse the output registers
}
auto cArgs2 = builder.newListOperand();
for (int i = numMmaRets / 2; i < numMmaRets; ++i) {
cArgs2->listAppend(builder.newOperand(
fc[(base.m * colsPerThread + 4 * base.n) / numCPackedElem + i +
batchOffset * b],
std::to_string(i)));
// reuse the output registers
int kRegs, int mRegs) {
// Each m sub-tile gets numMmaRets/mRegs results (2 f64 values per m8n8k4).
int retsPerM = numMmaRets / mRegs;

// Build ret/c operand lists for each m sub-tile.
SmallVector<PTXBuilder::Operand *> retArgsList, cArgsList;
for (int vm = 0; vm < mRegs; ++vm) {
auto *retArgs = builder.newListOperand(retsPerM, "=d");
auto *cArgs = builder.newListOperand();
for (int i = 0; i < retsPerM; ++i) {
cArgs->listAppend(
builder.newOperand(fc[((base.m + vm) * colsPerThread +
numMmaRets * numCPackedElem * base.n) /
numCPackedElem +
i + batchOffset * b],
std::to_string(i)));
}
retArgsList.push_back(retArgs);
cArgsList.push_back(cArgs);
}

for (int vk = 0; vk < kRegs; ++vk) {
auto aArgs1 = builder.newListOperand({
{ha[{b, base.m, base.k + vk}], "d"},
});
auto bArgs = builder.newListOperand({{hb[{b, base.n, base.k + vk}], "d"}});
auto aArgs2 = builder.newListOperand({
{ha[{b, base.m + 1, base.k + vk}], "d"},
});
mma(retArgs1, aArgs1, bArgs, cArgs1);
mma(retArgs2, aArgs2, bArgs, cArgs2);
for (int vm = 0; vm < mRegs; ++vm) {
auto aArgs =
builder.newListOperand({{ha[{b, base.m + vm, base.k + vk}], "d"}});
mma(retArgsList[vm], aArgs, bArgs, cArgsList[vm]);
}
}
}

Expand Down Expand Up @@ -776,8 +778,8 @@ convertMMAImpl(DotOpInterface op, Value llvmA, Value llvmB, Value llvmC,
auto fc = unpackLLElements(loc, loadedC, rewriter);

int bitwidthRet = dTensorTy.getElementType().getIntOrFloatBitWidth();
auto numMmaRets = bitwidthRet == 64 ? 4 : bitwidthRet / 8;
int numCPackedElem = 4 / numMmaRets;
auto numMmaRets = bitwidthRet == 64 ? 2 : bitwidthRet / 8;
int numCPackedElem = bitwidthRet == 64 ? 1 : 4 / numMmaRets;

if (mmaInstructions.find(mmaType) == mmaInstructions.end()) {
return emitError(loc, "Unsupported MMA instruction for the given mma type");
Expand All @@ -801,7 +803,7 @@ convertMMAImpl(DotOpInterface op, Value llvmA, Value llvmB, Value llvmC,
Type elemTy = cast<LLVM::LLVMStructType>(mmaOut.getType()).getBody()[0];
for (int i = 0; i < numMmaRets; ++i) {
fc[(numRegisters.m * static_cast<int>(m) * colsPerThread +
4 * numRegisters.n * static_cast<int>(n)) /
numMmaRets * numCPackedElem * numRegisters.n * static_cast<int>(n)) /
numCPackedElem +
i + batchOffset * b] = tb.extract_val(elemTy, mmaOut, i);
}
Expand Down Expand Up @@ -845,7 +847,9 @@ LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
auto dTensorTy = op.getD().getType();

TensorCoreType mmaType = getMmaTypeDot(op, aTensorTy, bTensorTy, dTensorTy);
NumRegisters numRegisters = {2, 1, 2};
NumRegisters numRegisters = (mmaType == TensorCoreType::FP64_FP64_FP64_FP64)
? NumRegisters{1, 1, 1}
: NumRegisters{2, 1, 2};

const auto &instrMap = isTuring ? mmaInstrPtxTuring : mmaInstrPtxAmpere;
EmitMmaCallback emit = [&](PTXBuilder &builder, int b, int m, int n, int k,
Expand All @@ -854,10 +858,10 @@ LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
ValueTableV2 &ha, ValueTableV2 &hb,
const SmallVector<Value> &fc, RankedTensorType dTy,
int /*repK*/) {
const unsigned numCPackedElem = 4u / numMmaRets;
bool isIntMMA = dTy.getElementType().isInteger(32);
bool isAccF16 = dTy.getElementType().isF16();
bool isFp64MMA = dTy.getElementType().isF64();
const unsigned numCPackedElem = isFp64MMA ? 1u : 4u / numMmaRets;
BaseOffset base{numRegisters.m * m, numRegisters.n * n, numRegisters.k * k};
if (isTuring) {
assert(b == 0 && "Turing only supports batch size 1");
Expand All @@ -871,7 +875,7 @@ LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
if (isFp64MMA) {
callMmaAmpereFp64(builder, b, base, mma, numMmaRets, colsPerThread,
numCPackedElem, batchOffset, ha, hb, fc,
numRegisters.k);
numRegisters.k, numRegisters.m);
} else {
callMmaV2(builder, b, base, mma, numMmaRets, colsPerThread,
numCPackedElem, batchOffset, ha, hb, fc,
Expand Down
Loading