|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one |
| 3 | + * or more contributor license agreements. See the NOTICE file |
| 4 | + * distributed with this work for additional information |
| 5 | + * regarding copyright ownership. The ASF licenses this file |
| 6 | + * to you under the Apache License, Version 2.0 (the |
| 7 | + * "License"); you may not use this file except in compliance |
| 8 | + * with the License. You may obtain a copy of the License at |
| 9 | + * |
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | + * |
| 12 | + * Unless required by applicable law or agreed to in writing, |
| 13 | + * software distributed under the License is distributed on an |
| 14 | + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | + * KIND, either express or implied. See the License for the |
| 16 | + * specific language governing permissions and limitations |
| 17 | + * under the License. |
| 18 | + */ |
| 19 | + |
| 20 | +/*! |
| 21 | + * \file src/relay/collage/candidate_partition.cc |
| 22 | + * \brief A potential partition in the Collage search. |
| 23 | + */ |
| 24 | + |
| 25 | +#include "./candidate_partition.h" |
| 26 | + |
| 27 | +#include <tvm/relay/attrs/memory.h> |
| 28 | + |
| 29 | +#include "./candidate_set.h" |
| 30 | +#include "./partition_rule.h" |
| 31 | +#include "./partition_spec.h" |
| 32 | +#include "./utils.h" |
| 33 | + |
| 34 | +namespace tvm { |
| 35 | +namespace relay { |
| 36 | +namespace collage { |
| 37 | + |
| 38 | +TVM_REGISTER_NODE_TYPE(CandidatePartitionNode); |
| 39 | + |
| 40 | +void CandidatePartitionNode::VisitAttrs(AttrVisitor* v) { |
| 41 | + v->Visit("rule_name", &rule_name_); |
| 42 | + v->Visit("sub_graph", &sub_graph_); |
| 43 | + v->Visit("spec", &spec_); |
| 44 | + // TODO(mbs): cost_ |
| 45 | +} |
| 46 | + |
| 47 | +PartitionSpec CandidatePartitionNode::partition_spec() const { |
| 48 | + return Downcast<PartitionSpec>(spec_); |
| 49 | +} |
| 50 | + |
| 51 | +std::string CandidatePartitionNode::partition_spec_name() const { |
| 52 | + return Downcast<PartitionSpec>(spec_)->spec_name_; |
| 53 | +} |
| 54 | + |
| 55 | +Target CandidatePartitionNode::target() const { return Downcast<PartitionSpec>(spec_)->target_; } |
| 56 | + |
| 57 | +std::string CandidatePartitionNode::ToSummary(const DataflowGraph& dataflow_graph) const { |
| 58 | + std::ostringstream os; |
| 59 | + os << sub_graph_->label_; |
| 60 | + os << " | ("; |
| 61 | + bool first = true; |
| 62 | + for (PostDfsIndex index : sub_graph_->input_) { |
| 63 | + Expr sub_expr = dataflow_graph.index_to_node(index)->ref(); |
| 64 | + if (CanInline(sub_expr)) { |
| 65 | + continue; |
| 66 | + } |
| 67 | + if (first) { |
| 68 | + first = false; |
| 69 | + } else { |
| 70 | + os << ", "; |
| 71 | + } |
| 72 | + os << PrettyPrint(sub_expr->checked_type()); |
| 73 | + } |
| 74 | + os << ") -> ("; |
| 75 | + first = true; |
| 76 | + for (PostDfsIndex index : sub_graph_->exit_) { |
| 77 | + Expr sub_expr = dataflow_graph.index_to_node(index)->ref(); |
| 78 | + if (CanInline(sub_expr)) { |
| 79 | + continue; |
| 80 | + } |
| 81 | + if (first) { |
| 82 | + first = false; |
| 83 | + } else { |
| 84 | + os << ", "; |
| 85 | + } |
| 86 | + os << PrettyPrint(sub_expr->checked_type()); |
| 87 | + } |
| 88 | + os << ") | "; |
| 89 | + os << sub_graph_->inside_.ToString(); |
| 90 | + os << " | "; |
| 91 | + os << partition_spec_name(); |
| 92 | + os << " | "; |
| 93 | + os << cost_.ToString(); |
| 94 | + return os.str(); |
| 95 | +} |
| 96 | + |
| 97 | +std::string CandidatePartitionNode::ToString() const { |
| 98 | + std::ostringstream os; |
| 99 | + os << "{rule_name=" << rule_name_; |
| 100 | + os << ",sub_graph=" << sub_graph_->ToString(); |
| 101 | + os << ",spec_name=" << partition_spec_name(); |
| 102 | + if (!cost_.is_unknown()) { |
| 103 | + os << ",cost=" << cost_.ToString(); |
| 104 | + } |
| 105 | + os << "}"; |
| 106 | + return os.str(); |
| 107 | +} |
| 108 | + |
| 109 | +CandidatePartition::CandidatePartition(String rule_name, SubGraph sub_graph, |
| 110 | + ObjectRef /* actually PartitionSpec */ spec, Cost cost) { |
| 111 | + auto node = runtime::make_object<CandidatePartitionNode>(); |
| 112 | + node->rule_name_ = std::move(rule_name); |
| 113 | + node->sub_graph_ = std::move(sub_graph); |
| 114 | + node->spec_ = std::move(spec); |
| 115 | + node->cost_ = cost; |
| 116 | + data_ = std::move(node); |
| 117 | +} |
| 118 | + |
| 119 | +CandidatePartition WithRuleName(CandidatePartition candidate, String rule_name) { |
| 120 | + if (rule_name == candidate->rule_name_) { |
| 121 | + return candidate; |
| 122 | + } |
| 123 | + auto* node = candidate.CopyOnWrite(); |
| 124 | + node->rule_name_ = std::move(rule_name); |
| 125 | + return GetRef<CandidatePartition>(node); |
| 126 | +} |
| 127 | + |
| 128 | +CandidatePartition WithSubGraph(CandidatePartition candidate, SubGraph sub_graph) { |
| 129 | + if (sub_graph == candidate->sub_graph_) { |
| 130 | + return candidate; |
| 131 | + } |
| 132 | + auto* node = candidate.CopyOnWrite(); |
| 133 | + node->sub_graph_ = std::move(sub_graph); |
| 134 | + return GetRef<CandidatePartition>(node); |
| 135 | +} |
| 136 | + |
| 137 | +bool CandidatePartition::operator<(const CandidatePartition& that) const { |
| 138 | + // Order lexicographically on sub-graphs. |
| 139 | + if (*get()->sub_graph_.get() < *that->sub_graph_.get()) { |
| 140 | + return true; |
| 141 | + } |
| 142 | + if (*that->sub_graph_.get() < *get()->sub_graph_.get()) { |
| 143 | + return false; |
| 144 | + } |
| 145 | + // Break ties by rule name. |
| 146 | + return get()->rule_name_ < that->rule_name_; |
| 147 | +} |
| 148 | + |
| 149 | +bool CandidatePartition::AreTouching(const DataflowGraph& dataflow_graph, |
| 150 | + const CandidatePartition& that) const { |
| 151 | + return get()->spec_ == that->spec_ && |
| 152 | + get()->sub_graph_.AreTouching(dataflow_graph, that->sub_graph_); |
| 153 | +} |
| 154 | + |
| 155 | +CandidatePartition CandidatePartition::DisjointUnion(const DataflowGraph& dataflow_graph, |
| 156 | + const CandidatePartition& that) const { |
| 157 | + ICHECK_EQ(get()->spec_, that->spec_); |
| 158 | + return CandidatePartition(UnionLabels(get()->rule_name_, that->rule_name_), |
| 159 | + get()->sub_graph_.DisjointUnion(dataflow_graph, that->sub_graph_), |
| 160 | + get()->spec_, get()->cost_ + that->cost_); |
| 161 | +} |
| 162 | + |
| 163 | +/*static*/ |
| 164 | +CandidatePartition CandidatePartition::DisjointUnion(const DataflowGraph& dataflow_graph, |
| 165 | + std::vector<CandidatePartition> candidates) { |
| 166 | + ICHECK_GT(candidates.size(), 1); |
| 167 | + CandidatePartition result = candidates.front(); |
| 168 | + for (size_t i = 1; i < candidates.size(); ++i) { |
| 169 | + result = result.DisjointUnion(dataflow_graph, candidates[i]); |
| 170 | + } |
| 171 | + return result; |
| 172 | +} |
| 173 | + |
| 174 | +/*static*/ |
| 175 | +Expr CandidatePartition::ParallelRewrite(const DataflowGraph& dataflow_graph, |
| 176 | + const std::vector<CandidatePartition>& candidates) { |
| 177 | + std::vector<SubGraph> sub_graphs; |
| 178 | + sub_graphs.reserve(candidates.size()); |
| 179 | + for (const auto& candidate : candidates) { |
| 180 | + sub_graphs.emplace_back(candidate->sub_graph_); |
| 181 | + } |
| 182 | + return SubGraph::ParallelRewrite(dataflow_graph, sub_graphs); |
| 183 | +} |
| 184 | + |
| 185 | +/*static*/ |
| 186 | +std::vector<CandidatePartition> CandidatePartition::MaxCoalesce( |
| 187 | + const DataflowGraph& dataflow_graph, std::vector<CandidatePartition> candidates) { |
| 188 | + VLOG(1) << "Running MaxCoalesce over " << candidates.size() << " candidates"; |
| 189 | + // This is an eager version of using the simple (kOpaque, kOpaque) combiner. |
| 190 | + |
| 191 | + // Switch to set representation. |
| 192 | + CandidateSet result_set(std::move(candidates)); |
| 193 | + |
| 194 | + // Until fixed point... |
| 195 | + size_t num_rounds = 0; |
| 196 | + while (result_set.PrepareForNextRound()) { |
| 197 | + VLOG_CONTEXT << "round " << ++num_rounds; |
| 198 | + VLOG(1) << "checking " << result_set.size() << " candidates (" << result_set.first_new_index() |
| 199 | + << " existing)"; |
| 200 | + IndexSet removed_this_round(result_set.size()); // over candidate indexes! |
| 201 | + |
| 202 | + // Build map from post-dfs indices to the indices of candidates with corresponding entry node. |
| 203 | + // NOTE: the index set is over candidate indices not post-dfs indices! |
| 204 | + std::vector<IndexSet> entry_map(dataflow_graph.size(), IndexSet(result_set.size())); |
| 205 | + for (size_t i = 0; i < result_set.size(); ++i) { |
| 206 | + CandidatePartition candidate = result_set.at(i); |
| 207 | + for (PostDfsIndex entry_index : candidate->sub_graph_->entry_) { |
| 208 | + entry_map[entry_index].Add(i); |
| 209 | + } |
| 210 | + } |
| 211 | + |
| 212 | + for (size_t i = 0; i < result_set.size(); ++i) { |
| 213 | + if (removed_this_round[i]) { |
| 214 | + // Already merged. |
| 215 | + continue; |
| 216 | + } |
| 217 | + CandidatePartition upstream = result_set.at(i); |
| 218 | + // Narrow our search to just those candidates which could touch. |
| 219 | + IndexSet possible_downstream(result_set.size()); // over candidate indexes! |
| 220 | + for (PostDfsIndex output_index : upstream->sub_graph_->output_) { |
| 221 | + possible_downstream = possible_downstream | entry_map[output_index]; |
| 222 | + } |
| 223 | + for (size_t j : possible_downstream) { |
| 224 | + if (removed_this_round[j]) { |
| 225 | + // Already merged. |
| 226 | + continue; |
| 227 | + } |
| 228 | + if (i == j) { |
| 229 | + // Ignore self. |
| 230 | + continue; |
| 231 | + } |
| 232 | + CandidatePartition downstream = result_set.at(j); |
| 233 | + if (!upstream.AreTouching(dataflow_graph, downstream)) { |
| 234 | + continue; |
| 235 | + } |
| 236 | + CandidatePartition new_candidate = upstream.DisjointUnion(dataflow_graph, downstream); |
| 237 | + VLOG(2) << "Merging upstream candidate " << upstream->ToString() |
| 238 | + << " and downstream candidate " << downstream->ToString() << " to yield " |
| 239 | + << new_candidate->ToString(); |
| 240 | + result_set.Add(dataflow_graph, new_candidate); |
| 241 | + result_set.Remove(upstream); |
| 242 | + removed_this_round.Add(i); |
| 243 | + result_set.Remove(downstream); |
| 244 | + removed_this_round.Add(j); |
| 245 | + } |
| 246 | + } |
| 247 | + } |
| 248 | + |
| 249 | + // Restore canonical order. |
| 250 | + result_set.sort(); |
| 251 | + |
| 252 | + VLOG(1) << "MaxCoalesce produced " << result_set.size() << " candidates"; |
| 253 | + return result_set.MovedCurrentCandidates(); |
| 254 | +} |
| 255 | + |
| 256 | +} // namespace collage |
| 257 | +} // namespace relay |
| 258 | +} // namespace tvm |
0 commit comments