Skip to content

Commit a49a7fe

Browse files
author
Siyuan Feng
authored
[Relay][Pass] Separate out the graph partitioning code from fuse_ops.cc (#13964)
* [Relay][Pass] Separate out the graph partitioning code from fuse_ops.cc The current `fuse_ops.cc` contains the following parts: 1. `IndexedForwardGraph` and `DominatorTree`, which are used for graph partitioning 2. A Relay Expr visitor to create the `DominatorTree` 3. A Relay Expr mutator to fuse the ops This PR separates the graph partitioning code from `fuse_ops.cc` and moves it to the analysis folder, for: 1. Better code organization and readability, as the graph partitioning code is quite long and not directly related to the fusion mutator 2. Possible reuse opportunities for other fusion passes in Relax NOTE: we won't bring Relax fusion into the `main` branch for now, but this PR is still reasonable for `main`. * lint
1 parent 49b6c3a commit a49a7fe

File tree

3 files changed

+615
-504
lines changed

3 files changed

+615
-504
lines changed
Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#include "./graph_partitioner.h"
21+
22+
#include <vector>
23+
24+
namespace tvm {
25+
namespace relay {
26+
27+
DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) {
  // Build the post-dominator tree by visiting nodes in reverse post-DFS order,
  // so every node's outputs already have tree nodes when it is processed.
  DominatorTree result;
  const size_t num_nodes = graph.post_dfs_order.size();
  result.nodes.resize(num_nodes, nullptr);
  for (size_t processed = 0; processed < num_nodes; ++processed) {
    const size_t idx = num_nodes - 1 - processed;
    result.nodes[idx] = result.GetNode(arena, graph.post_dfs_order[idx]);
  }
  return result;
}
37+
38+
DominatorTree::Node* DominatorTree::LeastCommonAncestor(Node* lhs, Node* rhs,
                                                        OpPatternKind* edge_pattern) {
  // Climb both nodes toward the root, always advancing the deeper one first,
  // and fold every pattern seen along the way into *edge_pattern.
  while (lhs != rhs) {
    if (lhs == nullptr || rhs == nullptr) return nullptr;
    if (lhs->depth < rhs->depth) {
      *edge_pattern = CombinePattern(*edge_pattern, rhs->pattern);
      rhs = rhs->parent;
    } else if (rhs->depth < lhs->depth) {
      *edge_pattern = CombinePattern(*edge_pattern, lhs->pattern);
      lhs = lhs->parent;
    } else {
      // Equal depth: step both sides up one level.
      *edge_pattern = CombinePattern(*edge_pattern, lhs->pattern);
      *edge_pattern = CombinePattern(*edge_pattern, rhs->pattern);
      lhs = lhs->parent;
      rhs = rhs->parent;
    }
  }
  return lhs;
}
58+
59+
DominatorTree::Node* DominatorTree::LeastCommonAncestor(
    const LinkedList<IndexedForwardGraph::Edge>& input_nodes, OpPatternKind* edge_pattern) {
  auto* entry = input_nodes.head;
  if (entry == nullptr) {
    return nullptr;
  }
  // Resolve a graph edge to its (already constructed) dominator-tree node.
  auto lookup = [this](const IndexedForwardGraph::Edge& edge) {
    const size_t out_index = edge.node->index;
    ICHECK_LT(out_index, nodes.size());
    Node* tree_node = nodes[out_index];
    ICHECK(tree_node != nullptr);
    return tree_node;
  };
  // Fold the LCA over all edges, accumulating edge patterns as we go.
  Node* ancestor = lookup(entry->value);
  *edge_pattern = CombinePattern(*edge_pattern, entry->value.pattern);
  for (entry = entry->next; entry != nullptr; entry = entry->next) {
    ancestor = LeastCommonAncestor(ancestor, lookup(entry->value), edge_pattern);
    *edge_pattern = CombinePattern(*edge_pattern, entry->value.pattern);
  }
  return ancestor;
}
81+
82+
DominatorTree::Node* DominatorTree::GetNode(support::Arena* arena,
                                            IndexedForwardGraph::Node* gnode) {
  Node* tnode = arena->make<Node>();
  tnode->gnode = gnode;
  if (gnode->extern_ref) {
    // Externally referenced nodes get no dominator parent; they sit at depth 1
    // with an opaque pattern so nothing fuses across them.
    tnode->depth = 1;
    tnode->parent = nullptr;
    tnode->pattern = kOpaque;
  } else {
    // The post-dominator is the least common ancestor of all output edges.
    OpPatternKind pattern = kElemWise;
    Node* parent = LeastCommonAncestor(gnode->outputs, &pattern);
    tnode->parent = parent;
    tnode->depth = (parent == nullptr) ? 1 : parent->depth + 1;
    tnode->pattern = pattern;
  }
  return tnode;
}
100+
101+
std::vector<GraphPartitioner::Group*> GraphPartitioner::Partition(
102+
const IndexedForwardGraph& graph) {
103+
this->InitGroups(graph);
104+
if (opt_level_ == 0) return std::move(groups_);
105+
// get post dominator tree
106+
auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
107+
// run fusion algorithm.
108+
for (int phase = 0; phase < 3; ++phase) {
109+
this->RunFuse(graph, post_dom_tree, phase);
110+
}
111+
return std::move(groups_);
112+
}
113+
114+
GraphPartitioner::Group* GraphPartitioner::Group::FindRoot() {
  // Union-find lookup: follow parent links to the set representative.
  Group* root = this;
  while (root->parent != nullptr) {
    root = root->parent;
  }
  // Path compression: repoint every node on the traversed chain at the root
  // so later lookups are O(1). A no-op when this node is already the root.
  Group* node = this;
  while (node != root) {
    Group* next = node->parent;
    node->parent = root;
    node = next;
  }
  return root;
}
129+
130+
template <typename F>
131+
bool GraphPartitioner::CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
132+
F fcond) {
133+
if (visited_.count(src)) return true;
134+
visited_.insert(src);
135+
Group* gnode = groups_[src->index];
136+
ICHECK(gnode != nullptr);
137+
gnode = gnode->FindRoot();
138+
if (!fcond(gnode->pattern, src == sink)) return false;
139+
if (src == sink) return true;
140+
for (auto link = src->outputs.head; link != nullptr; link = link->next) {
141+
if (!CheckPath_(link->value.node, sink, fcond)) return false;
142+
}
143+
return true;
144+
}
145+
146+
template <typename F>
147+
bool GraphPartitioner::CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
148+
F fcond) {
149+
ICHECK(!src->extern_ref);
150+
visited_.clear();
151+
ICHECK(src != sink);
152+
for (auto link = src->outputs.head; link != nullptr; link = link->next) {
153+
if (!CheckPath_(link->value.node, sink, fcond)) return false;
154+
}
155+
return true;
156+
}
157+
158+
OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) {
  // Merging two "complex" patterns (anything beyond broadcast) is illegal:
  // a fused group may contain at most one such anchor.
  if (lhs > relay::kBroadcast && rhs > relay::kBroadcast) {
    LOG(FATAL) << "Cannot merge two complex group together";
  }
  // Otherwise the combined pattern is simply the more restrictive of the two.
  return (lhs > rhs) ? lhs : rhs;
}
165+
166+
void GraphPartitioner::MergeFromTo(Group* child, Group* parent) {
  // Union the two group sets; a self-union is a no-op.
  child = child->FindRoot();
  parent = parent->FindRoot();
  if (child == parent) return;
  // The surviving root absorbs the child's node count.
  parent->num_nodes += child->num_nodes;
  child->parent = parent;
  // Transfer the anchor (and its pattern) if the child owns one; at most one
  // anchor may exist per group.
  if (child->anchor_ref != nullptr) {
    ICHECK(parent->anchor_ref == nullptr);
    parent->anchor_ref = child->anchor_ref;
    parent->pattern = CombinePattern(child->pattern, parent->pattern);
  }
}
180+
181+
void GraphPartitioner::CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
                                   Group* target) {
  // Stop at the sink and at already-visited nodes.
  if (src == sink || visited_.count(src)) return;
  visited_.insert(src);
  Group* group = groups_[src->index];
  ICHECK(group != nullptr);
  // Merge this node's group into the target group.
  MergeFromTo(group, target);
  // Recurse over all outgoing edges toward the sink.
  for (auto* link = src->outputs.head; link != nullptr; link = link->next) {
    CommitFuse_(link->value.node, sink, target);
  }
}
194+
195+
void GraphPartitioner::CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) {
  // Fuse everything on the paths from src (exclusive of sink) into the
  // sink's group.
  ICHECK(src != sink);
  visited_.clear();
  CommitFuse_(src, sink, groups_[sink->index]);
}
201+
202+
size_t GraphPartitioner::CountNodesUptoSink_(IndexedForwardGraph::Node* src,
                                             IndexedForwardGraph::Node* sink) {
  // Sum group sizes over the DFS region between src and sink (sink excluded);
  // visited_ guards against double-counting shared nodes.
  if (src == sink || visited_.count(src)) return 0;
  visited_.insert(src);
  Group* group = groups_[src->index];
  ICHECK(group != nullptr);
  size_t total = group->num_nodes;
  for (auto* link = src->outputs.head; link != nullptr; link = link->next) {
    total += CountNodesUptoSink_(link->value.node, sink);
  }
  return total;
}
214+
215+
size_t GraphPartitioner::CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child,
                                                     IndexedForwardGraph::Node* dom_parent) {
  // Size the group that would result from fusing `child` up to `dom_parent`:
  // the parent group's current size plus every node on the paths in between.
  Group* target = groups_[dom_parent->index];
  ICHECK(child != dom_parent);
  visited_.clear();
  return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent);
}
222+
223+
void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) {
224+
groups_.resize(graph.post_dfs_order.size());
225+
for (size_t nid = 0; nid < groups_.size(); ++nid) {
226+
const auto* graph_node = graph.post_dfs_order[nid];
227+
auto* group_node = arena_->make<Group>();
228+
group_node->pattern = graph_node->pattern;
229+
group_node->root_ref = graph_node->ref;
230+
// set anchor ref if necessary.
231+
if (group_node->pattern == relay::kOutEWiseFusable) {
232+
group_node->anchor_ref = graph_node->ref;
233+
}
234+
groups_[nid] = group_node;
235+
}
236+
}
237+
238+
// Runs one phase of the fusion algorithm over every node in post-DFS order.
// Phase 0 fuses elemwise/broadcast chains into OutEWiseFusable anchors,
// phase 1 handles injective/tuple fusion, and phase 2 fuses injective ops
// into intermediate tuples. Each node is considered against its immediate
// post-dominator; fusion is committed only when CheckPath confirms every
// path between them satisfies the phase's pattern condition.
void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph,  //
                               const DominatorTree& post_dom_tree,  //
                               int phase) {
  for (size_t nid = 0; nid < groups_.size(); ++nid) {
    // the group of current node has been specified already.
    auto* graph_node = graph.post_dfs_order[nid];
    auto* dom_node = post_dom_tree.nodes[nid];
    Group* group_node = groups_[nid];
    ICHECK(group_node != nullptr);
    // no actions for opaque nodes
    if (group_node->pattern == kOpaque) continue;
    // no actions needed if the current node have no dominator
    if (dom_node->parent == nullptr) continue;
    ICHECK(!graph_node->extern_ref);
    size_t dom_parent_gindex = dom_node->parent->gnode->index;

    // refuse the fusion if too many ops are going to be fused together
    if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_)
      continue;

    if (phase == 2) {
      // Fuse injective ops into intermediate tuples, if any
      if (group_node->pattern > relay::kInjective) continue;
      Group* dom_parent_group = groups_[dom_parent_gindex];
      Group* dom_root_group = dom_parent_group->FindRoot();
      // If dom node group has a tuple as its root, we do not fuse tuple fields into it
      if (dom_root_group->pattern == relay::kTuple) continue;
      if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= relay::kInjective) {
        // Now we know the tuple has been fused into subsequent injective ops
        auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
        // dom_root_group can also be tuple, as in inception layers
        // CheckPath is needed to avoid fusing two intermediate tuples
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);
        }
      }
      // Phase 2 only performs the tuple fusion above; skip the phase-0/1 logic.
      continue;
    }

    // Skip if current node is already fused to the parent.
    if (groups_[dom_parent_gindex] != nullptr &&
        group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {
      continue;
    }
    // Do not fuse into tuple for now
    if (groups_[dom_parent_gindex]->pattern == kTuple) continue;
    // Try to fuse current node to its post-dominator.
    if (group_node->pattern == kOutEWiseFusable) {
      // OutEWiseFusable anchors are fused only in phase 0.
      if (phase != 0) continue;
      // Path for OutEWiseFusable: conv2d
      // Check if the dominator relation is elemwise.
      if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) {
        ICHECK(dom_node->parent->gnode != nullptr);
        // The fuse can be executed if all the intermediate ops are still broadcast.
        auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; };
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);
        }
      }
    } else if (group_node->pattern <= kBroadcast) {
      // Pre-condition: can only be fused to parent which is injective or reduction.
      if (dom_node->parent != nullptr &&
          (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) {
        // Check if all the intermediate ops are still broadcast.
        // The final terminal node can already be fused to a OutEWiseFusable group.
        auto fcond = [](OpPatternKind kind, bool is_sink) {
          if (!is_sink) {
            // Elemwise, broadcast, and injective ops on the parallel branches
            // are allowed be fused to the elemwise/broadcast anchor.
            return kind <= kInjective;
          } else {
            // The sink itself may additionally be a reduction, injective op,
            // or an already-formed OutEWiseFusable group.
            return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective ||
                    kind == kOutEWiseFusable);
          }
        };
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);
        }
      }
    } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
      // defer injective fusion to second phase.
      // so conv2d always finishes fusing.
      if (phase != 1) continue;
      // Check if all path are injective.
      auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
      if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
        CommitFuse(graph_node, dom_node->parent->gnode);
      }
    } else {
      // do nothing.
      ICHECK(group_node->pattern == kCommReduce);
    }
  }
}
332+
333+
} // namespace relay
334+
} // namespace tvm

0 commit comments

Comments
 (0)