diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp index c098dc1e7cc4..2ad9edb845ab 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp @@ -46,66 +46,275 @@ #define DEBUG_TYPE "iree-dispatch-creation-form-dispatch-regions" -static const char kRootOpAttr[] = "__root_op__"; -static const char kFusionGroupsAttr[] = "__fused_op__"; - namespace mlir::iree_compiler::DispatchCreation { #define GEN_PASS_DEF_FORMDISPATCHREGIONSPASS #include "iree/compiler/DispatchCreation/Passes.h.inc" +/// Returns a bit vector of size number of loops of the `interfaceOp` with +/// the bits corresponding to outer parallel loops set to `true`. +static llvm::SmallBitVector getOuterParallelLoops(Operation *op) { + if (auto setEncodingOp = dyn_cast(op)) { + return llvm::SmallBitVector(setEncodingOp.getResultType().getRank(), true); + } + if (auto unsetEncodingOp = dyn_cast(op)) { + return llvm::SmallBitVector(unsetEncodingOp.getResultType().getRank(), + true); + } + + auto interfaceOp = dyn_cast(op); + if (!interfaceOp) { + // For ops that dont implement the `TilingInterface` just return empty. + return llvm::SmallBitVector{}; + } + SmallVector loopIteratorTypes = + interfaceOp.getLoopIteratorTypes(); + llvm::SmallBitVector parallelLoops(loopIteratorTypes.size()); + for (auto iteratorType : llvm::enumerate(loopIteratorTypes)) { + if (iteratorType.value() != utils::IteratorType::parallel) + break; + parallelLoops.set(iteratorType.index()); + } + return parallelLoops; +} + //===----------------------------------------------------------------------===// -// Root and fusion group attribute handling +// Root and fusion group handling //===----------------------------------------------------------------------===// -/// Returns true if an op has a root operation. -static bool hasRootOpAttribute(Operation *op) { - return static_cast(op->getAttrOfType(kRootOpAttr)); -} -/// Removes root attribute. Asserts if root attribute is not present. -static void removeRootOpAttribute(Operation *op) { - op->removeAttr(kRootOpAttr); -} -/// Sets the root attribute for an operation. The root attribute needs a number -/// to identify the root. Asserts if root attribute is already set on an -/// operation. -static void setRootAttribute(MLIRContext *context, Operation *op, - int64_t rootNumber) { - assert(!op->hasAttr(kRootOpAttr) && - "invalid to update root attribute on an op"); - op->setAttr(kRootOpAttr, - IntegerAttr::get(IntegerType::get(context, 64), rootNumber)); -} -/// Returns the number of the root. Asserts if the operation is not already set -/// as a root. -static int64_t getRootNumber(Operation *op) { - return op->getAttrOfType(kRootOpAttr).getInt(); -} -/// Returns true if an op is part of a fusion group. -static bool hasFusionGroupsAttribute(Operation *op) { - return static_cast(op->getAttrOfType(kFusionGroupsAttr)); -} -/// Returns the fusion groups for the given `op`. -static SmallVector getFusionGroups(Operation *op) { - SmallVector fusionGroups = {}; - if (auto fusionGroupsAttr = op->getAttrOfType(kFusionGroupsAttr)) { - fusionGroups = llvm::map_to_vector(fusionGroupsAttr, [](Attribute attr) { - return llvm::cast(attr).getInt(); - }); +namespace { +// `FusionGroup` is used to track operations that are to be fused with a given +// `rootOp`. +// +// This class contains an AffineMap for each operation to be fused. This map +// represents a mapping from the root op's outer parallel dims to this op's +// iteration space. `0` is used to represent when the iteration dimension has no +// mapping to the root op's outer parallel dimensions. +// +// For example: +// affine_map<(d0, d1) -> (d0, 0, d1)> +// +// The root op has 2 outer parallel loops (`d0` and `d1`) and the example op +// has 3 dimensions where the first and last map `d0` and `d1` and the middle +// has no mapping to the root's outer parallel dimensions. +class FusionGroup { +public: + FusionGroup(Operation *op) : rootOp(op) { + llvm::SmallBitVector loops = getOuterParallelLoops(op); + auto map = AffineMap::getFilteredIdentityMap( + op->getContext(), loops.size(), [&](AffineDimExpr dimExpr) { + return loops.test(dimExpr.getPosition()); + }); + map = inverseAndBroadcastProjectedPermutation(map); + loopMaps.insert({op, map}); + }; + + SmallVector getFusedOperations() const { + return llvm::map_to_vector( + loopMaps.getArrayRef(), + [](std::pair pair) { return pair.first; }); + } + + Operation *getRoot() const { return rootOp; } + + // Get the mapping from `rootOp`'s outer parallel loops to `op`. This assumes + // that the dependency chain from `rootOp` to `op` has already been inserted + // into the group. + // + // Returns failure when there is no mapping or more than one mapping exists. + FailureOr getRootParallelLoopToOpMap(Operation *op) const; + + bool isFusable(Operation *op) const { + return succeeded(getRootParallelLoopToOpMap(op)); + } + + bool contains(Operation *op) const { return loopMaps.contains(op); } + + // Insert `op` into the fusion group. + void insert(Operation *op); + +private: + Operation *rootOp; + // All operations to be fused with the root op. This does not include + // `rootOp`. + llvm::MapVector loopMaps; +}; +} // namespace + +void FusionGroup::insert(Operation *op) { + assert(!contains(op) && "op already fused"); + FailureOr map = getRootParallelLoopToOpMap(op); + if (succeeded(map)) { + loopMaps.insert({op, map.value()}); + } else { + // TODO(IanWood1): some ops can be fused but don't implement + // `LinalgFusionOpInterface` e.g. `tensor.insert_slice` or `linalg.unpack`. + // `getRootParallelLoopToOpMap` fails when `op` is trying to fuse with one + // of these ops. So, give `op` a root map. + llvm::SmallBitVector loops = getOuterParallelLoops(op); + auto map = AffineMap::getFilteredIdentityMap( + op->getContext(), loops.size(), [&](AffineDimExpr dimExpr) { + return loops.test(dimExpr.getPosition()); + }); + map = inverseAndBroadcastProjectedPermutation(map); + loopMaps.insert({op, map}); } - return fusionGroups; -} -/// Appends the given `op` to the `newGroups` fusion groups. -static void appendToFusionGroup(Operation *op, ArrayRef newGroups) { - SmallVector fusionGroups = getFusionGroups(op); - fusionGroups.append(newGroups.begin(), newGroups.end()); - op->setAttr(kFusionGroupsAttr, Builder(op).getI64ArrayAttr(fusionGroups)); } -/// Removes the fusion groups attribute. -static void removeFusionGroupsAttribute(Operation *op) { - op->removeAttr(kFusionGroupsAttr); + +FailureOr +FusionGroup::getRootParallelLoopToOpMap(Operation *op) const { + assert(!contains(op) && "op cannot already be in group"); + auto fusionOp = dyn_cast(op); + if (!fusionOp) { + return failure(); + } + + bool isConsumer = llvm::any_of(op->getOperands(), [this](Value v) { + return contains(v.getDefiningOp()); + }); + assert(isConsumer != + llvm::any_of(op->getUsers(), + [this](Operation *op) { return contains(op); }) && + "op must be not be a producer and consumer"); + + /// Computes the mapping from the root ops outer parallel loops to `op`'s + /// iteration space via a direct producer/consumer of `op` that is already in + /// the fusion group. + auto getMapFromOpInFusionGroup = + [&](AffineMap otherToOperand, AffineMap thisToOperand, + AffineMap otherMap) -> FailureOr { + if (!otherToOperand || !thisToOperand || + !otherToOperand.isProjectedPermutation() || + !thisToOperand.isProjectedPermutation()) { + return failure(); + } + + // `thisToOperand` is a mapping from the iteration space of `op` to the + // operand's data space. + // `inverseMap` is the mapping from the operand data space to `op`'s + // iteration space. + AffineMap inverseMap = + inverseAndBroadcastProjectedPermutation(thisToOperand); + + // `otherToOperand` maps "other's" (an op in the fusion group) iteration + // space to the same operand's data space. Composing the two yields a + // mapping from other's iteration space to `op`'s iteration space. + AffineMap composedMap = inverseMap.compose(otherToOperand); + + // `otherMap` is other's mapping from the root's outer parallel loops to + // other's iteration space. `composedMap.compose(otherMap)` computes the + // mapping from the root's outer parallel loops to `op`'s iteration space. + return composedMap.compose(otherMap); + }; + + AffineMap newMap; + if (isConsumer) { + for (OpOperand &operand : op->getOpOperands()) { + Operation *definingOp = operand.get().getDefiningOp(); + if (!contains(definingOp)) { + continue; + } + auto fusionProducer = + operand.get() + .getDefiningOp(); + if (!fusionProducer) { + return failure(); + } + auto it = loopMaps.find(fusionProducer); + assert(it != loopMaps.end()); + + AffineMap producerResultMap = fusionProducer.getIndexingMapMatchingResult( + cast(operand.get())); + AffineMap consumerOperandMap = fusionOp.getMatchingIndexingMap(&operand); + FailureOr composedMap = getMapFromOpInFusionGroup( + producerResultMap, consumerOperandMap, it->second); + // Mapping must be the same for all operands. + if (failed(composedMap) || (newMap && composedMap != newMap)) { + return failure(); + } + newMap = composedMap.value(); + } + } else { + for (OpOperand &operand : op->getUses()) { + if (!contains(operand.getOwner())) { + continue; + } + auto fusionConsumer = dyn_cast( + operand.getOwner()); + if (!fusionConsumer) { + return failure(); + } + auto it = loopMaps.find(operand.getOwner()); + assert(it != loopMaps.end()); + + AffineMap consumerOperandMap = + fusionConsumer.getMatchingIndexingMap(&operand); + AffineMap producerResultMap = + fusionOp.getIndexingMapMatchingResult(cast(operand.get())); + FailureOr composedMap = getMapFromOpInFusionGroup( + consumerOperandMap, producerResultMap, it->second); + // Mapping must be the same for all operands. + if (failed(composedMap) || (newMap && composedMap != newMap)) { + return failure(); + } + newMap = composedMap.value(); + + // Producers cannot be more parallel than consumers. + if (compressUnusedDims(newMap).getNumDims() != it->second.getNumDims()) { + return failure(); + } + } + } + if (!newMap) { + return failure(); + } + return newMap; } +namespace { + +/// Tracks all the FusionGroups for the program. +class FusionTracker { +public: + /// Create a new fusion group with `op` as the root. + FusionGroup &createFusionGroup(MLIRContext *ctx, Operation *op) { + fusionGroups.push_back(std::make_unique(op)); + opToGroup[op] = fusionGroups.back().get(); + return *fusionGroups.back(); + } + + // Get the fusion group that contains `op`. + const FusionGroup &getFusionGroup(Operation *op) const { + return *opToGroup.at(op); + } + + // Get the fusion group that contains `op`. + FusionGroup &getFusionGroup(Operation *op) { return *opToGroup.at(op); } + + const SmallVector> &getFusionGroups() const { + return fusionGroups; + } + + void appendToFusionGroup(Operation *op, FusionGroup &fusionGroup) { + assert(!isFusedOp(op) && "op already in a group"); + fusionGroup.insert(op); + opToGroup[op] = &fusionGroup; + } + + // Returns if `op` has been added to a FusionGroup in the tracker. + bool isFusedOp(Operation *op) const { return opToGroup.contains(op); } + + // Returns if `op` is the root of a FusionGroup. + bool isRootOp(Operation *op) const { + return isFusedOp(op) && op == getFusionGroup(op).getRoot(); + } + +private: + SmallVector> fusionGroups; + DenseMap opToGroup; +}; +} // namespace + //===----------------------------------------------------------------------===// // Op property charecterizations //===----------------------------------------------------------------------===// @@ -156,7 +365,7 @@ static bool hasFusableUnpackProducer(linalg::LinalgOp linalgOp) { /// Operations that are treated as root operations for dispatch region /// formation. -static bool isRootOp(Operation *op) { +static bool isRootLikeOp(Operation *op) { if (op->getParentOfType()) { return false; } @@ -196,133 +405,6 @@ static bool isUnpackLikeOp(Operation *op) { // Heuristics for fusing dispatchble ops with root ops using tile + fuse. //===----------------------------------------------------------------------===// -/// Returns a bit vector of size number of loops of the `interfaceOp` with -/// the bits corresponding to outer parallel loops set to `true`. -static llvm::SmallBitVector getOuterParallelLoops(Operation *op) { - if (auto setEncodingOp = dyn_cast(op)) { - return llvm::SmallBitVector(setEncodingOp.getResultType().getRank(), true); - } - if (auto unsetEncodingOp = dyn_cast(op)) { - return llvm::SmallBitVector(unsetEncodingOp.getResultType().getRank(), - true); - } - - auto interfaceOp = dyn_cast(op); - if (!interfaceOp) { - // For ops that dont implement the `TilingInterface` just return empty. - return llvm::SmallBitVector{}; - } - SmallVector loopIteratorTypes = - interfaceOp.getLoopIteratorTypes(); - llvm::SmallBitVector parallelLoops(loopIteratorTypes.size()); - for (auto iteratorType : llvm::enumerate(loopIteratorTypes)) { - if (iteratorType.value() != utils::IteratorType::parallel) - break; - parallelLoops.set(iteratorType.index()); - } - return parallelLoops; -} - -/// Returns true if `map` is an identity map with zeros, i.e. if you -/// drop the result exprs that are constant zeros, the `map` will become an -/// identity. -static bool isIdentityMapWithZeros(AffineMap map) { - if (map.getNumSymbols() != 0) - return false; - if (map.isEmpty()) - return false; - unsigned dimsSeen = 0; - for (AffineExpr result : map.getResults()) { - if (auto dimExpr = dyn_cast(result)) { - if (dimExpr.getPosition() != dimsSeen) { - return false; - } - dimsSeen++; - } else if (auto constExpr = dyn_cast(result)) { - if (constExpr.getValue() != 0) { - return false; - } - } else { - return false; - } - } - return dimsSeen == map.getNumDims(); -} - -static bool -matchIteratorTypes(const llvm::SmallBitVector &rootOuterParallelLoop, - const llvm::SmallBitVector &candidateOuterParallelLoop) { - // If the candidate is not all parallel, then its loop configuration should be - // the same as the root. - if (candidateOuterParallelLoop.size() != candidateOuterParallelLoop.count()) { - return rootOuterParallelLoop == candidateOuterParallelLoop; - } - - // If the candidate is all parallel, then it should be at least as parallel as - // the root. - for (int pos : llvm::seq(0, std::min(candidateOuterParallelLoop.size(), - rootOuterParallelLoop.size()))) { - // If we reach the end of the outer loops of the root, break out of the - // loop. - if (!rootOuterParallelLoop.test(pos)) - break; - // If the root loop is parallel, the candidate loop should also be parallel. - if (!candidateOuterParallelLoop.test(pos)) - return false; - } - return true; -} - -// Method to check if the op with have compatible indexing map on outer-parallel -// loops. Currently it means the map needs to be identity on the those -// dimensions, ignoring its reduction dimensions. -static bool hasCompatibleOuterParallelLoops( - TilingInterface tileOp, AffineMap indexingMap, - const llvm::SmallBitVector &rootOuterParallelLoops) { - if (!indexingMap.isProjectedPermutation()) { - return false; - } - - llvm::SmallBitVector parallelLoops = getOuterParallelLoops(tileOp); - if (!matchIteratorTypes(rootOuterParallelLoops, parallelLoops)) { - return false; - } - - /// Project out the non-parallel dimensions. - llvm::SmallBitVector projectedDims(rootOuterParallelLoops); - projectedDims.flip(); - projectedDims.resize(tileOp.getLoopIteratorTypes().size(), true); - auto projectedMap = getProjectedMap(indexingMap, projectedDims); - return isIdentityMapWithZeros(projectedMap); -} - -// Method to check if two `linalg.generic` op with producer-consumer -// relationship through `operand` have compatible outer-parallel loops. -static bool hasCompatibleOuterParallelLoops( - OpOperand &operand, const llvm::SmallBitVector &rootOuterParallelLoops) { - auto producer = - operand.get().getDefiningOp(); - auto consumer = - dyn_cast(operand.getOwner()); - if (!producer || !consumer) - return false; - - auto producerIndexingMap = producer.getIndexingMapMatchingResult( - llvm::cast(operand.get())); - auto consumerIndexingMap = consumer.getMatchingIndexingMap(&operand); - - if (!producerIndexingMap || !consumerIndexingMap) { - return false; - } - - return hasCompatibleOuterParallelLoops( - cast(producer.getOperation()), - producerIndexingMap, rootOuterParallelLoops) && - hasCompatibleOuterParallelLoops( - cast(consumer.getOperation()), - consumerIndexingMap, rootOuterParallelLoops); -} - /// For all uses of an operation, return the uses that could be fused. /// The returned vector contains the uses in dominance order. static SmallVector @@ -353,160 +435,6 @@ getFusableUses(MLIRContext *context, Operation *op, return usesVec; } -/// Returns true if the operands are fusable. -static bool areOpsFusable(Operation *producer, Operation *consumer, - const llvm::SmallBitVector &rootOuterParallelLoops) { - // Collect all the uses from producer to consumer. - SmallVector allUses; - for (OpOperand &producerUse : producer->getUses()) { - if (producerUse.getOwner() != consumer) - continue; - allUses.push_back(&producerUse); - } - - // Check that the consumer and producer have compatible outer parallel loops. - if (!llvm::all_of(allUses, [&](OpOperand *operand) { - return hasCompatibleOuterParallelLoops(*operand, - rootOuterParallelLoops); - })) { - return false; - } - return true; -} - -/// The logic to decide fusability (using the `hasCompatibleOuterParallelLoops`) -/// currently works when the indexing map corresponding to result of the -/// producer and indexing map corresponding to operand in the result are not -/// transposed with respect to each other. To find more fusion opportunities for -/// consumer elementwise operation, the indexing maps in the consumer can be -/// made to "align" with the indexing map of the producer to enhance fusion. -static bool makeConsumerFusableViaInterchange( - OpOperand &fusableOperand, - const llvm::SmallBitVector &rootOuterParallelLoops) { - auto producer = - fusableOperand.get() - .getDefiningOp(); - if (!producer) { - return false; - } - - auto consumer = dyn_cast(fusableOperand.getOwner()); - if (!consumer) { - return false; - } - - if (!linalg::isElementwise(consumer) || consumer.getNumResults() != 1) { - return false; - } - - // If the indexing map in the consumer is already "compatible" with the - // indexing map in the producer, do nothing. - AffineMap producerIndexingMap = producer.getIndexingMapMatchingResult( - cast(fusableOperand.get())); - if (!producerIndexingMap) { - return false; - } - producerIndexingMap = getProjectedMap( - producerIndexingMap, getUnusedDimsBitVector(producerIndexingMap)); - AffineMap consumerIndexingMap = - consumer.getMatchingIndexingMap(&fusableOperand); - - // Since the iteration space of the consumer is going to be permuted - // to make it match with the indexing map in the producer, the interchange - // requires the indexing map in the consumer to be a permutation. - // If the producer indexing map and consumer indexing map are the same, - // then the permutation of iteration space becomes a no-op, in which - // case the permutation wasnt required for fusion. Return false here - // to indicate that the permutation is not going to "enhance" the - // fusion opportunities. - if (!consumerIndexingMap.isPermutation() || - producerIndexingMap == consumerIndexingMap) { - return false; - } - OpResult result = cast(consumer.getResult(0)); - if (!consumer.getIndexingMapMatchingResult(result).isPermutation()) { - return false; - } - - // For now this is restricting that all indexing maps corresponding to the - // input are same as the indexing map of the fused operand, or are projected - // permutations. This avoids ping-ponging between different iteration space - // permutations without having any way to pick which is better. - if (!llvm::all_of( - consumer.getDpsInputOperands(), [&](OpOperand *inputOperand) { - AffineMap map = consumer.getMatchingIndexingMap(inputOperand); - return map == consumerIndexingMap || - (map.isProjectedPermutation() && !map.isPermutation()); - })) { - return false; - } - - // Make the input map match the producer map by applying a permutation map - // computed with consumerIndexingMap.compose(inv(producerIndexingMap)) - AffineMap invProducerIndexingMap = inversePermutation(producerIndexingMap); - AffineMap permutationMap = - consumerIndexingMap.compose(invProducerIndexingMap); - auto perm = llvm::map_to_vector(permutationMap.getResults(), - [](AffineExpr e) -> unsigned { - return cast(e).getPosition(); - }); - IRRewriter rewriter(consumer->getContext()); - FailureOr interchangedOp = - linalg::interchangeGenericOp(rewriter, consumer, perm); - (void)interchangedOp; - assert(succeeded(interchangedOp) && "expected interchange to succeed"); - assert(interchangedOp.value() == consumer && - "expected interchange to happen in place"); - return true; -} - -static bool makeProducerFusableViaInterchange( - OpOperand &fusableOperand, - const llvm::SmallBitVector &rootOuterParallelLoops) { - auto producer = fusableOperand.get().getDefiningOp(); - if (!producer) { - return false; - } - - auto consumer = dyn_cast( - fusableOperand.getOwner()); - if (!consumer) { - return false; - } - - if (!linalg::isElementwise(producer) || producer.getNumResults() != 1) { - return false; - } - - AffineMap producerIndexingMap = producer.getIndexingMapMatchingResult( - cast(fusableOperand.get())); - producerIndexingMap = getProjectedMap( - producerIndexingMap, getUnusedDimsBitVector(producerIndexingMap)); - AffineMap consumerIndexingMap = - consumer.getMatchingIndexingMap(&fusableOperand); - if (!consumerIndexingMap || !consumerIndexingMap.isPermutation() || - producerIndexingMap == consumerIndexingMap) { - return false; - } - - // Make the input map match the consumer map by applying a permutation map - AffineMap invProducerIndexingMap = inversePermutation(producerIndexingMap); - AffineMap permutationMap = - consumerIndexingMap.compose(invProducerIndexingMap); - auto perm = llvm::map_to_vector(permutationMap.getResults(), - [](AffineExpr e) -> unsigned { - return cast(e).getPosition(); - }); - IRRewriter rewriter(consumer->getContext()); - FailureOr interchangedOp = - linalg::interchangeGenericOp(rewriter, producer, perm); - (void)interchangedOp; - assert(succeeded(interchangedOp) && "expected interchange to succeed"); - assert(interchangedOp.value() == producer && - "expected interchange to happen in place"); - return true; -} - /// For the fusion of root op -> elementwise operation to be bufferized /// in-place without use of extra memory, the result of the root operation /// must be able to reuse the buffer for the result of the elementwise @@ -552,8 +480,7 @@ static bool canUseInOperandAsInitOperand(OpOperand *inOperand, /// Returns true if this is a fusable use, while fusing a root with its /// consumer. static bool -isFusableWithConsumer(OpOperand &fusedOperand, - const llvm::SmallBitVector &rootOuterParallelLoops, +isFusableWithConsumer(OpOperand &fusedOperand, const FusionTracker &tracker, FormDispatchRegionsPassOptions const &options) { Operation *producer = fusedOperand.get().getDefiningOp(); Operation *consumer = fusedOperand.getOwner(); @@ -585,13 +512,11 @@ isFusableWithConsumer(OpOperand &fusedOperand, return TypeSwitch(producer) .Case([&](auto padOp) { return true; }) .Case([&](auto linalgOp) { - auto producerIndexingMap = linalgOp.getIndexingMapMatchingResult( + AffineMap producerIndexingMap = linalgOp.getIndexingMapMatchingResult( llvm::cast(fusedOperand.get())); // Make sure the producer op has an identity result indexing map. As // CPU backend currently can't handle transpose between fused ops. - return hasCompatibleOuterParallelLoops( - cast(linalgOp.getOperation()), - producerIndexingMap, rootOuterParallelLoops); + return producerIndexingMap.isIdentity(); }) .Default([](Operation *) { return false; }); } @@ -639,15 +564,8 @@ isFusableWithConsumer(OpOperand &fusedOperand, return false; } - if (!areOpsFusable(producer, consumer, rootOuterParallelLoops)) { - // Check if interchange in the consumer makes it fusable. - // Currently limit it to horizontally fused gemms. - // TODO(#20019) to remove this restriction. - if (!IREE::LinalgExt::isaHorizontallyFusedContraction(producer) || - !makeConsumerFusableViaInterchange(fusedOperand, - rootOuterParallelLoops)) { - return false; - } + if (!tracker.getFusionGroup(producer).isFusable(consumer)) { + return false; } // Check if the iteration spaces of the producer and consumer are same. @@ -702,12 +620,12 @@ isFusableWithConsumer(OpOperand &fusedOperand, static void fuseRootsWithConsumers(MLIRContext *context, ArrayRef roots, DominanceInfo const &dominanceInfo, - FormDispatchRegionsPassOptions const &options) { + FormDispatchRegionsPassOptions const &options, + FusionTracker &tracker) { // Fuse with consumers where possible. for (Operation *root : roots) { SmallVector workList; - llvm::SmallBitVector rootOuterParallelLoops = getOuterParallelLoops(root); - int64_t rootNumber = getRootNumber(root); + FusionGroup &fusionGroup = tracker.getFusionGroup(root); workList.push_back(root); while (!workList.empty()) { Operation *currRoot = workList.pop_back_val(); @@ -733,14 +651,12 @@ fuseRootsWithConsumers(MLIRContext *context, ArrayRef roots, // Analyse the use to see if it is fusable. for (OpOperand *fusableUse : fusableUses) { Operation *consumerOp = fusableUse->getOwner(); - if (hasRootOpAttribute(consumerOp) || - hasFusionGroupsAttribute(consumerOp)) { + if (tracker.isRootOp(consumerOp) || tracker.isFusedOp(consumerOp)) { continue; } - if (isFusableWithConsumer(*fusableUse, rootOuterParallelLoops, - options)) { - appendToFusionGroup(consumerOp, rootNumber); + if (isFusableWithConsumer(*fusableUse, tracker, options)) { + tracker.appendToFusionGroup(consumerOp, fusionGroup); workList.push_back(consumerOp); } else { break; @@ -751,9 +667,10 @@ fuseRootsWithConsumers(MLIRContext *context, ArrayRef roots, } /// Method to check if the consumer of a use can be fused with its producer. -static bool isFusableWithProducer( - OpOperand &operand, const llvm::SmallBitVector &rootOuterParallelLoops, - FormDispatchRegionsPassOptions const &options, bool fuseWithTruncate) { +static bool isFusableWithProducer(OpOperand &operand, + const FusionTracker &tracker, + FormDispatchRegionsPassOptions const &options, + bool fuseWithTruncate) { Operation *producer = operand.get().getDefiningOp(); Operation *consumer = operand.getOwner(); @@ -790,13 +707,11 @@ static bool isFusableWithProducer( return false; } } - auto producerIndexingMap = linalgOp.getIndexingMapMatchingResult( + AffineMap producerIndexingMap = linalgOp.getIndexingMapMatchingResult( llvm::cast(operand.get())); // Make sure the producer op has an identity result indexing map. As // CPU backend currently can't handle transpose between fused ops. - return hasCompatibleOuterParallelLoops( - cast(linalgOp.getOperation()), - producerIndexingMap, rootOuterParallelLoops); + return producerIndexingMap.isIdentity(); }) .Default([](Operation *) { return false; }); } @@ -813,10 +728,8 @@ static bool isFusableWithProducer( } } - if (!areOpsFusable(producer, consumer, rootOuterParallelLoops)) { - if (!makeProducerFusableViaInterchange(operand, rootOuterParallelLoops)) { - return false; - } + if (!tracker.getFusionGroup(consumer).isFusable(producer)) { + return false; } return true; } @@ -824,13 +737,13 @@ static bool isFusableWithProducer( /// Starting from the `root` op, traverse the operand use-def chain /// in reverse to fuse with producers. static void -fuseRootsWithProducers(MLIRContext *context, Operation *root, unsigned groupNum, +fuseRootsWithProducers(MLIRContext *context, Operation *root, + FusionGroup &fusionGroup, DominanceInfo const &dominanceInfo, FormDispatchRegionsPassOptions const &options, - bool fuseWithTruncate) { + FusionTracker &tracker, bool fuseWithTruncate) { SmallVector worklist; worklist.push_back(root); - llvm::SmallBitVector rootOuterParallelLoops = getOuterParallelLoops(root); IREE::Flow::ClonableIntoDispatchOptions clonableOptions; clonableOptions.aggressive = options.aggressiveFusion; while (!worklist.empty()) { @@ -840,12 +753,11 @@ fuseRootsWithProducers(MLIRContext *context, Operation *root, unsigned groupNum, if (!producer) continue; if (IREE::Flow::isClonableIntoDispatchOp(producer, clonableOptions) || - hasFusionGroupsAttribute(producer) || hasRootOpAttribute(producer)) { + tracker.isFusedOp(producer) || tracker.isRootOp(producer)) { continue; } - if (!isFusableWithProducer(operand, rootOuterParallelLoops, options, - fuseWithTruncate)) { + if (!isFusableWithProducer(operand, tracker, options, fuseWithTruncate)) { continue; } @@ -855,24 +767,18 @@ fuseRootsWithProducers(MLIRContext *context, Operation *root, unsigned groupNum, if (fusableUses.empty() || fusableUses.front()->getOwner() != candidate) continue; - appendToFusionGroup(producer, groupNum); + tracker.appendToFusionGroup(producer, fusionGroup); worklist.push_back(producer); } } } /// Some heuristic is needed to fuse a dispatchable op with root operations -/// using tile + fuse. Using some heuristic, each root operation is tagged with -/// an ID (using an IntegerAttr with name `kRootOpAttr`) and all dispatchable -/// ops to be fused with it is tagged with the same ID (using a list of -/// IntegerAttr with name `kFusionGroupsAttr`). Each dispatchable operation can -/// be marked to fuse with multiple root operations (i.e. replicated). For now a -/// very simple heuristic is used below, but the mechanism should be general -/// enough to capture any heuristic. -static unsigned +/// using tile + fuse. +static void decideFusableLinalgOps(Region ®ion, DominanceInfo const &dominanceInfo, FormDispatchRegionsPassOptions const &options, - unsigned numRootOps = 0) { + FusionTracker &tracker, unsigned numRootOps = 0) { MLIRContext *context = region.getContext(); OpBuilder builder(context); IREE::Flow::ClonableIntoDispatchOptions clonableOptions; @@ -887,27 +793,27 @@ decideFusableLinalgOps(Region ®ion, DominanceInfo const &dominanceInfo, for (Operation &op : llvm::reverse(block)) { if (isa(op.getDialect())) { for (auto ®ion : op.getRegions()) { - numRootOps = decideFusableLinalgOps(region, dominanceInfo, options, - numRootOps); + decideFusableLinalgOps(region, dominanceInfo, options, tracker, + numRootOps); } continue; } // Start with a root operation and fuse its producers. - if (hasFusionGroupsAttribute(&op) || !isRootOp(&op)) + if (tracker.isFusedOp(&op) || !isRootLikeOp(&op)) continue; - unsigned newGroup = numRootOps++; - setRootAttribute(context, &op, newGroup); - + FusionGroup &newGroup = tracker.createFusionGroup(context, &op); fuseRootsWithProducers(context, &op, newGroup, dominanceInfo, options, + tracker, /*fuseWithTruncate=*/false); roots.push_back(&op); } roots = llvm::to_vector(llvm::reverse(roots)); - fuseRootsWithConsumers(context, roots, dominanceInfo, options); + fuseRootsWithConsumers(context, roots, dominanceInfo, options, tracker); for (Operation *root : roots) { - int64_t rootNumber = getRootNumber(root); - fuseRootsWithProducers(context, root, rootNumber, dominanceInfo, options, + FusionGroup &fusionGroup = tracker.getFusionGroup(root); + fuseRootsWithProducers(context, root, fusionGroup, dominanceInfo, options, + tracker, /*fuseWithTruncate=*/true); } } @@ -918,7 +824,7 @@ decideFusableLinalgOps(Region ®ion, DominanceInfo const &dominanceInfo, SmallVector roots; for (Operation &op : llvm::reverse(block)) { // If it is part of a fusion group or root op, ignore it. - if (hasFusionGroupsAttribute(&op) || hasRootOpAttribute(&op)) + if (tracker.isFusedOp(&op) || tracker.isRootOp(&op)) continue; // Only look for Linalg ops here. Avoid moving `linalg.fill` that aren't // fused with anything else into their own dispatches since it is better @@ -942,23 +848,21 @@ decideFusableLinalgOps(Region ®ion, DominanceInfo const &dominanceInfo, continue; } - unsigned newGroup = numRootOps++; - setRootAttribute(context, &op, newGroup); - + FusionGroup &newGroup = tracker.createFusionGroup(context, &op); fuseRootsWithProducers(context, &op, newGroup, dominanceInfo, options, + tracker, /*fuseWithTruncate=*/false); roots.push_back(&op); } roots = llvm::to_vector(llvm::reverse(roots)); - fuseRootsWithConsumers(context, roots, dominanceInfo, options); + fuseRootsWithConsumers(context, roots, dominanceInfo, options, tracker); for (Operation *root : roots) { - int64_t rootNumber = getRootNumber(root); - fuseRootsWithProducers(context, root, rootNumber, dominanceInfo, options, + FusionGroup &fusionGroup = tracker.getFusionGroup(root); + fuseRootsWithProducers(context, root, fusionGroup, dominanceInfo, options, + tracker, /*fuseWithTruncate=*/true); } } - - return numRootOps; } //===----------------------------------------------------------------------===// @@ -971,12 +875,10 @@ createFusionGroups(TensorDimTrackingRewriter &rewriter, mlir::FunctionOpInterface funcOp, DominanceInfo &dominanceInfo, FormDispatchRegionsPassOptions const &options) { - // Step 1: Decide fusion groups (heuristic). This marks rootOps with an - // attribute - unsigned numRoots = - decideFusableLinalgOps(funcOp.getFunctionBody(), dominanceInfo, options); - SmallVector roots(numRoots, nullptr); - DenseMap> fusedOperations; + // Step 1: Decide fusion groups (heuristic). + FusionTracker tracker; + decideFusableLinalgOps(funcOp.getFunctionBody(), dominanceInfo, options, + tracker); LLVM_DEBUG({ llvm::dbgs() << "\n--- After deciding fusion groups ---\n"; @@ -984,29 +886,15 @@ createFusionGroups(TensorDimTrackingRewriter &rewriter, llvm::dbgs() << "\n\n"; }); - // TODO: Incrementally add ops to an empty DispatchGroupOp instead of - // annotating fusion group IDs via attributes. - funcOp.walk([&](Operation *op) { - if (hasRootOpAttribute(op)) { - roots[getRootNumber(op)] = op; - fusedOperations[getRootNumber(op)].push_back(op); - removeRootOpAttribute(op); - } - if (hasFusionGroupsAttribute(op)) { - assert(getFusionGroups(op).size() == 1 && "expected exactly one group"); - fusedOperations[getFusionGroups(op).front()].push_back(op); - removeFusionGroupsAttribute(op); - } - }); - // Step 2. Create a DispatchRegionOp for every fusion group. OpBuilder::InsertionGuard g(rewriter); SmallVector regionOps; - for (auto [rootIndex, root] : llvm::enumerate(roots)) { - + for (const auto &fusionGroup : tracker.getFusionGroups()) { + Operation *root = fusionGroup->getRoot(); // Sort producers and consumers topologically. All fused ops must be in the // same block as the root. - SmallVector &currFusedOperations = fusedOperations[rootIndex]; + SmallVector currFusedOperations = + fusionGroup->getFusedOperations(); bool sortResult = mlir::computeTopologicalSorting(currFusedOperations); (void)sortResult; assert(sortResult && "could not compute topological sorting"); diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp index cf51973de655..5640fe37c3a9 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp @@ -69,6 +69,11 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand, if (!options.fuseTruncateOps && IREE::LinalgExt::isBitTruncateOp(producerOp)) { + // TODO(IanWood1): do this regardless of `options.fuseTruncateOps`. + // Never fuse truncate -> extend. + if (IREE::LinalgExt::isBitExtendOp(consumerOp)) { + return false; + } // Do not fuse with bit-truncate-like operations with their consumers // unless: // diff --git a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors.mlir b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors.mlir index 59c001472dff..f362de965ad6 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors.mlir @@ -1541,7 +1541,7 @@ util.func public @fuse_conv2d_with_multiple_uses(%input: tensor<1x225x225x16xf32 // ----- -util.func public @dont_fuse_conv2d_with_non_identity_map(%input: tensor<1x225x225x16xf32>, %filter: tensor<3x3x16x32xf32>, %offset: tensor<32xf32>) -> tensor<1x112x112x32xf32> { +util.func public @fuse_conv2d_with_non_identity_map(%input: tensor<1x225x225x16xf32>, %filter: tensor<3x3x16x32xf32>, %offset: tensor<32xf32>) -> tensor<1x112x112x32xf32> { %cst = arith.constant 0.000000e+00 : f32 %0 = tensor.empty() : tensor<1x112x112x32xf32> %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32> @@ -1565,13 +1565,11 @@ util.func public @dont_fuse_conv2d_with_non_identity_map(%input: tensor<1x225x22 util.return %3 : tensor<1x112x112x32xf32> } -// CHECK-LABEL: util.func public @dont_fuse_conv2d_with_non_identity_map - -// CHECK: flow.dispatch.workgroups -// CHECK: linalg.conv_2d_nhwc_hwcf - -// CHECK: flow.dispatch.workgroups -// CHECK: linalg.generic +// CHECK-LABEL: util.func public @fuse_conv2d_with_non_identity_map +// CHECK: flow.dispatch.workgroups +// CHECK: linalg.conv_2d_nhwc_hwcf +// CHECK-NOT: flow.dispatch.workgroups +// CHECK: linalg.generic // ----- diff --git a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir index 031482c7cd4b..5ce65e6466e8 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir @@ -1184,6 +1184,7 @@ util.func @avoid_use_def_violation_on_consumer_fusion(%arg0 : tensor, // ----- +// Test transposed output. util.func @horizontal_fusion3(%lhs : tensor<2x4096x640xf16>, %rhs0 : tensor<10x64x640xf16>, %rhs1 : tensor<10x64x640xf16>, %rhs2 : tensor<10x64x640xf16>) -> @@ -1254,7 +1255,6 @@ util.func @horizontal_fusion3(%lhs : tensor<2x4096x640xf16>, } -> tensor<2x10x64x4096xf16> util.return %8, %9, %10 : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16> } -// CHECK: #[[INTERCHANGED_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)> // CHECK: func public @horizontal_fusion3 // CHECK: %[[DISPATCH:.+]]:3 = flow.dispatch.region // CHECK: %[[GENERIC:.+]]:3 = linalg.generic @@ -1263,7 +1263,6 @@ util.func @horizontal_fusion3(%lhs : tensor<2x4096x640xf16>, // CHECK: %[[TRUNC1:.+]] = linalg.generic // CHECK-SAME: ins(%[[GENERIC]]#1 : // CHECK: %[[TRUNC2:.+]] = linalg.generic -// CHECK-SANE: indexing_maps = [#[[INTERCHANGED_MAP]], #[[INTERCHANGED_MAP]]] // CHECK-SAME: ins(%[[GENERIC]]#2 : // CHECK: flow.return %[[TRUNC0]], %[[TRUNC1]], %[[TRUNC2]] // CHECK: util.return %[[DISPATCH]]#0, %[[DISPATCH]]#1, %[[DISPATCH]]#2 @@ -1597,3 +1596,170 @@ util.func public @dynamic_quantization_fp4(%arg0 : tensor, %arg1 : ind // CHECK: flow.return %[[SCALE]], %[[QUANTIZED]] // CHECK: %[[BITCAST:.+]] = iree_tensor_ext.bitcast %[[DISPATCH]]#1 // CHECK: return %[[BITCAST]], %[[DISPATCH]]#0 + +// ----- + +util.func public @fuse_both_uses_transposed(%arg0 : tensor, + %arg1 : tensor) -> tensor { + %cst = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %d0 = tensor.dim %arg1, %c0 : tensor + %d1 = tensor.dim %arg1, %c1 : tensor + %empty = tensor.empty(%d0, %d1) : tensor + %3 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) + outs(%empty : tensor) -> tensor + %5 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, + affine_map<(d0, d1) -> (d1, d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%3, %3 : tensor, tensor) + outs(%empty : tensor) { + ^bb0(%b0 : f32, %b1 : f32, %b2 :f32) : + %6 = arith.addf %b0, %b1 : f32 + linalg.yield %6 : f32 + } -> tensor + util.return %5 : tensor +} +// CHECK-LABEL: @fuse_both_uses_transposed +// CHECK: flow.dispatch.region +// CHECK: %[[MATMUL:.+]] = linalg.matmul +// CHECK: linalg.generic +// CHECK-SAME: ins(%[[MATMUL]], %[[MATMUL]] + +// ----- + +util.func public @dont_fuse_use_transpose_and_identity(%arg0 : tensor, + %arg1 : tensor) -> tensor { + %cst = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %d0 = tensor.dim %arg1, %c0 : tensor + %d1 = tensor.dim %arg1, %c1 : tensor + %empty = tensor.empty(%d0, %d1) : tensor + %3 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) + outs(%empty : tensor) -> tensor + %5 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d1, d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%3, %3 : tensor, tensor) + outs(%empty : tensor) { + ^bb0(%b0 : f32, %b1 : f32, %b2 :f32) : + %6 = arith.addf %b0, %b1 : f32 + linalg.yield %6 : f32 + } -> tensor + util.return %5 : tensor +} +// CHECK-LABEL: @dont_fuse_use_transpose_and_identity +// CHECK: flow.dispatch.region +// CHECK: linalg.matmul +// CHECK: flow.dispatch.region +// CHECK: linalg.generic + +// ----- + +util.func public @dont_fuse_use_consumer_transposed_use_of_producer(%arg0 : tensor) -> tensor { + %cst = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %d0 = tensor.dim %arg0, %c0 : tensor + %d1 = tensor.dim %arg0, %c1 : tensor + %empty = tensor.empty(%d0, %d0, %d0) : tensor + %empty2 = tensor.empty(%d0, %d0) : tensor + %5 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%arg0 : tensor) + outs(%empty : tensor) { + ^bb0(%b0 : f32, %b1 :f32) : + %6 = arith.addf %b0, %b1 : f32 + linalg.yield %6 : f32 + } -> tensor + %6 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%5 : tensor) + outs(%empty2 : tensor) { + ^bb0(%b0 : f32, %b1 :f32) : + %6 = arith.addf %b0, %b1 : f32 + linalg.yield %6 : f32 + } -> tensor + // The transpose on %5 makes this unfusable + %7 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%5, %6 : tensor, tensor) + outs(%empty : tensor) { + ^bb0(%b0 : f32, %b1 :f32, %b2 : f32) : + %8 = arith.addf %b0, %b1 : f32 + linalg.yield %8 : f32 + } -> tensor + util.return %7 : tensor +} +// CHECK-LABEL: @dont_fuse_use_consumer_transposed_use_of_producer +// CHECK: flow.dispatch.region +// CHECK: linalg.generic +// CHECK: linalg.generic +// CHECK: flow.dispatch.region +// CHECK: linalg.generic + +// ----- + +util.func public @unpack_multi_elementwise_fusion( + %arg0: tensor, + %arg1: tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %d0 = tensor.dim %arg0, %c0 : tensor + %d1 = tensor.dim %arg0, %c1 : tensor + %d2 = tensor.dim %arg0, %c2 : tensor + %d3 = tensor.dim %arg0, %c3 : tensor + %folded_dim0 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%d0, %d2] + %folded_dim1 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%d1, %d3] + %dest = tensor.empty(%folded_dim0, %folded_dim1) : tensor + %0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [%d2, %d3] + into %dest : tensor -> tensor + %1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%0, %arg1 : tensor, tensor) + outs(%dest : tensor) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %2 = arith.addf %b0, %b1 : f32 + linalg.yield %2 : f32 + } -> tensor + %2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%1, %arg1 : tensor, tensor) + outs(%dest : tensor) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %2 = arith.addf %b0, %b1 : f32 + linalg.yield %2 : f32 + } -> tensor + util.return %2 : tensor +} +// CHECK-LABEL: util.func public @unpack_multi_elementwise_fusion( +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK-SAME: %[[ARG1:.+]]: tensor) +// CHECK: %[[RESULT:.+]] = flow.dispatch.region +// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] +// CHECK: %[[GENERIC0:.+]] = linalg.generic +// CHECK-SAME: ins(%[[UNPACK]], %[[ARG1]] +// CHECK: %[[GENERIC1:.+]] = linalg.generic +// CHECK-SAME: ins(%[[GENERIC0]], %[[ARG1]] +// CHECK: flow.return %[[GENERIC1]] +// CHECK: util.return %[[RESULT]]