diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index a26a18ed96bc..dbf5be69325b 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -5,6 +5,8 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include @@ -36,16 +38,15 @@ int getWmmaVersion(StringRef archGen) { return 0; } -SmallVector warpsPerTile(tt::DotOp dotOp, - const ArrayRef shape, - int numWarps, - SmallVector shapePerWarp) { +SmallVector +warpsPerTile(Operation *dotOp, ArrayRef shape, int numWarps, + std::pair shapePerWarp) { auto rank = shape.size(); // Early exit for batched matmul if (rank == 3) return {(unsigned)numWarps, 1, 1}; - auto filter = [&dotOp](Operation *op) { + auto filter = [dotOp](Operation *op) { return op->getParentRegion() == dotOp->getParentRegion(); }; ForwardSliceOptions fwdOpt; @@ -55,7 +56,7 @@ SmallVector warpsPerTile(tt::DotOp dotOp, bwdOpt.filter = filter; auto slices = getSlice(dotOp, bwdOpt, fwdOpt); for (Operation *op : slices) - if (isa(op) && (op != dotOp)) + if (op->hasTrait() && (op != dotOp)) return {(unsigned)numWarps, 1}; SmallVector tensorShape = {shape[0], shape[1]}; @@ -63,9 +64,9 @@ SmallVector warpsPerTile(tt::DotOp dotOp, do { if (ret[0] * ret[1] >= numWarps) break; - if (tensorShape[0] / (shapePerWarp[0] * 2) / ret[0] >= - tensorShape[1] / shapePerWarp[1] / ret[1]) { - if (ret[0] < tensorShape[0] / shapePerWarp[0]) { + if (tensorShape[0] / (shapePerWarp.first * 2) / ret[0] >= + tensorShape[1] / shapePerWarp.second / ret[1]) { + if (ret[0] < tensorShape[0] / shapePerWarp.first) { ret[0] *= 2; } else ret[1] *= 2; @@ -74,24 +75,89 @@ SmallVector warpsPerTile(tt::DotOp dotOp, } } while (true); - if (ret[1] * shapePerWarp[1] > tensorShape[1]) { + if (ret[1] * shapePerWarp.second > tensorShape[1]) { return {ret[1], ret[0]}; } return ret; } -SmallVector -warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef shape, int numWarps, - SmallVector shapePerWarp) { +SmallVector +warpsPerTileMFMA(Operation *dotOp, ArrayRef shape, int numWarps, + std::pair shapePerWarp) { return warpsPerTile(dotOp, shape, numWarps, shapePerWarp); } -SmallVector -warpsPerTileWMMA(tt::DotOp dotOp, const ArrayRef shape, int numWarps) { - return warpsPerTile(dotOp, shape, numWarps, - {ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr()[0], - ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr()[1]}); +SmallVector +warpsPerTileWMMA(Operation *dotOp, ArrayRef shape, int numWarps) { + auto mnk = ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr(); + return warpsPerTile(dotOp, shape, numWarps, {mnk[0], mnk[1]}); +} + +// Chooses a proper MFMA instruction that can used to compute the given dot op. +// If enforcedNonKDim is not zero, it will be used to overwrite the default +// logic to chose a MFMA with matching M/N dim. +FailureOr chooseMfmaInstruction(RankedTensorType cType, + Type aElemType, Type bElemType, + int inputKSize, int mfmaVersion, + int enforcedNonKDim) { + // number of matrix elements along k dim per one MFMA intruction + unsigned kDim = 0; + + auto resShape = cType.getShape(); + auto rank = resShape.size(); + auto M = resShape[rank - 2]; + auto N = resShape[rank - 1]; + + unsigned mDim = 0; + unsigned nDim = 0; + if (enforcedNonKDim != 0) { + mDim = nDim = enforcedNonKDim; + } else { + int minSize = std::min(M, N); + if (minSize >= 32) { + mDim = 32; + nDim = 32; + } + if (minSize >= 16 && minSize < 32) { + mDim = 16; + nDim = 16; + } + if (minSize < 16) { + if (M < 16 && N >= 64) { + mDim = 4; + nDim = 64; + } else if (M >= 64 && N < 16) { + mDim = 64; + nDim = 4; + } else { + assert(inputKSize >= 64 && + "k should be at least 64 to use this layout"); + mDim = 4; + nDim = 4; + } + } + } + assert(mDim != 0 && nDim != 0); + + auto maybeMfmaInsn = + MfmaInsn::selectMfma(mDim, nDim, aElemType, bElemType, mfmaVersion); + if (failed(maybeMfmaInsn)) + llvm::report_fatal_error("No match found in MFMA database\n"); + + kDim = maybeMfmaInsn->getKDim(); + assert(kDim != 0); + assert(M % mDim == 0 && N % nDim == 0); + assert(inputKSize % kDim == 0); + return maybeMfmaInsn; +} + +FailureOr chooseMfmaInstruction(tt::DotOp dot, int mfmaVersion, + int nonKDim) { + RankedTensorType aType = dot.getA().getType(); + return chooseMfmaInstruction(dot.getC().getType(), aType.getElementType(), + dot.getB().getType().getElementType(), + aType.getShape().back(), mfmaVersion, nonKDim); } using OperandTypesVector = SmallVector; @@ -259,15 +325,16 @@ Value convertAndCastTensor(PatternRewriter &rewriter, Value value, return castedTensor; } -class BlockedToMFMA : public RewritePattern { +class BlockedToMFMA : public OpRewritePattern { int mfmaVersion; - int enforcedNonKDim; + int nonKDim; int kPack; public: - BlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim, int kPack) - : RewritePattern(tt::DotOp::getOperationName(), 2, context), - mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim), kPack(kPack) {} + BlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim, int kPack, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion), + nonKDim(nonKDim), kPack(kPack) {} bool isSecondDot(tt::DotOp &dotOp) const { auto filter = [&dotOp](Operation *op) { @@ -285,75 +352,15 @@ class BlockedToMFMA : public RewritePattern { return false; } - /// @brief Choose MFMA instruction parameters - /// @param dot target dot operation - /// @return MfmaInsn or failure - FailureOr chooseMfmaInstruction(tt::DotOp dot) const { - // number of matrix elements along k dim per one MFMA intruction - unsigned kDim = 0; - auto opType = cast(dot.getA().getType()); - auto dataTypeA = opType.getElementType(); - auto dataTypeB = - cast(dot.getB().getType()).getElementType(); - - auto resType = cast(dot.getD().getType()); - auto resShape = resType.getShape(); - auto rank = resShape.size(); - auto M = resShape[rank - 2]; - auto N = resShape[rank - 1]; - - unsigned mDim = 0; - unsigned nDim = 0; - if (enforcedNonKDim != 0) { - mDim = enforcedNonKDim; - nDim = enforcedNonKDim; - } else { - int minSize = std::min(M, N); - if (minSize >= 32) { - mDim = 32; - nDim = 32; - } - if (minSize >= 16 && minSize < 32) { - mDim = 16; - nDim = 16; - } - if (minSize < 16) { - if (M < 16 && N >= 64) { - mDim = 4; - nDim = 64; - } else if (M >= 64 && N < 16) { - mDim = 64; - nDim = 4; - } else { - assert(opType.getShape()[rank - 1] >= 64 && - "k should be at least 64 to use this layout"); - mDim = 4; - nDim = 4; - } - } - } - assert(mDim != 0 && nDim != 0); - - auto maybeMfmaInsn = - MfmaInsn::selectMfma(mDim, nDim, dataTypeA, dataTypeB, mfmaVersion); - if (failed(maybeMfmaInsn)) - llvm::report_fatal_error("No match found in MFMA database\n"); - - kDim = maybeMfmaInsn->getKDim(); - assert(kDim != 0); - assert(M % mDim == 0 && N % nDim == 0); - assert(opType.getShape()[rank - 1] % kDim == 0); - return maybeMfmaInsn; - } - - LogicalResult matchAndRewrite(Operation *op, + LogicalResult matchAndRewrite(tt::DotOp dotOp, PatternRewriter &rewriter) const override { - auto dotOp = cast(op); - RankedTensorType oldRetType = dotOp.getType(); if (!oldRetType.getEncoding() || !isa(oldRetType.getEncoding())) return failure(); + if (!isa_and_nonnull(dotOp.getType().getEncoding())) + return rewriter.notifyMatchFailure( + dotOp, "expected blocked encoding result tensor"); if (!supportMFMA(dotOp)) return failure(); @@ -362,7 +369,7 @@ class BlockedToMFMA : public RewritePattern { // get MFMA encoding for the given number of warps auto retShape = oldRetType.getShape(); - auto mod = op->getParentOfType(); + auto mod = dotOp->getParentOfType(); int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); // operands @@ -374,7 +381,7 @@ class BlockedToMFMA : public RewritePattern { ttg::AMDMfmaEncodingAttr mfmaEnc; - auto mfmaInstr = chooseMfmaInstruction(dotOp); + auto mfmaInstr = chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim); auto mDim = mfmaInstr.value().getMDim(); auto nDim = mfmaInstr.value().getNDim(); auto kDim = mfmaInstr.value().getKDim(); @@ -397,7 +404,7 @@ class BlockedToMFMA : public RewritePattern { mfmaAccType = rewriter.getF32Type(); // convert accumulator - auto oldAcc = dotOp.getOperand(2); + auto oldAcc = dotOp.getC(); auto newAcc = convertAndCastTensor(rewriter, oldAcc, mfmaEnc, mfmaAccType); // Here is a brief explanation of kWidth, kBase, and kDim @@ -456,11 +463,12 @@ class BlockedToMFMA : public RewritePattern { convertAndCastTensor(rewriter, newDot, oldRetType.getEncoding(), oldRetType.getElementType()); - rewriter.replaceOp(op, dotOutput); + rewriter.replaceOp(dotOp, dotOutput); return success(); } }; + static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, Type promotedType) { Type tensorPromotedType = cast(operand.getType()) @@ -566,18 +574,17 @@ static void decomposeMixedModeDotOp(ModuleOp mod) { }); } -class BlockedToWMMA : public RewritePattern { +class BlockedToWMMA : public OpRewritePattern { int wmmaVersion; public: - BlockedToWMMA(MLIRContext *context, int wmmaVersion) - : RewritePattern(tt::DotOp::getOperationName(), 2, context), - wmmaVersion(wmmaVersion) {} + BlockedToWMMA(MLIRContext *context, int wmmaVersion, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), wmmaVersion(wmmaVersion) {} - LogicalResult matchAndRewrite(Operation *op, + LogicalResult matchAndRewrite(tt::DotOp dotOp, PatternRewriter &rewriter) const override { - auto ctx = op->getContext(); - auto dotOp = cast(op); + auto ctx = dotOp->getContext(); Value a = dotOp.getA(); Value b = dotOp.getB(); @@ -603,7 +610,7 @@ class BlockedToWMMA : public RewritePattern { if (wmmaVersion == 2 && llvm::isa(oldAType) && oldAType.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure(op, "not supported yet"); + return rewriter.notifyMatchFailure(dotOp, "not supported yet"); } // get operand types @@ -612,7 +619,7 @@ class BlockedToWMMA : public RewritePattern { return failure(); // get WMMA encoding for the given number of warps - auto mod = op->getParentOfType(); + auto mod = dotOp->getParentOfType(); int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); ttg::AMDWmmaEncodingAttr wmmaEnc; @@ -626,7 +633,7 @@ class BlockedToWMMA : public RewritePattern { auto newRetType = RankedTensorType::get(retShape, operandTypes[3], wmmaEnc); // convert accumulator - auto oldAcc = dotOp.getOperand(2); + auto oldAcc = dotOp.getC(); auto newAcc = convertAndCastTensor(rewriter, oldAcc, wmmaEnc, operandTypes[2]); @@ -653,7 +660,7 @@ class BlockedToWMMA : public RewritePattern { Value dotOutput = convertAndCastTensor(rewriter, newDot, oldRetEncoding, oldRetType.getElementType()); - rewriter.replaceOp(op, dotOutput); + rewriter.replaceOp(dotOp, dotOutput); return success(); } };