5 changes: 4 additions & 1 deletion include/tvm/relay/transform.h
@@ -120,9 +120,12 @@ TVM_DLL Pass FoldConstant(bool fold_qnn = false);
/*!
* \brief Split function with huge number of arguments to smaller pieces.
*
* \param max_function_args Maximum number of function arguments. If it is 0, SplitArgs
* does not split the function.
*
* \return The pass.
*/
TVM_DLL Pass SplitArgs(int max_function_args);
TVM_DLL Pass SplitArgs(uint64_t max_function_args);

/*!
* \brief Fuse operations into expr into separate functions.
2 changes: 1 addition & 1 deletion include/tvm/topi/transform.h
@@ -722,7 +722,7 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b
}

/*!
* \brief Calcluate the output shape of strided_slice, the entry point for Relay type relation
* \brief Calculate the output shape of strided_slice, the entry point for Relay type relation
*
* \param ishape The input tensor shape
* \param begin The indices to begin with in the slicing
9 changes: 6 additions & 3 deletions python/tvm/relay/op/tensor.py
@@ -23,7 +23,7 @@

from . import _make
from .dyn import _make as _dyn_make
from ..expr import Tuple, Expr, Constant
from ..expr import Tuple, Expr, Constant, Call
from . import op as reg


@@ -1141,12 +1141,15 @@ def concatenate(data, axis):
result: relay.Expr
The concatenated tensor.
"""
data = list(data)
if not isinstance(data, Call):
data = list(data)
if not data:
raise ValueError("relay.concatenate requires data to be non-empty.")
if not isinstance(data, Call):
data = Tuple(data)
if not isinstance(axis, int):
raise ValueError("For now, we only support integer axis")
return _make.concatenate(Tuple(data), axis)
return _make.concatenate(data, axis)


def einsum(data, equation):
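The `concatenate` change above lets `data` be an expression that already evaluates to a tuple (a `Call`) rather than only a Python sequence, so such an expression is passed through without being re-wrapped. A minimal sketch of both call styles (the variable names and the use of `relay.split` as a tuple-producing call are illustrative, not part of the patch):

```python
from tvm import relay

x = relay.var("x", shape=(1, 4), dtype="float32")
y = relay.var("y", shape=(1, 4), dtype="float32")

# Usual form: a Python list of tensors, wrapped into a relay Tuple internally.
out_from_list = relay.concatenate([x, y], axis=0)

# New form: a Call that yields a tuple (here the underlying call of relay.split)
# is forwarded to _make.concatenate without re-wrapping its fields.
parts = relay.split(x, indices_or_sections=2, axis=1).astuple()
out_from_call = relay.concatenate(parts, axis=1)
```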
9 changes: 8 additions & 1 deletion python/tvm/relay/transform/transform.py
@@ -1376,10 +1376,17 @@ def ToMixedPrecision(mixed_precision_type="float16", missing_op_mode=1):
def SplitArgs(max_function_args):
"""Split function with huge number of arguments to smaller pieces.

Parameters
----------
max_function_args: int
Maximum number of function arguments. If it is 0, SplitArgs does not
split the function.


Returns
-------
ret : tvm.transform.Pass
The registered pass for constant folding.
The registered pass.
"""
return _ffi_api.SplitArgs(max_function_args)

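For reference, a minimal sketch of how this pass is typically scheduled (mirroring the C++ pipelines later in this diff; the limit of 64 is illustrative, and in the real pipelines it comes from the target's `max_function_args` attribute):

```python
import tvm
from tvm import relay

# SplitArgs runs before FuseOps so that functions with too many parameters are
# broken up before fusion; a limit of 0 disables the splitting entirely.
seq = tvm.transform.Sequential(
    [
        relay.transform.SplitArgs(64),
        relay.transform.FuseOps(),
    ]
)
# mod = seq(mod)  # apply to an existing IRModule
```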
2 changes: 1 addition & 1 deletion python/tvm/target/target.py
@@ -194,7 +194,7 @@ def max_shared_memory_per_block(self):

@property
def max_function_args(self):
return int(self.attrs.get("max_function_args", -1))
return int(self.attrs.get("max_function_args", 0))

@property
def vtcm_capacity(self):
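With the default changed from -1 to 0, "attribute not set" and "no limit / do not split" become the same value. A small sketch of reading the property; whether a given target kind registers `max_function_args` is target-dependent, and the assumption here is that `llvm` does not:

```python
import tvm

target = tvm.target.Target("llvm")
# Returns 0 when the target kind does not define max_function_args,
# which downstream passes treat as "no argument limit".
print(target.max_function_args)
```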
136 changes: 136 additions & 0 deletions src/relay/analysis/graph_partitioner.cc
@@ -169,6 +169,7 @@ void GraphPartitioner::MergeFromTo(Group* child, Group* parent) {
if (child == parent) return;
// update the number of nodes of the parent group
parent->num_nodes += child->num_nodes;
parent->args_num += child->args_num;
child->parent = parent;
// update anchor ref and pattern
if (child->anchor_ref != nullptr) {
@@ -180,6 +181,10 @@

void GraphPartitioner::CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
Group* target) {
if (postpone_node_ != nullptr) {
postponed_fusing_map_.insert({postpone_node_, src});
return;
}
if (src == sink) return;
if (visited_.count(src)) return;
visited_.insert(src);
@@ -220,7 +225,113 @@ size_t GraphPartitioner::CountFusedNodesWithNewChild(IndexedForwardGraph::Node*
return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent);
}

size_t GraphPartitioner::CountAdditionalArgs_(const TensorTypeNode* ttype, bool with_strides) {
size_t any_dims = 0;
for (const auto& dim : ttype->shape) {
if (dim.as<AnyNode>()) {
any_dims++;
}
}
if (with_strides && any_dims > 0) any_dims += ttype->shape.size();
return any_dims;
}

size_t GraphPartitioner::CountArgs_(IndexedForwardGraph::Node* src,
const IndexedForwardGraph& graph, bool update_postpone) {
std::unordered_set<Group*> visited_groups;
Group* gnode = groups_[src->index];
ICHECK(gnode != nullptr);
auto sum = gnode->args_num;
visited_groups.insert(gnode->FindRoot());
auto calc_args_number = [this, src, &graph, &visited_groups,
update_postpone](const relay::Expr& arg) -> size_t {
if (arg.as<VarNode>()) return 0;
auto* node = graph.node_map.at(arg.get());
Group* prev_group = groups_[node->index]->FindRoot();
if (visited_groups.count(prev_group) == 0) {
visited_groups.insert(prev_group);
if (prev_group->args_num > 0) {
// Get the number of arguments from the group
return prev_group->args_num;
} else if (update_postpone) {
// Update pointer to the node which should be postponed for deferred fusing
postpone_node_ = src;
} else {
// Calculate the number of arguments for the node which wasn't processed before
return CountArgs_(node, graph, update_postpone);
}
}
return 0;
};
if (auto call_node = GetRef<ObjectRef>(src->ref).as<CallNode>()) {
for (auto& it : call_node->args) {
sum += calc_args_number(it);
}
} else if (auto tuple_node = GetRef<ObjectRef>(src->ref).as<TupleNode>()) {
for (auto& it : tuple_node->fields) {
sum += calc_args_number(it);
}
}
return sum;
}

size_t GraphPartitioner::CountArgsLimit_(const IndexedForwardGraph::Node* child) {
auto* outputs_list = child->outputs.head;
size_t output_args = 0;
while (outputs_list != nullptr) {
output_args++;
if (auto call_node = GetRef<ObjectRef>(outputs_list->value.node->ref).as<CallNode>()) {
if (const auto* ttype = call_node->checked_type().as<TensorTypeNode>()) {
output_args += CountAdditionalArgs_(ttype, false);
}
}
outputs_list = outputs_list->next;
}
return (max_function_args_ > output_args) ? max_function_args_ - output_args : 0;
}

size_t GraphPartitioner::CountFusedArgs(const IndexedForwardGraph& graph,
IndexedForwardGraph::Node* child) {
size_t args_num = 0;
auto* outputs_list = child->outputs.head;
while (outputs_list != nullptr) {
args_num = std::max(args_num, CountArgs_(outputs_list->value.node, graph));
outputs_list = outputs_list->next;
}
return args_num;
}

void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) {
auto args_counter = [this](const tvm::Object* obj) {
size_t args_num = 0;
if (auto call_node = GetRef<ObjectRef>(obj).as<CallNode>()) {
for (auto& it : call_node->args) {
if (it.as<VarNode>() || it.as<TupleGetItemNode>()) {
args_num++;
if (const auto* ttype = it.as<ExprNode>()->checked_type().as<TensorTypeNode>()) {
args_num += CountAdditionalArgs_(ttype);
}
}
}
} else if (auto tuple_node = GetRef<ObjectRef>(obj).as<TupleNode>()) {
for (auto& it : tuple_node->fields) {
if (it.as<VarNode>() || it.as<TupleGetItemNode>()) {
args_num++;
if (const auto* ttype = it.as<ExprNode>()->checked_type().as<TensorTypeNode>()) {
args_num += CountAdditionalArgs_(ttype);
}
}
}
} else if (GetRef<ObjectRef>(obj).as<VarNode>()) {
args_num++;
if (const auto* ttype =
GetRef<ObjectRef>(obj).as<ExprNode>()->checked_type().as<TensorTypeNode>()) {
args_num += CountAdditionalArgs_(ttype);
}
}
return args_num;
};

groups_.resize(graph.post_dfs_order.size());
for (size_t nid = 0; nid < groups_.size(); ++nid) {
const auto* graph_node = graph.post_dfs_order[nid];
@@ -231,6 +342,7 @@ void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) {
if (group_node->pattern == relay::kOutEWiseFusable) {
group_node->anchor_ref = graph_node->ref;
}
group_node->args_num = args_counter(graph_node->ref);
groups_[nid] = group_node;
}
}
@@ -244,6 +356,21 @@ void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, //
auto* dom_node = post_dom_tree.nodes[nid];
Group* group_node = groups_[nid];
ICHECK(group_node != nullptr);
postpone_node_ = nullptr;
// Check if the fusing of some inputs was postponed
if (postponed_fusing_map_.count(graph_node)) {
auto range = postponed_fusing_map_.equal_range(graph_node);
for (auto it = range.first; it != range.second; ++it) {
// If the number of arguments does not exceed the limit, the input can be fused
if (CountArgs_(graph_node, graph, false) <= CountArgsLimit_(graph_node)) {
auto* src = it->second;
auto* snode = post_dom_tree.nodes[src->index]->parent->gnode;
if (groups_[snode->index]->anchor_ref != nullptr) continue;
CommitFuse(src, snode);
}
}
postponed_fusing_map_.erase(graph_node);
}
// no actions for opaque nodes
if (group_node->pattern == kOpaque) continue;
// no actions needed if the current node have no dominator
@@ -254,6 +381,15 @@ void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, //
// refuse the fusion if too many ops are going to be fused together
if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_)
continue;
// Refuse the fusion if too many arguments are going to be in the fused function
if (max_function_args_ > 0) {
auto limit = CountArgsLimit_(graph_node);
if (limit > 0) {
if (CountFusedArgs(graph, graph_node) > limit) {
continue;
}
}
}

if (phase == 2) {
// Fuse injective ops into intermediate tuples, if any
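A simplified Python model of the check added to `RunFuse` above, to make the formula explicit (the function and parameter names are illustrative, not part of the patch): the effective limit is `max_function_args` minus the arguments consumed by the node's outputs, and fusion is refused when the fused function would need more input arguments than that.

```python
def allow_fusion(max_function_args, output_args, fused_args):
    """Illustrative model of the argument-limit check in RunFuse."""
    if max_function_args == 0:
        return True  # no limit configured; the check is skipped entirely
    # CountArgsLimit_: subtract the arguments taken by the outputs
    limit = max_function_args - output_args if max_function_args > output_args else 0
    if limit == 0:
        return True  # mirrors the `if (limit > 0)` guard in the C++ code
    # CountFusedArgs: refuse fusion when the fused function needs too many inputs
    return fused_args <= limit
```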
46 changes: 43 additions & 3 deletions src/relay/analysis/graph_partitioner.h
@@ -78,7 +78,7 @@ class IndexedForwardGraph {
std::vector<Node*> post_dfs_order;

/*! \brief Dump the graph into string. */
void DebugDump() {
void DebugDump() const {
std::ostringstream os;
for (size_t i = 0; i < post_dfs_order.size(); ++i) {
Node* node = post_dfs_order[i];
@@ -162,8 +162,12 @@ class DominatorTree {
*/
class GraphPartitioner {
public:
explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth)
: arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {}
explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth,
size_t max_function_args)
: arena_(arena),
opt_level_(opt_level),
max_fuse_depth_(max_fuse_depth),
max_function_args_(max_function_args) {}
/*!
* \brief Group as a union find data structure.
*/
@@ -183,6 +187,10 @@ class GraphPartitioner {
* \brief The number of nodes belonging to this group
*/
uint32_t num_nodes{1};
/*!
* \brief The number of function arguments belonging to this group
*/
size_t args_num{0};

/*! \brief Optional attributes to annotate the grouped function. */
runtime::Map<runtime::String, ObjectRef> attrs;
@@ -205,10 +213,21 @@ class GraphPartitioner {
int opt_level_;
/*! \brief The maximum number of operations in one fused function */
size_t max_fuse_depth_;
/*! \brief The maximum number of arguments in one fused function */
size_t max_function_args_;
/*! \brief The internal groups. */
std::vector<Group*> groups_;
/*! \brief internal field used for deduplication */
std::unordered_set<IndexedForwardGraph::Node*> visited_;
/*! \brief The map of nodes whose fusing was postponed. */
std::unordered_multimap<const IndexedForwardGraph::Node*, IndexedForwardGraph::Node*>
postponed_fusing_map_;
/*!
* \brief Fusing of this node should be postponed until all child nodes have been evaluated.
* It is used to calculate the number of arguments which will be passed to this node in
* the generated function.
*/
const IndexedForwardGraph::Node* postpone_node_{nullptr};
// Internal implementation of CheckPath
template <typename F>
bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond);
@@ -247,6 +266,23 @@ void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink);
void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink);

size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink);
// Count the number of additional arguments. In the case of dynamic shapes,
// the generated function takes several additional arguments, such as the sizes
// of the dynamic dimensions and their strides.
// This function calculates the number of such additional arguments.
size_t CountAdditionalArgs_(const TensorTypeNode* ttype, bool with_strides = true);
// Calculate the number of arguments for the node.
size_t CountArgs_(IndexedForwardGraph::Node* src, const IndexedForwardGraph& graph,
bool update_postpone = true);
// Count the actual limit of arguments for a generated function.
// max_function_args_ specifies the maximum number of function arguments, but
// output tensors are usually also passed to the function as arguments.
// Additionally, in the case of dynamic shapes, the parameters which specify the
// sizes of the dynamic dimensions must also be taken into account.
// This function computes the limit of arguments by the following formula:
// limit = max_function_args_ - output_args_count
size_t CountArgsLimit_(const IndexedForwardGraph::Node* child);

// Count the number of nodes in a fused subgraph if child is additionally fused.
// dom_parent is already known to be a part of the subgraph.
@@ -256,6 +292,10 @@ size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child,
// is important for correct calculation.
size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child,
IndexedForwardGraph::Node* dom_parent);
// Count the number of arguments in a fused subgraph. This function also takes into account the
// number of the child's output node arguments. This helps stop fusing before the node at which
// the limit would be exceeded.
size_t CountFusedArgs(const IndexedForwardGraph& graph, IndexedForwardGraph::Node* child);

// Initialize the groups.
void InitGroups(const IndexedForwardGraph& graph);
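To illustrate `CountAdditionalArgs_` as documented above, a small Python model of the counting rule (None stands in for a dynamic `Any` dimension; this is a sketch, not TVM API):

```python
def count_additional_args(shape, with_strides=True):
    """Each dynamic dimension adds one size argument; if any dimension is
    dynamic and strides are counted, one stride argument per dimension is
    added on top."""
    any_dims = sum(1 for dim in shape if dim is None)
    if with_strides and any_dims > 0:
        any_dims += len(shape)
    return any_dims

# A rank-3 tensor with one dynamic dimension: 1 size + 3 strides = 4 extra args.
assert count_additional_args([None, 16, 32]) == 4
```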
2 changes: 1 addition & 1 deletion src/relay/backend/build_module.cc
@@ -337,7 +337,7 @@ class RelayBuildModule : public runtime::ModuleNode {
if (config_->optional_homogeneous_target.defined()) {
// This pass currently only supports the homogeneous case.
pass_seqs.push_back(transform::SplitArgs(
config_->optional_homogeneous_target->GetAttr<Integer>("max_function_args", -1)
config_->optional_homogeneous_target->GetAttr<Integer>("max_function_args", 0)
.value()
.IntValue()));
}
7 changes: 7 additions & 0 deletions src/relay/backend/vm/compiler.cc
@@ -1059,6 +1059,13 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
// Always plan devices so the remaining passes don't need to distinguish homogeneous vs
// heterogeneous execution.
pass_seqs.push_back(transform::PlanDevices(config_));
if (config_->optional_homogeneous_target.defined()) {
// This pass currently only supports the homogeneous case.
pass_seqs.push_back(transform::SplitArgs(
config_->optional_homogeneous_target->GetAttr<Integer>("max_function_args", 0)
.value()
.IntValue()));
}

pass_seqs.push_back(transform::FuseOps());
