From 68350e991eaa92ca0d30c726469bb490fd4ed8b8 Mon Sep 17 00:00:00 2001
From: Alexander Efimov
Date: Tue, 2 Jul 2024 17:06:13 +0200
Subject: [PATCH] Relax dot operand constraints with FMA based dot

This PR:
- Refactors FMA dot implementation
- Supports dot3d in FMA path
- Fixes several issues in operand offset computation
- Enables small dot operands
---
 .../Conversion/TritonGPUToLLVM/Utility.h      |  16 +
 include/triton/Dialect/TritonGPU/IR/Dialect.h |  16 +
 lib/Analysis/Utility.cpp                      |  12 +-
 .../SharedToDotOperandFMA.cpp                 | 333 ++++++++----------
 .../TritonGPUToLLVM/DotOpToLLVM/FMA.cpp       |  98 +++---
 .../TritonToTritonGPUPass.cpp                 |   5 +
 lib/Dialect/TritonGPU/IR/Dialect.cpp          |  43 ++-
 .../Transforms/ReduceDataDuplication.cpp      |  43 ++-
 python/test/unit/language/test_core.py        |  20 +-
 third_party/amd/backend/compiler.py           |  28 +-
 third_party/nvidia/backend/compiler.py        |  15 +-
 .../SharedToDotOperandMMAv2.cpp               |  17 -
 12 files changed, 338 insertions(+), 308 deletions(-)

diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h
index 8ea01d2f1fb5..0c8d3da0d423 100644
--- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h
+++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1473,6 +1473,22 @@ inline bool isLayoutMmaV1(Attribute layout) {
   return isMmaV1;
 }
 
+inline SharedMemoryObject
+getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc,
+                              SharedMemoryObject smemObj,
+                              ArrayRef<int64_t> shape) {
+  auto strides = smemObj.getStrides();
+  auto offsets = smemObj.getOffsets();
+  auto rank = strides.size();
+  if (rank == 3)
+    return smemObj;
+  strides.insert(strides.begin(), i32_val(shape[0] * shape[1]));
+  offsets.insert(offsets.begin(), i32_val(0));
+  auto expandedSmemObj = SharedMemoryObject(
+      smemObj.getBase(), smemObj.getBaseElemType(), strides, offsets);
+  return expandedSmemObj;
+}
+
 } // namespace mlir
 
 #endif
diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index 16e6506e5bad..3c3e4760a236 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -131,6 +131,22 @@ void dumpHWLayout(RankedTensorType tensorType);
 // Return a string representation of the layout of the tensor.
 std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView);
 
+template <typename T>
+llvm::SmallVector<T> expandMatrixShapeWithBatch(llvm::ArrayRef<T> s) {
+  llvm::SmallVector<T> expanded(3 - s.size(), 1);
+  expanded.append(s.begin(), s.end());
+  return expanded;
+}
+
+template <typename T>
+llvm::SmallVector<T> expandMatrixOrderWithBatch(llvm::ArrayRef<T> o) {
+  int oldRank = o.size();
+  llvm::SmallVector<T> expanded(3, 0);
+  for (int i = 0; i < oldRank; ++i)
+    expanded[i] += o[i] + 3 - oldRank;
+  return expanded;
+}
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index be68f416f4e9..3a689b4a3137 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -480,12 +480,18 @@ bool supportMMA(triton::DotOp op, int version) {
   // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
   auto aElemTy = op.getA().getType().getElementType();
   auto bElemTy = op.getB().getType().getElementType();
+  auto retType = op.getType();
+  auto retShapePerCTA = getShapePerCTA(retType);
+  auto rank = retShapePerCTA.size();
+  auto aTensorTy = cast<RankedTensorType>(op.getA().getType());
+  auto aShape = aTensorTy.getShape();
+  auto encoding = cast(aTensorTy.getEncoding());
+  if (retShapePerCTA[rank - 2] < 16 || retShapePerCTA[rank - 1] < 16 ||
+      aShape[rank - 1] < 16)
+    return false;
   if (version == 3) {
     if (triton::tools::getBoolEnv("DISABLE_MMA_V3"))
       return false;
-    auto retType = op.getType();
-    auto retShapePerCTA = getShapePerCTA(retType);
-    auto rank = retShapePerCTA.size();
     auto mod = op->getParentOfType<ModuleOp>();
     int numWarps = TritonGPUDialect::getNumWarps(mod);
     // TODO(Keren): for now, fallback to MMAv2 if handling batch matmul.
diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp
index b7bd5fbc3432..236ee8538622 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp
@@ -1,5 +1,6 @@
 #include "mlir/Support/LLVM.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
 
 using ValueTable = std::map<std::pair<int, int>, Value>;
 using ::mlir::LLVM::delinearize;
@@ -7,6 +8,8 @@ using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
 using ::mlir::LLVM::getStridesFromShapeAndOrder;
 using ::mlir::LLVM::linearize;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
+using ::mlir::triton::gpu::expandMatrixOrderWithBatch;
+using ::mlir::triton::gpu::expandMatrixShapeWithBatch;
 using ::mlir::triton::gpu::getContigPerThread;
 using ::mlir::triton::gpu::getOrder;
 using ::mlir::triton::gpu::getShapePerCTA;
@@ -15,47 +18,6 @@ using ::mlir::triton::gpu::getTotalElemsPerThread;
 using ::mlir::triton::gpu::isaDistributedLayout;
 using ::mlir::triton::gpu::SharedEncodingAttr;
 
-SmallVector<Value>
-getThreadIds(Value threadId, ArrayRef<unsigned> shapePerCTATile,
-             ArrayRef<unsigned> sizePerThread, ArrayRef<unsigned> order,
-             ConversionPatternRewriter &rewriter, Location loc) {
-  int dim = order.size();
-  SmallVector<Value> threadIds(dim);
-  for (unsigned k = 0; k < dim - 1; k++) {
-    Value dimK = i32_val(shapePerCTATile[order[k]] / sizePerThread[order[k]]);
-    Value rem = urem(threadId, dimK);
-    threadId = udiv(threadId, dimK);
-    threadIds[order[k]] = rem;
-  }
-  Value dimK = i32_val(shapePerCTATile[order[dim - 1]]);
-  threadIds[order[dim - 1]] = urem(threadId, dimK);
-  return threadIds;
-}
-
-// Get shapePerCTATile for M or N
axis. -int getShapePerCTATileForMN(BlockedEncodingAttr layout, bool isM) { - auto order = layout.getOrder(); - auto shapePerCTATile = getShapePerCTATile(layout); - - int mShapePerCTATile = - order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; - int nShapePerCTATile = - order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; - return isM ? mShapePerCTATile : nShapePerCTATile; -} - -// Get sizePerThread for M or N axis. -int getSizePerThreadForMN(BlockedEncodingAttr layout, bool isM) { - auto order = layout.getOrder(); - auto sizePerThread = getSizePerThread(layout); - - int mSizePerThread = - order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; - int nSizePerThread = - order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; - return isM ? mSizePerThread : nSizePerThread; -} - Value getStructFromValueTable(ArrayRef vals, ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter *typeConverter, @@ -71,154 +33,151 @@ Value getStructFromValueTable(ArrayRef vals, return packLLElements(loc, typeConverter, elems, rewriter, structTy); } -ValueTable getValueTableFromStruct(Value val, int K, int n0, int shapePerCTA, - int sizePerThread, - ConversionPatternRewriter &rewriter, - Location loc, - const LLVMTypeConverter *typeConverter, - Type type) { - ValueTable res; - auto elems = unpackLLElements(loc, val, rewriter); - int index = 0; - for (unsigned k = 0; k < K; ++k) { - for (unsigned m = 0; m < n0; m += shapePerCTA) - for (unsigned mm = 0; mm < sizePerThread; ++mm) { - res[{m + mm, k}] = elems[index++]; - } - } - return res; +SmallVector swizzleIndices(ConversionPatternRewriter &rewriter, + Location loc, SmallVector rawIndices, + SharedEncodingAttr layout) { + const auto &order = layout.getOrder(); + auto rank = order.size(); + + if (layout.getMaxPhase() == 1) + return rawIndices; + + auto vec = i32_val(layout.getVec()); + auto perPhase = i32_val(layout.getPerPhase()); + auto maxPhase = i32_val(layout.getMaxPhase()); + + auto fastIdx = rawIndices[order[0]]; + auto secondIdx = rawIndices[order[1]]; + // Original algorithm taken from getSwizzledSharedPtrs function + // (TritonGPUToLLVMBase.h) + // + // phase = (secondIdx // perPhase) % maxPhase + // swizzledGroup = ((fastIdx // vec) ^ phase) * vec + // groupRemainder = fastIdx % vec + // colOff = swizzledGroup + groupRemainder + auto phase = urem(udiv(secondIdx, perPhase), maxPhase); + auto swizzledGroup = mul(xor_(udiv(fastIdx, vec), phase), vec); + auto groupRemainder = urem(fastIdx, vec); + auto colOff = add(swizzledGroup, groupRemainder); + + SmallVector swizzledIndices = rawIndices; + swizzledIndices[order[0]] = colOff; + + return swizzledIndices; } -Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread, - Location loc, const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter) { - auto aTensorTy = cast(A.getType()); - auto aLayout = cast(aTensorTy.getEncoding()); - auto aShapePerCTA = getShapePerCTA(aTensorTy); - - auto aOrder = aLayout.getOrder(); - auto order = dLayout.getOrder(); - - bool isARow = aOrder[0] == 1; - - auto aSmem = getSharedMemoryObjectFromStruct( - loc, llA, typeConverter->convertType(aTensorTy.getElementType()), +Value loadFMAOp(Value dotOp, Value llA, BlockedEncodingAttr dLayout, + Value thread, Location loc, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, const int dotOpNo) { + auto ctx = dotOp.getContext(); + const int bDim = 0; + const int kDim = dotOpNo == 0 ? 
2 : 1; + const int nonKDim = dotOpNo == 0 ? 1 : 2; + auto opTensorTy = cast(dotOp.getType()); + auto opLayout = cast(opTensorTy.getEncoding()); + auto opShapePerCTA = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(opTensorTy))); + + auto order = expandMatrixOrderWithBatch(dLayout.getOrder()); + + auto origSmem = getSharedMemoryObjectFromStruct( + loc, llA, typeConverter->convertType(opTensorTy.getElementType()), rewriter); - Value strideAM = aSmem.strides[0]; - Value strideAK = aSmem.strides[1]; - Value strideA0 = isARow ? strideAK : strideAM; - Value strideA1 = isARow ? strideAM : strideAK; - int aNumPtr = 8; - int K = aShapePerCTA[1]; - int M = aShapePerCTA[0]; - - auto shapePerCTATile = getShapePerCTATile(dLayout); - auto sizePerThread = getSizePerThread(dLayout); - - Value _0 = i32_val(0); - - Value mContig = i32_val(sizePerThread[order[1]]); + auto smem = getExpandedSharedMemoryObject(rewriter, loc, origSmem, + opTensorTy.getShape()); + auto strides = smem.strides; + int B = opShapePerCTA[bDim]; + int K = opShapePerCTA[kDim]; + int NonK = opShapePerCTA[nonKDim]; + + auto shapePerCTATile = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout))); + auto sizePerThread = + expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout))); + auto threadsPerWarp = + expandMatrixShapeWithBatch(ArrayRef(dLayout.getThreadsPerWarp())); + auto warpsPerCTA = + expandMatrixShapeWithBatch(ArrayRef(dLayout.getWarpsPerCTA())); // threadId in blocked layout - auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order, - rewriter, loc); - Value threadIdM = threadIds[0]; - - Value offA0 = isARow ? _0 : mul(threadIdM, mContig); - Value offA1 = isARow ? mul(threadIdM, mContig) : _0; - SmallVector aOff(aNumPtr); - for (int i = 0; i < aNumPtr; ++i) { - aOff[i] = add(mul(offA0, strideA0), mul(offA1, strideA1)); - } - auto elemTy = typeConverter->convertType(aTensorTy.getElementType()); - - Type ptrTy = ptr_ty(rewriter.getContext(), 3); - SmallVector aPtrs(aNumPtr); - for (int i = 0; i < aNumPtr; ++i) - aPtrs[i] = gep(ptrTy, elemTy, aSmem.base, aOff[i]); - - SmallVector vas; - - int mShapePerCTATile = getShapePerCTATileForMN(dLayout, true /*isM*/); - int mSizePerThread = getSizePerThreadForMN(dLayout, true /*isM*/); - - for (unsigned k = 0; k < K; ++k) - for (unsigned m = 0; m < M; m += mShapePerCTATile) - for (unsigned mm = 0; mm < mSizePerThread; ++mm) { - Value offset = - add(mul(i32_val(m + mm), strideAM), mul(i32_val(k), strideAK)); - Value pa = gep(ptrTy, elemTy, aPtrs[0], offset); - Value va = load(elemTy, pa); - vas.emplace_back(va); - } - - return getStructFromValueTable(vas, rewriter, loc, typeConverter, elemTy); -} - -Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread, - Location loc, const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter) { - auto bTensorTy = cast(B.getType()); - auto bLayout = cast(bTensorTy.getEncoding()); - auto bShapePerCTA = getShapePerCTA(bTensorTy); - - auto bOrder = bLayout.getOrder(); - auto order = dLayout.getOrder(); - - bool isBRow = bOrder[0] == 1; - - auto bSmem = getSharedMemoryObjectFromStruct( - loc, llB, typeConverter->convertType(bTensorTy.getElementType()), - rewriter); - Value strideBN = bSmem.strides[1]; - Value strideBK = bSmem.strides[0]; - Value strideB0 = isBRow ? strideBN : strideBK; - Value strideB1 = isBRow ? 
strideBK : strideBN; - int bNumPtr = 8; - int K = bShapePerCTA[0]; - int N = bShapePerCTA[1]; - - auto shapePerCTATile = getShapePerCTATile(dLayout); - auto sizePerThread = getSizePerThread(dLayout); - - Value _0 = i32_val(0); - - Value nContig = i32_val(sizePerThread[order[0]]); - - // threadId in blocked layout - auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order, - rewriter, loc); - Value threadIdN = threadIds[1]; - - Value offB0 = isBRow ? mul(threadIdN, nContig) : _0; - Value offB1 = isBRow ? _0 : mul(threadIdN, nContig); - SmallVector bOff(bNumPtr); - for (int i = 0; i < bNumPtr; ++i) { - bOff[i] = add(mul(offB0, strideB0), mul(offB1, strideB1)); - } - auto elemTy = typeConverter->convertType(bTensorTy.getElementType()); - - Type ptrTy = ptr_ty(rewriter.getContext(), 3); - SmallVector bPtrs(bNumPtr); - for (int i = 0; i < bNumPtr; ++i) - bPtrs[i] = gep(ptrTy, elemTy, bSmem.base, bOff[i]); - - SmallVector vbs; - - int nShapePerCTATile = getShapePerCTATileForMN(dLayout, false /*isM*/); - int nSizePerThread = getSizePerThreadForMN(dLayout, false /*isM*/); - - for (unsigned k = 0; k < K; ++k) - for (unsigned n = 0; n < N; n += nShapePerCTATile) - for (unsigned nn = 0; nn < nSizePerThread; ++nn) { - Value offset = - add(mul(i32_val(n + nn), strideBN), mul(i32_val(k), strideBK)); - Value pb = gep(ptrTy, elemTy, bPtrs[0], offset); - Value vb = load(elemTy, pb); - vbs.emplace_back(vb); - } - - return getStructFromValueTable(vbs, rewriter, loc, typeConverter, elemTy); + auto warpSize = i32_val(triton::gpu::getWarpSize(dLayout)); + auto laneId = urem(thread, warpSize); + auto warpId = udiv(thread, warpSize); + auto laneIds = + mlir::LLVM::delinearize(rewriter, loc, laneId, threadsPerWarp, order); + auto warpIds = + mlir::LLVM::delinearize(rewriter, loc, warpId, warpsPerCTA, order); + auto sizePerWarpB = sizePerThread[bDim] * threadsPerWarp[bDim]; + auto sizePerWarpNonK = sizePerThread[nonKDim] * threadsPerWarp[nonKDim]; + + Value bTileOffset = mul(laneIds[bDim], i32_val(sizePerThread[bDim])); + bTileOffset = add(bTileOffset, mul(warpIds[bDim], i32_val(sizePerWarpB))); + Value nonKTileOffset = mul(laneIds[nonKDim], i32_val(sizePerThread[nonKDim])); + nonKTileOffset = + add(nonKTileOffset, mul(warpIds[nonKDim], i32_val(sizePerWarpNonK))); + + auto elemTy = typeConverter->convertType(opTensorTy.getElementType()); + Type ptrTy = ptr_ty(ctx, 3); + + unsigned vectorSize = order[0] == kDim ? 
K : sizePerThread[order[0]]; + if (opLayout.getMaxPhase() > 0) + vectorSize = std::min(vectorSize, opLayout.getVec()); + auto vecTy = vec_ty(elemTy, vectorSize); + + unsigned dimStep[3] = {1, 1, 1}; + dimStep[order[0]] = vectorSize; + + int shapePerCTABTile = shapePerCTATile[bDim]; + int shapePerCTANonKTile = shapePerCTATile[nonKDim]; + int sizeBPerThread = sizePerThread[bDim]; + int sizeNonKPerThread = sizePerThread[nonKDim]; + int numBTiles = std::max(1, B / shapePerCTABTile); + int numNonKTiles = std::max(1, NonK / shapePerCTANonKTile); + + SmallVector opValues(numBTiles * sizeBPerThread * K * numNonKTiles * + sizeNonKPerThread); + + for (unsigned bTile = 0; bTile < numBTiles; ++bTile) + for (unsigned b = 0; b < sizeBPerThread; b += dimStep[bDim]) + for (unsigned k = 0; k < K; k += dimStep[kDim]) + for (unsigned nonKTile = 0; nonKTile < numNonKTiles; ++nonKTile) + for (unsigned nonK = 0; nonK < sizeNonKPerThread; + nonK += dimStep[nonKDim]) { + SmallVector rawIndices(3); + rawIndices[bDim] = + add(bTileOffset, i32_val(bTile * shapePerCTABTile + b)); + rawIndices[nonKDim] = add( + nonKTileOffset, i32_val(nonKTile * shapePerCTANonKTile + nonK)); + rawIndices[kDim] = i32_val(k); + + SmallVector swizzledIndices = + swizzleIndices(rewriter, loc, rawIndices, opLayout); + + Value offset = i32_val(0); + for (int dim = 0; dim < order.size(); ++dim) + offset = add(offset, mul(urem(swizzledIndices[dim], + i32_val(opShapePerCTA[dim])), + strides[dim])); + + Value elemAddr = gep(ptrTy, elemTy, smem.base, offset); + Value vecAddr = bitcast(elemAddr, ptr_ty(ctx, 3)); + Value vec = load(vecTy, elemAddr); + for (int elem = 0; elem < vectorSize; ++elem) { + int outIdx[3] = {}; + outIdx[bDim] = bTile * sizeBPerThread + b; + outIdx[kDim] = k; + outIdx[nonKDim] = nonKTile * sizeNonKPerThread + nonK; + outIdx[order[0]] += elem; + int idx = (outIdx[bDim] * K + outIdx[kDim]) * numNonKTiles * + sizeNonKPerThread + + outIdx[nonKDim]; + opValues[idx] = extract_element(elemTy, vec, i32_val(elem)); + } + } + + return getStructFromValueTable(opValues, rewriter, loc, typeConverter, + elemTy); } namespace SharedToDotOperandFMA { @@ -226,9 +185,7 @@ Value convertLayout(int opIdx, Value val, Value llVal, BlockedEncodingAttr dLayout, Value thread, Location loc, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter) { - if (opIdx == 0) - return loadAFMA(val, llVal, dLayout, thread, loc, typeConverter, rewriter); - else - return loadBFMA(val, llVal, dLayout, thread, loc, typeConverter, rewriter); + return loadFMAOp(val, llVal, dLayout, thread, loc, typeConverter, rewriter, + opIdx); } } // namespace SharedToDotOperandFMA diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp index afb5bf01d48b..29bb10d4a5a1 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp @@ -1,29 +1,30 @@ #include "mlir/Support/LLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" using namespace mlir; using namespace mlir::triton; +using namespace ::mlir::triton::gpu; -using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::expandMatrixOrderWithBatch; +using ::mlir::triton::gpu::expandMatrixShapeWithBatch; using ::mlir::triton::gpu::getShapePerCTA; -using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; +using ::mlir::triton::gpu::getSizePerThread; -using ValueTableFMA 
= std::map, Value>; +using ValueTableFMA = std::map, Value>; static ValueTableFMA -getValueTableFromStructFMA(Value val, int K, int n0, int shapePerCTATile, - int sizePerThread, - ConversionPatternRewriter &rewriter, Location loc, - const LLVMTypeConverter *typeConverter, Type type) { +getValueTableFromStructFMA(Value val, int batch, int nonK, int K, + ConversionPatternRewriter &rewriter, Location loc) { ValueTableFMA res; auto elems = unpackLLElements(loc, val, rewriter); + assert(elems.size() == K * nonK * batch); int index = 0; - for (unsigned k = 0; k < K; ++k) { - for (unsigned m = 0; m < n0; m += shapePerCTATile) - for (unsigned mm = 0; mm < sizePerThread; ++mm) { - res[{m + mm, k}] = elems[index++]; - } - } + for (unsigned b = 0; b < batch; ++b) + for (unsigned k = 0; k < K; ++k) + for (unsigned i = 0; i < nonK; ++i) + res[{b, i, k}] = elems[index++]; return res; } @@ -39,61 +40,56 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, auto D = op.getResult(); auto aTensorTy = cast(A.getType()); - auto bTensorTy = cast(B.getType()); auto dTensorTy = cast(D.getType()); + auto dElemTy = dTensorTy.getElementType(); - auto aShapePerCTA = getShapePerCTA(aTensorTy); - auto bShapePerCTA = getShapePerCTA(bTensorTy); + SmallVector aShapePerCTA = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy))); + auto dShapePerCTA = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy))); BlockedEncodingAttr dLayout = cast(dTensorTy.getEncoding()); - auto order = dLayout.getOrder(); + auto order = expandMatrixOrderWithBatch(dLayout.getOrder()); auto cc = unpackLLElements(loc, adaptor.getC(), rewriter); Value llA = adaptor.getA(); Value llB = adaptor.getB(); - auto sizePerThread = getSizePerThread(dLayout); - auto shapePerCTATile = getShapePerCTATile(dLayout); + auto sizePerThread = + expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout))); + auto shapePerCTATile = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout))); - int K = aShapePerCTA[1]; - int M = aShapePerCTA[0]; - int N = bShapePerCTA[1]; + int K = aShapePerCTA[2]; - int mShapePerCTATile = - order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; - int mSizePerThread = - order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; - int nShapePerCTATile = - order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; - int nSizePerThread = - order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + unsigned retSize[3]; + for (int i = 0; i < 3; ++i) { + unsigned numRep = dShapePerCTA[i] / shapePerCTATile[i]; + numRep = std::max(static_cast(1), numRep); + retSize[i] = numRep * sizePerThread[i]; + } auto has = - getValueTableFromStructFMA(llA, K, M, mShapePerCTATile, mSizePerThread, - rewriter, loc, typeConverter, aTensorTy); + getValueTableFromStructFMA(llA, retSize[0], retSize[1], K, rewriter, loc); auto hbs = - getValueTableFromStructFMA(llB, K, N, nShapePerCTATile, nSizePerThread, - rewriter, loc, typeConverter, bTensorTy); + getValueTableFromStructFMA(llB, retSize[0], retSize[2], K, rewriter, loc); SmallVector ret = cc; - bool isCRow = order[0] == 1; - - for (unsigned k = 0; k < K; k++) { - for (unsigned m = 0; m < M; m += mShapePerCTATile) - for (unsigned n = 0; n < N; n += nShapePerCTATile) - for (unsigned mm = 0; mm < mSizePerThread; ++mm) - for (unsigned nn = 0; nn < nSizePerThread; ++nn) { - int mIdx = m / mShapePerCTATile * mSizePerThread + mm; - int nIdx = n / nShapePerCTATile * nSizePerThread + nn; - - int z = isCRow - ? 
mIdx * N / nShapePerCTATile * mSizePerThread + nIdx - : nIdx * M / mShapePerCTATile * nSizePerThread + mIdx; - ret[z] = rewriter.create(loc, has[{m + mm, k}], - hbs[{n + nn, k}], ret[z]); - } - } + + for (unsigned b = 0; b < retSize[0]; ++b) + for (unsigned m = 0; m < retSize[1]; ++m) + for (unsigned n = 0; n < retSize[2]; ++n) { + unsigned idx[] = {b, m, n}; + unsigned linearIdx = 0; + for (auto dim : llvm::reverse(order)) { + linearIdx = linearIdx * retSize[dim] + idx[dim]; + } + for (unsigned k = 0; k < K; ++k) { + ret[linearIdx] = rewriter.create( + loc, has[{b, m, k}], hbs[{b, n, k}], ret[linearIdx]); + } + } auto res = packLLElements(loc, typeConverter, ret, rewriter, dTensorTy); rewriter.replaceOp(op, res); diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 4aa2712ec939..daa788d5c9af 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -235,6 +235,11 @@ struct TritonDotPattern : public OpConversionPattern { retSizePerThread[rank - 1] = 4; retSizePerThread[rank - 2] = 4; } + retSizePerThread[rank - 1] = std::min( + retSizePerThread[rank - 1], static_cast(origShape[rank - 1])); + retSizePerThread[rank - 2] = std::min( + retSizePerThread[rank - 2], static_cast(origShape[rank - 2])); + SmallVector retOrder(rank); for (unsigned i = 0; i < rank; ++i) retOrder[i] = rank - 1 - i; diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index d5b5d459a910..344f6181b207 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -939,29 +939,26 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, getKWidth(), getOpIdx()); } if (auto blockedLayout = mlir::dyn_cast(getParent())) { - auto shapePerCTA = getShapePerCTA(*this, shape); - auto shapePerCTATile = ::getShapePerCTATile(blockedLayout); - auto order = blockedLayout.getOrder(); - auto sizePerThread = ::getSizePerThread(blockedLayout); - - int K = getOpIdx() == 0 ? shapePerCTA[1] : shapePerCTA[0]; - int otherDim = getOpIdx() == 1 ? shapePerCTA[1] : shapePerCTA[0]; - - bool isM = getOpIdx() == 0; - - int mSizePerThread = - order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; - int nSizePerThread = - order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; - int sizePerThreadMN = isM ? mSizePerThread : nSizePerThread; - - int mShapePerCTATile = - order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; - int nShapePerCTATile = - order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; - int shapePerCTAMNTile = isM ? mShapePerCTATile : nShapePerCTATile; - - return K * std::max(otherDim / shapePerCTAMNTile, 1) * sizePerThreadMN; + auto shapePerCTA = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(*this, shape))); + auto shapePerCTATile = expandMatrixShapeWithBatch( + ArrayRef(::getShapePerCTATile(blockedLayout))); + auto sizePerThread = + expandMatrixShapeWithBatch(ArrayRef(::getSizePerThread(blockedLayout))); + + int batchDim = 0; + int kDim = getOpIdx() == 0 ? 2 : 1; + int nonKDim = getOpIdx() == 0 ? 
1 : 2; + + int batchSize = + std::max(shapePerCTA[batchDim] / shapePerCTATile[batchDim], 1) * + sizePerThread[batchDim]; + int kSize = shapePerCTA[kDim]; + int nonKSize = + std::max(shapePerCTA[nonKDim] / shapePerCTATile[nonKDim], 1) * + sizePerThread[nonKDim]; + + return batchSize * kSize * nonKSize; } llvm_unreachable("unknown dot operand parent layout"); return 0; diff --git a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp index 8c1f18e459c5..fef219a20180 100644 --- a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp +++ b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp @@ -28,6 +28,31 @@ namespace gpu { class TritonGPUReduceDataDuplicationPass : public impl::TritonGPUReduceDataDuplicationBase< TritonGPUReduceDataDuplicationPass> { + + static bool isLayoutConvertShortcut(RankedTensorType srcType, + RankedTensorType dstType) { + auto srcEncoding = srcType.getEncoding(); + auto dstDotOp = + dyn_cast(dstType.getEncoding()); + if (auto srcMmaEncoding = + dyn_cast(srcEncoding)) { + + if (srcMmaEncoding.getVersionMajor() != 2 || + (srcMmaEncoding.getWarpsPerCTA()[1] == 1 && + dstDotOp.getParent() == srcMmaEncoding)) + return true; + } + if (auto srcMfmaEncoding = + dyn_cast(srcEncoding)) { + + if (srcMfmaEncoding.getWarpsPerCTA()[1] == 1 && + srcMfmaEncoding.getIsTransposed() && + dstDotOp.getParent() == srcMfmaEncoding) + return true; + } + return false; + } + public: void runOnOperation() override { ModuleOp mod = getOperation(); @@ -42,22 +67,8 @@ class TritonGPUReduceDataDuplicationPass dyn_cast(dstType.getEncoding()); if (!dstDotOp) return; - if (auto srcMmaEncoding = - dyn_cast(srcEncoding)) { - - if (srcMmaEncoding.getVersionMajor() != 2 || - (srcMmaEncoding.getWarpsPerCTA()[1] == 1 && - dstDotOp.getParent() == srcMmaEncoding)) - return; - } - if (auto srcMfmaEncoding = - dyn_cast(srcEncoding)) { - - if (srcMfmaEncoding.getWarpsPerCTA()[1] == 1 && - srcMfmaEncoding.getIsTransposed() && - dstDotOp.getParent() == srcMfmaEncoding) - return; - } + if (isLayoutConvertShortcut(srcType, dstType)) + return; auto srcOrder = triton::gpu::getOrder(srcEncoding); auto rank = srcOrder.size(); SmallVector sharedOrder; diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 028db766b3b0..6e7e963717e8 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3090,7 +3090,11 @@ def convert_fp8_to_fp32(x, device, dtype_str): ([(16, 16, 8, 4, False, False, 'None', 'ieee', 'float32', 'float32', 1), (32, 16, 8, 4, False, False, 'None', 'ieee', 'float16', 'float16', 1)] if "gfx9" in get_arch() else []) + [(128, 128, 64, 4, False, False, 'chain-dot', 'ieee', float8_type, 'float32', 1) - for float8_type in ["float8e5", "float8e4nv"]]) + for float8_type in ["float8e5", "float8e4nv"]] + + [(*shape_nw, False, False, epilogue, 'ieee', in_dtype, out_dtype, 1) + for shape_nw in [(2, 2, 16, 1), (1, 64, 64, 1), (64, 2, 64, 2), (64, 64, 4, 4)] + for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'] + for in_dtype, out_dtype in [('float16', 'float16'), ('float32', 'float32')]]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, num_ctas, device): if is_interpreter(): @@ -3280,6 +3284,9 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid return # make sure ld/st are vectorized ptx = 
pgm.asm['ptx'] + is_fma = K < 16 or N < 16 or M < 16 + if is_fma: + return if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4): # XXX: skip small sizes because they are not vectorized assert 'ld.global.v4' in ptx @@ -3323,7 +3330,14 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid for in_dtype_str, out_dtype_str in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')]] + # Large block sizes - [(4, 4, 128, 128, 64, 64, 64, 'float16', 'float16')]) + [(4, 4, 128, 128, 64, 64, 64, 'float16', 'float16')] + + # Small block sizes + [(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str) + for B in [1, 2, 8] + for num_warps in [1, 2, 4] + for BLOCK_M, BLOCK_N in [(1, 32), (32, 2), (8, 8)] + for M, N, K in [(32, 32, 32)] + for in_dtype_str, out_dtype_str in [('float16', 'float16'), ('float32', 'float32')]]) def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str, device): if is_hip(): # hip does not support tf32 precision, so use ieee for all tests @@ -3394,6 +3408,8 @@ def kernel( if in_dtype_str == 'int8': out = numpy_random((B, M, N), dtype_str='int32', rs=rs) else: + x *= 0.1 + y *= 0.1 out = numpy_random((B, M, N), dtype_str=out_dtype_str, rs=rs) x_tri = to_triton(x, device=device) diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index bf966ae2af74..4d3cfa73849d 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -12,16 +12,30 @@ def min_dot_size(target: GPUTarget): + + def fma_supported(lhsType, rhsType): + return lhsType == rhsType and (lhsType.is_fp16() or lhsType.is_fp32()) + + def gfx94_limits(lhsType, rhsType): + if fma_supported(lhsType.scalar, rhsType.scalar): + return (1, 1, 1) + # CDNA 3.0 supports k==8 in all mfma variants except for int8 + # (where the smallest `k` supported is 16) + return (16, 16, 16) if (lhsType.scalar.is_int8() or rhsType.scalar.is_int8()) else (16, 16, 8) + + def gfx9_limits(lhsType, rhsType): + if fma_supported(lhsType.scalar, rhsType.scalar): + return (1, 1, 1) + # CDNA 2.0 always supports `k==8` + return (16, 16, 8) + arch_str = target.arch - # CDNA 3.0 supports k==8 in all mfma variants except for int8 - # (where the smallest `k` supported is 16) if "gfx94" in arch_str: - return lambda lhsType, rhsType: (16, 16, 16) if (lhsType.is_int8() or rhsType.is_int8()) else (16, 16, 8) - # CDNA 2.0 always supports `k==8` + return gfx94_limits if "gfx9" in arch_str: - return lambda lhsType, rhsType: (16, 16, 8) - # Other architectures will only support 16,16,16 - return lambda lhsType, rhsType: (16, 16, 16) + return gfx9_limits + # Other architectures will only support 16,16,16 with mfma instructions + return lambda lhsType, rhsType: (1, 1, 1) if fma_supported(lhsType.scalar, rhsType.scalar) else (16, 16, 16) @dataclass(frozen=True) diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 5dd75e530fec..51abf0e3618e 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -14,7 +14,20 @@ def min_dot_size(target: GPUTarget): - return lambda lhsType, rhsType: (16, 32, 16) if lhsType.is_int8() else (16, 16, 16) + + def fma_supported(lhsType, rhsType): + return lhsType == rhsType and (lhsType.is_fp16() or lhsType.is_fp32()) + + def limits(lhsType, rhsType): + if fma_supported(lhsType.scalar, rhsType.scalar): + return (1, 1, 1) + # TODO it should be lhsType.scalar.is_int8() + if 
lhsType.is_int8(): + return (16, 32, 16) + else: + return (16, 16, 16) + + return limits @functools.lru_cache() diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp index d1086c189d33..09162f5c6e9a 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp @@ -743,23 +743,6 @@ MemDescType getExpandedDesc(MemDescType descTy) { return expandedDesc; } -SharedMemoryObject -getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc, - SharedMemoryObject smemObj, - ArrayRef shape) { - auto strides = smemObj.getStrides(); - auto offsets = smemObj.getOffsets(); - auto rank = strides.size(); - if (rank == 3) - return smemObj; - auto expandedStrides = insertValue(strides, 0, i32_val(shape[0] * shape[1])); - auto expandedOffsets = insertValue(offsets, 0, i32_val(0)); - auto expandedSmemObj = - SharedMemoryObject(smemObj.getBase(), smemObj.getBaseElemType(), - expandedStrides, expandedOffsets); - return expandedSmemObj; -} - namespace SharedToDotOperandMMAv2 { Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, Location loc, Value tensor, DotOperandEncodingAttr encoding,
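
Reviewer note (not part of the patch): the indexing in the new FMA path rests on two small ideas, padding 2D operands to rank 3 with a unit batch dimension, and the swizzled column offset reproduced in swizzleIndices. The sketch below is a minimal, standalone C++ model of that arithmetic using plain std:: types and no MLIR; the helper names deliberately mirror expandMatrixShapeWithBatch, expandMatrixOrderWithBatch, and swizzleIndices from the patch, but the functions themselves are illustrative assumptions, not the lowering code.

// Standalone model of the batch-expansion and swizzle arithmetic used by the
// FMA dot path. Names mirror the patch but this file is illustrative only.
// Build with: c++ -std=c++17 fma_indexing_sketch.cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Model of expandMatrixShapeWithBatch: pad a 2D shape to rank 3 by
// prepending a batch dimension of size 1; rank-3 shapes pass through.
std::vector<int64_t> expandShapeWithBatch(std::vector<int64_t> s) {
  std::vector<int64_t> expanded(3 - s.size(), 1);
  expanded.insert(expanded.end(), s.begin(), s.end());
  return expanded;
}

// Model of expandMatrixOrderWithBatch: shift each 2D dimension index by one
// so the batch dimension (index 0) becomes the slowest-varying dimension.
std::vector<unsigned> expandOrderWithBatch(std::vector<unsigned> o) {
  unsigned oldRank = o.size();
  std::vector<unsigned> expanded(3, 0);
  for (unsigned i = 0; i < oldRank; ++i)
    expanded[i] += o[i] + 3 - oldRank;
  return expanded;
}

// Integer model of the swizzled column offset computed in swizzleIndices,
// following the formula quoted there from getSwizzledSharedPtrs:
//   phase          = (secondIdx / perPhase) % maxPhase
//   swizzledGroup  = ((fastIdx / vec) ^ phase) * vec
//   groupRemainder = fastIdx % vec
//   colOff         = swizzledGroup + groupRemainder
unsigned swizzledColOffset(unsigned fastIdx, unsigned secondIdx, unsigned vec,
                           unsigned perPhase, unsigned maxPhase) {
  unsigned phase = (secondIdx / perPhase) % maxPhase;
  unsigned swizzledGroup = ((fastIdx / vec) ^ phase) * vec;
  unsigned groupRemainder = fastIdx % vec;
  return swizzledGroup + groupRemainder;
}

int main() {
  // A 64x32 operand becomes a 1x64x32 batch-expanded shape.
  assert((expandShapeWithBatch({64, 32}) == std::vector<int64_t>{1, 64, 32}));
  // A row-major 2D order {1, 0} becomes {2, 1, 0}: the batch dim is slowest.
  assert((expandOrderWithBatch({1, 0}) == std::vector<unsigned>{2, 1, 0}));
  // With maxPhase == 1 the swizzle reduces to the identity on the fast index,
  // which is why swizzleIndices early-returns in that case.
  assert(swizzledColOffset(/*fastIdx=*/5, /*secondIdx=*/3, /*vec=*/4,
                           /*perPhase=*/1, /*maxPhase=*/1) == 5);
  return 0;
}

Normalizing every operand to rank 3 this way is what lets loadFMAOp and convertFMADot share one batch/nonK/K loop nest for both tt.dot and dot3d, instead of the separate loadAFMA/loadBFMA code paths this patch removes.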