diff --git a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp index 9eb1b9be8f3d..63d75e68ef2b 100644 --- a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp +++ b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp @@ -28,6 +28,7 @@ #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/Block.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" @@ -514,109 +515,121 @@ static Operation *getConvOp(Operation *op) { return (isa_and_nonnull(op)) ? op : nullptr; } -// This function traverse an upward tree where the root is the input. -// It traverses the tree until it hit the gemm/conv or last elementwise -// operation that may or maynot be interleaved with reshape-like ops. Note -// there is a TODO to explore relaxing reshape-like ops constraints to more -// of rock.transforms. (See the implementation for the TODO) +/* +GEMM+GEMM based ops can have elementwise region between first gemm and second +gemm. This helps with matching such GEMM+GEMM ops and also constructing the +elementwise region afterwards. +*/ template -static std::tuple> -getElementwiseRegion(Value input, OpBuilder ®ionBuilder, Block *block, - SmallVector &elementwiseArgs, - std::optional loc = std::nullopt, - bool doRewrite = false, int recDepth = 0) { - PatternRewriter::InsertionGuard guard(regionBuilder); - regionBuilder.setInsertionPointToEnd(block); - // If the matmul/conv is found, we return this information to the - // root. - LLVM_DEBUG(llvm::dbgs() << std::string(recDepth, '\t') - << "getElementwiseRegion:input=" << input << "\n"); - - OpT fusionOp = input.getDefiningOp(); - Operation *op = input.getDefiningOp(); - // we need to traverse tranposes if it's conv2d - if (std::is_same_v && op) { - Operation *convOp = getConvOp(op); - if (convOp) - fusionOp = cast(convOp); - } - - if (fusionOp) { - Value fusionMemRef; - if (doRewrite) { - fusionMemRef = addBlockArgument(regionBuilder, input, block, loc.value()); - rock::RockGemmGemmWrapperInterface gemmGemmLikeOp = - cast(block->getParentOp()); - gemmGemmLikeOp.setFirstGemmIndices( - {static_cast(block->getArguments().size() - 1)}); - } - LLVM_DEBUG(llvm::dbgs() << std::string(recDepth, '\t') - << "matmul/conv found. terminating recursion.\n"); - return {fusionMemRef, fusionOp}; - } - if (tosa::ConstOp constOp = input.getDefiningOp()) { - Value newConstOpRes; - if (doRewrite) { - auto *newConstOp = regionBuilder.clone(*constOp); - newConstOpRes = newConstOp->getResult(0); - } - LLVM_DEBUG(llvm::dbgs() << std::string(recDepth, '\t') - << "const found. terminating recursion.\n"); - return {newConstOpRes, failure()}; - } - - // Right now, this is a bit restricted that we only allow reshape-like - // ops between in the elementwise tree that get fused to the fusion point. - // TODO: however, the latest code gridwise-gemm-to-blockwise should tackle - // more cases. The absolute restriction is gemm0Output to Linalg block - // should contain invertible transforms, but that's future work. - if (!op || (!isElementwiseOp(op) && - !isa(op))) { - Value blockArg; - if (doRewrite) { - blockArg = addBlockArgument(regionBuilder, input, block, loc.value()); - } - elementwiseArgs.push_back(input); - LLVM_DEBUG(llvm::dbgs() - << std::string(recDepth, '\t') - << "unsupported region op found. terminating recursion.\n"); - return {blockArg, failure()}; - } - // Following section recursively calls into the left and right - // sub-tree to grab as much of the elementwise tree rooted on softmax - // input. - mlir::IRMapping mapper; - SmallVector newOperands; - - FailureOr maybeFusionOp = failure(); - for (auto operand : op->getOperands()) { - auto [result, maybeSubTreeFusionOp] = getElementwiseRegion( - operand, regionBuilder, block, elementwiseArgs, loc, doRewrite, - recDepth + 1); - mapper.map(operand, result); - newOperands.push_back(result); - if (succeeded(maybeSubTreeFusionOp)) { - maybeFusionOp = maybeSubTreeFusionOp; - } - } - - Value res; - if (doRewrite) { - auto *newOp = regionBuilder.clone(*op, mapper); - res = newOp->getResult(0); - } - // We convey to the caller the result - // of the cloning as well if this subtree - // contains the first matmul/conv. - if (succeeded(maybeFusionOp)) { - LLVM_DEBUG(llvm::dbgs() << std::string(recDepth, '\t') - << "a subtree have a matmul/conv in it.\n"); - return {res, maybeFusionOp}; - } - LLVM_DEBUG(llvm::dbgs() << std::string(recDepth, '\t') - << "none of subtress have a matmul/conv in it.\n"); - return {res, failure()}; -} +struct ElementwiseRegionFinder { + /* + This is simple DFS traversal to find out if it can hit gemm/conv op from the + input. It keeps track of visited nodes to avoid cycles. It caches visited ops + in topological order for rewrite. It also caches constant values and block + argument candidates which will be used during rewrite. + */ + void visit(Value input) { + if (visitedSet.contains(input)) + return; + visitedSet.insert(input); + OpT fusionOp = input.getDefiningOp(); + Operation *op = input.getDefiningOp(); + // we need to traverse tranposes if it's conv2d + if (std::is_same_v && op) { + Operation *convOp = getConvOp(op); + if (convOp) + fusionOp = cast(convOp); + } + if (fusionOp) { + firstGemmBasedOp = fusionOp; + firstGemmBasedVal = input; + // cache blockArgCandidates for rewrite + blockArgCandidates.push_back(input); + return; + } + if (op && dyn_cast(op)) { + constantVals.push_back(input); + return; + } + // Right now, this is a bit restricted that we only allow reshape-like + // ops between in the elementwise tree that get fused to the fusion point. + // TODO: however, the latest code gridwise-gemm-to-blockwise should tackle + // more cases. The absolute restriction is gemm0Output to Linalg block + // should contain invertible transforms, but that's future work. + if (!op || (!isElementwiseOp(op) && + !isa(op))) { + // cache blockArgCandidates for rewrite + blockArgCandidates.push_back(input); + return; + } + for (Value operand : op->getOperands()) { + // do a DFS on each operand + visit(operand); + } + // keep topological order for rewrite + visitedOps.push_back(op); + } + + FailureOr getFirstGemmBasedOp() const { + if (!firstGemmBasedOp) + return failure(); + return firstGemmBasedOp; + } + + SmallVector getElementwiseArgs() const { + // ElementwiseArgs doesn't contain output from the first gemm explictly. + // Therefore remove it. + SmallVector elementwiseArgs = blockArgCandidates; + uint64_t firstGemmBlockIndex = getFirstGemmBlockIndex(); + elementwiseArgs.erase(elementwiseArgs.begin() + firstGemmBlockIndex); + return elementwiseArgs; + } + + int64_t getFirstGemmBlockIndex() const { + return std::find_if(blockArgCandidates.begin(), blockArgCandidates.end(), + [this](Value v) { return v == firstGemmBasedVal; }) - + blockArgCandidates.begin(); + } + + void rewrite(Value input, OpBuilder ®ionBuilder, Block *block, + Location loc) const { + PatternRewriter::InsertionGuard guard(regionBuilder); + regionBuilder.setInsertionPointToEnd(block); + IRMapping mapper; + for (Value v : constantVals) { + auto *newConstOp = regionBuilder.clone(*v.getDefiningOp()); + mapper.map(v, newConstOp->getResult(0)); + } + for (Value v : blockArgCandidates) { + auto newBlockArg = addBlockArgument(regionBuilder, v, block, loc); + mapper.map(v, newBlockArg); + } + // make sure firstGemmBasedVal is passed as blockArgument for it is always + // present + Value lastRes = mapper.lookup(firstGemmBasedVal); + for (Operation *op : visitedOps) { + auto *newOp = regionBuilder.clone(*op, mapper); + lastRes = newOp->getResult(0); + mapper.map(lastRes, newOp->getResult(0)); + } + RankedTensorType resTensorType = cast(lastRes.getType()); + MemRefType resMemRefType = MemRefType::get(resTensorType.getShape(), + resTensorType.getElementType()); + Value resMemref = regionBuilder.create( + loc, cast(resMemRefType), lastRes); + Value outMemref = block->addArgument(resMemRefType, loc); + regionBuilder.create(loc, resMemref, outMemref); + regionBuilder.create(loc); + } + +private: + OpT firstGemmBasedOp = nullptr; + Value firstGemmBasedVal = nullptr; + DenseSet visitedSet; + SmallVector blockArgCandidates; + SmallVector constantVals; + SmallVector visitedOps; +}; template class ConvConverter final : public OpConversionPattern { @@ -1132,22 +1145,33 @@ struct ConvElementwiseGemmRewritePattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult match(tosa::MatMulOp op) const { - OpBuilder b{op}; - SmallVector vec; - FailureOr maybeConv; - std::tie(std::ignore, maybeConv) = - getElementwiseRegion(op.getA(), b, nullptr, vec); + FailureOr> + match(tosa::MatMulOp op) const { + ElementwiseRegionFinder elementwiseRegionFinder; + elementwiseRegionFinder.visit(op.getA()); + FailureOr maybeConv = + elementwiseRegionFinder.getFirstGemmBasedOp(); if (succeeded(maybeConv)) LLVM_DEBUG(llvm::dbgs() << "conv = " << maybeConv.value() << "\n"); - else + else { LLVM_DEBUG(llvm::dbgs() << "conv not found\n"); + return failure(); + } - return maybeConv; + tosa::Conv2DOp firstConv = maybeConv.value(); + // bias not supported + if (!isConstantZero(firstConv.getBias())) { + op.emitOpError("bias not supported yet"); + return failure(); + } + return elementwiseRegionFinder; } - void rewrite(tosa::MatMulOp op, PatternRewriter &rewriter) const { + void rewrite( + tosa::MatMulOp op, + const ElementwiseRegionFinder &elementwiseRegionFinder, + PatternRewriter &rewriter) const { Location loc = op.getLoc(); auto outputType = cast(op.getType()); Value output = rewriter.create( @@ -1156,20 +1180,13 @@ struct ConvElementwiseGemmRewritePattern std::optional numCu; rock::GemmFeatures features; std::tie(arch, numCu, features) = getArchAttributes(op, op.getType()); - SmallVector elementwiseOtherArgs; - - FailureOr maybeConv; - std::tie(std::ignore, maybeConv) = getElementwiseRegion( - op.getA(), rewriter, nullptr, elementwiseOtherArgs); // This is guaranteed by the matcher - tosa::Conv2DOp firstConv = maybeConv.value(); + tosa::Conv2DOp firstConv = + elementwiseRegionFinder.getFirstGemmBasedOp().value(); - // bias not supported - if (!isConstantZero(firstConv.getBias())) { - op.emitOpError("bias not supported yet"); - return; - } + SmallVector elementwiseOtherArgs = + elementwiseRegionFinder.getElementwiseArgs(); int64_t group = 1; if (auto attr = op->template getAttrOfType("group")) @@ -1179,6 +1196,7 @@ struct ConvElementwiseGemmRewritePattern output, firstConv.getPadAttr(), firstConv.getStrideAttr(), firstConv.getDilationAttr(), group, numCu, features); + auto firstGemmBlockIndex = elementwiseRegionFinder.getFirstGemmBlockIndex(); auto convElentwiseGemmOp = rewriter.create( loc, outputType, convFields.filterExp, convFields.inputExp, op.getB(), elementwiseOtherArgs, output, @@ -1186,7 +1204,8 @@ struct ConvElementwiseGemmRewritePattern /*oTransposed=*/nullptr, arch, convFields.features, convFields.numCU, convFields.pad, convFields.stride, convFields.dilation, /*params0=*/nullptr, /*params1=*/nullptr, - /*firstGemmIndices=*/rewriter.getDenseI64ArrayAttr({0})); + /*firstGemmIndices=*/ + rewriter.getDenseI64ArrayAttr(firstGemmBlockIndex)); addConvAttributes(rewriter, convElentwiseGemmOp, convFields); @@ -1195,30 +1214,20 @@ struct ConvElementwiseGemmRewritePattern { PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(preSecondGemmElemwiseBlock); - Value res; - std::tie(res, std::ignore) = getElementwiseRegion( - op.getA(), rewriter, preSecondGemmElemwiseBlock, elementwiseOtherArgs, - loc, true); - RankedTensorType resTensorType = cast(res.getType()); - MemRefType resMemRefType = MemRefType::get( - resTensorType.getShape(), resTensorType.getElementType()); - Value resMemref = rewriter.create( - loc, cast(resMemRefType), res); - Value outMemref = - preSecondGemmElemwiseBlock->addArgument(resMemRefType, loc); - rewriter.create(loc, resMemref, outMemref); - rewriter.create(loc); + elementwiseRegionFinder.rewrite(op.getA(), rewriter, + preSecondGemmElemwiseBlock, loc); } rewriter.replaceOp(op, convElentwiseGemmOp.getResult()); } LogicalResult matchAndRewrite(tosa::MatMulOp op, PatternRewriter &rewriter) const override { - auto result = match(op); - if (result.succeeded()) { - rewrite(op, rewriter); + FailureOr> elemwiseFinder = + match(op); + if (succeeded(elemwiseFinder)) { + rewrite(op, elemwiseFinder.value(), rewriter); } - return result; + return elemwiseFinder; } }; @@ -1226,42 +1235,43 @@ struct GemmElementwiseGemmRewritePattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult match(tosa::MatMulOp op) const { - OpBuilder b{op}; - SmallVector vec; - FailureOr maybeFirstMatMul; - std::tie(std::ignore, maybeFirstMatMul) = - getElementwiseRegion(op.getA(), b, nullptr, vec); - + FailureOr> + match(tosa::MatMulOp op) const { + ElementwiseRegionFinder elemwiseRegionFinder; + elemwiseRegionFinder.visit(op.getA()); + FailureOr maybeFirstMatMul = + elemwiseRegionFinder.getFirstGemmBasedOp(); if (succeeded(maybeFirstMatMul)) LLVM_DEBUG(llvm::dbgs() << "first matmul = " << maybeFirstMatMul.value() << "\n"); - else + else { LLVM_DEBUG(llvm::dbgs() << "first matmul not found\n"); - - return maybeFirstMatMul; + return failure(); + } + return elemwiseRegionFinder; } - void rewrite(tosa::MatMulOp op, PatternRewriter &rewriter) const { + void rewrite(tosa::MatMulOp op, + const ElementwiseRegionFinder &elemwiseFinder, + PatternRewriter &rewriter) const { Location loc = op.getLoc(); auto outputType = cast(op.getType()); Value output = rewriter.create( loc, outputType, ValueRange{}); + StringAttr arch; std::optional numCu; rock::GemmFeatures features; std::tie(arch, numCu, features) = getArchAttributes(op, op.getType()); - SmallVector elementwiseOtherArgs; - - FailureOr maybeFirstMatMul; - std::tie(std::ignore, maybeFirstMatMul) = - getElementwiseRegion(op.getA(), rewriter, nullptr, - elementwiseOtherArgs); - // This is guranteed by the matcher - tosa::MatMulOp firstMatMulOp = maybeFirstMatMul.value(); IntegerAttr numCUAttr = numCu.has_value() ? rewriter.getI32IntegerAttr(numCu.value()) : nullptr; + SmallVector elementwiseOtherArgs = + elemwiseFinder.getElementwiseArgs(); + // This is guranteed by the matcher + tosa::MatMulOp firstMatMulOp = elemwiseFinder.getFirstGemmBasedOp().value(); + int64_t firstGemmBlockIndex = elemwiseFinder.getFirstGemmBlockIndex(); + rock::GemmElementwiseGemmOp gemmElentwiseGemmOp = rewriter.create( loc, outputType, firstMatMulOp.getA(), firstMatMulOp.getB(), @@ -1273,36 +1283,26 @@ struct GemmElementwiseGemmRewritePattern rewriter.getAttr(features), numCUAttr, /*params0=*/nullptr, /*params1=*/nullptr, /*firstGemmIndices=*/ - rewriter.getDenseI64ArrayAttr(ArrayRef({0}))); + rewriter.getDenseI64ArrayAttr(firstGemmBlockIndex)); Block *preSecondGemmElemwiseBlock = &gemmElentwiseGemmOp.getPreSecondGemmBody().emplaceBlock(); { PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(preSecondGemmElemwiseBlock); - Value res; - std::tie(res, std::ignore) = getElementwiseRegion( - op.getA(), rewriter, preSecondGemmElemwiseBlock, elementwiseOtherArgs, - loc, true); - RankedTensorType resTensorType = cast(res.getType()); - MemRefType resMemRefType = MemRefType::get( - resTensorType.getShape(), resTensorType.getElementType()); - Value resMemref = rewriter.create( - loc, cast(resMemRefType), res); - Value outMemref = - preSecondGemmElemwiseBlock->addArgument(resMemRefType, loc); - rewriter.create(loc, resMemref, outMemref); - rewriter.create(loc); + elemwiseFinder.rewrite(op.getA(), rewriter, preSecondGemmElemwiseBlock, + loc); } rewriter.replaceOp(op, gemmElentwiseGemmOp.getResult()); } LogicalResult matchAndRewrite(tosa::MatMulOp op, PatternRewriter &rewriter) const override { - auto result = match(op); - if (result.succeeded()) { - rewrite(op, rewriter); + FailureOr> elemwiseFinder = + match(op); + if (succeeded(elemwiseFinder)) { + rewrite(op, elemwiseFinder.value(), rewriter); } - return result; + return elemwiseFinder; } }; @@ -1322,8 +1322,7 @@ struct AttentionMatcherValues { Value currentSeqLen; bool isCausal; Type softmaxType; - tosa::MatMulOp firstMatmulOp; - SmallVector elementwiseOtherArgs; + ElementwiseRegionFinder preSoftmaxElementwiseFinder; }; struct AttentionRewritePattern : public OpRewritePattern { @@ -1919,11 +1918,10 @@ struct AttentionRewritePattern : public OpRewritePattern { Value causalMaskInput = isCausal ? causal.value() : kvCacheInput; OpBuilder b{op}; - SmallVector elementwiseOtherArgs; - FailureOr maybeFirstMatMul; - std::tie(std::ignore, maybeFirstMatMul) = - getElementwiseRegion(causalMaskInput, b, nullptr, - elementwiseOtherArgs); + ElementwiseRegionFinder preSoftmaxElementwiseFinder; + preSoftmaxElementwiseFinder.visit(causalMaskInput); + FailureOr maybeFirstMatMul = + preSoftmaxElementwiseFinder.getFirstGemmBasedOp(); if (failed(maybeFirstMatMul)) { LLVM_DEBUG(llvm::dbgs() << "first matmul not found\n"); return failure(); @@ -1968,8 +1966,8 @@ struct AttentionRewritePattern : public OpRewritePattern { attentionMatcherValues.lse = lse; attentionMatcherValues.causalMaskInput = causalMaskInput; attentionMatcherValues.currentSeqLen = currentSeqLen; - attentionMatcherValues.firstMatmulOp = maybeFirstMatMul.value(); - attentionMatcherValues.elementwiseOtherArgs = elementwiseOtherArgs; + attentionMatcherValues.preSoftmaxElementwiseFinder = + preSoftmaxElementwiseFinder; return attentionMatcherValues; } @@ -2002,13 +2000,16 @@ struct AttentionRewritePattern : public OpRewritePattern { std::optional numCu; rock::GemmFeatures features; std::tie(arch, numCu, features) = getArchAttributes(op, op.getType()); + ElementwiseRegionFinder preSoftmaxElementwiseFinder = + attentionMatcherValues.preSoftmaxElementwiseFinder; SmallVector elementwiseOtherArgs = - attentionMatcherValues.elementwiseOtherArgs; + preSoftmaxElementwiseFinder.getElementwiseArgs(); // causalMaskInput would be equal to kvCacheInput if there is no causal // mask and kvCacheInput would be same as softmaxInput if there is no // kv-cache. see match() for details Value causalMaskInput = attentionMatcherValues.causalMaskInput; - tosa::MatMulOp firstMatMulOp = attentionMatcherValues.firstMatmulOp; + tosa::MatMulOp firstMatMulOp = + preSoftmaxElementwiseFinder.getFirstGemmBasedOp().value(); Value currentSeqLen = attentionMatcherValues.currentSeqLen; bool isCausal = attentionMatcherValues.isCausal; TypeAttr softmaxTypeAttr = @@ -2024,6 +2025,9 @@ struct AttentionRewritePattern : public OpRewritePattern { op.getLoc(), currentSeqLen, reassocIndices); } UnitAttr causalAttr = isCausal ? rewriter.getUnitAttr() : nullptr; + ElementwiseRegionFinder elemwiseRegion = + attentionMatcherValues.preSoftmaxElementwiseFinder; + int64_t firstGemmBlockIndex = elemwiseRegion.getFirstGemmBlockIndex(); rock::AttentionOp attnOp = rewriter.create( loc, outputType, lseType, firstMatMulOp.getA(), firstMatMulOp.getB(), op.getB(), elementwiseOtherArgs, currentSeqLen, output, lseOut, @@ -2034,25 +2038,14 @@ struct AttentionRewritePattern : public OpRewritePattern { rewriter.getAttr(features), softmaxTypeAttr, numCUAttr, /*params0=*/nullptr, /*params1=*/nullptr, - /*firstGemmIndices=*/rewriter.getDenseI64ArrayAttr({0})); - + /*firstGemmIndices=*/ + rewriter.getDenseI64ArrayAttr(firstGemmBlockIndex)); Block *preSoftmaxElemwiseBlock = &attnOp.getPreSoftmaxBody().emplaceBlock(); { PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(preSoftmaxElemwiseBlock); - Value res; - std::tie(res, std::ignore) = getElementwiseRegion( - causalMaskInput, rewriter, preSoftmaxElemwiseBlock, - elementwiseOtherArgs, loc, true); - RankedTensorType resTensorType = cast(res.getType()); - MemRefType resMemRefType = MemRefType::get( - resTensorType.getShape(), resTensorType.getElementType()); - Value resMemref = rewriter.create( - loc, cast(resMemRefType), res); - Value outMemref = - preSoftmaxElemwiseBlock->addArgument(resMemRefType, loc); - rewriter.create(loc, resMemref, outMemref); - rewriter.create(loc); + elemwiseRegion.rewrite(causalMaskInput, rewriter, preSoftmaxElemwiseBlock, + loc); } tosa::AddOp addOp; Value expandedOutLse; diff --git a/mlir/test/fusion/pr-e2e/gemm-gemm/mixr-gemm-gemm-multiple-traces-to-first-gemm.mlir b/mlir/test/fusion/pr-e2e/gemm-gemm/mixr-gemm-gemm-multiple-traces-to-first-gemm.mlir new file mode 100644 index 000000000000..23ba4be5de6c --- /dev/null +++ b/mlir/test/fusion/pr-e2e/gemm-gemm/mixr-gemm-gemm-multiple-traces-to-first-gemm.mlir @@ -0,0 +1,23 @@ +// RUN: rocmlir-gen -fut mlir_gemm_gemm --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx | rocmlir-driver -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_gemm_gemm_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal -kernel-pipeline full | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// CHECK: [1 1 1] +module { + func.func @mlir_gemm_gemm(%arg0: !migraphx.shaped<124x976xf32, 976x1>, %arg1: !migraphx.shaped<976x664xf32, 664x1>, %arg2: !migraphx.shaped<664xf32, 1>, %arg3: !migraphx.shaped<664xf32, 1>, %arg4: !migraphx.shaped<664xf32, 1>, %arg5: !migraphx.shaped<664xf32, 1>, %arg6: !migraphx.shaped<664x88xf32, 88x1>) -> !migraphx.shaped<124x88xf32, 88x1> { + %0 = migraphx.literal(dense<1.000000e+00> : tensor<1xf32>) : <1xf32, 0> + %1 = migraphx.dot %arg0, %arg1 : <124x976xf32, 976x1>, <976x664xf32, 664x1> -> <124x664xf32, 664x1> + %2 = migraphx.multibroadcast %arg2 {out_dyn_dims = [], out_lens = [124, 664]} : <664xf32, 1> -> <124x664xf32, 0x1> + %3 = migraphx.multibroadcast %arg3 {out_dyn_dims = [], out_lens = [124, 664]} : <664xf32, 1> -> <124x664xf32, 0x1> + %4 = migraphx.multibroadcast %arg4 {out_dyn_dims = [], out_lens = [124, 664]} : <664xf32, 1> -> <124x664xf32, 0x1> + %5 = migraphx.multibroadcast %arg5 {out_dyn_dims = [], out_lens = [124, 664]} : <664xf32, 1> -> <124x664xf32, 0x1> + %6 = migraphx.add %2, %1 : <124x664xf32, 0x1>, <124x664xf32, 664x1> -> <124x664xf32, 664x1> + %7 = migraphx.mul %6, %3 : <124x664xf32, 664x1>, <124x664xf32, 0x1> -> <124x664xf32, 664x1> + %8 = migraphx.add %7, %4 : <124x664xf32, 664x1>, <124x664xf32, 0x1> -> <124x664xf32, 664x1> + %9 = migraphx.sigmoid %8 : <124x664xf32, 664x1> -> <124x664xf32, 664x1> + %10 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [124, 664]} : <1xf32, 0> -> <124x664xf32, 0x0> + %11 = migraphx.sub %10, %9 : <124x664xf32, 0x0>, <124x664xf32, 664x1> -> <124x664xf32, 664x1> + %12 = migraphx.mul %11, %5 : <124x664xf32, 664x1>, <124x664xf32, 0x1> -> <124x664xf32, 664x1> + %13 = migraphx.add %9, %12 : <124x664xf32, 664x1>, <124x664xf32, 664x1> -> <124x664xf32, 664x1> + %14 = migraphx.mul %13, %6 : <124x664xf32, 664x1>, <124x664xf32, 664x1> -> <124x664xf32, 664x1> + %15 = migraphx.dot %14, %arg6 : <124x664xf32, 664x1>, <664x88xf32, 88x1> -> <124x88xf32, 88x1> + return %15 : !migraphx.shaped<124x88xf32, 88x1> + } +}