Skip to content

[mlir][SCF] Add scf::tileAndFuseConsumer that tiles a consumer into a given tiled loop nest.#167634

Merged
MaheshRavishankar merged 4 commits intollvm:mainfrom
MaheshRavishankar:users/MaheshRavishankar/tileAndFuseConsumer
Nov 20, 2025
Merged

[mlir][SCF] Add scf::tileAndFuseConsumer that tiles a consumer into a given tiled loop nest.#167634
MaheshRavishankar merged 4 commits intollvm:mainfrom
MaheshRavishankar:users/MaheshRavishankar/tileAndFuseConsumer

Conversation

@MaheshRavishankar
Copy link
Contributor

The existing scf::tileAndFuseConsumerOfSlices takes a list of slices (and loops they are part of), tries to find the consumer of these slices (all slices are expected to be the same consumer), and then tiles the consumer into the loop nest using the TilingInterface. A more natural way of doing consumer fusion is to just start from the consumer, look for operands that are produced by the loop nest passed in as loops (presumably these loops are generated by tiling, but that is not a requirement for consumer fusion). Using the consumer you can find the slices of the operands that are accessed within the loop which you can then use to tile and fuse the consumer (using TilingInterface). This handles more naturally the case where multiple operands of the consumer come from the loop nest.

The scf::tileAndFuseConsumerOfSlices was implemented as a mirror of scf::tileAndFuseProducerOfSlice. For the latter, the slice has a single producer for the source of the slice, which makes it a natural way of specifying producer fusion. But for consumers, the result might have multiple users, resulting in multiple candidates for fusion, as well as a fusion candidate using multiple results from the tiled loop nest. This means using slices
(tensor.insert_slice/tensor.parallel_insert_slice) as a hook for consumer fusion turns out to be quite hard to navigate. The use of the consumer directly avoids all those pain points. In time the scf::tileAndFuseConsumerOfSlices should be deprecated in favor of scf::tileAndFuseConsumer. There is a lot of tech-debt that has accumulated in scf::tileAndFuseConsumerOfSlices that needs to be cleanedup. So while that gets cleaned up, and required functionality is moved to scf::tileAndFuseConsumer, the old path is still maintained.

The test for scf::tileAndFuseConsumerUsingSlices is copied to tile-and-fuse-consumer.mlir to
tile-and-fuse-consumer-using-slices.mlir. All the tests that were there in this file are now using the tileAndFuseConsumer method. The test op test.tile_and_fuse_consumer is modified to call scf::tileAndFuseConsumer, while a new op
test.tile_and_fuse_consumer_of_slice is used to keep the old path tested while it is deprecated.

… a given tiled loop nest.

The existing `scf::tileAndFuseConsumerOfSlices` takes a list of slices
(and loops they are part of), tries to find the consumer of these
slices (all slices are expected to be the same consumer), and then
tiles the consumer into the loop nest using the `TilingInterface`. A
more natural way of doing consumer fusion is to just start from the
consumer, look for operands that are produced by the loop nest passed
in as `loops` (presumably these loops are generated by tiling, but
that is not a requirement for consumer fusion). Using the consumer you
can find the slices of the operands that are accessed within the loop
which you can then use to tile and fuse the consumer (using
`TilingInterface`). This handles more naturally the case where
multiple operands of the consumer come from the loop nest.

The `scf::tileAndFuseConsumerOfSlices` was implemented as a mirror of
`scf::tileAndFuseProducerOfSlice`. For the latter, the slice has a
single producer for the source of the slice, which makes it a natural
way of specifying producer fusion. But for consumers, the result might
have multiple users, resulting in multiple candidates for fusion, as
well as a fusion candidate using multiple results from the tiled loop
nest. This means using slices
(`tensor.insert_slice`/`tensor.parallel_insert_slice`) as a hook for
consumer fusion turns out to be quite hard to navigate. The use of the
consumer directly avoids all those pain points. In time the
`scf::tileAndFuseConsumerOfSlices` should be deprecated in favor of
`scf::tileAndFuseConsumer`. There is a lot of tech-debt that has
accumulated in `scf::tileAndFuseConsumerOfSlices` that needs to be
cleanedup. So while that gets cleaned up, and required functionality
is moved to `scf::tileAndFuseConsumer`, the old path is still
maintained.

The test for `scf::tileAndFuseConsumerUsingSlices` is copied to
`tile-and-fuse-consumer.mlir` to
`tile-and-fuse-consumer-using-slices.mlir`. All the tests that were
there in this file are now using the `tileAndFuseConsumer` method. The
test op `test.tile_and_fuse_consumer` is modified to call
`scf::tileAndFuseConsumer`, while a new op
`test.tile_and_fuse_consumer_of_slice` is used to keep the old path
tested while it is deprecated.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
@llvmbot
Copy link
Member

llvmbot commented Nov 12, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-scf

Author: None (MaheshRavishankar)

Changes

The existing scf::tileAndFuseConsumerOfSlices takes a list of slices (and loops they are part of), tries to find the consumer of these slices (all slices are expected to be the same consumer), and then tiles the consumer into the loop nest using the TilingInterface. A more natural way of doing consumer fusion is to just start from the consumer, look for operands that are produced by the loop nest passed in as loops (presumably these loops are generated by tiling, but that is not a requirement for consumer fusion). Using the consumer you can find the slices of the operands that are accessed within the loop which you can then use to tile and fuse the consumer (using TilingInterface). This handles more naturally the case where multiple operands of the consumer come from the loop nest.

The scf::tileAndFuseConsumerOfSlices was implemented as a mirror of scf::tileAndFuseProducerOfSlice. For the latter, the slice has a single producer for the source of the slice, which makes it a natural way of specifying producer fusion. But for consumers, the result might have multiple users, resulting in multiple candidates for fusion, as well as a fusion candidate using multiple results from the tiled loop nest. This means using slices
(tensor.insert_slice/tensor.parallel_insert_slice) as a hook for consumer fusion turns out to be quite hard to navigate. The use of the consumer directly avoids all those pain points. In time the scf::tileAndFuseConsumerOfSlices should be deprecated in favor of scf::tileAndFuseConsumer. There is a lot of tech-debt that has accumulated in scf::tileAndFuseConsumerOfSlices that needs to be cleanedup. So while that gets cleaned up, and required functionality is moved to scf::tileAndFuseConsumer, the old path is still maintained.

The test for scf::tileAndFuseConsumerUsingSlices is copied to tile-and-fuse-consumer.mlir to
tile-and-fuse-consumer-using-slices.mlir. All the tests that were there in this file are now using the tileAndFuseConsumer method. The test op test.tile_and_fuse_consumer is modified to call scf::tileAndFuseConsumer, while a new op
test.tile_and_fuse_consumer_of_slice is used to keep the old path tested while it is deprecated.


Patch is 141.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/167634.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+5)
  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+12)
  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+172-49)
  • (modified) mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir (+2-2)
  • (added) mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir (+1156)
  • (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir (+189-191)
  • (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (+71-8)
  • (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td (+23-1)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index cd033c140a233..8bdf3e0b566ef 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -613,6 +613,11 @@ def ForallOp : SCF_Op<"forall", [
                                     getNumDynamicControlOperands() + getRank());
     }
 
+    BlockArgument getTiedBlockArgument(OpResult opResult) {
+      assert(opResult.getDefiningOp() == getOperation()  && "invalid OpResult");
+      return getBody()->getArgument(getRank() + opResult.getResultNumber());
+    }
+
     ::mlir::Value getInductionVar(int64_t idx) {
       return getInductionVars()[idx];
     }
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 7c735d825b445..0005fad3d5c01 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -415,6 +415,10 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
 /// tiled in a manner that is consistent for all the passed slices. Note that
 /// the method replaces the uses of `candidateSlices` with the tiled and fused
 /// consumer value but does not delete the slice operations.
+/// TODO(MaheshRavishankar): A more natural way of exposing the consumer fusion
+/// is to take the consumer operation, and find the slices to use for fusion
+/// by walking its operands to the `loops` and then into the body to get the
+/// slices used for fusion.
 struct SCFFuseConsumerOfSliceResult {
   // Original untiled consumer operands.
   SmallVector<OpOperand *> origConsumerOperands;
@@ -427,6 +431,14 @@ tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
                             ArrayRef<Operation *> candidateSlices,
                             MutableArrayRef<LoopLikeOpInterface> loops);
 
+/// Fuse the `consumer` operation into the loop nest provided by `loops`.
+/// The transformation looks for operands in the `consumer` that are defined
+/// by the outermost loop of the loop nest in `loops`. The nested loop is
+/// expected to have the structure of the loops generated through tiling.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer,
+                    MutableArrayRef<LoopLikeOpInterface> loops);
+
 /// Method to lower an `op` that implements the `TilingInterface` to
 /// loops/scalars.
 FailureOr<SmallVector<scf::ForOp>>
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 29b770fb4b279..7e715ee189740 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1092,7 +1092,7 @@ static LogicalResult addInitOperandsToLoopNest(
   for (auto [outerLoop, innerLoop] :
        llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
     // Again assume that all the outer loops are scf.for operations.
-    auto outerForLoop = cast<scf::ForOp>(outerLoop);
+    auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation());
     auto outerLoopYield =
         cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
     SmallVector<Value> newYields =
@@ -2184,61 +2184,24 @@ cloneAsInsertSlices(RewriterBase &rewriter,
   return clonedSlices;
 }
 
-/// Implementation of fusing consumer of a single slice by computing the
-/// slice of the consumer in-place for scf loop.
-FailureOr<scf::SCFFuseConsumerOfSliceResult>
-mlir::scf::tileAndFuseConsumerOfSlices(
-    RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
-    MutableArrayRef<LoopLikeOpInterface> loops) {
-  if (candidateSlices.empty()) {
-    return rewriter.notifyMatchFailure(
-        rewriter.getUnknownLoc(),
-        "no candidate slices provided for consumer fusion");
-  }
-  // Return if `loops` is empty, return an error for now. Caller is expected
-  // to handle this case.
-  if (loops.empty()) {
-    return rewriter.notifyMatchFailure(
-        candidateSlices.front(),
-        "cannot call tile and fuse consumer with an empty loop nest");
-  }
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSlicesImpl(RewriterBase &rewriter, Operation *consumerOp,
+                                ArrayRef<OpOperand *> consumerOpOperands,
+                                ArrayRef<Operation *> candidateSlices,
+                                MutableArrayRef<LoopLikeOpInterface> loops) {
+  assert(!loops.empty() && "expected loops to be not empty");
 
-  if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
-        llvm::all_of(candidateSlices,
-                     llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
+  // 1. Check assumption for loop with `reorderOperations` disabled.
+  if (failed(checkAssumptionForLoop(loops.front(), consumerOp, false))) {
     return rewriter.notifyMatchFailure(
-        candidateSlices.front(),
-        "candidates slices need to be all `tensor.extract_slice`s or "
-        "`tensor.parallel_insert_slice`s");
-  }
-
-  // 1. Get the consumer of scf.for for the result yielded by
-  // tensor.insert_slice/parallel_insert_slice.
-  SmallVector<OpOperand *> consumerOpOperands;
-  Operation *consumerOp;
-  {
-    FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
-        getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
-    if (failed(maybeConsumerOpOperand)) {
-      return rewriter.notifyMatchFailure(candidateSlices.front(),
-                                         "could not fetch consumer to fuse");
-    }
-    std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
-    consumerOp = consumerOpOperands.front()->getOwner();
+        loops.front(), "the first user of loop should not dominate any define "
+                       "of consumer operand(s)");
   }
 
   LoopLikeOpInterface outerMostLoop = loops.front();
   LoopLikeOpInterface innerMostLoop = loops.back();
 
-  // Check assumption for loop with `reorderOperations` disabled.
-  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
-    return rewriter.notifyMatchFailure(
-        outerMostLoop, "the first user of loop should not dominate any define "
-                       "of consumer operand(s)");
-  }
-
   OpBuilder::InsertionGuard g(rewriter);
-
   // 2. Check consumer is not using scf loop's output as init.
   auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
   if (!dstOp)
@@ -2428,11 +2391,171 @@ mlir::scf::tileAndFuseConsumerOfSlices(
       llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
         return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
       });
+  auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands);
   return scf::SCFFuseConsumerOfSliceResult{
-      std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
+      std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands),
       std::move(tileAndFuseResult->tiledOps)};
 }
 
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf loop.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlices(
+    RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
+    MutableArrayRef<LoopLikeOpInterface> loops) {
+  if (candidateSlices.empty()) {
+    return rewriter.notifyMatchFailure(
+        rewriter.getUnknownLoc(),
+        "no candidate slices provided for consumer fusion");
+  }
+  // Return if `loops` is empty, return an error for now. Caller is expected
+  // to handle this case.
+  if (loops.empty()) {
+    return rewriter.notifyMatchFailure(
+        candidateSlices.front(),
+        "cannot call tile and fuse consumer with an empty loop nest");
+  }
+
+  if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
+        llvm::all_of(candidateSlices,
+                     llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
+    return rewriter.notifyMatchFailure(
+        candidateSlices.front(),
+        "candidates slices need to be all `tensor.extract_slice`s or "
+        "`tensor.parallel_insert_slice`s");
+  }
+
+  // Get the consumer of scf.for for the result yielded by
+  // tensor.insert_slice/parallel_insert_slice.
+  SmallVector<OpOperand *> consumerOpOperands;
+  Operation *consumerOp;
+  {
+    FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
+        getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
+    if (failed(maybeConsumerOpOperand)) {
+      return rewriter.notifyMatchFailure(candidateSlices.front(),
+                                         "could not fetch consumer to fuse");
+    }
+    std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
+    consumerOp = consumerOpOperands.front()->getOwner();
+  }
+
+  return tileAndFuseConsumerOfSlicesImpl(
+      rewriter, consumerOp, consumerOpOperands, candidateSlices, loops);
+}
+
+/// For a given `result` of a `forallOp` return the
+/// `tensor.parallel_insert_slice` op (or combining op) that is used to
+/// construct this result.
+static std::optional<Operation *>
+getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) {
+  if (result.getOwner() != forallOp)
+    return std::nullopt;
+  BlockArgument bbArg = forallOp.getTiedBlockArgument(result);
+  SmallVector<Operation *> combiningOps = forallOp.getCombiningOps(bbArg);
+  // If the number of combining ops is not 1, then this is unexpected. Return
+  // nullopt.
+  if (combiningOps.size() != 1) {
+    return std::nullopt;
+  }
+  return combiningOps[0];
+}
+
+/// For a given result of the loop nest that is a tiled loop nest, return the
+/// insert slice-like op that is used for consumer fusion
+std::optional<Operation *>
+getProducingInsertSliceLikeOp(OpResult result,
+                              ArrayRef<LoopLikeOpInterface> loops) {
+  assert(!loops.empty() && "Expected loops to be not empty");
+  LoopLikeOpInterface outermostLoop = loops.front();
+
+  if (auto forallOp = dyn_cast<scf::ForallOp>(outermostLoop.getOperation())) {
+    assert(loops.size() == 1 &&
+           "expected only a single loop when tiling using scf.forall");
+    return getProducingParallelInsertSlice(forallOp, result);
+  }
+  // Assume that the loop nest is a nested `scf.for` that is created through
+  // tiling and retrieve the `tensor.insert_slice` operation used to construct
+  // the result.
+  while (loops.size() != 1) {
+    if (result.getOwner() != loops.front())
+      return std::nullopt;
+    auto forOp = dyn_cast<scf::ForOp>(loops.front());
+    if (!forOp)
+      return std::nullopt;
+    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+    OpResult innerForResult =
+        dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber()));
+    if (!innerForResult)
+      return std::nullopt;
+    result = innerForResult;
+    loops = loops.drop_front();
+  }
+  if (result.getOwner() != loops.front())
+    return std::nullopt;
+  auto forOp = dyn_cast<scf::ForOp>(loops.front());
+  if (!forOp)
+    return std::nullopt;
+  auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+  auto insertSliceOp = yieldOp.getOperand(result.getResultNumber())
+                           .getDefiningOp<tensor::InsertSliceOp>();
+  if (!insertSliceOp)
+    return std::nullopt;
+  return insertSliceOp;
+}
+
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *user,
+                               MutableArrayRef<LoopLikeOpInterface> loops) {
+  // Only handle users that implement the `TilingInterface`.
+  if (!isa<TilingInterface>(user)) {
+    return rewriter.notifyMatchFailure(
+        user, "unhandled user that does not implement TilingInterface");
+  }
+
+  // Return if `loops` is empty, return an error for now. Caller is expected
+  // to handle this case.
+  if (loops.empty()) {
+    return rewriter.notifyMatchFailure(
+        user, "cannot call tile and fuse consumer with an empty loop nest");
+  }
+
+  LoopLikeOpInterface outermostLoop = loops.front();
+
+  // Collect the operands of the user that come from the outermost loop of the
+  // loop nest.
+  SmallVector<OpOperand *> consumerFusableOperands;
+  for (OpOperand &opOperand : user->getOpOperands()) {
+    if (opOperand.get().getDefiningOp() == outermostLoop) {
+      consumerFusableOperands.push_back(&opOperand);
+    }
+  }
+
+  // Nothing to fuse. Just return an empty set.
+  if (consumerFusableOperands.empty()) {
+    return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands,
+                                                   SmallVector<OpOperand *>{},
+                                                   SmallVector<Operation *>{}};
+  }
+
+  // Collect the relevant tensor.insert_slice/tensor.parallel_insert_slices
+  // for fusion.
+  SmallVector<Operation *> candidateSlices;
+  candidateSlices.reserve(consumerFusableOperands.size());
+  for (OpOperand *opOperand : consumerFusableOperands) {
+    std::optional<Operation *> slice =
+        getProducingInsertSliceLikeOp(cast<OpResult>(opOperand->get()), loops);
+    if (!slice) {
+      return rewriter.notifyMatchFailure(
+          user,
+          "couldnt find producing insert-slice like operation for operand");
+    }
+    candidateSlices.push_back(slice.value());
+  }
+  return tileAndFuseConsumerOfSlicesImpl(
+      rewriter, user, consumerFusableOperands, candidateSlices, loops);
+}
+
 //===----------------------------------------------------------------------===//
 // lowerToLoopsUsingSCFForOp implementation.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
index 185fb9b358055..d72ab080f3c5c 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
@@ -170,7 +170,7 @@ module {
       // Fuse the consumer operation into the tiled loop.
       %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
           : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
-      transform.test.fuse_consumer %slice_op in (%forall_op)
+      transform.test.fuse_consumer_using_slice %slice_op in (%forall_op)
         : (!transform.op<"tensor.parallel_insert_slice">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
       transform.yield
     }
@@ -231,7 +231,7 @@ module {
       // Fuse the consumer operation into the tiled loop.
       %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
           : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
-      // Note that we cannot apply transform.test.fuse_consumer here because the extract_slice
+      // Note that we cannot apply transform.test.fuse_consumer_using_slice here because the extract_slice
       // is not qualified consumer operation. Forcing this will yeild "could not fetch consumer
       // to fuse" error.
       transform.yield
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir
new file mode 100644
index 0000000000000..62dd7faec4eb7
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir
@@ -0,0 +1,1156 @@
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<(d0) -> (d0)>
+module {
+  func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
+        ^bb0(%in: f32, %in_16: f32, %out: f32):
+          %13 = arith.mulf %in, %in_16 : f32
+          %14 = arith.addf %out, %13 : f32
+          linalg.yield %14 : f32
+        } -> tensor<32xf32>
+      %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+      scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32>
+    }
+    %in_operand_2 = tensor.empty() : tensor<64xf32>
+    %out_operand_3 = tensor.empty() : tensor<64xf32>
+    %2 = linalg.add ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3 : tensor<64xf32>) -> tensor<64xf32>
+    return %2 : tensor<64xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.for"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer_using_slice %yield in (%loop)
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @fuse_tileable_consumer_scf_for(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
+//      CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//      CHECK:   %0 = tensor.empty() : tensor<64xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]]
+// CHECK-SAME:      iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %0)
+// CHECK-SAME:   {
+//      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[MAT_OUT:.*]] = linalg.generic
+// CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
+//      CHECK:      %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[ELEM_OUT:.*]] = linalg.add
+// CHECK-SAME:              ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:              outs(%[[SLICE_OUT]] :
+//      CHECK:      %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] :
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#2 :
+
+// -----
+
+module {
+  func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+         tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+         tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+      }
+    }
+    %in_ope...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 12, 2025

@llvm/pr-subscribers-mlir-linalg

Author: None (MaheshRavishankar)

Changes

The existing scf::tileAndFuseConsumerOfSlices takes a list of slices (and loops they are part of), tries to find the consumer of these slices (all slices are expected to be the same consumer), and then tiles the consumer into the loop nest using the TilingInterface. A more natural way of doing consumer fusion is to just start from the consumer, look for operands that are produced by the loop nest passed in as loops (presumably these loops are generated by tiling, but that is not a requirement for consumer fusion). Using the consumer you can find the slices of the operands that are accessed within the loop which you can then use to tile and fuse the consumer (using TilingInterface). This handles more naturally the case where multiple operands of the consumer come from the loop nest.

The scf::tileAndFuseConsumerOfSlices was implemented as a mirror of scf::tileAndFuseProducerOfSlice. For the latter, the slice has a single producer for the source of the slice, which makes it a natural way of specifying producer fusion. But for consumers, the result might have multiple users, resulting in multiple candidates for fusion, as well as a fusion candidate using multiple results from the tiled loop nest. This means using slices
(tensor.insert_slice/tensor.parallel_insert_slice) as a hook for consumer fusion turns out to be quite hard to navigate. The use of the consumer directly avoids all those pain points. In time the scf::tileAndFuseConsumerOfSlices should be deprecated in favor of scf::tileAndFuseConsumer. There is a lot of tech-debt that has accumulated in scf::tileAndFuseConsumerOfSlices that needs to be cleanedup. So while that gets cleaned up, and required functionality is moved to scf::tileAndFuseConsumer, the old path is still maintained.

The test for scf::tileAndFuseConsumerUsingSlices is copied to tile-and-fuse-consumer.mlir to
tile-and-fuse-consumer-using-slices.mlir. All the tests that were there in this file are now using the tileAndFuseConsumer method. The test op test.tile_and_fuse_consumer is modified to call scf::tileAndFuseConsumer, while a new op
test.tile_and_fuse_consumer_of_slice is used to keep the old path tested while it is deprecated.


Patch is 141.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/167634.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+5)
  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+12)
  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+172-49)
  • (modified) mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir (+2-2)
  • (added) mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir (+1156)
  • (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir (+189-191)
  • (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (+71-8)
  • (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td (+23-1)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index cd033c140a233..8bdf3e0b566ef 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -613,6 +613,11 @@ def ForallOp : SCF_Op<"forall", [
                                     getNumDynamicControlOperands() + getRank());
     }
 
+    BlockArgument getTiedBlockArgument(OpResult opResult) {
+      assert(opResult.getDefiningOp() == getOperation()  && "invalid OpResult");
+      return getBody()->getArgument(getRank() + opResult.getResultNumber());
+    }
+
     ::mlir::Value getInductionVar(int64_t idx) {
       return getInductionVars()[idx];
     }
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 7c735d825b445..0005fad3d5c01 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -415,6 +415,10 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
 /// tiled in a manner that is consistent for all the passed slices. Note that
 /// the method replaces the uses of `candidateSlices` with the tiled and fused
 /// consumer value but does not delete the slice operations.
+/// TODO(MaheshRavishankar): A more natural way of exposing the consumer fusion
+/// is to take the consumer operation, and find the slices to use for fusion
+/// by walking its operands to the `loops` and then into the body to get the
+/// slices used for fusion.
 struct SCFFuseConsumerOfSliceResult {
   // Original untiled consumer operands.
   SmallVector<OpOperand *> origConsumerOperands;
@@ -427,6 +431,14 @@ tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
                             ArrayRef<Operation *> candidateSlices,
                             MutableArrayRef<LoopLikeOpInterface> loops);
 
+/// Fuse the `consumer` operation into the loop nest provided by `loops`.
+/// The transformation looks for operands in the `consumer` that are defined
+/// by the outermost loop of the loop nest in `loops`. The nested loop is
+/// expected to have the structure of the loops generated through tiling.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer,
+                    MutableArrayRef<LoopLikeOpInterface> loops);
+
 /// Method to lower an `op` that implements the `TilingInterface` to
 /// loops/scalars.
 FailureOr<SmallVector<scf::ForOp>>
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 29b770fb4b279..7e715ee189740 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1092,7 +1092,7 @@ static LogicalResult addInitOperandsToLoopNest(
   for (auto [outerLoop, innerLoop] :
        llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
     // Again assume that all the outer loops are scf.for operations.
-    auto outerForLoop = cast<scf::ForOp>(outerLoop);
+    auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation());
     auto outerLoopYield =
         cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
     SmallVector<Value> newYields =
@@ -2184,61 +2184,24 @@ cloneAsInsertSlices(RewriterBase &rewriter,
   return clonedSlices;
 }
 
-/// Implementation of fusing consumer of a single slice by computing the
-/// slice of the consumer in-place for scf loop.
-FailureOr<scf::SCFFuseConsumerOfSliceResult>
-mlir::scf::tileAndFuseConsumerOfSlices(
-    RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
-    MutableArrayRef<LoopLikeOpInterface> loops) {
-  if (candidateSlices.empty()) {
-    return rewriter.notifyMatchFailure(
-        rewriter.getUnknownLoc(),
-        "no candidate slices provided for consumer fusion");
-  }
-  // Return if `loops` is empty, return an error for now. Caller is expected
-  // to handle this case.
-  if (loops.empty()) {
-    return rewriter.notifyMatchFailure(
-        candidateSlices.front(),
-        "cannot call tile and fuse consumer with an empty loop nest");
-  }
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSlicesImpl(RewriterBase &rewriter, Operation *consumerOp,
+                                ArrayRef<OpOperand *> consumerOpOperands,
+                                ArrayRef<Operation *> candidateSlices,
+                                MutableArrayRef<LoopLikeOpInterface> loops) {
+  assert(!loops.empty() && "expected loops to be not empty");
 
-  if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
-        llvm::all_of(candidateSlices,
-                     llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
+  // 1. Check assumption for loop with `reorderOperations` disabled.
+  if (failed(checkAssumptionForLoop(loops.front(), consumerOp, false))) {
     return rewriter.notifyMatchFailure(
-        candidateSlices.front(),
-        "candidates slices need to be all `tensor.extract_slice`s or "
-        "`tensor.parallel_insert_slice`s");
-  }
-
-  // 1. Get the consumer of scf.for for the result yielded by
-  // tensor.insert_slice/parallel_insert_slice.
-  SmallVector<OpOperand *> consumerOpOperands;
-  Operation *consumerOp;
-  {
-    FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
-        getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
-    if (failed(maybeConsumerOpOperand)) {
-      return rewriter.notifyMatchFailure(candidateSlices.front(),
-                                         "could not fetch consumer to fuse");
-    }
-    std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
-    consumerOp = consumerOpOperands.front()->getOwner();
+        loops.front(), "the first user of loop should not dominate any define "
+                       "of consumer operand(s)");
   }
 
   LoopLikeOpInterface outerMostLoop = loops.front();
   LoopLikeOpInterface innerMostLoop = loops.back();
 
-  // Check assumption for loop with `reorderOperations` disabled.
-  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
-    return rewriter.notifyMatchFailure(
-        outerMostLoop, "the first user of loop should not dominate any define "
-                       "of consumer operand(s)");
-  }
-
   OpBuilder::InsertionGuard g(rewriter);
-
   // 2. Check consumer is not using scf loop's output as init.
   auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
   if (!dstOp)
@@ -2428,11 +2391,171 @@ mlir::scf::tileAndFuseConsumerOfSlices(
       llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
         return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
       });
+  auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands);
   return scf::SCFFuseConsumerOfSliceResult{
-      std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
+      std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands),
       std::move(tileAndFuseResult->tiledOps)};
 }
 
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf loop.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlices(
+    RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
+    MutableArrayRef<LoopLikeOpInterface> loops) {
+  if (candidateSlices.empty()) {
+    return rewriter.notifyMatchFailure(
+        rewriter.getUnknownLoc(),
+        "no candidate slices provided for consumer fusion");
+  }
+  // Return if `loops` is empty, return an error for now. Caller is expected
+  // to handle this case.
+  if (loops.empty()) {
+    return rewriter.notifyMatchFailure(
+        candidateSlices.front(),
+        "cannot call tile and fuse consumer with an empty loop nest");
+  }
+
+  if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
+        llvm::all_of(candidateSlices,
+                     llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
+    return rewriter.notifyMatchFailure(
+        candidateSlices.front(),
+        "candidates slices need to be all `tensor.extract_slice`s or "
+        "`tensor.parallel_insert_slice`s");
+  }
+
+  // Get the consumer of scf.for for the result yielded by
+  // tensor.insert_slice/parallel_insert_slice.
+  SmallVector<OpOperand *> consumerOpOperands;
+  Operation *consumerOp;
+  {
+    FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
+        getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
+    if (failed(maybeConsumerOpOperand)) {
+      return rewriter.notifyMatchFailure(candidateSlices.front(),
+                                         "could not fetch consumer to fuse");
+    }
+    std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
+    consumerOp = consumerOpOperands.front()->getOwner();
+  }
+
+  return tileAndFuseConsumerOfSlicesImpl(
+      rewriter, consumerOp, consumerOpOperands, candidateSlices, loops);
+}
+
+/// For a given `result` of a `forallOp` return the
+/// `tensor.parallel_insert_slice` op (or combining op) that is used to
+/// construct this result.
+static std::optional<Operation *>
+getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) {
+  if (result.getOwner() != forallOp)
+    return std::nullopt;
+  BlockArgument bbArg = forallOp.getTiedBlockArgument(result);
+  SmallVector<Operation *> combiningOps = forallOp.getCombiningOps(bbArg);
+  // If the number of combining ops is not 1, then this is unexpected. Return
+  // nullopt.
+  if (combiningOps.size() != 1) {
+    return std::nullopt;
+  }
+  return combiningOps[0];
+}
+
+/// For a given result of the loop nest that is a tiled loop nest, return the
+/// insert slice-like op that is used for consumer fusion
+std::optional<Operation *>
+getProducingInsertSliceLikeOp(OpResult result,
+                              ArrayRef<LoopLikeOpInterface> loops) {
+  assert(!loops.empty() && "Expected loops to be not empty");
+  LoopLikeOpInterface outermostLoop = loops.front();
+
+  if (auto forallOp = dyn_cast<scf::ForallOp>(outermostLoop.getOperation())) {
+    assert(loops.size() == 1 &&
+           "expected only a single loop when tiling using scf.forall");
+    return getProducingParallelInsertSlice(forallOp, result);
+  }
+  // Assume that the loop nest is a nested `scf.for` that is created through
+  // tiling and retrieve the `tensor.insert_slice` operation used to construct
+  // the result.
+  while (loops.size() != 1) {
+    if (result.getOwner() != loops.front())
+      return std::nullopt;
+    auto forOp = dyn_cast<scf::ForOp>(loops.front());
+    if (!forOp)
+      return std::nullopt;
+    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+    OpResult innerForResult =
+        dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber()));
+    if (!innerForResult)
+      return std::nullopt;
+    result = innerForResult;
+    loops = loops.drop_front();
+  }
+  if (result.getOwner() != loops.front())
+    return std::nullopt;
+  auto forOp = dyn_cast<scf::ForOp>(loops.front());
+  if (!forOp)
+    return std::nullopt;
+  auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+  auto insertSliceOp = yieldOp.getOperand(result.getResultNumber())
+                           .getDefiningOp<tensor::InsertSliceOp>();
+  if (!insertSliceOp)
+    return std::nullopt;
+  return insertSliceOp;
+}
+
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *user,
+                               MutableArrayRef<LoopLikeOpInterface> loops) {
+  // Only handle users that implement the `TilingInterface`.
+  if (!isa<TilingInterface>(user)) {
+    return rewriter.notifyMatchFailure(
+        user, "unhandled user that does not implement TilingInterface");
+  }
+
+  // Return if `loops` is empty, return an error for now. Caller is expected
+  // to handle this case.
+  if (loops.empty()) {
+    return rewriter.notifyMatchFailure(
+        user, "cannot call tile and fuse consumer with an empty loop nest");
+  }
+
+  LoopLikeOpInterface outermostLoop = loops.front();
+
+  // Collect the operands of the user that come from the outermost loop of the
+  // loop nest.
+  SmallVector<OpOperand *> consumerFusableOperands;
+  for (OpOperand &opOperand : user->getOpOperands()) {
+    if (opOperand.get().getDefiningOp() == outermostLoop) {
+      consumerFusableOperands.push_back(&opOperand);
+    }
+  }
+
+  // Nothing to fuse. Just return an empty set.
+  if (consumerFusableOperands.empty()) {
+    return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands,
+                                                   SmallVector<OpOperand *>{},
+                                                   SmallVector<Operation *>{}};
+  }
+
+  // Collect the relevant tensor.insert_slice/tensor.parallel_insert_slices
+  // for fusion.
+  SmallVector<Operation *> candidateSlices;
+  candidateSlices.reserve(consumerFusableOperands.size());
+  for (OpOperand *opOperand : consumerFusableOperands) {
+    std::optional<Operation *> slice =
+        getProducingInsertSliceLikeOp(cast<OpResult>(opOperand->get()), loops);
+    if (!slice) {
+      return rewriter.notifyMatchFailure(
+          user,
+          "couldnt find producing insert-slice like operation for operand");
+    }
+    candidateSlices.push_back(slice.value());
+  }
+  return tileAndFuseConsumerOfSlicesImpl(
+      rewriter, user, consumerFusableOperands, candidateSlices, loops);
+}
+
 //===----------------------------------------------------------------------===//
 // lowerToLoopsUsingSCFForOp implementation.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
index 185fb9b358055..d72ab080f3c5c 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
@@ -170,7 +170,7 @@ module {
       // Fuse the consumer operation into the tiled loop.
       %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
           : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
-      transform.test.fuse_consumer %slice_op in (%forall_op)
+      transform.test.fuse_consumer_using_slice %slice_op in (%forall_op)
         : (!transform.op<"tensor.parallel_insert_slice">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
       transform.yield
     }
@@ -231,7 +231,7 @@ module {
       // Fuse the consumer operation into the tiled loop.
       %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
           : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
-      // Note that we cannot apply transform.test.fuse_consumer here because the extract_slice
+      // Note that we cannot apply transform.test.fuse_consumer_using_slice here because the extract_slice
       // is not qualified consumer operation. Forcing this will yeild "could not fetch consumer
       // to fuse" error.
       transform.yield
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir
new file mode 100644
index 0000000000000..62dd7faec4eb7
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir
@@ -0,0 +1,1156 @@
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<(d0) -> (d0)>
+module {
+  func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
+        ^bb0(%in: f32, %in_16: f32, %out: f32):
+          %13 = arith.mulf %in, %in_16 : f32
+          %14 = arith.addf %out, %13 : f32
+          linalg.yield %14 : f32
+        } -> tensor<32xf32>
+      %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+      scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32>
+    }
+    %in_operand_2 = tensor.empty() : tensor<64xf32>
+    %out_operand_3 = tensor.empty() : tensor<64xf32>
+    %2 = linalg.add ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3 : tensor<64xf32>) -> tensor<64xf32>
+    return %2 : tensor<64xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.for"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer_using_slice %yield in (%loop)
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @fuse_tileable_consumer_scf_for(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
+//      CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//      CHECK:   %0 = tensor.empty() : tensor<64xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]]
+// CHECK-SAME:      iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %0)
+// CHECK-SAME:   {
+//      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[MAT_OUT:.*]] = linalg.generic
+// CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
+//      CHECK:      %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[ELEM_OUT:.*]] = linalg.add
+// CHECK-SAME:              ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:              outs(%[[SLICE_OUT]] :
+//      CHECK:      %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] :
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#2 :
+
+// -----
+
+module {
+  func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+         tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+         tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+      }
+    }
+    %in_ope...
[truncated]

@github-actions
Copy link

github-actions bot commented Nov 12, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Nov 12, 2025
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test for scf::tileAndFuseConsumerUsingSlices is copied to tile-and-fuse-consumer.mlir to
tile-and-fuse-consumer-using-slices.mlir. All the tests that were there in this file are now using the tileAndFuseConsumer method. The test op test.tile_and_fuse_consumer is modified to call scf::tileAndFuseConsumer, while a new op
test.tile_and_fuse_consumer_of_slice is used to keep the old path tested while it is deprecated.

I was curious about how the new op work with multiple consumers, but I don't see fuse_add_multiple_tilable_consumers in the tile-and-fuse-consumer.mlir file. Are we missing the test case?


// -----

// Check that when the given operand tiles are inconsistent, tiling fails.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this no longer hold? I don't see changes on lit checks. Or it may belong to the below test case: multi_slice_fusion_with_broadcast, that has an expected_error check.

@MaheshRavishankar
Copy link
Contributor Author

The test for scf::tileAndFuseConsumerUsingSlices is copied to tile-and-fuse-consumer.mlir to
tile-and-fuse-consumer-using-slices.mlir. All the tests that were there in this file are now using the tileAndFuseConsumer method. The test op test.tile_and_fuse_consumer is modified to call scf::tileAndFuseConsumer, while a new op
test.tile_and_fuse_consumer_of_slice is used to keep the old path tested while it is deprecated.

I was curious about how the new op work with multiple consumers, but I don't see fuse_add_multiple_tilable_consumers in the tile-and-fuse-consumer.mlir file. Are we missing the test case?

I think the expected use is for multiple consumers, you call the fusion method multiple times. I think this test case https://github.com/llvm/llvm-project/pull/167634/files#diff-154c2d387b69e0e07cde6b61865996435b248f8abac476ee42de6c3cfb12d715R736 shows this.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
@github-actions
Copy link

🐧 Linux x64 Test Results

  • 7104 tests passed
  • 594 tests skipped

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks

@MaheshRavishankar MaheshRavishankar merged commit dbeda4f into llvm:main Nov 20, 2025
8 of 9 checks passed
@MaheshRavishankar MaheshRavishankar deleted the users/MaheshRavishankar/tileAndFuseConsumer branch November 20, 2025 22:14
aadeshps-mcw pushed a commit to aadeshps-mcw/llvm-project that referenced this pull request Nov 26, 2025
… a given tiled loop nest. (llvm#167634)

The existing `scf::tileAndFuseConsumerOfSlices` takes a list of slices
(and loops they are part of), tries to find the consumer of these slices
(all slices are expected to be the same consumer), and then tiles the
consumer into the loop nest using the `TilingInterface`. A more natural
way of doing consumer fusion is to just start from the consumer, look
for operands that are produced by the loop nest passed in as `loops`
(presumably these loops are generated by tiling, but that is not a
requirement for consumer fusion). Using the consumer you can find the
slices of the operands that are accessed within the loop which you can
then use to tile and fuse the consumer (using `TilingInterface`). This
handles more naturally the case where multiple operands of the consumer
come from the loop nest.

The `scf::tileAndFuseConsumerOfSlices` was implemented as a mirror of
`scf::tileAndFuseProducerOfSlice`. For the latter, the slice has a
single producer for the source of the slice, which makes it a natural
way of specifying producer fusion. But for consumers, the result might
have multiple users, resulting in multiple candidates for fusion, as
well as a fusion candidate using multiple results from the tiled loop
nest. This means using slices
(`tensor.insert_slice`/`tensor.parallel_insert_slice`) as a hook for
consumer fusion turns out to be quite hard to navigate. The use of the
consumer directly avoids all those pain points. In time the
`scf::tileAndFuseConsumerOfSlices` should be deprecated in favor of
`scf::tileAndFuseConsumer`. There is a lot of tech-debt that has
accumulated in `scf::tileAndFuseConsumerOfSlices` that needs to be
cleanedup. So while that gets cleaned up, and required functionality is
moved to `scf::tileAndFuseConsumer`, the old path is still maintained.

The test for `scf::tileAndFuseConsumerUsingSlices` is copied to
`tile-and-fuse-consumer.mlir` to
`tile-and-fuse-consumer-using-slices.mlir`. All the tests that were
there in this file are now using the `tileAndFuseConsumer` method. The
test op `test.tile_and_fuse_consumer` is modified to call
`scf::tileAndFuseConsumer`, while a new op
`test.tile_and_fuse_consumer_of_slice` is used to keep the old path
tested while it is deprecated.

---------

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
Priyanshu3820 pushed a commit to Priyanshu3820/llvm-project that referenced this pull request Nov 26, 2025
… a given tiled loop nest. (llvm#167634)

The existing `scf::tileAndFuseConsumerOfSlices` takes a list of slices
(and loops they are part of), tries to find the consumer of these slices
(all slices are expected to be the same consumer), and then tiles the
consumer into the loop nest using the `TilingInterface`. A more natural
way of doing consumer fusion is to just start from the consumer, look
for operands that are produced by the loop nest passed in as `loops`
(presumably these loops are generated by tiling, but that is not a
requirement for consumer fusion). Using the consumer you can find the
slices of the operands that are accessed within the loop which you can
then use to tile and fuse the consumer (using `TilingInterface`). This
handles more naturally the case where multiple operands of the consumer
come from the loop nest.

The `scf::tileAndFuseConsumerOfSlices` was implemented as a mirror of
`scf::tileAndFuseProducerOfSlice`. For the latter, the slice has a
single producer for the source of the slice, which makes it a natural
way of specifying producer fusion. But for consumers, the result might
have multiple users, resulting in multiple candidates for fusion, as
well as a fusion candidate using multiple results from the tiled loop
nest. This means using slices
(`tensor.insert_slice`/`tensor.parallel_insert_slice`) as a hook for
consumer fusion turns out to be quite hard to navigate. The use of the
consumer directly avoids all those pain points. In time the
`scf::tileAndFuseConsumerOfSlices` should be deprecated in favor of
`scf::tileAndFuseConsumer`. There is a lot of tech-debt that has
accumulated in `scf::tileAndFuseConsumerOfSlices` that needs to be
cleanedup. So while that gets cleaned up, and required functionality is
moved to `scf::tileAndFuseConsumer`, the old path is still maintained.

The test for `scf::tileAndFuseConsumerUsingSlices` is copied to
`tile-and-fuse-consumer.mlir` to
`tile-and-fuse-consumer-using-slices.mlir`. All the tests that were
there in this file are now using the `tileAndFuseConsumer` method. The
test op `test.tile_and_fuse_consumer` is modified to call
`scf::tileAndFuseConsumer`, while a new op
`test.tile_and_fuse_consumer_of_slice` is used to keep the old path
tested while it is deprecated.

---------

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
augusto2112 pushed a commit to augusto2112/llvm-project that referenced this pull request Dec 3, 2025
… a given tiled loop nest. (llvm#167634)

The existing `scf::tileAndFuseConsumerOfSlices` takes a list of slices
(and loops they are part of), tries to find the consumer of these slices
(all slices are expected to be the same consumer), and then tiles the
consumer into the loop nest using the `TilingInterface`. A more natural
way of doing consumer fusion is to just start from the consumer, look
for operands that are produced by the loop nest passed in as `loops`
(presumably these loops are generated by tiling, but that is not a
requirement for consumer fusion). Using the consumer you can find the
slices of the operands that are accessed within the loop which you can
then use to tile and fuse the consumer (using `TilingInterface`). This
handles more naturally the case where multiple operands of the consumer
come from the loop nest.

The `scf::tileAndFuseConsumerOfSlices` was implemented as a mirror of
`scf::tileAndFuseProducerOfSlice`. For the latter, the slice has a
single producer for the source of the slice, which makes it a natural
way of specifying producer fusion. But for consumers, the result might
have multiple users, resulting in multiple candidates for fusion, as
well as a fusion candidate using multiple results from the tiled loop
nest. This means using slices
(`tensor.insert_slice`/`tensor.parallel_insert_slice`) as a hook for
consumer fusion turns out to be quite hard to navigate. The use of the
consumer directly avoids all those pain points. In time the
`scf::tileAndFuseConsumerOfSlices` should be deprecated in favor of
`scf::tileAndFuseConsumer`. There is a lot of tech-debt that has
accumulated in `scf::tileAndFuseConsumerOfSlices` that needs to be
cleanedup. So while that gets cleaned up, and required functionality is
moved to `scf::tileAndFuseConsumer`, the old path is still maintained.

The test for `scf::tileAndFuseConsumerUsingSlices` is copied to
`tile-and-fuse-consumer.mlir` to
`tile-and-fuse-consumer-using-slices.mlir`. All the tests that were
there in this file are now using the `tileAndFuseConsumer` method. The
test op `test.tile_and_fuse_consumer` is modified to call
`scf::tileAndFuseConsumer`, while a new op
`test.tile_and_fuse_consumer_of_slice` is used to keep the old path
tested while it is deprecated.

---------

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants