Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/triton/Dialect/TritonGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class SwizzledSharedEncodingAttr;
// Version = 3: <m, n, k>
SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
const ArrayRef<int64_t> &shape,
Type type, int numWarps);
Type type, int numWarps,
int computeCapability = 80);

// Gets the order of a tensor from its contiguity. Places the dimensions with
// the largest contiguity as the inner most dimension. If the contiguity is
Expand Down
10 changes: 9 additions & 1 deletion lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2783,9 +2783,17 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
}
// warpSizeK * (warpRepK * VecBitWidth)
auto tileBitWidthK = bitwidth == 64 ? (1 * 256) : (4 * 64);
// FP64: instrShape[M] disambiguates m16n8k4 (M=16) from legacy m8n8k4 (M=8).
unsigned mTile = 16;
if (bitwidth == 64) {
auto instrShape = getInstrShape();
unsigned r = shape.size();
if (instrShape[r - 2] == 8)
mTile = 8;
}
if (opIdx == 0) {
// m x k
tileSize.push_back(bitwidth == 64 ? 8 : 16);
tileSize.push_back(mTile);
tileSize.push_back(tileBitWidthK / bitwidth);
} else {
// k x n
Expand Down
23 changes: 16 additions & 7 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -986,8 +986,9 @@ NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

SmallVector<unsigned> tileShape;
if (isAmpere()) {
// Ampere.getInstrShape() returns the tile shape
tileShape = SmallVector<unsigned>(getInstrShape());
// FP64 instrShape may carry a trailing K-dim; accumulator tile is M×N only.
auto instr = getInstrShape();
tileShape = SmallVector<unsigned>(instr.take_front(rank));
} else {
assert(isHopper());
auto instrShapeMNK = getInstrShape();
Expand All @@ -1014,16 +1015,24 @@ LinearLayout nvidiaDotToLinearLayout(ArrayRef<int64_t> shape,
MLIRContext *ctx = mma.getContext();

SmallVector<unsigned> tileShape(rank, 1);
unsigned instrM = mma.getInstrShape()[rank - 2];
// For fp64 (instrM == 8), the native m8n8k4 instruction uses a smaller tile.
unsigned kTileMultiplier = instrM == 8 ? 4 : 8;
auto instrShape = mma.getInstrShape();
unsigned instrM = instrShape[rank - 2];
// K-tile: FP64 m16-family stores K explicitly in instrShape; legacy
// m8n8k4 has K=4 implicit; everything else uses kWidth*8.
unsigned kTile;
if (instrShape.size() > static_cast<size_t>(rank))
kTile = instrShape.back();
else if (instrM == 8)
kTile = 4;
else
kTile = kWidth * 8;
if (isA) {
tileShape[rank - 2] = instrM;
tileShape[rank - 1] = kWidth * kTileMultiplier;
tileShape[rank - 1] = kTile;
} else {
// Hopper takes the rhs via shared memory
assert(mma.isAmpere());
tileShape[rank - 2] = kWidth * kTileMultiplier;
tileShape[rank - 2] = kTile;
tileShape[rank - 1] = 8;
}
auto order = getOrderForDotOperand(dot.getOpIdx(), rank, /*kContig*/ true);
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,8 @@ static MMAEncodingResult createMMAEncodingForDot(DotOpInterface dotOp,
auto CGALayout = getCGALayout(oldRetType.getEncoding());
auto retShapePerCTA = getShapePerCTA(oldRetType);
auto instrShape = mmaVersionToInstrShape(versionMajor, retShapePerCTA,
oldAType.getElementType(), numWarps);
oldAType.getElementType(), numWarps,
computeCapability);
auto warpsPerTile = getWarpsPerTile(dotOp, retShapePerCTA, versionMajor,
numWarps, instrShape);

Expand Down
23 changes: 20 additions & 3 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,33 @@ namespace mlir {

using namespace triton;

// Chooses the K extent of the FP64 m16n8kK MMA tile for the given compute
// capability.
//
// Returns 0 for every capability below 90, which signals "no m16 variant
// available" (the legacy m8n8k4 shape is used instead). Capabilities of 90
// and above currently always get K = 4. This is the single place to extend
// when adding K = 8/16 support: the K chosen here is carried downstream as a
// trailing instrShape element.
static unsigned pickFp64MmaK(int computeCapability) {
  constexpr int kFirstM16Capability = 90;
  constexpr unsigned kNoM16Variant = 0;
  constexpr unsigned kM16TileK = 4;
  return computeCapability >= kFirstM16Capability ? kM16TileK : kNoM16Variant;
}

SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
const ArrayRef<int64_t> &shape,
Type eltType, int numWarps) {
Type eltType, int numWarps,
int computeCapability) {
if (version == 1)
return {16, 16};
else if (version == 2) {
auto rank = shape.size();
SmallVector<unsigned, 3> ret(rank, 1);
// FP64 sm_90+: m16n8kK with K stored as a trailing instrShape elt.
// FP64 sm_80: legacy m8n8k4 (2-elt instrShape, K=4 implicit).
bool isF64 = eltType.isF64();
unsigned f64K = isF64 ? pickFp64MmaK(computeCapability) : 0;
bool isF64M16 = f64K != 0;
SmallVector<unsigned, 3> ret(rank + (isF64M16 ? 1u : 0u), 1);
ret[rank - 1] = 8;
ret[rank - 2] = eltType.isF64() ? 8 : 16;
ret[rank - 2] = isF64 && !isF64M16 ? 8 : 16;
if (isF64M16)
ret[rank] = f64K;
return ret;
} else if (version == 3) {
unsigned k = 256 / eltType.getIntOrFloatBitWidth();
Expand Down
3 changes: 2 additions & 1 deletion test/Conversion/tritongpu_to_llvm_hopper.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,8 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-

// -----

#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 8]}>
// Legacy m8n8k4: instrShape=[8, 8] (default FP64 path uses m16n8k4 — see below).
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [8, 8]}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 2, order = [1, 0]}>
#smem = #ttg.shared_memory
Expand Down
3 changes: 2 additions & 1 deletion test/TritonGPU/accelerate-matmul.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ module attributes {"ttg.target" = "cuda:89", "ttg.num-ctas" = 1 : i32, "ttg.num-

// -----

// sm_80: legacy m8n8k4 (m16n8k4.f64 needs sm_90+).
// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [8, 8]}>
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
Expand All @@ -136,7 +137,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ

// -----

// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [8, 8]}>
// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8, 4]}>
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @fp64_dot_hopper(
Expand Down
111 changes: 86 additions & 25 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,8 @@ enum class TensorCoreType : uint8_t {
INT32_INT4_INT4_INT32, // Not implemented
INT32_INT8_INT8_INT32, // Not implemented
// double precision tensor core instr
FP64_FP64_FP64_FP64,
FP64_FP64_FP64_FP64, // m16n8k4.f64 (Phase 1+)
FP64_FP64_FP64_FP64_M8, // m8n8k4.f64 (legacy, used when instrShape[M] = 8)
// scaled mxfp8 x mxfp8 matmul
FP32_FP8E5M2_FP8E5M2_FP32_SCALE_VEC_1X,
FP32_FP8E5M2_FP8E4M3FN_FP32_SCALE_VEC_1X,
Expand All @@ -309,6 +310,8 @@ static Type getMmaRetType(TensorCoreType mmaType, MLIRContext *ctx) {
Type i32Ty = type::i32Ty(ctx);
Type fp64x2Ty =
LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(2, fp64Ty));
Type fp64x4Ty =
LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, fp64Ty));
Type fp32x4Ty =
LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, fp32Ty));
Type i32x4Ty =
Expand Down Expand Up @@ -337,6 +340,8 @@ static Type getMmaRetType(TensorCoreType mmaType, MLIRContext *ctx) {
case TensorCoreType::INT32_INT8_INT8_INT32:
return i32x4Ty;
case TensorCoreType::FP64_FP64_FP64_FP64:
return fp64x4Ty;
case TensorCoreType::FP64_FP64_FP64_FP64_M8:
return fp64x2Ty;
case TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32_SCALE_VEC_1X:
case TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32_SCALE_VEC_1X:
Expand Down Expand Up @@ -427,8 +432,18 @@ static TensorCoreType getMmaTypeDot(DotOp op, RankedTensorType aTy,
llvm::isa<Float8E4M3FNType>(bTy.getElementType()))
return TensorCoreType::FP16_FP8E4M3FN_FP8E4M3FN_FP16;
} else if (dTy.getElementType().isF64()) {
if (aTy.getElementType().isF64() && bTy.getElementType().isF64())
if (aTy.getElementType().isF64() && bTy.getElementType().isF64()) {
// instrShape[M] selects m16n8k4 (M=16) vs legacy m8n8k4 (M=8).
auto mmaEnc = dyn_cast<NvidiaMmaEncodingAttr>(dTy.getEncoding());
if (mmaEnc) {
auto instrShape = mmaEnc.getInstrShape();
unsigned rank = dTy.getRank();
unsigned instrM = instrShape[rank - 2];
if (instrM == 8)
return TensorCoreType::FP64_FP64_FP64_FP64_M8;
}
return TensorCoreType::FP64_FP64_FP64_FP64;
}
}

return TensorCoreType::NOT_APPLICABLE;
Expand Down Expand Up @@ -482,6 +497,8 @@ inline static const std::map<TensorCoreType, std::string> mmaInstrPtxAmpere = {
"mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16"},

{TensorCoreType::FP64_FP64_FP64_FP64,
"mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64"},
{TensorCoreType::FP64_FP64_FP64_FP64_M8,
"mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64"},
};

Expand Down Expand Up @@ -588,21 +605,15 @@ static void callMmaTuringFp16(PTXBuilder &builder, int b,
mma(retArgs, aArgs2, bArgs2, cArgs);
}

// Emit m8n8k4 fp64 MMA instructions.
// With numRegisters.m=1, numRegisters.k=1: emits a single m8n8k4.
// With numRegisters.m=2, numRegisters.k=2: emits 2*2=4 m8n8k4 grouped as
// m16n8k8
static void callMmaAmpereFp64(PTXBuilder &builder, int b,
const BaseOffset &base,
mlir::triton::PTXInstr &mma, unsigned numMmaRets,
unsigned colsPerThread, int numCPackedElem,
unsigned batchOffset, ValueTableV2 &ha,
ValueTableV2 &hb, const SmallVector<Value> &fc,
int kRegs, int mRegs) {
// Each m sub-tile gets numMmaRets/mRegs results (2 f64 values per m8n8k4).
// Legacy m8n8k4.f64 emit (sm_80 fallback; selected when instrShape=[8, 8]).
static void
callMmaAmpereFp64M8K4(PTXBuilder &builder, int b, const BaseOffset &base,
mlir::triton::PTXInstr &mma, unsigned numMmaRets,
unsigned colsPerThread, int numCPackedElem,
unsigned batchOffset, ValueTableV2 &ha, ValueTableV2 &hb,
const SmallVector<Value> &fc, int kRegs, int mRegs) {
int retsPerM = numMmaRets / mRegs;

// Build ret/c operand lists for each m sub-tile.
SmallVector<PTXBuilder::Operand *> retArgsList, cArgsList;
for (int vm = 0; vm < mRegs; ++vm) {
auto *retArgs = builder.newListOperand(retsPerM, "=d");
Expand All @@ -629,6 +640,39 @@ static void callMmaAmpereFp64(PTXBuilder &builder, int b,
}
}

// Emits exactly one m16n8k4.f64 MMA for the tile addressed by `base` within
// batch `b`.
//
// Per-thread operand counts: A = 2 registers (gid, gid+8), read from two
// consecutive m-slots of `ha`; B = 1 register from `hb`; C/D = 4 registers.
// Every operand uses the "d" constraint, and the i-th accumulator input is
// given the constraint string "i" so it pairs with the i-th result operand.
//
// TODO(fp64 m16 family): extend for m16n8k{8,16} — A/B reg counts grow
// with kWidth. K-selector to extend: pickFp64MmaK in
// TritonGPU/Transforms/Utility.cpp.
static void callMmaAmpereFp64M16K4(PTXBuilder &builder, int b,
                                   const BaseOffset &base,
                                   mlir::triton::PTXInstr &mma,
                                   unsigned numMmaRets, unsigned colsPerThread,
                                   int numCPackedElem, unsigned batchOffset,
                                   ValueTableV2 &ha, ValueTableV2 &hb,
                                   const SmallVector<Value> &fc) {
  assert(numMmaRets == 4 && "m16n8k4.f64 produces 4 f64 outputs/thread");

  // Operand creation order (results, accumulators, A, B) is kept as-is: the
  // builder hands out operands in the order they are requested.
  auto *dOperands = builder.newListOperand(numMmaRets, "=d");
  auto *accOperands = builder.newListOperand();

  // Position in `fc` of the first accumulator value for this (b, m, n) tile.
  const unsigned accStart =
      (base.m * colsPerThread + numMmaRets * numCPackedElem * base.n) /
          numCPackedElem +
      batchOffset * b;
  for (unsigned slot = 0; slot != numMmaRets; ++slot) {
    auto *tied = builder.newOperand(fc[accStart + slot], std::to_string(slot));
    accOperands->listAppend(tied);
  }

  // The two A registers live in adjacent m-entries of the A value table.
  Value aFirst = ha[{b, base.m, base.k}];
  Value aSecond = ha[{b, base.m + 1, base.k}];
  auto *aOperands = builder.newListOperand({{aFirst, "d"}, {aSecond, "d"}});

  Value bOnly = hb[{b, base.n, base.k}];
  auto *bOperands = builder.newListOperand({{bOnly, "d"}});

  mma(dOperands, aOperands, bOperands, accOperands);
}

// Unified MMAV2 function for Ampere and HopperF64 architectures
static void callMmaV2(PTXBuilder &builder, int b, const BaseOffset &base,
mlir::triton::PTXInstr &mma, unsigned numMmaRets,
Expand Down Expand Up @@ -778,7 +822,12 @@ convertMMAImpl(DotOpInterface op, Value llvmA, Value llvmB, Value llvmC,
auto fc = unpackLLElements(loc, loadedC, rewriter);

int bitwidthRet = dTensorTy.getElementType().getIntOrFloatBitWidth();
auto numMmaRets = bitwidthRet == 64 ? 2 : bitwidthRet / 8;
// f64 m16n8k4 returns 4 f64 outputs/thread; legacy m8n8k4 returns 2.
unsigned numMmaRets;
if (bitwidthRet == 64)
numMmaRets = (mmaType == TensorCoreType::FP64_FP64_FP64_FP64_M8) ? 2 : 4;
else
numMmaRets = bitwidthRet / 8;
int numCPackedElem = bitwidthRet == 64 ? 1 : 4 / numMmaRets;

auto rank = dTensorTy.getRank();
Expand Down Expand Up @@ -849,19 +898,28 @@ LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
return op.emitError(
"unsupported MMA instruction for the given operand/result types");

NumRegisters numRegisters = (mmaType == TensorCoreType::FP64_FP64_FP64_FP64)
? NumRegisters{1, 1, 1}
: NumRegisters{2, 1, 2};
// f64 m16n8k4: A=2 regs (gid, gid+8). f64 m8n8k4 legacy: 1 A reg.
NumRegisters numRegisters;
if (mmaType == TensorCoreType::FP64_FP64_FP64_FP64)
numRegisters = NumRegisters{2, 1, 1};
else if (mmaType == TensorCoreType::FP64_FP64_FP64_FP64_M8)
numRegisters = NumRegisters{1, 1, 1};
else
numRegisters = NumRegisters{2, 1, 2};

EmitMmaCallback emit = [&](PTXBuilder &builder, int b, int m, int n, int k,
bool isFp64M16 = mmaType == TensorCoreType::FP64_FP64_FP64_FP64;
bool isFp64M8 = mmaType == TensorCoreType::FP64_FP64_FP64_FP64_M8;

EmitMmaCallback emit = [&, isFp64M16, isFp64M8](
PTXBuilder &builder, int b, int m, int n, int k,
mlir::triton::PTXInstr &mma, unsigned numMmaRets,
unsigned colsPerThread, unsigned batchOffset,
ValueTableV2 &ha, ValueTableV2 &hb,
const SmallVector<Value> &fc, RankedTensorType dTy,
int /*repK*/) {
bool isIntMMA = dTy.getElementType().isInteger(32);
bool isAccF16 = dTy.getElementType().isF16();
bool isFp64MMA = dTy.getElementType().isF64();
bool isFp64MMA = isFp64M16 || isFp64M8;
const unsigned numCPackedElem = isFp64MMA ? 1u : 4u / numMmaRets;
BaseOffset base{numRegisters.m * m, numRegisters.n * n, numRegisters.k * k};
if (isTuring) {
Expand All @@ -873,10 +931,13 @@ LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
callMmaTuringFp16(builder, b, base, mma, numMmaRets, colsPerThread,
numCPackedElem, ha, hb, fc, isAccF16);
} else {
if (isFp64MMA) {
callMmaAmpereFp64(builder, b, base, mma, numMmaRets, colsPerThread,
numCPackedElem, batchOffset, ha, hb, fc,
numRegisters.k, numRegisters.m);
if (isFp64M16) {
callMmaAmpereFp64M16K4(builder, b, base, mma, numMmaRets, colsPerThread,
numCPackedElem, batchOffset, ha, hb, fc);
} else if (isFp64M8) {
callMmaAmpereFp64M8K4(builder, b, base, mma, numMmaRets, colsPerThread,
numCPackedElem, batchOffset, ha, hb, fc,
numRegisters.k, numRegisters.m);
} else {
callMmaV2(builder, b, base, mma, numMmaRets, colsPerThread,
numCPackedElem, batchOffset, ha, hb, fc,
Expand Down