diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 44b1bcf8e4300..66f369e8a5f65 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -779,13 +779,13 @@ void RemoveDeadValues::runOnOperation() { module->walk([&](RegionBranchOpInterface regionBranchOp) { opsToCanonicalize.push_back(regionBranchOp.getOperation()); }); - // TODO: Apply only region branch op canonicalization patterns or find a - // better API to collect all canonicalization patterns. + // Collect all canonicalization patterns for region branch ops. RewritePatternSet owningPatterns(context); - for (auto *dialect : context->getLoadedDialects()) - dialect->getCanonicalizationPatterns(owningPatterns); - for (RegisteredOperationName op : context->getRegisteredOperations()) - op.getCanonicalizationPatterns(owningPatterns, context); + DenseSet populatedPatterns; + for (Operation *op : opsToCanonicalize) + if (std::optional info = op->getRegisteredInfo()) + if (populatedPatterns.insert(*info).second) + info->getCanonicalizationPatterns(owningPatterns, context); if (failed(applyOpPatternsGreedily(opsToCanonicalize, std::move(owningPatterns)))) { module->emitError("greedy pattern rewrite failed to converge");