diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 088978607081..c8512fce57fa 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -781,22 +781,24 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> { InterfaceMethod<"Return shape per CTA.", "SmallVector", - "getShapePerCTATileForDotOperands", + "getShapePerCTATileForOperand", (ins "ArrayRef":$tensorShape, - "unsigned":$opIdx)>, + "int":$kWidth, + "int":$opIdx)>, InterfaceMethod<"Return total element size per thread for dot operands.", "unsigned", - "getTotalElemsPerThreadForOperands", + "getTotalElemsPerThreadForOperand", (ins "ArrayRef":$tensorShape, "Type":$eltTy, - "unsigned":$kWidth, - "unsigned":$opIdx)>, + "int":$kWidth, + "int":$opIdx)>, InterfaceMethod<"Return size per thread for dot operands.", "SmallVector", - "getSizePerThreadForOperands", - (ins "unsigned":$opIdx)>, + "getSizePerThreadForOperand", + (ins "int":$kWidth, + "int":$opIdx)>, ]; } @@ -914,11 +916,11 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129, bool supportReduction() const { return true; } - SmallVector getSizePerThreadForOperands(unsigned opIdx) const; - SmallVector getShapePerCTATileForDotOperands(ArrayRef shape, int opIdx) const; - unsigned getTotalElemsPerThreadForOperands(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; - SmallVector getMFMAInstrShapeForOperands(int kWidth, int opIdx) const; - SmallVector getMFMARepForOperands(ArrayRef operandShape, int kWidth, int opIdx) const; + SmallVector getSizePerThreadForOperand(int kWidth, int opIdx) const; + SmallVector getShapePerCTATileForOperand(ArrayRef shape, int kWidth, int opIdx) const; + unsigned getTotalElemsPerThreadForOperand(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; + SmallVector getInstrShapeForOperand(int kWidth, int opIdx) const; + SmallVector 
getRepForOperand(ArrayRef operandShape, int kWidth, int opIdx) const; SmallVector getContigPerThread() { auto rank = getWarpsPerCTA().size(); @@ -1021,12 +1023,12 @@ Row | warp 0 warp 2 bool supportReduction() const { return true; } - SmallVector getSizePerThreadForOperands(unsigned opIdx) const; - SmallVector getShapePerCTATileForDotOperands(ArrayRef shape, int opIdx) const; - unsigned getTotalElemsPerThreadForOperands(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; + SmallVector getSizePerThreadForOperand(int kWidth, int opIdx) const; + SmallVector getShapePerCTATileForOperand(ArrayRef shape, int kWidth, int opIdx) const; + unsigned getTotalElemsPerThreadForOperand(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; SmallVector getElemsPerInstrForOperands() const; - SmallVector getRepForOperands(ArrayRef operandShape, - Type elemType, int kWidth, int opIdx) const; + SmallVector getRepForOperand(ArrayRef operandShape, + Type elemType, int kWidth, int opIdx) const; static SmallVector getMNKDimPerInstr(); SmallVector getContigPerThread() { @@ -1222,8 +1224,8 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: SmallVector getMMAv1Rep(int opIdx) const; SmallVector getMMAv1ShapePerWarp(int opIdx) const; int getMMAv1Vec(int opIdx) const; - SmallVector getMMAv2Rep(ArrayRef shape, - int bitwidth, int opIdx) const; + SmallVector getMMAv2RepForOperand(ArrayRef shape, + int bitwidth, int kWidth, int opIdx) const; bool supportReduction() const { if (isAmpere() || isHopper()) { @@ -1231,9 +1233,9 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: } return false; }; - SmallVector getSizePerThreadForOperands(unsigned opIdx) const; - SmallVector getShapePerCTATileForDotOperands(ArrayRef shape, int opIdx) const; - unsigned getTotalElemsPerThreadForOperands(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; + SmallVector getSizePerThreadForOperand(int kWidth, int opIdx) const; + SmallVector 
getShapePerCTATileForOperand(ArrayRef shape, int kWidth, int opIdx) const; + unsigned getTotalElemsPerThreadForOperand(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; SmallVector getContigPerThread() { assert(isVolta() || isAmpere() || isHopper()); @@ -1344,7 +1346,16 @@ elements along the K dim, or they use all elements of the tensor along the K dim let genVerifyDecl = 1; let extraClassDeclaration = extraDistributedDeclaration # [{ SmallVector getContigPerThread() { - return getSizePerThread(); + auto rank = getWarpsPerCTA().size(); + assert(rank == 2 || rank == 3); + SmallVector contigPerThread(rank, 1); + auto kWidth = getKWidth(); + assert(kWidth != 0 && "Do not support kWidth=0"); + if (getOpIdx() == 0) + contigPerThread[rank - 1] = kWidth; + else + contigPerThread[rank - 2] = kWidth; + return contigPerThread; }; }]; } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 60bfd56cb00f..80fe1aed29f4 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -244,11 +244,20 @@ SmallVector getWarpOrder(Attribute layout) { order.erase(it); order.insert(order.begin(), 0); } + } else if (auto dotOpLayout = dyn_cast(layout)) { + // opIdx=0: [/*dim0*/batch, /*dim1=*/m, /*dim2=*/k] -> order=[1, 2, 0] + // opIdx=1: [/*dim0*/batch, /*dim1=*/k, /*dim2=*/n] -> order=[2, 1, 0] + std::iota(order.rbegin(), order.rend(), 0); + if (dotOpLayout.getOpIdx() == 0) { + std::swap(order[0], order[1]); + } } return order; } SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank) { + assert((rank == 2 || rank == 3) && + "Invalid rank for dot operand order computation"); SmallVector order(rank); // The 'order' field typically represents a descending sorted array of // dimensions based on contiguity. 
For instance, in axisInfo utilities that @@ -257,14 +266,16 @@ SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank) { // // The relation between contiguity and order is only relevant if the layout // interfaces with HBM, as is the case when we load tensor from HBM to - // registers in the dot layout to bypass LDS. When bypassing LDS, we make the - // following assumptions about tensor layouts: + // registers in the dot layout to bypass LDS. When bypassing LDS, we make + // the following assumptions about tensor layouts: // - Tensor A (opIdx == 0) is considered to be row-major. // - Tensor B (opIdx == 1) is considered to be column-major. // // Based on these assumptions, we define the following orders: - // - For opIdx == 0, we assume an order of [1, 0]. - // - For opIdx == 1, we assume an order of [0, 1]. + // - For opIdx == 0, batch=dim0, m=dim1, and k=dim2, we assume an order of [2, + // 1, 0] for 3D tensors. + // - For opIdx == 1, batch=dim0, k=dim1, and n=dim2, we assume an order of [1, + // 2, 0] for 3D tensors. 
std::iota(order.rbegin(), order.rend(), 0); if (opIdx == 1) { std::swap(order[0], order[1]); @@ -285,13 +296,7 @@ SmallVector getOrder(Attribute layout) { } if (auto dotLayout = dyn_cast(layout)) { auto rank = getWarpsPerCTA(dotLayout.getParent()).size(); - SmallVector order(rank); - if (isa(dotLayout.getParent())) { - return getOrderForDotOperand(dotLayout.getOpIdx(), rank); - } else { - std::iota(order.rbegin(), order.rend(), 0); - } - return order; + return getOrderForDotOperand(dotLayout.getOpIdx(), rank); } if (auto sliceLayout = dyn_cast(layout)) { SmallVector parentOrder = getOrder(sliceLayout.getParent()); @@ -336,8 +341,6 @@ SmallVector getCTAsPerCGA(Attribute layout) { ArrayRef ref; if (auto distributedLayout = mlir::dyn_cast(layout)) return distributedLayout.getCTAsPerCGA(); - else if (mlir::isa(layout)) - return {1, 1}; else if (auto sharedLayout = mlir::dyn_cast(layout)) ref = sharedLayout.getCTALayout().getCTAsPerCGA(); else @@ -350,9 +353,6 @@ SmallVector getCTASplitNum(Attribute layout) { if (auto distributedLayout = mlir::dyn_cast(layout)) { return distributedLayout.getCTASplitNum(); - } else if (mlir::isa(layout)) { - res.resize(2); - res[0] = res[1] = 1; } else if (auto sharedLayout = mlir::dyn_cast(layout)) { res.assign(sharedLayout.getCTALayout().getCTASplitNum().begin(), sharedLayout.getCTALayout().getCTASplitNum().end()); @@ -367,8 +367,6 @@ SmallVector getCTAOrder(Attribute layout) { if (auto distributedLayout = mlir::dyn_cast(layout)) { res = distributedLayout.getCTAOrder(); - } else if (mlir::isa(layout)) { - return {0, 1}; } else if (auto sharedLayout = mlir::dyn_cast(layout)) { res = SmallVector(sharedLayout.getCTALayout().getCTAOrder()); } else { @@ -392,9 +390,9 @@ SmallVector getShapePerCTA(ArrayRef CTASplitNum, SmallVector getShapePerCTA(Attribute layout, ArrayRef shape) { if (auto sharedLayout = mlir::dyn_cast(layout)) { // Special logic for pipeline pass, where shape is 3D and CTALayout is 2D. 
- // The first dim of shape is numStages. This is a work around, otherwise too - // many places would have to be modified in pipeline pass. Maybe we need to - // refactor this logic in the future. + // The first dim of shape is numStages. This is a work around, otherwise + // too many places would have to be modified in pipeline pass. Maybe we + // need to refactor this logic in the future. auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum(); if (shape.size() == CTASplitNum.size() + 1) { auto res = getShapePerCTA(CTASplitNum, shape.drop_front()); @@ -417,7 +415,8 @@ unsigned getNumWarpsPerCTA(Attribute layout) { else if (auto sliceLayout = dyn_cast(layout)) return getNumWarpsPerCTA(sliceLayout.getParent()); else if (auto mmaLayout = dyn_cast(layout)) { - // Use the distributed layout interface to get the number of warps per CTA. + // Use the distributed layout interface to get the number of warps per + // CTA. auto distributedLayout = cast(layout); warpsPerCTA = distributedLayout.getWarpsPerCTA(); } else if (auto mfmaLayout = dyn_cast(layout)) @@ -451,9 +450,9 @@ bool hasDotOperandEncoding(Value value) { } bool isExpensiveCat(CatOp cat, Attribute targetEncoding) { - // If the new elements per thread is less than the old one, we will need to do - // convert encoding that goes through shared memory anyway. So we consider it - // as expensive. + // If the new elements per thread is less than the old one, we will need to + // do convert encoding that goes through shared memory anyway. So we + // consider it as expensive. 
RankedTensorType tensorTy = cat.getType(); auto totalElemsPerThread = gpu::getTotalElemsPerThread(tensorTy); auto shape = tensorTy.getShape(); @@ -974,7 +973,7 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef shape, SmallVector elemsPerThread(rank); auto kWidth = getKWidth(); - auto rep = parent.getMFMARepForOperands(shape, kWidth, idx); + auto rep = parent.getRepForOperand(shape, kWidth, idx); if (rank == 3) elemsPerThread[0] = rep[0]; @@ -991,8 +990,8 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef shape, unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, Type eltTy) const { if (auto mmaParent = mlir::dyn_cast(getParent())) { - return mmaParent.getTotalElemsPerThreadForOperands(shape, eltTy, - getKWidth(), getOpIdx()); + return mmaParent.getTotalElemsPerThreadForOperand(shape, eltTy, getKWidth(), + getOpIdx()); } if (auto blockedLayout = mlir::dyn_cast(getParent())) { auto shapePerCTA = getShapePerCTA(*this, shape); @@ -1060,8 +1059,8 @@ SmallVector DotOperandEncodingAttr::getShapePerCTATile( auto parentLayout = getParent(); assert(parentLayout && "DotOperandEncodingAttr must have a parent"); if (auto parentMmaLayout = mlir::dyn_cast(parentLayout)) { - return parentMmaLayout.getShapePerCTATileForDotOperands(tensorShape, - getOpIdx()); + return parentMmaLayout.getShapePerCTATileForOperand( + tensorShape, getKWidth(), getOpIdx()); } else { llvm::report_fatal_error( "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not " @@ -1646,7 +1645,7 @@ SmallVector AMDMfmaEncodingAttr::getSizePerThread() const { } SmallVector -AMDMfmaEncodingAttr::getMFMAInstrShapeForOperands(int kWidth, int opIdx) const { +AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const { unsigned mDim = getMDim(); unsigned nDim = getNDim(); assert((mDim == nDim) && (mDim == 32 || mDim == 16 || mDim == 4) || @@ -1666,9 +1665,9 @@ AMDMfmaEncodingAttr::getMFMAInstrShapeForOperands(int kWidth, int opIdx) const { } SmallVector 
-AMDMfmaEncodingAttr::getMFMARepForOperands(ArrayRef operandShape, - int kWidth, int opIdx) const { - auto operandTileShape = getMFMAInstrShapeForOperands(kWidth, opIdx); +AMDMfmaEncodingAttr::getRepForOperand(ArrayRef operandShape, + int kWidth, int opIdx) const { + auto operandTileShape = getInstrShapeForOperand(kWidth, opIdx); auto rank = operandShape.size(); auto warpsPerCTA = getWarpsPerCTA(); int numRepBatch = @@ -1689,27 +1688,31 @@ AMDMfmaEncodingAttr::getMFMARepForOperands(ArrayRef operandShape, } } -unsigned AMDMfmaEncodingAttr::getTotalElemsPerThreadForOperands( +unsigned AMDMfmaEncodingAttr::getTotalElemsPerThreadForOperand( ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { - auto rep = getMFMARepForOperands(shape, kWidth, opIdx); + auto rep = getRepForOperand(shape, kWidth, opIdx); return product(rep) * kWidth; } SmallVector -AMDMfmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const { +AMDMfmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { + auto rank = getWarpsPerCTA().size(); + auto sizePerThread = SmallVector(rank, 1); if (opIdx == 0) { - return {4, 1}; + sizePerThread[rank - 2] = 1; + sizePerThread[rank - 1] = kWidth; } else if (opIdx == 1) { - return {1, 4}; + sizePerThread[rank - 2] = kWidth; + sizePerThread[rank - 1] = 1; } else { llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); - return {}; } + return sizePerThread; } SmallVector -AMDMfmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, - int opIdx) const { +AMDMfmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef shape, + int kWidth, int opIdx) const { assert(getMDim() == 32 || getMDim() == 16); auto parentShapePerCTATile = getShapePerCTATile(shape); auto rank = parentShapePerCTATile.size(); @@ -1779,7 +1782,7 @@ SmallVector AMDWmmaEncodingAttr::getSizePerThread() const { return sizePerThread; } SmallVector -AMDWmmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const { 
+AMDWmmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { auto rank = getWarpsPerCTA().size(); SmallVector sizePerThread(rank, 1); auto numReplicated = getVersion() == 1 ? 2 : 1; @@ -1798,8 +1801,8 @@ AMDWmmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const { } SmallVector -AMDWmmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, - int opIdx) const { +AMDWmmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef shape, + int kWidth, int opIdx) const { auto parentShapePerCTA = getShapePerCTATile(shape); auto rank = shape.size(); assert(rank == 2); @@ -1812,9 +1815,9 @@ AMDWmmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, } } -unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperands( +unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperand( ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { - auto rep = getRepForOperands(shape, eltTy, kWidth, opIdx); + auto rep = getRepForOperand(shape, eltTy, kWidth, opIdx); return product(rep) * kWidth; } @@ -1823,9 +1826,9 @@ SmallVector AMDWmmaEncodingAttr::getElemsPerInstrForOperands() const { } SmallVector -AMDWmmaEncodingAttr::getRepForOperands(ArrayRef operandShape, - Type elemType, int kWidth, - int opIdx) const { +AMDWmmaEncodingAttr::getRepForOperand(ArrayRef operandShape, + Type elemType, int kWidth, + int opIdx) const { auto operandTileShape = getElemsPerInstrForOperands(); assert(operandTileShape.size() == 2); auto warpsPerCTA = getWarpsPerCTA(); @@ -2016,9 +2019,8 @@ SmallVector NvidiaMmaEncodingAttr::getMMAv1ShapePerWarp(int opIdx) const { int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const { return 2 * getMMAv1Rep(opIdx)[opIdx]; } -SmallVector NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef shape, - int bitwidth, - int opIdx) const { +SmallVector NvidiaMmaEncodingAttr::getMMAv2RepForOperand( + ArrayRef shape, int bitwidth, int kWidth, int opIdx) const { auto rank = shape.size(); auto warpsPerCTA = getWarpsPerCTA(); SmallVector 
shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth}; @@ -2041,7 +2043,7 @@ SmallVector NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef shape, warpsPerCTA[rank - 1]))}; } } -unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperands( +unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand( ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { auto shapePerCTA = getShapePerCTA(*this, shape); int warpsPerCTAM = getWarpsPerCTA()[0]; @@ -2052,7 +2054,8 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperands( } // A100 if (isAmpere()) { - auto rep = getMMAv2Rep(shapePerCTA, eltTy.getIntOrFloatBitWidth(), opIdx); + auto rep = getMMAv2RepForOperand(shapePerCTA, eltTy.getIntOrFloatBitWidth(), + kWidth, opIdx); if (opIdx == 0) return 4 * rep[0] * rep[1] * rep[2]; if (opIdx == 1) @@ -2120,43 +2123,58 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperands( } llvm_unreachable("unknown mma layout"); } -SmallVector -NvidiaMmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, - int opIdx) const { +SmallVector NvidiaMmaEncodingAttr::getShapePerCTATileForOperand( + ArrayRef shape, int kWidth, int opIdx) const { assert(isAmpere() && "mmaLayout version = 1 is not implemented yet"); auto parentShapePerCTATile = getShapePerCTATile(shape); auto rank = parentShapePerCTATile.size(); + // 4 threads * 2 subtiles + unsigned kWidthTile = kWidth * 2 * 4; if (opIdx == 0) { if (rank == 2) - return {parentShapePerCTATile[rank - 2], 16}; + return {parentShapePerCTATile[rank - 2], kWidthTile}; else - return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 16}; + return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], + kWidthTile}; } else if (opIdx == 1) { if (rank == 2) - return {16, parentShapePerCTATile[rank - 1]}; + return {kWidthTile, parentShapePerCTATile[rank - 1]}; else - return {parentShapePerCTATile[0], 16, parentShapePerCTATile[rank - 1]}; + return {parentShapePerCTATile[0], kWidthTile, + parentShapePerCTATile[rank - 
1]}; } else { llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); } } SmallVector -NvidiaMmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const { +NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { assert(isAmpere() && "mmaLayout version = 1 is not implemented yet"); + auto rank = getWarpsPerCTA().size(); + auto sizePerThread = SmallVector(rank, 1); if (opIdx == 0) { - return {2, 4}; + sizePerThread[rank - 2] = 2; + sizePerThread[rank - 1] = 2 * kWidth; } else if (opIdx == 1) { - return {4, 1}; + sizePerThread[rank - 2] = 2 * kWidth; + sizePerThread[rank - 1] = 1; } else { llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); - return {}; } + return sizePerThread; } //===----------------------------------------------------------------------===// // DotOperand Encoding //===----------------------------------------------------------------------===// SmallVector DotOperandEncodingAttr::getThreadsPerWarp() const { + auto parent = getParent(); + if (auto mma = mlir::dyn_cast(parent)) { + auto threadsPerWarp = mma.getThreadsPerWarp(); + auto rank = threadsPerWarp.size(); + if (getOpIdx() == 1) + std::swap(threadsPerWarp[rank - 2], threadsPerWarp[rank - 1]); + return threadsPerWarp; + } llvm::report_fatal_error( "getThreadsPerWarp not implemented for DotOperandEncodingAttr"); } @@ -2164,7 +2182,7 @@ SmallVector DotOperandEncodingAttr::getSizePerThread() const { auto parentLayout = getParent(); assert(parentLayout && "DotOperandEncodingAttr must have a parent"); if (auto parentMmaLayout = mlir::dyn_cast(parentLayout)) { - return parentMmaLayout.getSizePerThreadForOperands(getOpIdx()); + return parentMmaLayout.getSizePerThreadForOperand(getKWidth(), getOpIdx()); } else { llvm::report_fatal_error( "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not " diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp 
b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp index c0e774788e58..a689809fa3f4 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -218,12 +218,12 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, auto elemTy = aTensorTy.getElementType(); auto kWidth = encoding.getKWidth(); - auto elemsPerInstr = mfmaLayout.getMFMAInstrShapeForOperands(kWidth, opIdx); + auto elemsPerInstr = mfmaLayout.getInstrShapeForOperand(kWidth, opIdx); int64_t mfmaInstrNonK; int64_t mfmaInstrK; // TODO(Lixun): make it simpler - // getMFMAInstrShapeForOperands always returns a 2D vector + // getInstrShapeForOperand always returns a 2D vector if (rank == 3) { mfmaInstrNonK = elemsPerInstr[nonKDimIdx - 1]; mfmaInstrK = elemsPerInstr[kDimIdx - 1]; @@ -232,12 +232,12 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, mfmaInstrK = elemsPerInstr[kDimIdx]; } - auto numReps = mfmaLayout.getMFMARepForOperands(shape, kWidth, opIdx); + auto numReps = mfmaLayout.getRepForOperand(shape, kWidth, opIdx); auto numRepNonK = numReps[nonKDimIdx]; auto numRepK = numReps[kDimIdx]; auto repB = numReps[0]; // TODO(Lixun): make it simpler - // getMFMARepForOperands always returns a 3D vector + // getRepForOperand always returns a 3D vector if (rank == 2) { numRepNonK = numReps[nonKDimIdx + 1]; numRepK = numReps[kDimIdx + 1]; diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp index 400bcd9f655e..1ca9e49745d6 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp @@ -166,7 +166,7 @@ Value convertLayout(int opIdx, 
ConversionPatternRewriter &rewriter, auto wmmaInstrNonK = elemsPerInstr[opIdx == 0 ? 0 : 1]; assert(wmmaInstrNonK == 16); - auto numReps = wmmaLayout.getRepForOperands(shape, elemTy, kWidth, opIdx); + auto numReps = wmmaLayout.getRepForOperand(shape, elemTy, kWidth, opIdx); auto numRepNonK = numReps[opIdx == 0 ? 1 : 2]; auto numRepK = numReps[opIdx == 0 ? 2 : 1]; auto repB = numReps[0]; @@ -177,7 +177,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, Value linearWaveId = udiv(thread, waveSize); unsigned numElemsPerThreadPerRep = - wmmaLayout.getSizePerThreadForOperands(opIdx)[kDimIdx]; + wmmaLayout.getSizePerThreadForOperand(kWidth, opIdx)[kDimIdx]; Value lane = urem(thread, waveSize); unsigned int maxNumWarps = shape[nonKDimIdx] / wmmaInstrNonK; diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp index bdc25f0a8596..1eed112c30c0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp @@ -194,10 +194,8 @@ struct DotOpMFMAConversionHelper { int kWidth = aEncoding.getKWidth(); auto rank = aTensorTy.getShape().size(); - auto repA = - mfmaLayout.getMFMARepForOperands(aTensorTy.getShape(), kWidth, 0); - auto repB = - mfmaLayout.getMFMARepForOperands(bTensorTy.getShape(), kWidth, 1); + auto repA = mfmaLayout.getRepForOperand(aTensorTy.getShape(), kWidth, 0); + auto repB = mfmaLayout.getRepForOperand(bTensorTy.getShape(), kWidth, 1); assert(repA[2] == repB[1]); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp index 9ed21fa00d2d..0042cf89e93b 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp @@ -264,9 +264,9 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor, int kWidth = aEncoding.getKWidth(); auto repA = - 
wmmaLayout.getRepForOperands(aTensorTy.getShape(), elemTy, kWidth, 0); + wmmaLayout.getRepForOperand(aTensorTy.getShape(), elemTy, kWidth, 0); auto repB = - wmmaLayout.getRepForOperands(bTensorTy.getShape(), elemTy, kWidth, 1); + wmmaLayout.getRepForOperand(bTensorTy.getShape(), elemTy, kWidth, 1); assert(repA[2] == repB[1]); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index 21b74ecf99fa..a26a18ed96bc 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -630,14 +630,18 @@ class BlockedToWMMA : public RewritePattern { auto newAcc = convertAndCastTensor(rewriter, oldAcc, wmmaEnc, operandTypes[2]); - auto newAType = RankedTensorType::get( - aShape, operandTypes[0], - ttg::DotOperandEncodingAttr::get( - ctx, 0, wmmaEnc, wmmaEnc.getSizePerThreadForOperands(0)[rank - 1])); - auto newBType = RankedTensorType::get( - bShape, operandTypes[1], - ttg::DotOperandEncodingAttr::get( - ctx, 1, wmmaEnc, wmmaEnc.getSizePerThreadForOperands(1)[rank - 2])); + auto newAType = + RankedTensorType::get(aShape, operandTypes[0], + ttg::DotOperandEncodingAttr::get( + ctx, 0, wmmaEnc, + wmmaEnc.getSizePerThreadForOperand( + /*kWidth=*/0, /*opIdx=*/0)[rank - 1])); + auto newBType = + RankedTensorType::get(bShape, operandTypes[1], + ttg::DotOperandEncodingAttr::get( + ctx, 1, wmmaEnc, + wmmaEnc.getSizePerThreadForOperand( + /*kWidth=*/0, /*opIdx=*/1)[rank - 2])); Value castedA = convertAndCastTensor(rewriter, a, newAType.getEncoding(), operandTypes[0]); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp index bf033bdd5322..73c21cae6de2 100644 --- 
a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp @@ -603,9 +603,9 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth; int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth; - auto numRep = - mmaLayout.getMMAv2Rep(shapePerCTA, bitwidth, encoding.getOpIdx()); int kWidth = encoding.getKWidth(); + auto numRep = mmaLayout.getMMAv2RepForOperand(shapePerCTA, bitwidth, kWidth, + encoding.getOpIdx()); auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); auto order = triton::gpu::getOrder(mmaLayout); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index af897ef546dd..4a3f530a747d 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -318,10 +318,12 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, int bitwidth = aTensorTy.getElementType().getIntOrFloatBitWidth(); auto dotOpA = cast(aTensorTy.getEncoding()); auto repA = cast(dotOpA.getParent()) - .getMMAv2Rep(aShapePerCTA, bitwidth, dotOpA.getOpIdx()); + .getMMAv2RepForOperand(aShapePerCTA, bitwidth, + dotOpA.getKWidth(), dotOpA.getOpIdx()); auto dotOpB = cast(bTensorTy.getEncoding()); auto repB = cast(dotOpB.getParent()) - .getMMAv2Rep(bShapePerCTA, bitwidth, dotOpB.getOpIdx()); + .getMMAv2RepForOperand(bShapePerCTA, bitwidth, + dotOpB.getKWidth(), dotOpB.getOpIdx()); assert(repA[2] == repB[1]); assert(repA[0] == repB[0]); diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index 60ccb6c5cad8..c65428d03975 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ 
b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -51,8 +51,8 @@ class LinearLayoutConversionsTest : public ::testing::Test { isTransposed, CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd)); } - DotOperandEncodingAttr amdDot(AMDMfmaEncodingAttr mfma, unsigned opIdx, - unsigned kWidth) { + DotOperandEncodingAttr mfmaDot(AMDMfmaEncodingAttr mfma, unsigned opIdx, + unsigned kWidth) { return DotOperandEncodingAttr::get(&ctx, opIdx, mfma, kWidth); } @@ -659,9 +659,9 @@ TEST_F(LinearLayoutConversionsTest, MFMA32_2x4x1Warps) { TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_rhs_kwidth8) { auto parentMfma_1_8 = mfma(/*warps=*/{1, 8}, /*mDim=*/32, /*nDim=*/32, /*isTransposed=*/false); - auto amdDot_1_8 = amdDot(parentMfma_1_8, /*opIdx=*/1, /*kWidth=*/8); + auto mfmaDot_1_8 = mfmaDot(parentMfma_1_8, /*opIdx=*/1, /*kWidth=*/8); EXPECT_EQ( - toLinearLayout({128, 128}, amdDot_1_8), + toLinearLayout({128, 128}, mfmaDot_1_8), LinearLayout( {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {16, 0}, {32, 0}, {64, 0}}}, {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}, @@ -670,7 +670,7 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_rhs_kwidth8) { {S("dim0"), S("dim1")})); EXPECT_EQ( - toLinearLayout({128, 256}, amdDot_1_8), + toLinearLayout({128, 256}, mfmaDot_1_8), LinearLayout( {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {16, 0}, {32, 0}, {64, 0}}}, {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}, @@ -678,7 +678,7 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_rhs_kwidth8) { {S("block"), {}}}, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({32, 64}, amdDot_1_8), + EXPECT_EQ(toLinearLayout({32, 64}, mfmaDot_1_8), LinearLayout( {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {16, 0}}}, {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}, @@ -687,7 +687,7 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_rhs_kwidth8) { {S("dim0"), S("dim1")})); EXPECT_EQ( - toLinearLayout({256, 256}, amdDot_1_8), + toLinearLayout({256, 
256}, mfmaDot_1_8), LinearLayout( {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {16, 0}, {32, 0}, {64, 0}, {128, 0}}}, @@ -698,8 +698,8 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_rhs_kwidth8) { auto parentMfma_1_4 = mfma(/*warps=*/{1, 4}, /*mDim=*/32, /*nDim=*/32, /*isTransposed=*/false); - auto amdDot_1_4 = amdDot(parentMfma_1_4, /*opIdx=*/1, /*kWidth=*/8); - EXPECT_EQ(toLinearLayout({256, 256}, amdDot_1_4), + auto mfmaDot_1_4 = mfmaDot(parentMfma_1_4, /*opIdx=*/1, /*kWidth=*/8); + EXPECT_EQ(toLinearLayout({256, 256}, mfmaDot_1_4), LinearLayout( {{S("register"), {{1, 0}, @@ -719,9 +719,9 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_rhs_kwidth8) { TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_rhs_kwidth8) { auto parentMfma_1_4 = mfma(/*warps=*/{1, 4}, /*mDim=*/16, /*nDim=*/16, /*isTransposed=*/false); - auto amdDot_1_4 = amdDot(parentMfma_1_4, /*opIdx=*/1, /*kWidth=*/8); + auto mfmaDot_1_4 = mfmaDot(parentMfma_1_4, /*opIdx=*/1, /*kWidth=*/8); EXPECT_EQ( - toLinearLayout({128, 128}, amdDot_1_4), + toLinearLayout({128, 128}, mfmaDot_1_4), LinearLayout( {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}, {0, 64}}}, {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {16, 0}}}, @@ -729,7 +729,7 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_rhs_kwidth8) { {S("block"), {}}}, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({1, 128}, amdDot_1_4), + EXPECT_EQ(toLinearLayout({1, 128}, mfmaDot_1_4), LinearLayout( {{S("register"), {{0, 0}, {0, 0}, {0, 0}, {0, 64}}}, {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 0}, {0, 0}}}, @@ -737,7 +737,7 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_rhs_kwidth8) { {S("block"), {}}}, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({128, 1}, amdDot_1_4), + EXPECT_EQ(toLinearLayout({128, 1}, mfmaDot_1_4), LinearLayout( {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}}}, {S("lane"), {{0, 0}, {0, 0}, {0, 0}, {0, 0}, {8, 0}, {16, 0}}}, @@ -745,7 +745,7 @@ 
TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_rhs_kwidth8) { {S("block"), {}}}, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({256, 256}, amdDot_1_4), + EXPECT_EQ(toLinearLayout({256, 256}, mfmaDot_1_4), LinearLayout( {{S("register"), {{1, 0}, @@ -763,9 +763,9 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_rhs_kwidth8) { auto parentMfma_1_8 = mfma(/*warps=*/{1, 8}, /*mDim=*/16, /*nDim=*/16, /*isTransposed=*/false); - auto amdDot_1_8 = amdDot(parentMfma_1_8, /*opIdx=*/1, /*kWidth=*/8); + auto mfmaDot_1_8 = mfmaDot(parentMfma_1_8, /*opIdx=*/1, /*kWidth=*/8); EXPECT_EQ( - toLinearLayout({256, 256}, amdDot_1_8), + toLinearLayout({256, 256}, mfmaDot_1_8), LinearLayout( {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}, {128, 0}, {0, 128}}}, @@ -776,9 +776,9 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_rhs_kwidth8) { auto parentMfma_1_8_1 = mfma(/*warps=*/{1, 1, 8}, /*mDim=*/16, /*nDim=*/16, /*isTransposed=*/false); - auto amdDot_1_8_1 = amdDot(parentMfma_1_8_1, /*opIdx=*/1, /*kWidth=*/8); + auto mfmaDot_1_8_1 = mfmaDot(parentMfma_1_8_1, /*opIdx=*/1, /*kWidth=*/8); - EXPECT_EQ(toLinearLayout({1, 256, 256}, amdDot_1_8_1), + EXPECT_EQ(toLinearLayout({1, 256, 256}, mfmaDot_1_8_1), LinearLayout({{S("register"), {{0, 1, 0}, {0, 2, 0},