diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp index 5d3cec402cab1..860384f954536 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp @@ -43,50 +43,34 @@ static bool overrideBuffer(Operation *op, Value buffer) { /// propagate the type change and erase old subview ops. static void replaceUsesAndPropagateType(RewriterBase &rewriter, Operation *oldOp, Value val) { - SmallVector opsToDelete; - SmallVector operandsToReplace; - - // Save the operand to replace / delete later (avoid iterator invalidation). - // TODO: can we use an early_inc iterator? - for (OpOperand &use : oldOp->getUses()) { - // Non-subview ops will be replaced by `val`. - auto subviewUse = dyn_cast(use.getOwner()); - if (!subviewUse) { - operandsToReplace.push_back(&use); + // Iterate with early_inc to erase current user inside the loop. + for (OpOperand &use : llvm::make_early_inc_range(oldOp->getUses())) { + Operation *user = use.getOwner(); + if (auto subviewUse = dyn_cast(user)) { + // `subview(old_op)` is replaced by a new `subview(val)`. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(subviewUse); + MemRefType newType = memref::SubViewOp::inferRankReducedResultType( + subviewUse.getType().getShape(), cast(val.getType()), + subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(), + subviewUse.getStaticStrides()); + Value newSubview = memref::SubViewOp::create( + rewriter, subviewUse->getLoc(), newType, val, + subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), + subviewUse.getMixedStrides()); + + // Ouch recursion ... is this really necessary? + replaceUsesAndPropagateType(rewriter, subviewUse, newSubview); + + // Safe to erase. + rewriter.eraseOp(subviewUse); continue; } - - // `subview(old_op)` is replaced by a new `subview(val)`. - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(subviewUse); - MemRefType newType = memref::SubViewOp::inferRankReducedResultType( - subviewUse.getType().getShape(), cast(val.getType()), - subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(), - subviewUse.getStaticStrides()); - Value newSubview = memref::SubViewOp::create( - rewriter, subviewUse->getLoc(), newType, val, - subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), - subviewUse.getMixedStrides()); - - // Ouch recursion ... is this really necessary? - replaceUsesAndPropagateType(rewriter, subviewUse, newSubview); - - opsToDelete.push_back(use.getOwner()); + // Non-subview: replace with new value. + rewriter.startOpModification(user); + use.set(val); + rewriter.finalizeOpModification(user); } - - // Perform late replacement. - // TODO: can we use an early_inc iterator? - for (OpOperand *operand : operandsToReplace) { - Operation *op = operand->getOwner(); - rewriter.startOpModification(op); - operand->set(val); - rewriter.finalizeOpModification(op); - } - - // Perform late op erasure. - // TODO: can we use an early_inc iterator? - for (Operation *op : opsToDelete) - rewriter.eraseOp(op); } // Transformation to do multi-buffering/array expansion to remove dependencies @@ -216,8 +200,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, offsets, sizes, strides); LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n"); - // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to - // handle dealloc uses separately.. + // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need + // to handle dealloc uses separately.. for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) { auto deallocOp = dyn_cast(use.getOwner()); if (!deallocOp)