Skip to content

Commit d274784

Browse files
author
Bin Li
committed
callbacks_map -> done, swapping false and true
1 parent d07496c commit d274784

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

src/relay/ir/dataflow_matcher.cc

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -796,15 +796,12 @@ Expr PatternRewriter::Rewrite(const Array<DFPatternCallback>& callbacks, const E
796796
bool equal = true;
797797
static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual");
798798
ICHECK(structural_equal) << "node.StructuralEqual is not registered.";
799-
// Use the callback until the callback's attribute is rewrite_once=true and has been rewritten
800-
std::unordered_map<DFPatternCallback, bool, ObjectPtrHash, ObjectPtrEqual> callbacks_map;
801-
for (auto callback : callbacks) {
802-
callbacks_map.insert({callback, true});
803-
}
799+
// Keep track of callbacks that have finished rewriting
800+
std::unordered_map<DFPatternCallback, bool, ObjectPtrHash, ObjectPtrEqual> done;
804801
do {
805802
last = post;
806803
for (auto callback : callbacks) {
807-
if (callbacks_map[callback]) {
804+
if (!done[callback]) {
808805
auto before = post;
809806
callback_ = callback;
810807
if (callback_->require_type) {
@@ -821,7 +818,7 @@ Expr PatternRewriter::Rewrite(const Array<DFPatternCallback>& callbacks, const E
821818
if (callback_->rewrite_once) {
822819
bool current_equal = (*structural_equal)(before, post, false, true);
823820
if (!current_equal) {
824-
callbacks_map[callback] = false;
821+
done[callback] = true;
825822
}
826823
}
827824
}

0 commit comments

Comments
 (0)