Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BYOC] Added "include_non_call_ops" parameter to AnnotateTarget pass #6655

Merged
merged 5 commits into from
Dec 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,14 +697,17 @@ def PartitionGraph():
return _ffi_api.PartitionGraph()


def AnnotateTarget(targets):
def AnnotateTarget(targets, include_non_call_ops=True):
"""Annotate ops in an experession with a provied compiler/target and then
use it for codegen.

Parameters
----------
targets : str or List[str]
The list of target compilers used for codegen.
include_non_call_ops : boolean
If True, non-call ops are also annotated with targets.
If False, non-call ops are not processed.

Returns
-------
Expand All @@ -714,7 +717,9 @@ def AnnotateTarget(targets):
"""
if isinstance(targets, str):
targets = [targets]
return _ffi_api.AnnotateTarget([tvm.runtime.container.String(t) for t in targets])
return _ffi_api.AnnotateTarget(
[tvm.runtime.container.String(t) for t in targets], include_non_call_ops
)


def DynamicToStatic():
Expand Down
187 changes: 143 additions & 44 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,19 @@ static const PackedFunc* make_begin_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
static const PackedFunc* make_end_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_end");

// A helper class to insert annotation boundaries for a program region that will
// be handled by a specific compiler.
static const char default_target[] = "default";
// A helper class to insert annotation boundaries for all the ops of a program
// region that will be handled by a specific compiler.
class AnnotateTargetRewriter : public ExprRewriter {
public:
explicit AnnotateTargetRewriter(Array<runtime::String> targets) : targets_(std::move(targets)) {}

protected:
/*! \brief The target backends for annotation. */
Array<runtime::String> targets_;
/*! \brief Maintain the decision of the target for each op expr. */
std::unordered_map<Expr, std::string, ObjectPtrHash, ObjectPtrEqual> op_expr_to_target_;

/*!
* \brief This function annotates a compiler end and a compiler begin to all arguments.
*
Expand All @@ -61,20 +67,27 @@ class AnnotateTargetRewriter : public ExprRewriter {
std::pair<std::string, Array<Expr>> AnnotateArgs(const Array<Expr>& args,
const std::string& target = "") {
std::string ref_target = "";
Array<Expr> compiler_begins;
Array<Expr> compiler_ends;
for (auto arg : args) {
std::string arg_target = "default";
std::string arg_target = default_target;
const CallNode* call = arg.as<CallNode>();

if (call && call->op == CompilerBeginOp()) {
// Argument is already compiler begin node meaning that this is not the first time
// running this pass, so we simply remove it and will add a new one later.
ICHECK_EQ(call->args.size(), 1U);
// Do not alter existing annotation if not default
if (default_target != call->attrs.as<CompilerAttrs>()->compiler) {
compiler_begins.push_back(arg);
} else {
// Remove default
compiler_ends.push_back(call->args[0]);
}
comaniac marked this conversation as resolved.
Show resolved Hide resolved
const CallNode* end = call->args[0].as<CallNode>();
if (end->op == CompilerEndOp()) {
if (end && end->op == CompilerEndOp()) {
arg_target = end->attrs.as<CompilerAttrs>()->compiler;
}
compiler_ends.push_back(call->args[0]);
} else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
arg_target = op_expr_to_target_[arg];
// If an argument is a call node and has no argument, then it should be tensor ops such as
Expand All @@ -93,18 +106,20 @@ class AnnotateTargetRewriter : public ExprRewriter {
if (ref_target == "") {
ref_target = arg_target;
} else if (ref_target != arg_target) {
ref_target = "default";
ref_target = default_target;
}
}

// Determine compiler begin target.
std::string op_target = (target == "") ? ref_target : target;

Array<Expr> compiler_begins;
for (const auto& end : compiler_ends) {
compiler_begins.push_back(InsertAnnotation(end, op_target, make_begin_op));
if (ref_target != "") {
for (const auto& end : compiler_ends) {
compiler_begins.push_back(InsertAnnotation(end, op_target, make_begin_op));
}
} else {
return {op_target, args};
comaniac marked this conversation as resolved.
Show resolved Hide resolved
}

return {op_target, compiler_begins};
}

Expand All @@ -128,14 +143,31 @@ class AnnotateTargetRewriter : public ExprRewriter {
* \return An annotated and target-propagated relay expression.
*/
Expr new_expr = expr;
if (op_expr_to_target_.find(expr) != op_expr_to_target_.end() && FreeVars(expr).size() != 0) {
new_expr = InsertAnnotation(expr, op_expr_to_target_[expr], make_end_op);
op_expr_to_target_[new_expr] = op_expr_to_target_[expr];
const CallNode* call = expr.as<CallNode>();
if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) {
// Check whether expr has args, if not - do not insert compiler_end.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cannot connect this comment to the following logic. Could you elaborate?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment relates to this part of the condition: (call && !call->args.empty())

if (expr->IsInstance<RefWriteNode>() || expr->IsInstance<RefCreateNode>() ||
expr->IsInstance<RefReadNode>() || expr->IsInstance<TupleNode>() ||
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There would be more nodes, like constructors. But I am still concerned about whether this change is needed. This really makes this already complicated pass more complicated. I still don't see a good reason why we don't run MergeCompilerRegions. That would solve this problem. Without running it, we would have a large number of small segments, which requires frequent data transfer between the host and device as well as frequent kernel launches.

Copy link
Contributor

@manupak manupak Nov 25, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While running merge compiler regions helps cut down the number of regions, it also makes it the external codegen's responsibility to allocate memory for intermediate tensors in those partitions. Thus, in the specific case of ACL, I think there is not much to be gained by such a merger, as ACL would be implementing each ACL primitive operator and letting TVM handle the memory allocation of the tensors passed on to the external function. Moreover, the kernel launch overhead should also be minimal as it is running on the CPU (so the host and device are almost the same here). Such a merger will also make the IO tensors live throughout the execution of the external function, while the space could be re-used if it was not merged.

The problem is that the specification for ACL did not indicate that the simple regions (or non-call ops) should be annotated, thus AnnotateTarget here is doing more than it was asked to do.

I quite agree that this pass is complicated and needs to be broken down. I guess that should be discussed in an RFC as to what it should look like. One direction may be to take out the annotation of simple regions (non-call ops) as a separate part (I believe this was how it looked some time back, when it had something called AnnotateRestDefault, until it got merged here :) ).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my comment here : #5277

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, it would be nice to have the RFC and list the options there

expr->IsInstance<TupleGetItemNode>() || (call && !call->args.empty())) {
std::string target = op_expr_to_target_[new_expr];
new_expr = InsertAnnotation(new_expr, target, make_end_op);
op_expr_to_target_[new_expr] = target;
}
} else if (call && call->op == CompilerEndOp()) {
if (default_target == call->attrs.as<CompilerAttrs>()->compiler) {
ICHECK_EQ(call->args.size(), 1U);
new_expr = call->args[0];
std::string target = op_expr_to_target_[new_expr];
new_expr = InsertAnnotation(new_expr, target, make_end_op);
op_expr_to_target_[new_expr] = target;
}
}

return std::move(new_expr);
}

Expr Rewrite_(const CallNode* pre, const Expr& post) final {
public:
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
// Supported targets for this node. The order implies the priority.
std::vector<std::string> supported_targets;

Expand All @@ -146,13 +178,19 @@ class AnnotateTargetRewriter : public ExprRewriter {
// Bypass compiler begin due to lack of target information. It will be processed
// when the following op handling arguments.
ICHECK_EQ(pre->args.size(), 1U);
return post.as<CallNode>()->args[0];
// Preserve annotations
return post;
} else if (op_node && pre->op == CompilerEndOp()) {
// Override compiler end with the new target.
ICHECK_EQ(pre->args.size(), 1U);
auto input_expr = post.as<CallNode>()->args[0];
// Already annotated. Recover target
if (op_expr_to_target_.find(input_expr) == op_expr_to_target_.end()) {
op_expr_to_target_[input_expr] = post.as<CallNode>()->attrs.as<CompilerAttrs>()->compiler;
}
Comment on lines +187 to +190
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like you don't need the IF? Even if input_expr is already in op_expr_to_target_, you can still override it, as suggested by the comment in L188. Accordingly, if you override the target, you need InsertAnnotation.

Copy link
Contributor Author

@d-smirnov d-smirnov Nov 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is not the first invocation of the pass, this branch is supposed to restore targets from already annotated ops.

ICHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end());
return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op);
// Preserve annotated nodes
return post;
}
// Check prior to peeking first argument
if (pre->args.size()) {
Expand All @@ -161,8 +199,9 @@ class AnnotateTargetRewriter : public ExprRewriter {
const CallNode* first_arg_call = pre->args[0].as<CallNode>();
if (first_arg_call && first_arg_call->op == CompilerBeginOp()) {
std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
if (arg_target != "default") {
supported_targets.push_back(arg_target);
if (arg_target != default_target) {
// annotated already
return post;
Comment on lines +202 to +204
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like you remove the feature that considers the target in existing annotation nodes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here it is peeking at the first arg, and if it is already annotated with a non-default target it returns the node untouched, preserving the target.

}
}
}
Expand All @@ -188,7 +227,6 @@ class AnnotateTargetRewriter : public ExprRewriter {
// if it is in the target list.
Function func = Downcast<Function>(pre->op);
ICHECK(func.defined());

if (auto comp_name = func->GetAttr<String>(attr::kComposite)) {
std::string comp_name_str = comp_name.value();
size_t i = comp_name_str.find('.');
Expand All @@ -203,16 +241,18 @@ class AnnotateTargetRewriter : public ExprRewriter {
}
}
}
supported_targets.push_back("default"); // Make default as the last option.

supported_targets.push_back(default_target); // Make default as the last option.
// Visit and mutate arguments after the target of this op has been determined.
Call post_call = Downcast<Call>(post);
if (pre->op->IsInstance<VarNode>()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you elaborate why a CallNode may have a VarNode as its op?

Copy link
Contributor Author

@d-smirnov d-smirnov Nov 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test case is tests/python/relay/test_pass_annotate_target.py::test_while_let

auto new_call = RewriteVarCall(post_call);
if (nullptr != new_call) return GetRef<Expr>(new_call->get());
}
// TODO(@comaniac, @zhiics): Now we simply assign this node to the target with
// the highest priority, but we should preserve all supported targets so that
// we can make a better decision.
std::string target = supported_targets[0];

// Visit and mutate arguments after the target of this op has been determined.
Call post_call = Downcast<Call>(post);

// Add annotations to each arg.
auto target_n_args = AnnotateArgs(post_call->args, target);
Array<Expr> compiler_begins = std::get<1>(target_n_args);
Expand All @@ -221,11 +261,12 @@ class AnnotateTargetRewriter : public ExprRewriter {

// Update the target map.
op_expr_to_target_[new_call] = target;

return std::move(new_call);
}

Expr Rewrite_(const TupleNode* op, const Expr& post) final {
virtual std::unique_ptr<Call> RewriteVarCall(const Call& post_call) { return nullptr; }

Expr Rewrite_(const TupleNode* op, const Expr& post) override {
auto expr = Downcast<Tuple>(post);

auto target_n_args = AnnotateArgs(expr->fields);
Expand All @@ -234,7 +275,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
return std::move(new_expr);
}

Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) override {
auto expr = Downcast<TupleGetItem>(post);

auto target_n_args = AnnotateArgs(Array<Expr>({expr->tuple}));
Expand All @@ -243,7 +284,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
return std::move(new_expr);
}

Expr Rewrite_(const FunctionNode* fn, const Expr& post) final {
Expr Rewrite_(const FunctionNode* fn, const Expr& post) override {
Function func;
Expr new_body;
// don't step into composite functions
Expand All @@ -257,7 +298,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs);
}

Expr Rewrite_(const LetNode* op, const Expr& post) final {
Expr Rewrite_(const LetNode* op, const Expr& post) override {
auto let = Downcast<Let>(post);

Expr new_expr;
Expand All @@ -274,7 +315,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
return std::move(new_expr);
}

Expr Rewrite_(const IfNode* op, const Expr& post) final {
Expr Rewrite_(const IfNode* op, const Expr& post) override {
auto expr = Downcast<If>(post);
Expr new_cond = InsertCompilerEndAndPropogateTarget(expr->cond);
Expr new_true_branch = InsertCompilerEndAndPropogateTarget(expr->true_branch);
Expand All @@ -284,7 +325,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
return std::move(new_expr);
}

Expr Rewrite_(const RefCreateNode* op, const Expr& post) final {
Expr Rewrite_(const RefCreateNode* op, const Expr& post) override {
auto expr = Downcast<RefCreate>(post);

auto target_n_args = AnnotateArgs(Array<Expr>({expr->value}));
Expand All @@ -293,7 +334,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
return std::move(new_expr);
}

Expr Rewrite_(const RefReadNode* op, const Expr& post) final {
Expr Rewrite_(const RefReadNode* op, const Expr& post) override {
auto expr = Downcast<RefRead>(post);

auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref}));
Expand All @@ -302,35 +343,93 @@ class AnnotateTargetRewriter : public ExprRewriter {
return std::move(new_expr);
}

Expr Rewrite_(const RefWriteNode* op, const Expr& post) final {
Expr Rewrite_(const RefWriteNode* op, const Expr& post) override {
auto expr = Downcast<RefWrite>(post);

auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref, expr->value}));
auto new_expr = RefWrite(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]);
op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
return std::move(new_expr);
}
};

private:
/*! \brief The target backends for annotation. */
Array<runtime::String> targets_;
/*! \brief Maintain the decision of the target for each op expr. */
std::unordered_map<Expr, std::string, ObjectPtrHash, ObjectPtrEqual> op_expr_to_target_;
// A helper class to insert annotation boundaries for call ops and function nodes
// in a program region that will be handled by a specific compiler.
class CallOpsTargetRewriter : public AnnotateTargetRewriter {
public:
explicit CallOpsTargetRewriter(Array<runtime::String> targets)
: AnnotateTargetRewriter(std::move(targets)) {}

std::unique_ptr<Call> RewriteVarCall(const Call& post_call) override {
Array<Expr> ends;
for (auto arg : post_call->args) {
ends.push_back(InsertCompilerEndAndPropogateTarget(arg));
}
auto new_call = std::make_unique<Call>(post_call->op, ends, post_call->attrs);
(*new_call)->checked_type_ = post_call->checked_type_;
return new_call;
}

Expr Rewrite_(const TupleNode* op, const Expr& post) override {
auto expr = Downcast<Tuple>(post);
Array<Expr> new_fields;
for (auto f : expr->fields) {
new_fields.push_back(InsertCompilerEndAndPropogateTarget(f));
}
return std::move(Tuple(new_fields));
}

Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) override {
auto expr = Downcast<TupleGetItem>(post);
return std::move(TupleGetItem(InsertCompilerEndAndPropogateTarget(expr->tuple), expr->index));
}

Expr Rewrite_(const IfNode* op, const Expr& post) override {
auto expr = Downcast<If>(post);
Expr new_cond = InsertCompilerEndAndPropogateTarget(expr->cond);
Expr new_true_branch = InsertCompilerEndAndPropogateTarget(expr->true_branch);
Expr new_false_branch = InsertCompilerEndAndPropogateTarget(expr->false_branch);

auto new_expr = If(new_cond, new_true_branch, new_false_branch);
return std::move(new_expr);
}

Expr Rewrite_(const RefCreateNode* op, const Expr& post) override {
auto expr = Downcast<RefCreate>(post);
auto new_expr = RefCreate(InsertCompilerEndAndPropogateTarget(expr->value));
return std::move(new_expr);
}

Expr Rewrite_(const RefReadNode* op, const Expr& post) override {
auto expr = Downcast<RefRead>(post);
auto new_expr = RefRead(InsertCompilerEndAndPropogateTarget(expr->ref));
return std::move(new_expr);
}

Expr Rewrite_(const RefWriteNode* op, const Expr& post) override {
auto expr = Downcast<RefWrite>(post);
auto new_expr = RefWrite(InsertCompilerEndAndPropogateTarget(expr->ref),
InsertCompilerEndAndPropogateTarget(expr->value));
return std::move(new_expr);
}
};

Expr AnnotateTarget(const Expr& expr, const Array<runtime::String>& targets) {
auto rewriter = AnnotateTargetRewriter(targets);
return PostOrderRewrite(expr, &rewriter);
Expr AnnotateTarget(const Expr& expr, const Array<runtime::String>& targets,
bool include_non_call_ops) {
auto r = include_non_call_ops ? std::make_unique<AnnotateTargetRewriter>(targets)
: std::make_unique<CallOpsTargetRewriter>(targets);
return PostOrderRewrite(expr, r.get());
}

} // namespace annotate_target

namespace transform {

Pass AnnotateTarget(const Array<runtime::String>& targets) {
Pass AnnotateTarget(const Array<runtime::String>& targets, bool include_non_call_ops) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, targets));
return Downcast<Function>(
relay::annotate_target::AnnotateTarget(f, targets, include_non_call_ops));
};
auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc", {"InferType"});
return transform::Sequential({func_pass, InferType()}, "AnnotateTarget");
Expand Down
Loading