Skip to content

Commit d8f57ed

Browse files
authored
[Relay] IndexedGraph improvements in preparation for Collage (#11481)
* [Relay] Odd's 'n ends changes to help Collage. - Complete the implementation of WithFields. (Unfortunately they appear to be without unit tests and I continue this tradition...) - InferTypeExpr for InferTypeLocal but return the expression rather than the type. - Remove python binding of InlineComposites since C++ impl was removed some time ago. - Make IndexedGraph<Expr/DFPattern> more robust as stand-alone datastructure, and avoid unnecessary copies. This will become a fundamental datastructure in Collage rather than just a helper for DFPatternMatcher. - Extend IndexedGraph with a notion of 'basic block' on every dataflow node. Needed by Collage to avoid impossible partitions. * - Revert non IndexedGraph changes. * - Stick to 'Indexed graph' terminology - More tests * - Stick to 'Indexed graph' terminology - More tests * - Remove silly unit test
1 parent 8170219 commit d8f57ed

File tree

7 files changed

+922
-237
lines changed

7 files changed

+922
-237
lines changed

src/relay/ir/dataflow_matcher.cc

Lines changed: 52 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ namespace relay {
3636

3737
// Pattern Matcher
3838
bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
39+
VLOG(1) << "Match " << PrettyPrint(pattern) << " in:" << std::endl << PrettyPrint(expr);
3940
memo_.clear();
4041
matched_nodes_.clear();
4142
return VisitDFPattern(pattern, expr);
@@ -58,6 +59,7 @@ bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr
5859
if (out) {
5960
memo_[pattern].push_back(expr);
6061
matched_nodes_.push_back(pattern);
62+
VLOG(1) << "Matched " << PrettyPrint(pattern) << " at:" << std::endl << PrettyPrint(expr);
6163
} else {
6264
ClearMap(watermark);
6365
}
@@ -124,7 +126,6 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons
124126
if (!matches) {
125127
return matches;
126128
}
127-
VLOG(1) << "considering AttrPatternNode at:\n" << PrettyPrint(expr);
128129
auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
129130
if (const auto* op_node = expr.as<OpNode>()) {
130131
Op op = GetRef<Op>(op_node);
@@ -299,14 +300,18 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
299300
// Recursively find the Dominator parent along all inputs paths.
300301
bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
301302
auto call_node = expr.as<CallNode>();
302-
for (auto node : expr_graph_.node_map_.at(expr)->inputs_) {
303-
if (!(call_node && node->ref_ == call_node->op)) {
303+
auto index_node = expr_to_node(expr);
304+
for (auto node : index_node->inputs_) {
305+
if (!(call_node && node->ref() == call_node->op)) {
304306
memoize_ = true;
305-
if (VisitDFPattern(op->parent, node->ref_)) {
307+
if (VisitDFPattern(op->parent, node->ref())) {
306308
return true;
307309
} else {
308310
memoize_ = false;
309-
if (!VisitDFPattern(op->path, node->ref_) || !MatchesPath(op, node->ref_)) {
311+
if (!VisitDFPattern(op->path, node->ref())) {
312+
return false;
313+
}
314+
if (!MatchesPath(op, node->ref())) {
310315
return false;
311316
}
312317
}
@@ -318,19 +323,19 @@ bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& e
318323
// Iteratively ensure that the parent is dominated somewhere by the child or the path
319324
bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) {
320325
std::stack<Expr> stack;
321-
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> visited;
326+
std::unordered_set<const ExprNode*> visited;
322327
stack.push(expr);
323328
while (!stack.empty()) {
324329
Expr current = stack.top();
325330
stack.pop();
326-
for (auto node : expr_graph_.node_map_.at(current)->dominator_children_) {
327-
if (visited.count(node->ref_) == 0) {
328-
if (VisitDFPattern(op->parent, node->ref_)) {
331+
for (auto node : expr_to_node(current)->dominator_children_) {
332+
if (visited.count(node->node_ref_) == 0) {
333+
if (VisitDFPattern(op->parent, node->ref())) {
329334
return true;
330335
} else {
331-
stack.push(node->ref_);
336+
stack.push(node->ref());
332337
}
333-
visited.insert(node->ref_);
338+
visited.insert(node->node_ref_);
334339
}
335340
}
336341
}
@@ -500,7 +505,8 @@ bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr
500505
}
501506

502507
bool MatchPattern(DFPattern pattern, Expr expr) {
503-
return DFPatternMatcher(expr).Match(pattern, expr);
508+
std::unique_ptr<IndexedGraph<Expr>> expr_graph = CreateIndexedGraph(expr);
509+
return DFPatternMatcher(expr_graph.get()).Match(pattern, expr);
504510
}
505511

506512
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern);
@@ -575,17 +581,18 @@ const std::unordered_map<int, PatternGrouper::Group>& PatternGrouper::GroupMatch
575581

576582
pattern_ = pattern;
577583
pattern_graph_ = CreateIndexedGraph(pattern_);
578-
auto matcher = DFPatternMatcher(pre);
584+
std::unique_ptr<IndexedGraph<Expr>> expr_graph = CreateIndexedGraph(pre);
585+
DFPatternMatcher matcher(expr_graph.get());
579586
matcher_ = &matcher;
580587
this->VisitExprs();
581588
return this->groups_;
582589
}
583590

584591
void PatternGrouper::VisitExprs() {
585592
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> pre_partitioned;
586-
for (size_t i = matcher_->expr_graph_.topological_order_.size(); i != 0; --i) {
587-
size_t index = i - 1;
588-
Expr current = matcher_->expr_graph_.topological_order_.at(index)->ref_;
593+
for (PostDfsIndex i = matcher_->size(); i != 0; --i) {
594+
PostDfsIndex index = i - 1;
595+
const auto current = matcher_->index_to_node(index)->ref();
589596
if (gid_assignments_.count(current) == 0) { // Don't visit nodes we've already grouped
590597
if (auto op = current.as<FunctionNode>()) {
591598
if (op->attrs.defined() && op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) {
@@ -607,22 +614,24 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
607614
auto node_map = matcher_->GetMemo();
608615
// Get fuzzy patterns
609616
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> fuzzy_matches;
610-
for (auto node : pattern_graph_.topological_order_) {
617+
for (PostDfsIndex index = 0; index < pattern_graph_->size(); ++index) {
618+
auto node = pattern_graph_->index_to_node(index);
611619
// Don't treat fuzzy Dominator patterns input variables for partition
612-
if (auto op = node->ref_.as<DominatorPatternNode>()) {
620+
if (auto op = node->ref().as<DominatorPatternNode>()) {
613621
for (auto fuzzy_op : {op->parent, op->path}) {
614622
for (auto match : node_map[fuzzy_op]) {
615623
fuzzy_matches.insert(match);
616624
}
617625
}
618626
}
619627
// Don't treat Function params or body as input variables for partition
620-
if (node->ref_.as<FunctionPatternNode>()) {
621-
auto matches = node_map[node->ref_];
628+
if (node->ref().as<FunctionPatternNode>()) {
629+
auto matches = node_map[node->ref()];
622630
for (auto match : matches) {
623-
auto graph = CreateIndexedGraph(match.as<FunctionNode>()->body);
624-
for (auto node : graph.topological_order_) {
625-
fuzzy_matches.insert(node->ref_);
631+
auto sub_graph = CreateIndexedGraph(match.as<FunctionNode>()->body);
632+
for (PostDfsIndex sub_index = 0; sub_index < sub_graph->size(); ++sub_index) {
633+
auto sub_node = sub_graph->index_to_node(sub_index);
634+
fuzzy_matches.insert(sub_node->ref());
626635
}
627636
}
628637
}
@@ -636,10 +645,11 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
636645
std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> inputs;
637646
Array<Var> params;
638647

639-
for (auto node : pattern_graph_.topological_order_) {
648+
for (PostDfsIndex index = 0; index < pattern_graph_->size(); ++index) {
649+
auto node = pattern_graph_->index_to_node(index);
640650
auto make_input = [&](const Expr& input) {
641651
if (fuzzy_matches.count(input) == 0 && input.as<OpNode>() == nullptr &&
642-
input.as<FunctionNode>() == nullptr && !EmbedConst(input, node->ref_)) {
652+
input.as<FunctionNode>() == nullptr && !EmbedConst(input, node->ref())) {
643653
inputs[input] =
644654
Var("FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number),
645655
NullValue<Type>());
@@ -648,29 +658,29 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
648658
var_number++;
649659
}
650660
};
651-
auto tuple = node->ref_.as<TuplePatternNode>();
652-
auto call = node->ref_.as<CallPatternNode>();
661+
auto tuple = node->ref().as<TuplePatternNode>();
662+
auto call = node->ref().as<CallPatternNode>();
653663
if (tuple && !tuple->fields.defined()) {
654-
if (node_map.count(node->ref_)) {
655-
auto matches = node_map[node->ref_];
664+
if (node_map.count(node->ref())) {
665+
auto matches = node_map[node->ref()];
656666
for (auto match : matches) {
657667
for (auto input : match.as<TupleNode>()->fields) {
658668
make_input(input);
659669
}
660670
}
661671
}
662672
} else if (call && !call->args.defined()) {
663-
if (node_map.count(node->ref_)) {
664-
auto matches = node_map[node->ref_];
673+
if (node_map.count(node->ref())) {
674+
auto matches = node_map[node->ref()];
665675
for (auto match : matches) {
666676
for (auto input : match.as<CallNode>()->args) {
667677
make_input(input);
668678
}
669679
}
670680
}
671681
} else if (node->inputs_.size() == 0) {
672-
if (node_map.count(node->ref_)) {
673-
auto matches = node_map[node->ref_];
682+
if (node_map.count(node->ref())) {
683+
auto matches = node_map[node->ref()];
674684
for (auto match : matches) {
675685
make_input(match);
676686
}
@@ -708,13 +718,17 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
708718
return;
709719
} else if (kv.second != body) {
710720
// if the node isn't the output of the group
711-
auto node = matcher_->expr_graph_.node_map_.at(kv.first);
721+
auto node = matcher_->expr_to_node(kv.first);
712722
for (auto* output : node->outputs_) {
713723
// and the node is used by nodes outside of the group
714-
if (memo.count(output->ref_) == 0 &&
715-
!matcher_->expr_graph_.node_map_.at(expr)->Dominates(output)) {
716-
// Exit because nodes in this pattern's body are used outside the pattern
717-
// fusing it would be invalid
724+
if (memo.count(output->ref()) == 0) {
725+
// TODO(mbs): This condition used to also include the following test, which since
726+
// the dominators relation is used back-to-front was always vacuously true. So the
727+
// code is just rejecting the match if a strictly internal node happened to connect
728+
// to an outside node.
729+
ICHECK(!matcher_->expr_to_node(expr)->Dominates(output));
730+
// Exit because nodes in this pattern's body are used outside the pattern, fusing it
731+
// would be invalid
718732
return;
719733
}
720734
}

src/relay/ir/dataflow_matcher_impl.h

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
#include <tvm/relay/dataflow_matcher.h>
2828
#include <tvm/relay/dataflow_pattern.h>
2929
#include <tvm/relay/dataflow_pattern_functor.h>
30+
#include <tvm/relay/expr_functor.h>
3031

32+
#include <memory>
3133
#include <string>
3234
#include <unordered_map>
3335
#include <vector>
@@ -39,10 +41,20 @@ namespace relay {
3941

4042
class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
4143
public:
42-
explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
44+
explicit DFPatternMatcher(const IndexedGraph<Expr>* expr_graph) : expr_graph_(expr_graph) {}
4345
bool Match(const DFPattern& pattern, const Expr& expr);
4446
Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
45-
const IndexedGraph<Expr> expr_graph_;
47+
48+
const IndexedGraph<Expr>::Node* expr_to_node(const Expr& expr) const {
49+
return expr_graph_->item_to_node(expr);
50+
}
51+
const IndexedGraph<Expr>::Node* index_to_node(size_t index) const {
52+
return expr_graph_->index_to_node(index);
53+
}
54+
size_t size() const { return expr_graph_->size(); }
55+
const std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual>& memo() const {
56+
return memo_;
57+
}
4658

4759
protected:
4860
bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
@@ -67,6 +79,7 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex
6779
bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
6880
bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
6981

82+
const IndexedGraph<Expr>* expr_graph_;
7083
std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual> memo_;
7184
std::vector<DFPattern> matched_nodes_;
7285
bool memoize_ = true;
@@ -131,7 +144,7 @@ class PatternGrouper {
131144
std::unordered_map<int, Group> groups_;
132145
std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> gid_assignments_;
133146
DFPatternMatcher* matcher_ = nullptr;
134-
IndexedGraph<DFPattern> pattern_graph_;
147+
std::unique_ptr<IndexedGraph<DFPattern>> pattern_graph_;
135148
int gid_ = 0;
136149
int graph_number_ = 0;
137150
};

0 commit comments

Comments
 (0)