From f79d525d8b076d0fcfab71d90a69f7fa4c37840b Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 26 Aug 2025 20:04:59 +0000 Subject: [PATCH 1/7] Refactor matching logic for elemwise tree --- mlir/lib/Conversion/TosaToRock/TosaToRock.cpp | 382 ++++++++---------- ...mm-gemm-multiple-traces-to-first-gemm.mlir | 23 ++ 2 files changed, 202 insertions(+), 203 deletions(-) create mode 100644 mlir/test/fusion/pr-e2e/gemm-gemm/mixr-gemm-gemm-multiple-traces-to-first-gemm.mlir diff --git a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp index 2e66ef7df548..dc7113848301 100644 --- a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp +++ b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp @@ -29,6 +29,7 @@ #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/Block.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" @@ -46,6 +47,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/LogicalResult.h" #include "llvm/Support/raw_ostream.h" +#include #include #define DEBUG_TYPE "convert-tosa-to-rock" @@ -493,109 +495,108 @@ static Operation *getConvOp(Operation *op) { return (isa_and_nonnull(op)) ? op : nullptr; } -// This function traverse an upward tree where the root is the input. -// It traverses the tree until it hit the gemm/conv or last elementwise -// operation that may or maynot be interleaved with reshape-like ops. Note -// there is a TODO to explore relaxing reshape-like ops constraints to more -// of rock.transforms. 
(See the implementation for the TODO) template -static std::tuple> -getElementwiseRegion(Value input, OpBuilder ®ionBuilder, Block *block, - SmallVector &elementwiseArgs, - std::optional loc = std::nullopt, - bool doRewrite = false, int recDepth = 0) { - PatternRewriter::InsertionGuard guard(regionBuilder); - regionBuilder.setInsertionPointToEnd(block); - // If the matmul/conv is found, we return this information to the - // root. - LLVM_DEBUG(llvm::dbgs() << std::string(recDepth, '\t') - << "getElementwiseRegion:input=" << input << "\n"); - - OpT fusionOp = input.getDefiningOp(); - Operation *op = input.getDefiningOp(); - // we need to traverse tranposes if it's conv2d - if (std::is_same_v && op) { - Operation *convOp = getConvOp(op); - if (convOp) - fusionOp = cast(convOp); - } - - if (fusionOp) { - Value fusionMemRef; - if (doRewrite) { - fusionMemRef = addBlockArgument(regionBuilder, input, block, loc.value()); - rock::RockGemmGemmWrapperInterface gemmGemmLikeOp = - cast(block->getParentOp()); - gemmGemmLikeOp.setFirstGemmIndices( - {static_cast(block->getArguments().size() - 1)}); - } - LLVM_DEBUG(llvm::dbgs() << std::string(recDepth, '\t') - << "matmul/conv found. terminating recursion.\n"); - return {fusionMemRef, fusionOp}; - } - if (tosa::ConstOp constOp = input.getDefiningOp()) { - Value newConstOpRes; - if (doRewrite) { - auto *newConstOp = regionBuilder.clone(*constOp); - newConstOpRes = newConstOp->getResult(0); - } - LLVM_DEBUG(llvm::dbgs() << std::string(recDepth, '\t') - << "const found. terminating recursion.\n"); - return {newConstOpRes, failure()}; - } - - // Right now, this is a bit restricted that we only allow reshape-like - // ops between in the elementwise tree that get fused to the fusion point. - // TODO: however, the latest code gridwise-gemm-to-blockwise should tackle - // more cases. The absolute restriction is gemm0Output to Linalg block - // should contain invertible transforms, but that's future work. 
- if (!op || (!isElementwiseOp(op) && - !isa(op))) { - Value blockArg; - if (doRewrite) { - blockArg = addBlockArgument(regionBuilder, input, block, loc.value()); - } - elementwiseArgs.push_back(input); - LLVM_DEBUG(llvm::dbgs() - << std::string(recDepth, '\t') - << "unsupported region op found. terminating recursion.\n"); - return {blockArg, failure()}; - } - // Following section recursively calls into the left and right - // sub-tree to grab as much of the elementwise tree rooted on softmax - // input. - mlir::IRMapping mapper; - SmallVector newOperands; - - FailureOr maybeFusionOp = failure(); - for (auto operand : op->getOperands()) { - auto [result, maybeSubTreeFusionOp] = getElementwiseRegion( - operand, regionBuilder, block, elementwiseArgs, loc, doRewrite, - recDepth + 1); - mapper.map(operand, result); - newOperands.push_back(result); - if (succeeded(maybeSubTreeFusionOp)) { - maybeFusionOp = maybeSubTreeFusionOp; - } - } - - Value res; - if (doRewrite) { - auto *newOp = regionBuilder.clone(*op, mapper); - res = newOp->getResult(0); - } - // We convey to the caller the result - // of the cloning as well if this subtree - // contains the first matmul/conv. 
- if (succeeded(maybeFusionOp)) { - LLVM_DEBUG(llvm::dbgs() << std::string(recDepth, '\t') - << "a subtree have a matmul/conv in it.\n"); - return {res, maybeFusionOp}; - } - LLVM_DEBUG(llvm::dbgs() << std::string(recDepth, '\t') - << "none of subtress have a matmul/conv in it.\n"); - return {res, failure()}; -} +struct ElementwiseRegionFinder { + void visit(Value input) { + if (visitedSet.contains(input)) + return; + visitedSet.insert(input); + OpT fusionOp = input.getDefiningOp(); + Operation *op = input.getDefiningOp(); + // we need to traverse tranposes if it's conv2d + if (std::is_same_v && op) { + Operation *convOp = getConvOp(op); + if (convOp) + fusionOp = cast(convOp); + } + if (fusionOp) { + firstGemmBasedOp = fusionOp; + firstGemmBasedVal = input; + // cache blockArgCandidates for rewrite + blockArgCandidates.push_back(input); + return; + } + if (op && dyn_cast(op)) { + constantVals.push_back(input); + return; + } + // Right now, this is a bit restricted that we only allow reshape-like + // ops between in the elementwise tree that get fused to the fusion point. + // TODO: however, the latest code gridwise-gemm-to-blockwise should tackle + // more cases. The absolute restriction is gemm0Output to Linalg block + // should contain invertible transforms, but that's future work. 
+ if (!op || (!isElementwiseOp(op) && + !isa(op))) { + // cache blockArgCandidates for rewrite + blockArgCandidates.push_back(input); + return; + } + for (Value operand : op->getOperands()) { + // do a DFS on each operand + visit(operand); + } + // keep topological order for rewrite + visitedOps.push_back(op); + } + + FailureOr getFirstGemmBasedOp() const { + if (!firstGemmBasedOp) + return failure(); + return firstGemmBasedOp; + } + + SmallVector getElementwiseArgs() const { + // remove fusionValue from the candidates + SmallVector elementwiseArgs = blockArgCandidates; + uint64_t firstGemmBlockIndex = getFirstGemmBlockIndex(); + elementwiseArgs.erase(elementwiseArgs.begin() + firstGemmBlockIndex); + return elementwiseArgs; + } + + uint64_t getFirstGemmBlockIndex() const { + return std::find_if(blockArgCandidates.begin(), blockArgCandidates.end(), + [this](Value v) { return v == firstGemmBasedVal; }) - + blockArgCandidates.begin(); + } + + void rewrite(Value input, OpBuilder ®ionBuilder, Block *block, + Location loc, uint32_t recDepth = 0) const { + PatternRewriter::InsertionGuard guard(regionBuilder); + regionBuilder.setInsertionPointToEnd(block); + IRMapping mapper; + for (Value v : constantVals) { + auto *newConstOp = regionBuilder.clone(*v.getDefiningOp()); + mapper.map(v, newConstOp->getResult(0)); + } + for (Value v : blockArgCandidates) { + auto newBlockArg = addBlockArgument(regionBuilder, v, block, loc); + mapper.map(v, newBlockArg); + } + // make sure firstGemmInput is passed through + Value lastRes = mapper.lookup(firstGemmBasedVal); + for (Operation *op : visitedOps) { + auto *newOp = regionBuilder.clone(*op, mapper); + lastRes = newOp->getResult(0); + mapper.map(lastRes, newOp->getResult(0)); + } + RankedTensorType resTensorType = cast(lastRes.getType()); + MemRefType resMemRefType = MemRefType::get(resTensorType.getShape(), + resTensorType.getElementType()); + Value resMemref = regionBuilder.create( + loc, cast(resMemRefType), lastRes); + Value 
outMemref = block->addArgument(resMemRefType, loc); + regionBuilder.create(loc, resMemref, outMemref); + regionBuilder.create(loc); + } + +private: + OpT firstGemmBasedOp = nullptr; + Value firstGemmBasedVal = nullptr; + DenseSet visitedSet; + SmallVector blockArgCandidates; + SmallVector constantVals; + SmallVector visitedOps; +}; template class ConvConverter final : public OpConversionPattern { @@ -1103,40 +1104,44 @@ struct ConvElementwiseGemmRewritePattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult match(tosa::MatMulOp op) const { - OpBuilder b{op}; - SmallVector vec; - FailureOr maybeConv; - std::tie(std::ignore, maybeConv) = - getElementwiseRegion(op.getA(), b, nullptr, vec); + FailureOr> + match(tosa::MatMulOp op) const { + ElementwiseRegionFinder elementwiseRegionFinder; + elementwiseRegionFinder.visit(op.getA()); + FailureOr maybeConv = + elementwiseRegionFinder.getFirstGemmBasedOp(); if (succeeded(maybeConv)) LLVM_DEBUG(llvm::dbgs() << "conv = " << maybeConv.value() << "\n"); - else + else { LLVM_DEBUG(llvm::dbgs() << "conv not found\n"); + return failure(); + } - return maybeConv; + tosa::Conv2DOp firstConv = maybeConv.value(); + // bias not supported + if (!isConstantZero(firstConv.getBias())) { + op.emitOpError("bias not supported yet"); + return failure(); + } + return elementwiseRegionFinder; } - void rewrite(tosa::MatMulOp op, PatternRewriter &rewriter) const { + void rewrite( + tosa::MatMulOp op, + const ElementwiseRegionFinder &elementwiseRegionFinder, + PatternRewriter &rewriter) const { Location loc = op.getLoc(); auto outputType = cast(op.getType()); Value output = rewriter.create( loc, outputType, ValueRange{}); - SmallVector elementwiseOtherArgs; - - FailureOr maybeConv; - std::tie(std::ignore, maybeConv) = getElementwiseRegion( - op.getA(), rewriter, nullptr, elementwiseOtherArgs); // This is guaranteed by the matcher - tosa::Conv2DOp firstConv = maybeConv.value(); + tosa::Conv2DOp firstConv = + 
elementwiseRegionFinder.getFirstGemmBasedOp().value(); - // bias not supported - if (!isConstantZero(firstConv.getBias())) { - op.emitOpError("bias not supported yet"); - return; - } + SmallVector elementwiseOtherArgs = + elementwiseRegionFinder.getElementwiseArgs(); int64_t group = 1; if (auto attr = op->template getAttrOfType("group")) @@ -1145,7 +1150,7 @@ struct ConvElementwiseGemmRewritePattern commonConv(rewriter, op, firstConv.getInput(), firstConv.getWeight(), output, firstConv.getPadAttr(), firstConv.getStrideAttr(), firstConv.getDilationAttr(), group); - + auto firstGemmBlockIndex = elementwiseRegionFinder.getFirstGemmBlockIndex(); auto convElentwiseGemmOp = rewriter.create( loc, outputType, convFields.filterExp, convFields.inputExp, op.getB(), elementwiseOtherArgs, output, @@ -1153,7 +1158,8 @@ struct ConvElementwiseGemmRewritePattern /*oTransposed=*/nullptr, /*features=*/nullptr, convFields.pad, convFields.stride, convFields.dilation, /*params0=*/nullptr, /*params1=*/nullptr, - /*firstGemmIndices=*/rewriter.getDenseI64ArrayAttr({0})); + /*firstGemmIndices=*/ + rewriter.getDenseI64ArrayAttr(firstGemmBlockIndex)); addConvAttributes(rewriter, convElentwiseGemmOp, convFields); @@ -1162,30 +1168,20 @@ struct ConvElementwiseGemmRewritePattern { PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(preSecondGemmElemwiseBlock); - Value res; - std::tie(res, std::ignore) = getElementwiseRegion( - op.getA(), rewriter, preSecondGemmElemwiseBlock, elementwiseOtherArgs, - loc, true); - RankedTensorType resTensorType = cast(res.getType()); - MemRefType resMemRefType = MemRefType::get( - resTensorType.getShape(), resTensorType.getElementType()); - Value resMemref = rewriter.create( - loc, cast(resMemRefType), res); - Value outMemref = - preSecondGemmElemwiseBlock->addArgument(resMemRefType, loc); - rewriter.create(loc, resMemref, outMemref); - rewriter.create(loc); + elementwiseRegionFinder.rewrite(op.getA(), rewriter, + 
preSecondGemmElemwiseBlock, loc); } rewriter.replaceOp(op, convElentwiseGemmOp.getResult()); } LogicalResult matchAndRewrite(tosa::MatMulOp op, PatternRewriter &rewriter) const override { - auto result = match(op); - if (result.succeeded()) { - rewrite(op, rewriter); + FailureOr> elemwiseFinder = + match(op); + if (succeeded(elemwiseFinder)) { + rewrite(op, elemwiseFinder.value(), rewriter); } - return result; + return elemwiseFinder; } }; @@ -1193,35 +1189,33 @@ struct GemmElementwiseGemmRewritePattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult match(tosa::MatMulOp op) const { - OpBuilder b{op}; - SmallVector vec; - FailureOr maybeFirstMatMul; - std::tie(std::ignore, maybeFirstMatMul) = - getElementwiseRegion(op.getA(), b, nullptr, vec); - + FailureOr> + match(tosa::MatMulOp op) const { + ElementwiseRegionFinder elemwiseRegionFinder; + elemwiseRegionFinder.visit(op.getA()); + FailureOr maybeFirstMatMul = + elemwiseRegionFinder.getFirstGemmBasedOp(); if (succeeded(maybeFirstMatMul)) LLVM_DEBUG(llvm::dbgs() << "first matmul = " << maybeFirstMatMul.value() << "\n"); - else + else { LLVM_DEBUG(llvm::dbgs() << "first matmul not found\n"); - - return maybeFirstMatMul; + return failure(); + } + return elemwiseRegionFinder; } - void rewrite(tosa::MatMulOp op, PatternRewriter &rewriter) const { + void rewrite(tosa::MatMulOp op, + const ElementwiseRegionFinder &elemwiseFinder, + PatternRewriter &rewriter) const { Location loc = op.getLoc(); auto outputType = cast(op.getType()); Value output = rewriter.create( loc, outputType, ValueRange{}); - SmallVector elementwiseOtherArgs; - - FailureOr maybeFirstMatMul; - std::tie(std::ignore, maybeFirstMatMul) = - getElementwiseRegion(op.getA(), rewriter, nullptr, - elementwiseOtherArgs); + SmallVector elementwiseOtherArgs = + elemwiseFinder.getElementwiseArgs(); // This is guranteed by the matcher - tosa::MatMulOp firstMatMulOp = maybeFirstMatMul.value(); + tosa::MatMulOp firstMatMulOp = 
elemwiseFinder.getFirstGemmBasedOp().value(); rock::GemmElementwiseGemmOp gemmElentwiseGemmOp = rewriter.create( loc, outputType, firstMatMulOp.getA(), firstMatMulOp.getB(), @@ -1233,36 +1227,26 @@ struct GemmElementwiseGemmRewritePattern /*features=*/nullptr, /*params0=*/nullptr, /*params1=*/nullptr, /*firstGemmIndices=*/ - rewriter.getDenseI64ArrayAttr(ArrayRef({0}))); + rewriter.getDenseI64ArrayAttr(ArrayRef(elemwiseFinder.getFirstGemmBlockIndex()))); Block *preSecondGemmElemwiseBlock = &gemmElentwiseGemmOp.getPreSecondGemmBody().emplaceBlock(); { PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(preSecondGemmElemwiseBlock); - Value res; - std::tie(res, std::ignore) = getElementwiseRegion( - op.getA(), rewriter, preSecondGemmElemwiseBlock, elementwiseOtherArgs, - loc, true); - RankedTensorType resTensorType = cast(res.getType()); - MemRefType resMemRefType = MemRefType::get( - resTensorType.getShape(), resTensorType.getElementType()); - Value resMemref = rewriter.create( - loc, cast(resMemRefType), res); - Value outMemref = - preSecondGemmElemwiseBlock->addArgument(resMemRefType, loc); - rewriter.create(loc, resMemref, outMemref); - rewriter.create(loc); + elemwiseFinder.rewrite(op.getA(), rewriter, + preSecondGemmElemwiseBlock, loc); } rewriter.replaceOp(op, gemmElentwiseGemmOp.getResult()); } LogicalResult matchAndRewrite(tosa::MatMulOp op, PatternRewriter &rewriter) const override { - auto result = match(op); - if (result.succeeded()) { - rewrite(op, rewriter); + FailureOr> elemwiseFinder = + match(op); + if (succeeded(elemwiseFinder)) { + rewrite(op, elemwiseFinder.value(), rewriter); } - return result; + return elemwiseFinder; } }; @@ -1282,8 +1266,7 @@ struct AttentionMatcherValues { Value currentSeqLen; bool isCausal; Type softmaxType; - tosa::MatMulOp firstMatmulOp; - SmallVector elementwiseOtherArgs; + ElementwiseRegionFinder preSoftmaxElementwiseFinder; }; struct AttentionRewritePattern : public OpRewritePattern { @@ 
-1879,11 +1862,10 @@ struct AttentionRewritePattern : public OpRewritePattern { Value causalMaskInput = isCausal ? causal.value() : kvCacheInput; OpBuilder b{op}; - SmallVector elementwiseOtherArgs; - FailureOr maybeFirstMatMul; - std::tie(std::ignore, maybeFirstMatMul) = - getElementwiseRegion(causalMaskInput, b, nullptr, - elementwiseOtherArgs); + ElementwiseRegionFinder preSoftmaxElementwiseFinder; + preSoftmaxElementwiseFinder.visit(causalMaskInput); + FailureOr maybeFirstMatMul = + preSoftmaxElementwiseFinder.getFirstGemmBasedOp(); if (failed(maybeFirstMatMul)) { LLVM_DEBUG(llvm::dbgs() << "first matmul not found\n"); return failure(); @@ -1928,8 +1910,8 @@ struct AttentionRewritePattern : public OpRewritePattern { attentionMatcherValues.lse = lse; attentionMatcherValues.causalMaskInput = causalMaskInput; attentionMatcherValues.currentSeqLen = currentSeqLen; - attentionMatcherValues.firstMatmulOp = maybeFirstMatMul.value(); - attentionMatcherValues.elementwiseOtherArgs = elementwiseOtherArgs; + attentionMatcherValues.preSoftmaxElementwiseFinder = + preSoftmaxElementwiseFinder; return attentionMatcherValues; } @@ -1958,13 +1940,16 @@ struct AttentionRewritePattern : public OpRewritePattern { lseOut = rewriter.create(loc, lseType, ValueRange{}); } + ElementwiseRegionFinder preSoftmaxElementwiseFinder = + attentionMatcherValues.preSoftmaxElementwiseFinder; SmallVector elementwiseOtherArgs = - attentionMatcherValues.elementwiseOtherArgs; + preSoftmaxElementwiseFinder.getElementwiseArgs(); // causalMaskInput would be equal to kvCacheInput if there is no causal // mask and kvCacheInput would be same as softmaxInput if there is no // kv-cache. 
see match() for details Value causalMaskInput = attentionMatcherValues.causalMaskInput; - tosa::MatMulOp firstMatMulOp = attentionMatcherValues.firstMatmulOp; + tosa::MatMulOp firstMatMulOp = + preSoftmaxElementwiseFinder.getFirstGemmBasedOp().value(); Value currentSeqLen = attentionMatcherValues.currentSeqLen; bool isCausal = attentionMatcherValues.isCausal; TypeAttr softmaxTypeAttr = @@ -1978,6 +1963,9 @@ struct AttentionRewritePattern : public OpRewritePattern { op.getLoc(), currentSeqLen, reassocIndices); } UnitAttr causalAttr = isCausal ? rewriter.getUnitAttr() : nullptr; + ElementwiseRegionFinder + elemwiseRegion = + attentionMatcherValues.preSoftmaxElementwiseFinder; rock::AttentionOp attnOp = rewriter.create( loc, outputType, lseType, firstMatMulOp.getA(), firstMatMulOp.getB(), op.getB(), elementwiseOtherArgs, currentSeqLen, output, lseOut, @@ -1987,25 +1975,13 @@ struct AttentionRewritePattern : public OpRewritePattern { /*oTransposed=*/nullptr, causalAttr, /*features=*/nullptr, softmaxTypeAttr, /*params0=*/nullptr, /*params1=*/nullptr, - /*firstGemmIndices=*/rewriter.getDenseI64ArrayAttr({0})); - + /*firstGemmIndices=*/rewriter.getDenseI64ArrayAttr(elemwiseRegion.getFirstGemmBlockIndex())); Block *preSoftmaxElemwiseBlock = &attnOp.getPreSoftmaxBody().emplaceBlock(); { PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(preSoftmaxElemwiseBlock); - Value res; - std::tie(res, std::ignore) = getElementwiseRegion( - causalMaskInput, rewriter, preSoftmaxElemwiseBlock, - elementwiseOtherArgs, loc, true); - RankedTensorType resTensorType = cast(res.getType()); - MemRefType resMemRefType = MemRefType::get( - resTensorType.getShape(), resTensorType.getElementType()); - Value resMemref = rewriter.create( - loc, cast(resMemRefType), res); - Value outMemref = - preSoftmaxElemwiseBlock->addArgument(resMemRefType, loc); - rewriter.create(loc, resMemref, outMemref); - rewriter.create(loc); + elemwiseRegion.rewrite( + causalMaskInput, 
rewriter, preSoftmaxElemwiseBlock, loc); } tosa::AddOp addOp; Value expandedOutLse; diff --git a/mlir/test/fusion/pr-e2e/gemm-gemm/mixr-gemm-gemm-multiple-traces-to-first-gemm.mlir b/mlir/test/fusion/pr-e2e/gemm-gemm/mixr-gemm-gemm-multiple-traces-to-first-gemm.mlir new file mode 100644 index 000000000000..23ba4be5de6c --- /dev/null +++ b/mlir/test/fusion/pr-e2e/gemm-gemm/mixr-gemm-gemm-multiple-traces-to-first-gemm.mlir @@ -0,0 +1,23 @@ +// RUN: rocmlir-gen -fut mlir_gemm_gemm --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx | rocmlir-driver -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_gemm_gemm_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal -kernel-pipeline full | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// CHECK: [1 1 1] +module { + func.func @mlir_gemm_gemm(%arg0: !migraphx.shaped<124x976xf32, 976x1>, %arg1: !migraphx.shaped<976x664xf32, 664x1>, %arg2: !migraphx.shaped<664xf32, 1>, %arg3: !migraphx.shaped<664xf32, 1>, %arg4: !migraphx.shaped<664xf32, 1>, %arg5: !migraphx.shaped<664xf32, 1>, %arg6: !migraphx.shaped<664x88xf32, 88x1>) -> !migraphx.shaped<124x88xf32, 88x1> { + %0 = migraphx.literal(dense<1.000000e+00> : tensor<1xf32>) : <1xf32, 0> + %1 = migraphx.dot %arg0, %arg1 : <124x976xf32, 976x1>, <976x664xf32, 664x1> -> <124x664xf32, 664x1> + %2 = migraphx.multibroadcast %arg2 {out_dyn_dims = [], out_lens = [124, 664]} : <664xf32, 1> -> <124x664xf32, 0x1> + %3 = migraphx.multibroadcast %arg3 {out_dyn_dims = [], out_lens = [124, 664]} : <664xf32, 1> -> <124x664xf32, 0x1> + %4 = migraphx.multibroadcast %arg4 {out_dyn_dims 
= [], out_lens = [124, 664]} : <664xf32, 1> -> <124x664xf32, 0x1> + %5 = migraphx.multibroadcast %arg5 {out_dyn_dims = [], out_lens = [124, 664]} : <664xf32, 1> -> <124x664xf32, 0x1> + %6 = migraphx.add %2, %1 : <124x664xf32, 0x1>, <124x664xf32, 664x1> -> <124x664xf32, 664x1> + %7 = migraphx.mul %6, %3 : <124x664xf32, 664x1>, <124x664xf32, 0x1> -> <124x664xf32, 664x1> + %8 = migraphx.add %7, %4 : <124x664xf32, 664x1>, <124x664xf32, 0x1> -> <124x664xf32, 664x1> + %9 = migraphx.sigmoid %8 : <124x664xf32, 664x1> -> <124x664xf32, 664x1> + %10 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [124, 664]} : <1xf32, 0> -> <124x664xf32, 0x0> + %11 = migraphx.sub %10, %9 : <124x664xf32, 0x0>, <124x664xf32, 664x1> -> <124x664xf32, 664x1> + %12 = migraphx.mul %11, %5 : <124x664xf32, 664x1>, <124x664xf32, 0x1> -> <124x664xf32, 664x1> + %13 = migraphx.add %9, %12 : <124x664xf32, 664x1>, <124x664xf32, 664x1> -> <124x664xf32, 664x1> + %14 = migraphx.mul %13, %6 : <124x664xf32, 664x1>, <124x664xf32, 664x1> -> <124x664xf32, 664x1> + %15 = migraphx.dot %14, %arg6 : <124x664xf32, 664x1>, <664x88xf32, 88x1> -> <124x88xf32, 88x1> + return %15 : !migraphx.shaped<124x88xf32, 88x1> + } +} From dc9212455d38d826ec60ef6a2af5f1e51fc244fd Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 26 Aug 2025 20:43:10 +0000 Subject: [PATCH 2/7] formatting --- mlir/lib/Conversion/TosaToRock/TosaToRock.cpp | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp index dc7113848301..48414208970e 100644 --- a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp +++ b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp @@ -47,7 +47,6 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/LogicalResult.h" #include "llvm/Support/raw_ostream.h" -#include #include #define DEBUG_TYPE "convert-tosa-to-rock" @@ -498,7 +497,7 @@ static Operation *getConvOp(Operation *op) { template struct 
ElementwiseRegionFinder { void visit(Value input) { - if (visitedSet.contains(input)) + if (visitedSet.contains(input)) return; visitedSet.insert(input); OpT fusionOp = input.getDefiningOp(); @@ -1227,14 +1226,15 @@ struct GemmElementwiseGemmRewritePattern /*features=*/nullptr, /*params0=*/nullptr, /*params1=*/nullptr, /*firstGemmIndices=*/ - rewriter.getDenseI64ArrayAttr(ArrayRef(elemwiseFinder.getFirstGemmBlockIndex()))); + rewriter.getDenseI64ArrayAttr( + ArrayRef(elemwiseFinder.getFirstGemmBlockIndex()))); Block *preSecondGemmElemwiseBlock = &gemmElentwiseGemmOp.getPreSecondGemmBody().emplaceBlock(); { PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(preSecondGemmElemwiseBlock); - elemwiseFinder.rewrite(op.getA(), rewriter, - preSecondGemmElemwiseBlock, loc); + elemwiseFinder.rewrite(op.getA(), rewriter, preSecondGemmElemwiseBlock, + loc); } rewriter.replaceOp(op, gemmElentwiseGemmOp.getResult()); } @@ -1963,9 +1963,8 @@ struct AttentionRewritePattern : public OpRewritePattern { op.getLoc(), currentSeqLen, reassocIndices); } UnitAttr causalAttr = isCausal ? 
rewriter.getUnitAttr() : nullptr; - ElementwiseRegionFinder - elemwiseRegion = - attentionMatcherValues.preSoftmaxElementwiseFinder; + ElementwiseRegionFinder elemwiseRegion = + attentionMatcherValues.preSoftmaxElementwiseFinder; rock::AttentionOp attnOp = rewriter.create( loc, outputType, lseType, firstMatMulOp.getA(), firstMatMulOp.getB(), op.getB(), elementwiseOtherArgs, currentSeqLen, output, lseOut, @@ -1975,13 +1974,14 @@ struct AttentionRewritePattern : public OpRewritePattern { /*oTransposed=*/nullptr, causalAttr, /*features=*/nullptr, softmaxTypeAttr, /*params0=*/nullptr, /*params1=*/nullptr, - /*firstGemmIndices=*/rewriter.getDenseI64ArrayAttr(elemwiseRegion.getFirstGemmBlockIndex())); + /*firstGemmIndices=*/ + rewriter.getDenseI64ArrayAttr(elemwiseRegion.getFirstGemmBlockIndex())); Block *preSoftmaxElemwiseBlock = &attnOp.getPreSoftmaxBody().emplaceBlock(); { PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(preSoftmaxElemwiseBlock); - elemwiseRegion.rewrite( - causalMaskInput, rewriter, preSoftmaxElemwiseBlock, loc); + elemwiseRegion.rewrite(causalMaskInput, rewriter, preSoftmaxElemwiseBlock, + loc); } tosa::AddOp addOp; Value expandedOutLse; From 6d11fa2ffb1db09a84a62857ca446ce97ac4679e Mon Sep 17 00:00:00 2001 From: Umang Yadav <29876643+umangyadav@users.noreply.github.com> Date: Tue, 26 Aug 2025 16:53:48 -0400 Subject: [PATCH 3/7] Update mlir/lib/Conversion/TosaToRock/TosaToRock.cpp --- mlir/lib/Conversion/TosaToRock/TosaToRock.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp index 48414208970e..91d74433f5a4 100644 --- a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp +++ b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp @@ -545,7 +545,7 @@ struct ElementwiseRegionFinder { } SmallVector getElementwiseArgs() const { - // remove fusionValue from the candidates + // ElementwiseArgs doesn't contain output from 
the first gemm explictly. Therefore remove it. SmallVector elementwiseArgs = blockArgCandidates; uint64_t firstGemmBlockIndex = getFirstGemmBlockIndex(); elementwiseArgs.erase(elementwiseArgs.begin() + firstGemmBlockIndex); From bfb760560ba5edfdbce6c4ef94b790c136886033 Mon Sep 17 00:00:00 2001 From: Umang Yadav <29876643+umangyadav@users.noreply.github.com> Date: Tue, 26 Aug 2025 16:55:05 -0400 Subject: [PATCH 4/7] Update mlir/lib/Conversion/TosaToRock/TosaToRock.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- mlir/lib/Conversion/TosaToRock/TosaToRock.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp index 91d74433f5a4..1f7b169d4f43 100644 --- a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp +++ b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp @@ -1158,7 +1158,7 @@ struct ConvElementwiseGemmRewritePattern convFields.stride, convFields.dilation, /*params0=*/nullptr, /*params1=*/nullptr, /*firstGemmIndices=*/ - rewriter.getDenseI64ArrayAttr(firstGemmBlockIndex)); + rewriter.getDenseI64ArrayAttr(ArrayRef({firstGemmBlockIndex}))); addConvAttributes(rewriter, convElentwiseGemmOp, convFields); From b16900e20eac4eef88ea226ab9a65fc59368fa7a Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 26 Aug 2025 21:00:41 +0000 Subject: [PATCH 5/7] Remove unused param --- mlir/lib/Conversion/TosaToRock/TosaToRock.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp index 1f7b169d4f43..7bffbf17d20c 100644 --- a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp +++ b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp @@ -559,7 +559,7 @@ struct ElementwiseRegionFinder { } void rewrite(Value input, OpBuilder ®ionBuilder, Block *block, - Location loc, uint32_t recDepth = 0) const { + Location loc) const { PatternRewriter::InsertionGuard 
guard(regionBuilder); regionBuilder.setInsertionPointToEnd(block); IRMapping mapper; From be1866b68d957bf54e6670ae17c89c5658c9fbc4 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 26 Aug 2025 21:06:10 +0000 Subject: [PATCH 6/7] add some comments --- mlir/lib/Conversion/TosaToRock/TosaToRock.cpp | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp index 7bffbf17d20c..022b1f97b4cf 100644 --- a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp +++ b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp @@ -494,8 +494,19 @@ static Operation *getConvOp(Operation *op) { return (isa_and_nonnull(op)) ? op : nullptr; } +/* +GEMM+GEMM based ops can have elementwise region between first gemm and second +gemm. This helps with matching such GEMM+GEMM ops and also constructing the +elementwise region afterwards. +*/ template struct ElementwiseRegionFinder { + /* + This is simple DFS traversal to find out if it can hit gemm/conv op from the + input. It keeps track of visited nodes to avoid cycles. It caches visited ops + in topological order for rewrite. It also caches constant values and block + argument candidates which will be used during rewrite. 
+ */ void visit(Value input) { if (visitedSet.contains(input)) return; @@ -571,7 +582,8 @@ struct ElementwiseRegionFinder { auto newBlockArg = addBlockArgument(regionBuilder, v, block, loc); mapper.map(v, newBlockArg); } - // make sure firstGemmInput is passed through + // make sure firstGemmBasedVal is passed as blockArgument for it is always + // present Value lastRes = mapper.lookup(firstGemmBasedVal); for (Operation *op : visitedOps) { auto *newOp = regionBuilder.clone(*op, mapper); From c707440f53307cc1012501e7a0a57d3456123702 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 26 Aug 2025 23:20:34 +0000 Subject: [PATCH 7/7] Revert back copilot review comment --- mlir/lib/Conversion/TosaToRock/TosaToRock.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp index 022b1f97b4cf..af999920ff9f 100644 --- a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp +++ b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp @@ -556,14 +556,15 @@ struct ElementwiseRegionFinder { } SmallVector getElementwiseArgs() const { - // ElementwiseArgs doesn't contain output from the first gemm explictly. Therefore remove it. + // ElementwiseArgs doesn't contain output from the first gemm explicitly. + // Therefore remove it. 
SmallVector elementwiseArgs = blockArgCandidates; uint64_t firstGemmBlockIndex = getFirstGemmBlockIndex(); elementwiseArgs.erase(elementwiseArgs.begin() + firstGemmBlockIndex); return elementwiseArgs; } - uint64_t getFirstGemmBlockIndex() const { + int64_t getFirstGemmBlockIndex() const { return std::find_if(blockArgCandidates.begin(), blockArgCandidates.end(), [this](Value v) { return v == firstGemmBasedVal; }) - blockArgCandidates.begin(); @@ -1170,7 +1171,7 @@ struct ConvElementwiseGemmRewritePattern convFields.stride, convFields.dilation, /*params0=*/nullptr, /*params1=*/nullptr, /*firstGemmIndices=*/ - rewriter.getDenseI64ArrayAttr(ArrayRef({firstGemmBlockIndex}))); + rewriter.getDenseI64ArrayAttr(firstGemmBlockIndex)); addConvAttributes(rewriter, convElentwiseGemmOp, convFields); @@ -1227,6 +1228,7 @@ struct GemmElementwiseGemmRewritePattern elemwiseFinder.getElementwiseArgs(); // This is guranteed by the matcher tosa::MatMulOp firstMatMulOp = elemwiseFinder.getFirstGemmBasedOp().value(); + int64_t firstGemmBlockIndex = elemwiseFinder.getFirstGemmBlockIndex(); rock::GemmElementwiseGemmOp gemmElentwiseGemmOp = rewriter.create( loc, outputType, firstMatMulOp.getA(), firstMatMulOp.getB(), @@ -1238,8 +1240,7 @@ struct GemmElementwiseGemmRewritePattern /*features=*/nullptr, /*params0=*/nullptr, /*params1=*/nullptr, /*firstGemmIndices=*/ - rewriter.getDenseI64ArrayAttr( - ArrayRef(elemwiseFinder.getFirstGemmBlockIndex()))); + rewriter.getDenseI64ArrayAttr(firstGemmBlockIndex)); Block *preSecondGemmElemwiseBlock = &gemmElentwiseGemmOp.getPreSecondGemmBody().emplaceBlock(); { @@ -1977,6 +1978,7 @@ struct AttentionRewritePattern : public OpRewritePattern { UnitAttr causalAttr = isCausal ? 
rewriter.getUnitAttr() : nullptr; ElementwiseRegionFinder elemwiseRegion = attentionMatcherValues.preSoftmaxElementwiseFinder; + int64_t firstGemmBlockIndex = elemwiseRegion.getFirstGemmBlockIndex(); rock::AttentionOp attnOp = rewriter.create( loc, outputType, lseType, firstMatMulOp.getA(), firstMatMulOp.getB(), op.getB(), elementwiseOtherArgs, currentSeqLen, output, lseOut, @@ -1987,7 +1989,7 @@ struct AttentionRewritePattern : public OpRewritePattern { /*features=*/nullptr, softmaxTypeAttr, /*params0=*/nullptr, /*params1=*/nullptr, /*firstGemmIndices=*/ - rewriter.getDenseI64ArrayAttr(elemwiseRegion.getFirstGemmBlockIndex())); + rewriter.getDenseI64ArrayAttr(firstGemmBlockIndex)); Block *preSoftmaxElemwiseBlock = &attnOp.getPreSoftmaxBody().emplaceBlock(); { PatternRewriter::InsertionGuard guard(rewriter);