@@ -36,6 +36,7 @@ namespace relay {
3636
3737// Pattern Matcher
3838bool DFPatternMatcher::Match (const DFPattern& pattern, const Expr& expr) {
39+ VLOG (1 ) << " Match " << PrettyPrint (pattern) << " in:" << std::endl << PrettyPrint (expr);
3940 memo_.clear ();
4041 matched_nodes_.clear ();
4142 return VisitDFPattern (pattern, expr);
@@ -58,6 +59,7 @@ bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr
5859 if (out) {
5960 memo_[pattern].push_back (expr);
6061 matched_nodes_.push_back (pattern);
62+ VLOG (1 ) << " Matched " << PrettyPrint (pattern) << " at:" << std::endl << PrettyPrint (expr);
6163 } else {
6264 ClearMap (watermark);
6365 }
@@ -124,7 +126,6 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons
124126 if (!matches) {
125127 return matches;
126128 }
127- VLOG (1 ) << " considering AttrPatternNode at:\n " << PrettyPrint (expr);
128129 auto attributes = attr_pattern->attrs .as <DictAttrsNode>()->dict ;
129130 if (const auto * op_node = expr.as <OpNode>()) {
130131 Op op = GetRef<Op>(op_node);
@@ -299,14 +300,18 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
299300// Recursively find the Dominator parent along all inputs paths.
300301bool DFPatternMatcher::MatchesPath (const DominatorPatternNode* op, const Expr& expr) {
301302 auto call_node = expr.as <CallNode>();
302- for (auto node : expr_graph_.node_map_ .at (expr)->inputs_ ) {
303- if (!(call_node && node->ref_ == call_node->op )) {
303+ auto index_node = expr_to_node (expr);
304+ for (auto node : index_node->inputs_ ) {
305+ if (!(call_node && node->ref () == call_node->op )) {
304306 memoize_ = true ;
305- if (VisitDFPattern (op->parent , node->ref_ )) {
307+ if (VisitDFPattern (op->parent , node->ref () )) {
306308 return true ;
307309 } else {
308310 memoize_ = false ;
309- if (!VisitDFPattern (op->path , node->ref_ ) || !MatchesPath (op, node->ref_ )) {
311+ if (!VisitDFPattern (op->path , node->ref ())) {
312+ return false ;
313+ }
314+ if (!MatchesPath (op, node->ref ())) {
310315 return false ;
311316 }
312317 }
@@ -318,19 +323,19 @@ bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& e
318323// Iteratively ensure that the parent is dominated somewhere by the child or the path
319324bool DFPatternMatcher::DominatesParent (const DominatorPatternNode* op, const Expr& expr) {
320325 std::stack<Expr> stack;
321- std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual > visited;
326+ std::unordered_set<const ExprNode* > visited;
322327 stack.push (expr);
323328 while (!stack.empty ()) {
324329 Expr current = stack.top ();
325330 stack.pop ();
326- for (auto node : expr_graph_. node_map_ . at (current)->dominator_children_ ) {
327- if (visited.count (node->ref_ ) == 0 ) {
328- if (VisitDFPattern (op->parent , node->ref_ )) {
331+ for (auto node : expr_to_node (current)->dominator_children_ ) {
332+ if (visited.count (node->node_ref_ ) == 0 ) {
333+ if (VisitDFPattern (op->parent , node->ref () )) {
329334 return true ;
330335 } else {
331- stack.push (node->ref_ );
336+ stack.push (node->ref () );
332337 }
333- visited.insert (node->ref_ );
338+ visited.insert (node->node_ref_ );
334339 }
335340 }
336341 }
@@ -500,7 +505,8 @@ bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr
500505}
501506
502507bool MatchPattern (DFPattern pattern, Expr expr) {
503- return DFPatternMatcher (expr).Match (pattern, expr);
508+ std::unique_ptr<IndexedGraph<Expr>> expr_graph = CreateIndexedGraph (expr);
509+ return DFPatternMatcher (expr_graph.get ()).Match (pattern, expr);
504510}
505511
506512TVM_REGISTER_GLOBAL (" relay.dataflow_pattern.match" ).set_body_typed(MatchPattern);
@@ -575,17 +581,18 @@ const std::unordered_map<int, PatternGrouper::Group>& PatternGrouper::GroupMatch
575581
576582 pattern_ = pattern;
577583 pattern_graph_ = CreateIndexedGraph (pattern_);
578- auto matcher = DFPatternMatcher (pre );
584+ std::unique_ptr<IndexedGraph<Expr>> expr_graph = CreateIndexedGraph (pre );
585+ DFPatternMatcher matcher (expr_graph.get ());
579586 matcher_ = &matcher;
580587 this ->VisitExprs ();
581588 return this ->groups_ ;
582589}
583590
584591void PatternGrouper::VisitExprs () {
585592 std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> pre_partitioned;
586- for (size_t i = matcher_->expr_graph_ . topological_order_ . size (); i != 0 ; --i) {
587- size_t index = i - 1 ;
588- Expr current = matcher_->expr_graph_ . topological_order_ . at (index)->ref_ ;
593+ for (PostDfsIndex i = matcher_->size (); i != 0 ; --i) {
594+ PostDfsIndex index = i - 1 ;
595+ const auto current = matcher_->index_to_node (index)->ref () ;
589596 if (gid_assignments_.count (current) == 0 ) { // Don't visit nodes we've already grouped
590597 if (auto op = current.as <FunctionNode>()) {
591598 if (op->attrs .defined () && op->attrs ->dict .count (attr::kPartitionedFromPattern ) != 0 ) {
@@ -607,22 +614,24 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
607614 auto node_map = matcher_->GetMemo ();
608615 // Get fuzzy patterns
609616 std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> fuzzy_matches;
610- for (auto node : pattern_graph_.topological_order_ ) {
617+ for (PostDfsIndex index = 0 ; index < pattern_graph_->size (); ++index) {
618+ auto node = pattern_graph_->index_to_node (index);
611619 // Don't treat fuzzy Dominator patterns input variables for partition
612- if (auto op = node->ref_ .as <DominatorPatternNode>()) {
620+ if (auto op = node->ref () .as <DominatorPatternNode>()) {
613621 for (auto fuzzy_op : {op->parent , op->path }) {
614622 for (auto match : node_map[fuzzy_op]) {
615623 fuzzy_matches.insert (match);
616624 }
617625 }
618626 }
619627 // Don't treat Function params or body as input variables for partition
620- if (node->ref_ .as <FunctionPatternNode>()) {
621- auto matches = node_map[node->ref_ ];
628+ if (node->ref () .as <FunctionPatternNode>()) {
629+ auto matches = node_map[node->ref () ];
622630 for (auto match : matches) {
623- auto graph = CreateIndexedGraph (match.as <FunctionNode>()->body );
624- for (auto node : graph.topological_order_ ) {
625- fuzzy_matches.insert (node->ref_ );
631+ auto sub_graph = CreateIndexedGraph (match.as <FunctionNode>()->body );
632+ for (PostDfsIndex sub_index = 0 ; sub_index < sub_graph->size (); ++sub_index) {
633+ auto sub_node = sub_graph->index_to_node (sub_index);
634+ fuzzy_matches.insert (sub_node->ref ());
626635 }
627636 }
628637 }
@@ -636,10 +645,11 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
636645 std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> inputs;
637646 Array<Var> params;
638647
639- for (auto node : pattern_graph_.topological_order_ ) {
648+ for (PostDfsIndex index = 0 ; index < pattern_graph_->size (); ++index) {
649+ auto node = pattern_graph_->index_to_node (index);
640650 auto make_input = [&](const Expr& input) {
641651 if (fuzzy_matches.count (input) == 0 && input.as <OpNode>() == nullptr &&
642- input.as <FunctionNode>() == nullptr && !EmbedConst (input, node->ref_ )) {
652+ input.as <FunctionNode>() == nullptr && !EmbedConst (input, node->ref () )) {
643653 inputs[input] =
644654 Var (" FunctionVar_" + std::to_string (graph_number_) + " _" + std::to_string (var_number),
645655 NullValue<Type>());
@@ -648,29 +658,29 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
648658 var_number++;
649659 }
650660 };
651- auto tuple = node->ref_ .as <TuplePatternNode>();
652- auto call = node->ref_ .as <CallPatternNode>();
661+ auto tuple = node->ref () .as <TuplePatternNode>();
662+ auto call = node->ref () .as <CallPatternNode>();
653663 if (tuple && !tuple->fields .defined ()) {
654- if (node_map.count (node->ref_ )) {
655- auto matches = node_map[node->ref_ ];
664+ if (node_map.count (node->ref () )) {
665+ auto matches = node_map[node->ref () ];
656666 for (auto match : matches) {
657667 for (auto input : match.as <TupleNode>()->fields ) {
658668 make_input (input);
659669 }
660670 }
661671 }
662672 } else if (call && !call->args .defined ()) {
663- if (node_map.count (node->ref_ )) {
664- auto matches = node_map[node->ref_ ];
673+ if (node_map.count (node->ref () )) {
674+ auto matches = node_map[node->ref () ];
665675 for (auto match : matches) {
666676 for (auto input : match.as <CallNode>()->args ) {
667677 make_input (input);
668678 }
669679 }
670680 }
671681 } else if (node->inputs_ .size () == 0 ) {
672- if (node_map.count (node->ref_ )) {
673- auto matches = node_map[node->ref_ ];
682+ if (node_map.count (node->ref () )) {
683+ auto matches = node_map[node->ref () ];
674684 for (auto match : matches) {
675685 make_input (match);
676686 }
@@ -708,13 +718,17 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
708718 return ;
709719 } else if (kv.second != body) {
710720 // if the node isn't the output of the group
711- auto node = matcher_->expr_graph_ . node_map_ . at (kv.first );
721+ auto node = matcher_->expr_to_node (kv.first );
712722 for (auto * output : node->outputs_ ) {
713723 // and the node is used by nodes outside of the group
714- if (memo.count (output->ref_ ) == 0 &&
715- !matcher_->expr_graph_ .node_map_ .at (expr)->Dominates (output)) {
716- // Exit because nodes in this pattern's body are used outside the pattern
717- // fusing it would be invalid
724+ if (memo.count (output->ref ()) == 0 ) {
725+ // TODO(mbs): This condition used to also include the following test, which since
726+ // the dominators relation is used back-to-front was always vacuously true. So the
727+ // code is just rejecting the match if a strictly internal node happened to connect
728+ // to an outside node.
729+ ICHECK (!matcher_->expr_to_node (expr)->Dominates (output));
730+ // Exit because nodes in this pattern's body are used outside the pattern, fusing it
731+ // would be invalid
718732 return ;
719733 }
720734 }
0 commit comments