@@ -1446,24 +1446,20 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
14461446 }
14471447}
14481448
1449- template < typename LinalgType>
1450- Operation * createCollapsedOp (LinalgType op ,
1451- const CollapsingInfo &collapsingInfo ,
1452- RewriterBase &rewriter) {
1453- static_assert (llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value ,
1454- " unsupported linalg op type to create " );
1449+ void collapseOperandsAndResults (LinalgOp op,
1450+ const CollapsingInfo &collapsingInfo ,
1451+ RewriterBase &rewriter ,
1452+ SmallVectorImpl<Value> &inputOperands,
1453+ SmallVectorImpl<Value> &outputOperands ,
1454+ SmallVectorImpl<Type> &resultTypes) {
14551455 Location loc = op->getLoc ();
1456-
1457- // Get the input operands.
1458- SmallVector<Value> inputOperands =
1456+ inputOperands =
14591457 llvm::map_to_vector (op.getDpsInputOperands (), [&](OpOperand *opOperand) {
14601458 return getCollapsedOpOperand (loc, op, opOperand, collapsingInfo,
14611459 rewriter);
14621460 });
14631461
14641462 // Get the output operands and result types.
1465- SmallVector<Type> resultTypes;
1466- SmallVector<Value> outputOperands;
14671463 resultTypes.reserve (op.getNumDpsInits ());
14681464 outputOperands.reserve (op.getNumDpsInits ());
14691465 for (OpOperand &output : op.getDpsInitsMutable ()) {
@@ -1475,41 +1471,69 @@ Operation *createCollapsedOp(LinalgType op,
14751471 if (!op.hasPureBufferSemantics ())
14761472 resultTypes.push_back (newOutput.getType ());
14771473 }
1474+ }
14781475
1479- if (isa<linalg::CopyOp>(op)) {
1480- return rewriter.create <linalg::CopyOp>(loc, inputOperands[0 ],
1481- outputOperands[0 ]);
1482- }
1476+ // / Clone a `LinalgOp` to a collapsed version of same name
1477+ template <typename OpTy>
1478+ OpTy cloneToCollapsedOp (RewriterBase &rewriter, OpTy origOp,
1479+ const CollapsingInfo &collapsingInfo) {
1480+ return nullptr ;
1481+ }
14831482
1484- // Get the iterator types for the operand.
1485- SmallVector<utils::IteratorType> iteratorTypes =
1486- getCollapsedOpIteratorTypes (op.getIteratorTypesArray (), collapsingInfo);
1483+ // / Collapse any `LinalgOp` that does not require any specialization such as
1484+ // / indexing_maps, iterator_types, etc.
1485+ template <>
1486+ LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1487+ const CollapsingInfo &collapsingInfo) {
1488+ SmallVector<Value> inputOperands, outputOperands;
1489+ SmallVector<Type> resultTypes;
1490+ collapseOperandsAndResults (origOp, collapsingInfo, rewriter, inputOperands,
1491+ outputOperands, resultTypes);
1492+ return cast<LinalgOp>(clone (
1493+ rewriter, origOp, resultTypes,
1494+ llvm::to_vector (llvm::concat<Value>(inputOperands, outputOperands))));
1495+ }
14871496
1488- // Get the indexing maps.
1489- auto indexingMaps =
1490- llvm::map_to_vector (op.getIndexingMapsArray (), [&](AffineMap map) {
1497+ // / Collapse a `GenericOp`
1498+ template <>
1499+ GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
1500+ GenericOp origOp,
1501+ const CollapsingInfo &collapsingInfo) {
1502+ SmallVector<Value> inputOperands, outputOperands;
1503+ SmallVector<Type> resultTypes;
1504+ collapseOperandsAndResults (origOp, collapsingInfo, rewriter, inputOperands,
1505+ outputOperands, resultTypes);
1506+ SmallVector<AffineMap> indexingMaps (
1507+ llvm::map_range (origOp.getIndexingMapsArray (), [&](AffineMap map) {
14911508 return getCollapsedOpIndexingMap (map, collapsingInfo);
1492- });
1509+ }));
1510+
1511+ SmallVector<utils::IteratorType> iteratorTypes (getCollapsedOpIteratorTypes (
1512+ origOp.getIteratorTypesArray (), collapsingInfo));
14931513
1494- Operation * collapsedOp = rewriter.create <linalg::GenericOp>(
1495- loc , resultTypes, inputOperands, outputOperands, indexingMaps,
1514+ GenericOp collapsedOp = rewriter.create <linalg::GenericOp>(
1515+ origOp. getLoc () , resultTypes, inputOperands, outputOperands, indexingMaps,
14961516 iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1497- Block *origOpBlock = &op ->getRegion (0 ).front ();
1517+ Block *origOpBlock = &origOp ->getRegion (0 ).front ();
14981518 Block *collapsedOpBlock = &collapsedOp->getRegion (0 ).front ();
14991519 rewriter.mergeBlocks (origOpBlock, collapsedOpBlock,
15001520 collapsedOpBlock->getArguments ());
1501-
15021521 return collapsedOp;
15031522}
15041523
1524+ LinalgOp createCollapsedOp (LinalgOp op, const CollapsingInfo &collapsingInfo,
1525+ RewriterBase &rewriter) {
1526+ if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation ())) {
1527+ return cloneToCollapsedOp (rewriter, genericOp, collapsingInfo);
1528+ } else {
1529+ return cloneToCollapsedOp (rewriter, op, collapsingInfo);
1530+ }
1531+ }
1532+
15051533// / Implementation of fusion with reshape operation by collapsing dimensions.
1506- template <typename LinalgType>
1507- FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims (
1508- LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
1534+ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims (
1535+ LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
15091536 RewriterBase &rewriter) {
1510- static_assert (llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
1511- " unsupported linalg op type to collapse" );
1512-
15131537 // Bail on trivial no-op cases.
15141538 if (op.getNumLoops () <= 1 || foldedIterationDims.empty () ||
15151539 llvm::all_of (foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
@@ -1538,8 +1562,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
15381562 }
15391563
15401564 // Bail on non-canonical ranges.
1541- SmallVector<Range> loopRanges =
1542- cast<LinalgOp>(op.getOperation ()).createLoopRanges (rewriter, op.getLoc ());
1565+ SmallVector<Range> loopRanges = op.createLoopRanges (rewriter, op.getLoc ());
15431566 auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
15441567 if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
15451568 return cast<IntegerAttr>(attr).getInt () == value;
@@ -1555,8 +1578,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
15551578 op, " expected all loop ranges to have zero start and unit stride" );
15561579 }
15571580
1558- LinalgType collapsedOp = cast<LinalgType>(
1559- createCollapsedOp<LinalgType>(op, collapsingInfo, rewriter));
1581+ LinalgOp collapsedOp = createCollapsedOp (op, collapsingInfo, rewriter);
15601582
15611583 Location loc = op->getLoc ();
15621584 if (collapsedOp.hasIndexSemantics ()) {
@@ -1597,7 +1619,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
15971619 results.push_back (collapsedOpResult);
15981620 }
15991621 }
1600- return results;
1622+ return CollapseResult{ results, collapsedOp} ;
16011623}
16021624
16031625namespace {
@@ -1629,15 +1651,14 @@ class FoldWithProducerReshapeOpByCollapsing
16291651 continue ;
16301652 }
16311653
1632- std::optional<SmallVector<Value>> replacements =
1633- collapseOpIterationDims<linalg::GenericOp>(
1634- genericOp, collapsableIterationDims, rewriter);
1635- if (!replacements) {
1654+ std::optional<CollapseResult> collapseResult = collapseOpIterationDims (
1655+ genericOp, collapsableIterationDims, rewriter);
1656+ if (!collapseResult) {
16361657 return rewriter.notifyMatchFailure (
16371658 genericOp, " failed to do the fusion by collapsing transformation" );
16381659 }
16391660
1640- rewriter.replaceOp (genericOp, *replacements );
1661+ rewriter.replaceOp (genericOp, collapseResult-> results );
16411662 return success ();
16421663 }
16431664 return failure ();
@@ -1671,13 +1692,12 @@ class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
16711692 op, " specified dimensions cannot be collapsed" );
16721693 }
16731694
1674- std::optional<SmallVector<Value>> replacements =
1675- collapseOpIterationDims<LinalgType>(op, collapsableIterationDims,
1676- rewriter);
1677- if (!replacements) {
1695+ std::optional<CollapseResult> collapseResult =
1696+ collapseOpIterationDims (op, collapsableIterationDims, rewriter);
1697+ if (!collapseResult) {
16781698 return rewriter.notifyMatchFailure (op, " failed to collapse dimensions" );
16791699 }
1680- rewriter.replaceOp (op, *replacements );
1700+ rewriter.replaceOp (op, collapseResult-> results );
16811701 return success ();
16821702 }
16831703
0 commit comments