diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 5784da3269a6..f68b882aa0f5 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -44,12 +44,13 @@ SmallVector mmaVersionToShapePerWarp(int version) { SmallVector warpsPerTileV2(triton::DotOp dotOp, const ArrayRef shape, int numWarps) { - SetVector slices; - mlir::getForwardSlice(dotOp.getResult(), &slices); - if (llvm::find_if(slices, [](Operation *op) { - return isa(op); - }) != slices.end()) - return {(unsigned)numWarps, 1}; + auto filter = [&dotOp](Operation *op) { + return op->getParentRegion() == dotOp->getParentRegion(); + }; + auto slices = mlir::getSlice(dotOp, filter); + for (Operation *op : slices) + if (isa(op) && (op != dotOp)) + return {(unsigned)numWarps, 1}; SmallVector ret = {1, 1}; SmallVector shapePerWarp = {16, 8};