/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include "./graph_partitioner.h"

#include <vector>

namespace tvm {
namespace relay {

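// Build the post-dominator tree of the graph. Nodes are visited in reverse
// post-DFS order, so every node's consumers are processed before the node itself.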
DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) {
  DominatorTree tree;
  tree.nodes.resize(graph.post_dfs_order.size(), nullptr);
  // reverse topo order
  for (size_t i = graph.post_dfs_order.size(); i != 0; --i) {
    size_t index = i - 1;
    tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]);
  }
  return tree;
}

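// Walk the two nodes up towards the root, always advancing the deeper one,
// until they meet. The patterns of the nodes crossed along the way are folded
// into *edge_pattern.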
DominatorTree::Node* DominatorTree::LeastCommonAncestor(Node* lhs, Node* rhs,
                                                        OpPatternKind* edge_pattern) {
  while (lhs != rhs) {
    if (lhs == nullptr) return nullptr;
    if (rhs == nullptr) return nullptr;
    if (lhs->depth < rhs->depth) {
      edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern);
      rhs = rhs->parent;
    } else if (rhs->depth < lhs->depth) {
      edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern);
      lhs = lhs->parent;
    } else {
      edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern);
      edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern);
      lhs = lhs->parent;
      rhs = rhs->parent;
    }
  }
  return lhs;
}

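// Compute the least common ancestor of all the output edges of a node,
// combining each edge's pattern into *edge_pattern.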
DominatorTree::Node* DominatorTree::LeastCommonAncestor(
    const LinkedList<IndexedForwardGraph::Edge>& input_nodes, OpPatternKind* edge_pattern) {
  auto link = input_nodes.head;
  if (link == nullptr) {
    return nullptr;
  }
  auto get_node = [&](const IndexedForwardGraph::Edge& edge) {
    size_t oindex = edge.node->index;
    ICHECK_LT(oindex, nodes.size());
    Node* onode = nodes[oindex];
    ICHECK(onode != nullptr);
    return onode;
  };
  Node* parent = get_node(link->value);
  *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
  link = link->next;
  for (; link != nullptr; link = link->next) {
    parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern);
    *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
  }
  return parent;
}

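// Create the dominator-tree node for gnode. Externally referenced nodes become
// roots; otherwise the parent is the least common ancestor of all outputs.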
DominatorTree::Node* DominatorTree::GetNode(support::Arena* arena,
                                            IndexedForwardGraph::Node* gnode) {
  Node* tnode = arena->make<Node>();
  tnode->gnode = gnode;
  if (gnode->extern_ref) {
    tnode->depth = 1;
    tnode->parent = nullptr;
    tnode->pattern = kOpaque;
  } else {
    // find the LCA of all outputs.
    OpPatternKind pattern = kElemWise;
    Node* parent = LeastCommonAncestor(gnode->outputs, &pattern);
    tnode->depth = parent ? parent->depth + 1 : 1;
    tnode->parent = parent;
    tnode->pattern = pattern;
  }
  return tnode;
}

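// Entry point: initialize one group per node, then run the fusion algorithm
// in three phases over the post-dominator tree.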
std::vector<GraphPartitioner::Group*> GraphPartitioner::Partition(
    const IndexedForwardGraph& graph) {
  this->InitGroups(graph);
  if (opt_level_ == 0) return std::move(groups_);
  // get post dominator tree
  auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
  // run fusion algorithm.
  for (int phase = 0; phase < 3; ++phase) {
    this->RunFuse(graph, post_dom_tree, phase);
  }
  return std::move(groups_);
}

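// Union-find root lookup with path compression.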
GraphPartitioner::Group* GraphPartitioner::Group::FindRoot() {
  // fast path
  if (this->parent == nullptr) return this;
  // slow path with path compression.
  Group* root = this;
  while (root->parent != nullptr) {
    root = root->parent;
  }
  for (Group* p = this; p != root;) {
    Group* parent = p->parent;
    p->parent = root;
    p = parent;
  }
  return root;
}

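// Recursive helper of CheckPath: check that src and every node reachable from
// it up to sink satisfy fcond; fcond's second argument tells whether the node
// being checked is the sink.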
template <typename F>
bool GraphPartitioner::CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
                                  F fcond) {
  if (visited_.count(src)) return true;
  visited_.insert(src);
  Group* gnode = groups_[src->index];
  ICHECK(gnode != nullptr);
  gnode = gnode->FindRoot();
  if (!fcond(gnode->pattern, src == sink)) return false;
  if (src == sink) return true;
  for (auto link = src->outputs.head; link != nullptr; link = link->next) {
    if (!CheckPath_(link->value.node, sink, fcond)) return false;
  }
  return true;
}

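// Check that fcond holds for all nodes on every path between src (exclusive)
// and sink (inclusive).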
template <typename F>
bool GraphPartitioner::CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
                                 F fcond) {
  ICHECK(!src->extern_ref);
  visited_.clear();
  ICHECK(src != sink);
  for (auto link = src->outputs.head; link != nullptr; link = link->next) {
    if (!CheckPath_(link->value.node, sink, fcond)) return false;
  }
  return true;
}

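// Combine two op patterns, returning the more restrictive (larger) one.
// Two patterns that are both more complex than broadcast cannot be merged.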
OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) {
  if (lhs > relay::kBroadcast && rhs > relay::kBroadcast) {
    LOG(FATAL) << "Cannot merge two complex groups together";
  }
  if (lhs > rhs) return lhs;
  return rhs;
}

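// Merge the group that contains child into the group that contains parent,
// transferring the node count and, if present, the anchor reference.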
void GraphPartitioner::MergeFromTo(Group* child, Group* parent) {
  child = child->FindRoot();
  parent = parent->FindRoot();
  if (child == parent) return;
  // update the number of nodes of the parent group
  parent->num_nodes += child->num_nodes;
  child->parent = parent;
  // update anchor ref and pattern
  if (child->anchor_ref != nullptr) {
    ICHECK(parent->anchor_ref == nullptr);
    parent->anchor_ref = child->anchor_ref;
    parent->pattern = CombinePattern(child->pattern, parent->pattern);
  }
}

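// Recursive helper of CommitFuse: merge every group on the paths from src to
// sink (exclusive) into the target group.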
void GraphPartitioner::CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
                                   Group* target) {
  if (src == sink) return;
  if (visited_.count(src)) return;
  visited_.insert(src);
  Group* gnode = groups_[src->index];
  ICHECK(gnode != nullptr);
  // merge the current group into the target group if possible.
  MergeFromTo(gnode, target);
  for (auto link = src->outputs.head; link != nullptr; link = link->next) {
    CommitFuse_(link->value.node, sink, target);
  }
}

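// Commit the fusion of src into sink: all nodes on the paths between them are
// merged into the group of sink.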
void GraphPartitioner::CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) {
  Group* target = groups_[sink->index];
  visited_.clear();
  ICHECK(src != sink);
  CommitFuse_(src, sink, target);
}

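// Sum the group sizes of the nodes reachable from src, stopping at sink
// (exclusive).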
size_t GraphPartitioner::CountNodesUptoSink_(IndexedForwardGraph::Node* src,
                                             IndexedForwardGraph::Node* sink) {
  if (src == sink || visited_.count(src)) return 0;
  visited_.insert(src);
  Group* gnode = groups_[src->index];
  ICHECK(gnode != nullptr);
  auto sum = gnode->num_nodes;
  for (auto link = src->outputs.head; link != nullptr; link = link->next) {
    sum += CountNodesUptoSink_(link->value.node, sink);
  }
  return sum;
}

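// Number of nodes the group of dom_parent would contain if child were
// additionally fused into it; used to enforce max_fuse_depth_.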
size_t GraphPartitioner::CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child,
                                                     IndexedForwardGraph::Node* dom_parent) {
  Group* target = groups_[dom_parent->index];
  visited_.clear();
  ICHECK(child != dom_parent);
  return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent);
}

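// Create the initial group for each graph node; kOutEWiseFusable nodes record
// themselves as the group's anchor.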
void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) {
  groups_.resize(graph.post_dfs_order.size());
  for (size_t nid = 0; nid < groups_.size(); ++nid) {
    const auto* graph_node = graph.post_dfs_order[nid];
    auto* group_node = arena_->make<Group>();
    group_node->pattern = graph_node->pattern;
    group_node->root_ref = graph_node->ref;
    // set anchor ref if necessary.
    if (group_node->pattern == relay::kOutEWiseFusable) {
      group_node->anchor_ref = graph_node->ref;
    }
    groups_[nid] = group_node;
  }
}

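// One pass of the fusion algorithm over the post-dominator tree.
// Phase 0 fuses kOutEWiseFusable anchors and elemwise/broadcast chains,
// phase 1 additionally fuses injective ops and tuples, and phase 2 fuses
// remaining injective ops into intermediate tuples.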
void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph,  //
                               const DominatorTree& post_dom_tree,  //
                               int phase) {
  for (size_t nid = 0; nid < groups_.size(); ++nid) {
    // the group of current node has been specified already.
    auto* graph_node = graph.post_dfs_order[nid];
    auto* dom_node = post_dom_tree.nodes[nid];
    Group* group_node = groups_[nid];
    ICHECK(group_node != nullptr);
    // no actions for opaque nodes
    if (group_node->pattern == kOpaque) continue;
    // no actions needed if the current node has no dominator
    if (dom_node->parent == nullptr) continue;
    ICHECK(!graph_node->extern_ref);
    size_t dom_parent_gindex = dom_node->parent->gnode->index;

    // refuse the fusion if too many ops are going to be fused together
    if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_)
      continue;

    if (phase == 2) {
      // Fuse injective ops into intermediate tuples, if any
      if (group_node->pattern > relay::kInjective) continue;
      Group* dom_parent_group = groups_[dom_parent_gindex];
      Group* dom_root_group = dom_parent_group->FindRoot();
      // If the dominator node's group has a tuple as its root, we do not fuse tuple fields into it
      if (dom_root_group->pattern == relay::kTuple) continue;
      if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= relay::kInjective) {
        // Now we know the tuple has been fused into subsequent injective ops
        auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
        // dom_root_group can also be tuple, as in inception layers
        // CheckPath is needed to avoid fusing two intermediate tuples
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);
        }
      }
      continue;
    }

    // Skip if current node is already fused to the parent.
    if (groups_[dom_parent_gindex] != nullptr &&
        group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {
      continue;
    }
    // Do not fuse into tuple for now
    if (groups_[dom_parent_gindex]->pattern == kTuple) continue;
    // Try to fuse current node to its post-dominator.
    if (group_node->pattern == kOutEWiseFusable) {
      if (phase != 0) continue;
      // Path for OutEWiseFusable: conv2d
      // Check if the dominator relation is elemwise.
      if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) {
        ICHECK(dom_node->parent->gnode != nullptr);
        // The fusion can be executed if all the intermediate ops are still broadcast.
        auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; };
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);
        }
      }
    } else if (group_node->pattern <= kBroadcast) {
      // Pre-condition: can only be fused to parent which is injective or reduction.
      if (dom_node->parent != nullptr &&
          (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) {
        // Check if all the intermediate ops are still broadcast.
        // The final terminal node can already be fused to an OutEWiseFusable group.
        auto fcond = [](OpPatternKind kind, bool is_sink) {
          if (!is_sink) {
            // Elemwise, broadcast, and injective ops on the parallel branches
            // are allowed to be fused to the elemwise/broadcast anchor.
            return kind <= kInjective;
          } else {
            return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective ||
                    kind == kOutEWiseFusable);
          }
        };
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);
        }
      }
    } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
      // defer injective fusion to the second phase,
      // so conv2d always finishes fusing.
      if (phase != 1) continue;
      // Check if all paths are injective.
      auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
      if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
        CommitFuse(graph_node, dom_node->parent->gnode);
      }
    } else {
      // do nothing.
      ICHECK(group_node->pattern == kCommReduce);
    }
  }
}

}  // namespace relay
}  // namespace tvm