diff --git a/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp
index 44809814df..efe1848f5c 100644
--- a/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp
@@ -124,6 +124,28 @@ Value getAccumLoopCountArg(scf::ForOp parentForOp) {
   return tmpAccumLoopCount;
 }
 
+// Check to see if op is enclosed under ifOp. Walks the chain of scf::IfOp
+// ancestors of op and returns true as soon as ifOp itself is found.
+static bool enclosing(scf::IfOp ifOp, Operation *op) {
+  auto pOp = op->getParentOfType<scf::IfOp>();
+  while (pOp) {
+    if (pOp == ifOp)
+      return true;
+    pOp = pOp->getParentOfType<scf::IfOp>();
+  }
+  return false;
+}
+
+// Check to see if there is no outer loop around subOp that is itself enclosed
+// under ifOp. Returns true when subOp has no enclosing scf::ForOp at all, or
+// when its nearest enclosing scf::ForOp is not nested inside ifOp.
+static bool immediateEnclosing(scf::IfOp ifOp, Operation *subOp) {
+  auto pOp = subOp->getParentOfType<scf::ForOp>();
+  if (!pOp)
+    return true;
+  return !enclosing(ifOp, pOp.getOperation());
+}
+
 // Return true if the IfOp contains a ForOp that is in loopWithBufferReuse.
 static bool
 needAccumulatedLoopCnt(scf::IfOp ifOp,
@@ -132,7 +154,9 @@ needAccumulatedLoopCnt(scf::IfOp ifOp,
   ifOp.walk([&](Operation *subOp) {
     if (auto forOp = dyn_cast<scf::ForOp>(subOp))
       for (auto tLoop : loopWithBufferReuse)
-        if (forOp.getOperation() == tLoop) {
+        // If ifOp contains an outer forOp that in turn contains tLoop, there
+        // is no need to generate accumLoopCount for ifOp.
+        if (forOp.getOperation() == tLoop && immediateEnclosing(ifOp, tLoop)) {
           needAccum = true;
           break;
         }