Commit 7ebc802

[Relay] Introduce arguments limit to FuseOps pass (#15137)
* [Relay] Introduce arguments limit to FuseOps pass

  PR #8313 introduced a parameter `max_function_args` that limits the number of function arguments; when the limit is exceeded, a concatenation layer is split into several concat operations. I faced a problem on Adreno GPU where enqueueNDRange crashed without any error for kernels with a large number of arguments. The crash was caused by the huge argument count, but the concat layer was not the only root cause: after fusing several operations, the resulting functions also ended up with many arguments. As discussed in #8313, limiting the number of function arguments in the FuseOps pass might be a good improvement.

  This PR introduces such a mechanism for limiting the number of function arguments in the FuseOps pass, and adds an argument limit of 128 parameters for OpenCL devices. The idea of the current approach is to compute the number of arguments for each node during the fusing algorithm; when the number of function arguments would exceed the limit specified by `max_function_args`, fusing stops. If a node has several inputs and the argument count has not yet been computed for some of them, fusing of that node is postponed and retried later, once the argument counts are available for all inputs. This postponed-fusing approach avoids additional computation during compilation.

  Additionally, dynamic shapes must be handled. With a dynamic shape, the function arguments also include the sizes of the dynamic dimensions and the strides. The number of strides equals the rank of the tensor, and the number of additional size parameters equals the number of dynamic dimensions.

* Fix memory_scope order in test
* Apply code review comments
* Apply comments
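
To make the dynamic-shape accounting concrete, here is a minimal sketch in plain Python (not TVM API; `count_tensor_args` is a hypothetical helper mirroring the rule above):

    # A tensor contributes one argument for its data, plus one argument per
    # dynamic ("Any") dimension for its size, plus `rank` stride arguments
    # when at least one dimension is dynamic.
    def count_tensor_args(shape):
        # shape: list of ints, with None marking a dynamic (Any) dimension
        dyn_dims = sum(1 for d in shape if d is None)
        extra = dyn_dims + (len(shape) if dyn_dims else 0)
        return 1 + extra

    assert count_tensor_args([8, 16]) == 1       # static shape: data only
    assert count_tensor_args([None, 16]) == 4    # data + 1 size + 2 strides
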
1 parent d8c0676 commit 7ebc802

File tree

17 files changed (+677, -96 lines)


include/tvm/relay/transform.h

Lines changed: 4 additions & 1 deletion

@@ -120,9 +120,12 @@ TVM_DLL Pass FoldConstant(bool fold_qnn = false);
 /*!
  * \brief Split function with huge number of arguments to smaller pieces.
  *
+ * \param max_function_args Maximum number of function arguments. If it equals 0 then SplitArgs
+ * shouldn't split the function.
+ *
  * \return The pass.
  */
-TVM_DLL Pass SplitArgs(int max_function_args);
+TVM_DLL Pass SplitArgs(uint64_t max_function_args);

 /*!
  * \brief Fuse operations into expr into separate functions.

include/tvm/topi/transform.h

Lines changed: 1 addition & 1 deletion

@@ -723,7 +723,7 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b
 }

 /*!
- * \brief Calcluate the output shape of strided_slice, the entry point for Relay type relation
+ * \brief Calculate the output shape of strided_slice, the entry point for Relay type relation
  *
  * \param ishape The input tensor shape
  * \param begin The indices to begin with in the slicing

python/tvm/relay/op/tensor.py

Lines changed: 6 additions & 3 deletions

@@ -23,7 +23,7 @@

 from . import _make
 from .dyn import _make as _dyn_make
-from ..expr import Tuple, Expr, Constant
+from ..expr import Tuple, Expr, Constant, Call
 from . import op as reg


@@ -1141,12 +1141,15 @@ def concatenate(data, axis):
     result: relay.Expr
         The concatenated tensor.
     """
-    data = list(data)
+    if not isinstance(data, Call):
+        data = list(data)
     if not data:
         raise ValueError("relay.concatenate requires data to be non-empty.")
+    if not isinstance(data, Call):
+        data = Tuple(data)
     if not isinstance(axis, int):
         raise ValueError("For now, we only support integer axis")
-    return _make.concatenate(Tuple(data), axis)
+    return _make.concatenate(data, axis)


 def einsum(data, equation):

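A usage sketch of the relaxed input type, assuming a standard TVM build: `concatenate` can now be handed a Call that yields a tuple, such as the underlying expression of `relay.split`, instead of a Python sequence:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(4, 8), dtype="float32")
    parts = relay.split(x, indices_or_sections=2, axis=1)
    # parts.astuple() is a Call node yielding a tuple; with this change it can
    # be passed to concatenate directly, without building a relay.Tuple by hand.
    out = relay.concatenate(parts.astuple(), axis=1)
    func = relay.Function([x], out)
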
python/tvm/relay/transform/transform.py

Lines changed: 8 additions & 1 deletion

@@ -1376,10 +1376,17 @@ def ToMixedPrecision(mixed_precision_type="float16", missing_op_mode=1):
 def SplitArgs(max_function_args):
     """Split function with huge number of arguments to smaller pieces.

+    Parameters
+    ----------
+    max_function_args: int
+        Maximum number of function arguments. If it equals 0 then SplitArgs
+        shouldn't split the function.
+
     Returns
     -------
     ret : tvm.transform.Pass
-        The registered pass for constant folding.
+        The registered pass.
     """
     return _ffi_api.SplitArgs(max_function_args)

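A minimal usage sketch, assuming a standard TVM build (the module and the limit of 4 are illustrative):

    import tvm
    from tvm import relay

    # Eight inputs concatenated in one op; a limit of 4 forces a split.
    inputs = [relay.var("x%d" % i, shape=(1, 4)) for i in range(8)]
    out = relay.concatenate(inputs, axis=1)
    mod = tvm.IRModule.from_expr(relay.Function(inputs, out))
    mod = relay.transform.InferType()(mod)   # SplitArgs inspects types
    mod = relay.transform.SplitArgs(4)(mod)  # 0 would leave the function untouched
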
python/tvm/target/target.py

Lines changed: 1 addition & 1 deletion

@@ -194,7 +194,7 @@ def max_shared_memory_per_block(self):

     @property
     def max_function_args(self):
-        return int(self.attrs.get("max_function_args", -1))
+        return int(self.attrs.get("max_function_args", 0))

     @property
     def vtcm_capacity(self):

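A quick check of the new default, assuming a standard TVM build (per this commit, OpenCL targets carry a 128-argument limit, and targets without the attribute now report 0, meaning "no restriction", instead of -1):

    import tvm

    tgt = tvm.target.Target("opencl")
    # Expected to print 128 for OpenCL after this commit.
    print(tgt.max_function_args)
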
src/relay/analysis/graph_partitioner.cc

Lines changed: 136 additions & 0 deletions

@@ -169,6 +169,7 @@ void GraphPartitioner::MergeFromTo(Group* child, Group* parent) {
   if (child == parent) return;
   // update the number of nodes of the parent group
   parent->num_nodes += child->num_nodes;
+  parent->args_num += child->args_num;
   child->parent = parent;
   // update anchor ref and pattern
   if (child->anchor_ref != nullptr) {
@@ -180,6 +181,10 @@

 void GraphPartitioner::CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
                                    Group* target) {
+  if (postpone_node_ != nullptr) {
+    postponed_fusing_map_.insert({postpone_node_, src});
+    return;
+  }
   if (src == sink) return;
   if (visited_.count(src)) return;
   visited_.insert(src);
@@ -220,7 +225,113 @@ size_t GraphPartitioner::CountFusedNodesWithNewChild(IndexedForwardGraph::Node*
   return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent);
 }

+size_t GraphPartitioner::CountAdditionalArgs_(const TensorTypeNode* ttype, bool with_strides) {
+  size_t any_dims = 0;
+  for (const auto& dim : ttype->shape) {
+    if (dim.as<AnyNode>()) {
+      any_dims++;
+    }
+  }
+  if (with_strides && any_dims > 0) any_dims += ttype->shape.size();
+  return any_dims;
+}
+
+size_t GraphPartitioner::CountArgs_(IndexedForwardGraph::Node* src,
+                                    const IndexedForwardGraph& graph, bool update_postpone) {
+  std::unordered_set<Group*> visited_groups;
+  Group* gnode = groups_[src->index];
+  ICHECK(gnode != nullptr);
+  auto sum = gnode->args_num;
+  visited_groups.insert(gnode->FindRoot());
+  auto calc_args_number = [this, src, &graph, &visited_groups,
+                           update_postpone](const relay::Expr& arg) -> size_t {
+    if (arg.as<VarNode>()) return 0;
+    auto* node = graph.node_map.at(arg.get());
+    Group* prev_group = groups_[node->index]->FindRoot();
+    if (visited_groups.count(prev_group) == 0) {
+      visited_groups.insert(prev_group);
+      if (prev_group->args_num > 0) {
+        // Get the number of arguments from the group
+        return prev_group->args_num;
+      } else if (update_postpone) {
+        // Update pointer to the node which should be postponed for deferred fusing
+        postpone_node_ = src;
+      } else {
+        // Calculate the number of arguments for the node which wasn't processed before
+        return CountArgs_(node, graph, update_postpone);
+      }
+    }
+    return 0;
+  };
+  if (auto call_node = GetRef<ObjectRef>(src->ref).as<CallNode>()) {
+    for (auto& it : call_node->args) {
+      sum += calc_args_number(it);
+    }
+  } else if (auto tuple_node = GetRef<ObjectRef>(src->ref).as<TupleNode>()) {
+    for (auto& it : tuple_node->fields) {
+      sum += calc_args_number(it);
+    }
+  }
+  return sum;
+}
+
+size_t GraphPartitioner::CountArgsLimit_(const IndexedForwardGraph::Node* child) {
+  auto* outputs_list = child->outputs.head;
+  size_t output_args = 0;
+  while (outputs_list != nullptr) {
+    output_args++;
+    if (auto call_node = GetRef<ObjectRef>(outputs_list->value.node->ref).as<CallNode>()) {
+      if (const auto* ttype = call_node->checked_type().as<TensorTypeNode>()) {
+        output_args += CountAdditionalArgs_(ttype, false);
+      }
+    }
+    outputs_list = outputs_list->next;
+  }
+  return (max_function_args_ > output_args) ? max_function_args_ - output_args : 0;
+}
+
+size_t GraphPartitioner::CountFusedArgs(const IndexedForwardGraph& graph,
+                                        IndexedForwardGraph::Node* child) {
+  size_t args_num = 0;
+  auto* outputs_list = child->outputs.head;
+  while (outputs_list != nullptr) {
+    args_num = std::max(args_num, CountArgs_(outputs_list->value.node, graph));
+    outputs_list = outputs_list->next;
+  }
+  return args_num;
+}
+
 void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) {
+  auto args_counter = [this](const tvm::Object* obj) {
+    size_t args_num = 0;
+    if (auto call_node = GetRef<ObjectRef>(obj).as<CallNode>()) {
+      for (auto& it : call_node->args) {
+        if (it.as<VarNode>() || it.as<TupleGetItemNode>()) {
+          args_num++;
+          if (const auto* ttype = it.as<ExprNode>()->checked_type().as<TensorTypeNode>()) {
+            args_num += CountAdditionalArgs_(ttype);
+          }
+        }
+      }
+    } else if (auto tuple_node = GetRef<ObjectRef>(obj).as<TupleNode>()) {
+      for (auto& it : tuple_node->fields) {
+        if (it.as<VarNode>() || it.as<TupleGetItemNode>()) {
+          args_num++;
+          if (const auto* ttype = it.as<ExprNode>()->checked_type().as<TensorTypeNode>()) {
+            args_num += CountAdditionalArgs_(ttype);
+          }
+        }
+      }
+    } else if (GetRef<ObjectRef>(obj).as<VarNode>()) {
+      args_num++;
+      if (const auto* ttype =
+              GetRef<ObjectRef>(obj).as<ExprNode>()->checked_type().as<TensorTypeNode>()) {
+        args_num += CountAdditionalArgs_(ttype);
+      }
+    }
+    return args_num;
+  };
+
   groups_.resize(graph.post_dfs_order.size());
   for (size_t nid = 0; nid < groups_.size(); ++nid) {
     const auto* graph_node = graph.post_dfs_order[nid];
@@ -231,6 +342,7 @@ void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) {
     if (group_node->pattern == relay::kOutEWiseFusable) {
       group_node->anchor_ref = graph_node->ref;
     }
+    group_node->args_num = args_counter(graph_node->ref);
     groups_[nid] = group_node;
   }
 }
@@ -244,6 +356,21 @@ void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph,  //
     auto* dom_node = post_dom_tree.nodes[nid];
     Group* group_node = groups_[nid];
     ICHECK(group_node != nullptr);
+    postpone_node_ = nullptr;
+    // Check if the fusing of some inputs was postponed
+    if (postponed_fusing_map_.count(graph_node)) {
+      auto range = postponed_fusing_map_.equal_range(graph_node);
+      for (auto it = range.first; it != range.second; ++it) {
+        // If the number of arguments is less than the limit then the input can be fused
+        if (CountArgs_(graph_node, graph, false) <= CountArgsLimit_(graph_node)) {
+          auto* src = it->second;
+          auto* snode = post_dom_tree.nodes[src->index]->parent->gnode;
+          if (groups_[snode->index]->anchor_ref != nullptr) continue;
+          CommitFuse(src, snode);
+        }
+      }
+      postponed_fusing_map_.erase(graph_node);
+    }
     // no actions for opaque nodes
     if (group_node->pattern == kOpaque) continue;
     // no actions needed if the current node have no dominator
@@ -254,6 +381,15 @@ void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph,  //
     // refuse the fusion if too many ops are going to be fused together
     if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_)
       continue;
+    // Refuse the fusion if too many arguments are going to be in the fused function
+    if (max_function_args_ > 0) {
+      auto limit = CountArgsLimit_(graph_node);
+      if (limit > 0) {
+        if (CountFusedArgs(graph, graph_node) > limit) {
+          continue;
+        }
+      }
+    }

     if (phase == 2) {
       // Fuse injective ops into intermediate tuples, if any

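For intuition, the guard added to RunFuse can be restated as a small sketch (plain Python with illustrative names, not TVM code):

    def allow_fusion(fused_args, output_args, max_function_args):
        # max_function_args == 0 disables the check entirely
        if max_function_args == 0:
            return True
        # Budget left for inputs after reserving arguments for the outputs
        limit = max_function_args - output_args if max_function_args > output_args else 0
        # When the outputs alone exhaust the budget, the C++ code skips the check
        if limit == 0:
            return True
        return fused_args <= limit
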
src/relay/analysis/graph_partitioner.h

Lines changed: 43 additions & 3 deletions

@@ -78,7 +78,7 @@ class IndexedForwardGraph {
   std::vector<Node*> post_dfs_order;

   /*! \brief Dump the graph into string. */
-  void DebugDump() {
+  void DebugDump() const {
     std::ostringstream os;
     for (size_t i = 0; i < post_dfs_order.size(); ++i) {
       Node* node = post_dfs_order[i];
@@ -162,8 +162,12 @@ class DominatorTree {
  */
 class GraphPartitioner {
  public:
-  explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth)
-      : arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {}
+  explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth,
+                            size_t max_function_args)
+      : arena_(arena),
+        opt_level_(opt_level),
+        max_fuse_depth_(max_fuse_depth),
+        max_function_args_(max_function_args) {}
   /*!
    * \brief Group as a union find data structure.
    */
@@ -183,6 +187,10 @@ class GraphPartitioner {
    * \brief The number of nodes belonging to this group
    */
   uint32_t num_nodes{1};
+  /*!
+   * \brief The number of function arguments belonging to this group
+   */
+  size_t args_num{0};

   /*! \brief Optional attributes to annotate the grouped function. */
   runtime::Map<runtime::String, ObjectRef> attrs;
@@ -205,10 +213,21 @@ class GraphPartitioner {
   int opt_level_;
   /*! \brief The maximum number of operations in one fused function */
   size_t max_fuse_depth_;
+  /*! \brief The maximum number of arguments in one fused function */
+  size_t max_function_args_;
   /*! \brief The internal groups. */
   std::vector<Group*> groups_;
   /*! \brief internal field used for deduplication */
   std::unordered_set<IndexedForwardGraph::Node*> visited_;
+  /*! \brief The map with nodes which were postponed for fusing. */
+  std::unordered_multimap<const IndexedForwardGraph::Node*, IndexedForwardGraph::Node*>
+      postponed_fusing_map_;
+  /*!
+   * \brief Fusing of this node should be postponed till all child nodes are evaluated.
+   * It is used to calculate the number of arguments which will be passed to this node in
+   * the generated function.
+   */
+  const IndexedForwardGraph::Node* postpone_node_{nullptr};
   // Internal implementation of CheckPath
   template <typename F>
   bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond);
@@ -247,6 +266,23 @@ class GraphPartitioner {
   void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink);

   size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink);
+  // Count the number of additional arguments. In the case of dynamic shape,
+  // generated function takes several additional arguments, such as the sizes of
+  // the dynamic dimensions and strides.
+  // This function calculates the number of such additional arguments.
+  size_t CountAdditionalArgs_(const TensorTypeNode* ttype, bool with_strides = true);
+  // Calculate the number of arguments for the node.
+  size_t CountArgs_(IndexedForwardGraph::Node* src, const IndexedForwardGraph& graph,
+                    bool update_postpone = true);
+  // Count the actual limit of arguments for a generated function.
+  // max_function_args_ specifies the number of maximum function arguments. But
+  // usually, output tensors are also passed to the function as arguments.
+  // Additionally, in the case of dynamic shape, it is necessary to take into
+  // account the number of parameters which specifies the sizes of the dynamic
+  // dimensions.
+  // This function computes the maximum number of arguments by the following formula:
+  //   limit = max_function_args_ - output_args_count
+  size_t CountArgsLimit_(const IndexedForwardGraph::Node* child);

   // Count the number of nodes in a fused subgraph if child is additionally fused.
   // dom_parent is already known to be a part of the subgraph.
@@ -256,6 +292,10 @@ class GraphPartitioner {
   // is important for correct calculation.
   size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child,
                                      IndexedForwardGraph::Node* dom_parent);
+  // Count the number of arguments in a fused subgraph. This function also takes into account the
+  // number of the child's output node argument. It helps to stop fusing before the node when the
+  // limit will be exceeded.
+  size_t CountFusedArgs(const IndexedForwardGraph& graph, IndexedForwardGraph::Node* child);

   // Initialize the groups.
   void InitGroups(const IndexedForwardGraph& graph);

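For intuition, a toy sketch (plain Python, not TVM code, names illustrative) of the union-find bookkeeping that the new `args_num` field plugs into, mirroring GraphPartitioner::MergeFromTo:

    class Group:
        """Toy stand-in for GraphPartitioner::Group (a union-find node)."""
        def __init__(self, args_num=0):
            self.parent = None
            self.num_nodes = 1
            self.args_num = args_num

        def find_root(self):
            node = self
            while node.parent is not None:
                node = node.parent
            return node

    def merge_from_to(child, parent):
        child, parent = child.find_root(), parent.find_root()
        if child is parent:
            return
        # The parent group absorbs the child's node and argument counts.
        parent.num_nodes += child.num_nodes
        parent.args_num += child.args_num
        child.parent = parent

    a, b = Group(args_num=3), Group(args_num=2)
    merge_from_to(a, b)
    assert b.find_root().args_num == 5
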
src/relay/backend/build_module.cc

Lines changed: 1 addition & 1 deletion

@@ -337,7 +337,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     if (config_->optional_homogeneous_target.defined()) {
       // This pass currently only supports the homogeneous case.
       pass_seqs.push_back(transform::SplitArgs(
-          config_->optional_homogeneous_target->GetAttr<Integer>("max_function_args", -1)
+          config_->optional_homogeneous_target->GetAttr<Integer>("max_function_args", 0)
              .value()
              .IntValue()));
     }

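A usage sketch, assuming a TVM build with OpenCL codegen enabled: with a homogeneous target, relay.build now picks the limit up from the target's `max_function_args` (128 for OpenCL per this commit, 0 meaning no restriction):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 8), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.relu(x)))
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="opencl")
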
src/relay/backend/vm/compiler.cc

Lines changed: 7 additions & 0 deletions

@@ -1059,6 +1059,13 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
   // Always plan devices so the remaining passes don't need to distinguish homogeneous vs
   // heterogeneous execution.
   pass_seqs.push_back(transform::PlanDevices(config_));
+  if (config_->optional_homogeneous_target.defined()) {
+    // This pass currently only supports the homogeneous case.
+    pass_seqs.push_back(transform::SplitArgs(
+        config_->optional_homogeneous_target->GetAttr<Integer>("max_function_args", 0)
+            .value()
+            .IntValue()));
+  }

   pass_seqs.push_back(transform::FuseOps());

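Likewise for the VM path, a hedged sketch assuming OpenCL support is available; the VM compiler now schedules SplitArgs before FuseOps as well:

    import tvm
    from tvm import relay
    from tvm.relay import vm

    x = relay.var("x", shape=(1, 8), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.relu(x)))
    # SplitArgs runs with the target's max_function_args before FuseOps.
    exe = vm.compile(mod, target="opencl")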