Skip to content

Commit

Permalink
[BYOC] Added annotate_non_call_ops parameter to AnnotateTarget pass
Browse files Browse the repository at this point in the history
Added annotate_non_call_ops parameter to AnnotateTarget pass to prevent
non-call to be promoted to previously annotated operations
This is useful in case if you are not running MergeCompilerRegions
pass after AnnotateTarget.
  • Loading branch information
d-smirnov committed Nov 19, 2020
1 parent 3950639 commit 301bd38
Show file tree
Hide file tree
Showing 3 changed files with 328 additions and 79 deletions.
9 changes: 7 additions & 2 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,14 +697,17 @@ def PartitionGraph():
return _ffi_api.PartitionGraph()


def AnnotateTarget(targets):
def AnnotateTarget(targets, include_non_call_ops=True):
"""Annotate ops in an experession with a provied compiler/target and then
use it for codegen.
Parameters
----------
targets : str or List[str]
The list of target compilers used for codegen.
include_non_call_ops : boolean
If True then non-call ops also will be annotated with targets
If False then non-call ops will not be processed
Returns
-------
Expand All @@ -714,7 +717,9 @@ def AnnotateTarget(targets):
"""
if isinstance(targets, str):
targets = [targets]
return _ffi_api.AnnotateTarget([tvm.runtime.container.String(t) for t in targets])
return _ffi_api.AnnotateTarget(
[tvm.runtime.container.String(t) for t in targets], include_non_call_ops
)


def DynamicToStatic():
Expand Down
187 changes: 143 additions & 44 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,19 @@ static const PackedFunc* make_begin_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
static const PackedFunc* make_end_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_end");

// A helper class to insert annotation boundaries for a program region that will
// be handled by a specific compiler.
static const std::string default_target = "default";
// A helper class to insert annotation boundaries for all the ops of a program
// region that will be handled by a specific compiler.
class AnnotateTargetRewriter : public ExprRewriter {
public:
explicit AnnotateTargetRewriter(Array<runtime::String> targets) : targets_(std::move(targets)) {}

protected:
/*! \brief The target backends for annotation. */
Array<runtime::String> targets_;
/*! \brief Maintain the decision of the target for each op expr. */
std::unordered_map<Expr, std::string, ObjectPtrHash, ObjectPtrEqual> op_expr_to_target_;

/*!
* \brief This function annotates a compiler end and a compiler begin to all arguments.
*
Expand All @@ -61,20 +67,27 @@ class AnnotateTargetRewriter : public ExprRewriter {
std::pair<std::string, Array<Expr>> AnnotateArgs(const Array<Expr>& args,
const std::string& target = "") {
std::string ref_target = "";
Array<Expr> compiler_begins;
Array<Expr> compiler_ends;
for (auto arg : args) {
std::string arg_target = "default";
std::string arg_target = default_target;
const CallNode* call = arg.as<CallNode>();

if (call && call->op == CompilerBeginOp()) {
// Argument is already compiler begin node meaning that this is not the first time
// running this pass, so we simply remove it and will add a new one later.
ICHECK_EQ(call->args.size(), 1U);
// Do not alter existing annotation if not default
if (default_target != call->attrs.as<CompilerAttrs>()->compiler) {
compiler_begins.push_back(arg);
} else {
// Remove default
compiler_ends.push_back(call->args[0]);
}
const CallNode* end = call->args[0].as<CallNode>();
if (end->op == CompilerEndOp()) {
if (end && end->op == CompilerEndOp()) {
arg_target = end->attrs.as<CompilerAttrs>()->compiler;
}
compiler_ends.push_back(call->args[0]);
} else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
arg_target = op_expr_to_target_[arg];
// If an argument is a call node and has no argument, then it should be tensor ops such as
Expand All @@ -93,18 +106,20 @@ class AnnotateTargetRewriter : public ExprRewriter {
if (ref_target == "") {
ref_target = arg_target;
} else if (ref_target != arg_target) {
ref_target = "default";
ref_target = default_target;
}
}

// Determine compiler begin target.
std::string op_target = (target == "") ? ref_target : target;

Array<Expr> compiler_begins;
for (const auto& end : compiler_ends) {
compiler_begins.push_back(InsertAnnotation(end, op_target, make_begin_op));
if (ref_target != "") {
for (const auto& end : compiler_ends) {
compiler_begins.push_back(InsertAnnotation(end, op_target, make_begin_op));
}
} else {
return {op_target, args};
}

return {op_target, compiler_begins};
}

Expand All @@ -128,14 +143,31 @@ class AnnotateTargetRewriter : public ExprRewriter {
* \return An annotated and target-propagated relay expression.
*/
Expr new_expr = expr;
if (op_expr_to_target_.find(expr) != op_expr_to_target_.end() && FreeVars(expr).size() != 0) {
new_expr = InsertAnnotation(expr, op_expr_to_target_[expr], make_end_op);
op_expr_to_target_[new_expr] = op_expr_to_target_[expr];
const CallNode* call = expr.as<CallNode>();
if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) {
// Check whether expr has args, if not - do not insert compiler_end.
expr->IsInstance<RefWriteNode>();
if (expr->IsInstance<RefWriteNode>() || expr->IsInstance<RefCreateNode>() ||
expr->IsInstance<RefReadNode>() || (call && !call->args.empty())) {
std::string target = op_expr_to_target_[new_expr];
new_expr = InsertAnnotation(new_expr, target, make_end_op);
op_expr_to_target_[new_expr] = target;
}
} else if (call && call->op == CompilerEndOp()) {
if (default_target == call->attrs.as<CompilerAttrs>()->compiler) {
ICHECK_EQ(call->args.size(), 1U);
new_expr = call->args[0];
std::string target = op_expr_to_target_[new_expr];
new_expr = InsertAnnotation(new_expr, target, make_end_op);
op_expr_to_target_[new_expr] = target;
}
}

return std::move(new_expr);
}

Expr Rewrite_(const CallNode* pre, const Expr& post) final {
public:
virtual Expr Rewrite_(const CallNode* pre, const Expr& post) override {
// Supported targets for this node. The order implies the priority.
std::vector<std::string> supported_targets;

Expand All @@ -146,13 +178,19 @@ class AnnotateTargetRewriter : public ExprRewriter {
// Bypass compiler begin due to lack of target information. It will be processed
// when the following op handling arguments.
ICHECK_EQ(pre->args.size(), 1U);
return post.as<CallNode>()->args[0];
// Preserve annotations
return post;
} else if (op_node && pre->op == CompilerEndOp()) {
// Override compiler end with the new target.
ICHECK_EQ(pre->args.size(), 1U);
auto input_expr = post.as<CallNode>()->args[0];
// Already annotated. Recover target
if (op_expr_to_target_.find(input_expr) == op_expr_to_target_.end()) {
op_expr_to_target_[input_expr] = post.as<CallNode>()->attrs.as<CompilerAttrs>()->compiler;
}
ICHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end());
return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op);
// Preserve annotated nodes
return post;
}
// Check prior to peeking first argument
if (pre->args.size()) {
Expand All @@ -161,8 +199,9 @@ class AnnotateTargetRewriter : public ExprRewriter {
const CallNode* first_arg_call = pre->args[0].as<CallNode>();
if (first_arg_call && first_arg_call->op == CompilerBeginOp()) {
std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
if (arg_target != "default") {
supported_targets.push_back(arg_target);
if (arg_target != default_target) {
// annotated already
return post;
}
}
}
Expand All @@ -188,7 +227,6 @@ class AnnotateTargetRewriter : public ExprRewriter {
// if it is in the target list.
Function func = Downcast<Function>(pre->op);
ICHECK(func.defined());

if (auto comp_name = func->GetAttr<String>(attr::kComposite)) {
std::string comp_name_str = comp_name.value();
size_t i = comp_name_str.find('.');
Expand All @@ -203,16 +241,18 @@ class AnnotateTargetRewriter : public ExprRewriter {
}
}
}
supported_targets.push_back("default"); // Make default as the last option.

supported_targets.push_back(default_target); // Make default as the last option.
// Visit and mutate arguments after the target of this op has been determined.
Call post_call = Downcast<Call>(post);
if (pre->op->IsInstance<VarNode>()) {
auto new_call = RewriteVarCall(post_call);
if (nullptr != new_call) return GetRef<Expr>(new_call->get());
}
// TODO(@comaniac, @zhiics): Now we simply assign this node to the target with
// the highest priority, but we should preserve all supported targets so that
// we can make a better decision.
std::string target = supported_targets[0];

// Visit and mutate arguments after the target of this op has been determined.
Call post_call = Downcast<Call>(post);

// Add annotations to each arg.
auto target_n_args = AnnotateArgs(post_call->args, target);
Array<Expr> compiler_begins = std::get<1>(target_n_args);
Expand All @@ -221,11 +261,12 @@ class AnnotateTargetRewriter : public ExprRewriter {

// Update the target map.
op_expr_to_target_[new_call] = target;

return std::move(new_call);
}

Expr Rewrite_(const TupleNode* op, const Expr& post) final {
virtual std::unique_ptr<Call> RewriteVarCall(const Call& post_call) { return nullptr; }

virtual Expr Rewrite_(const TupleNode* op, const Expr& post) override {
auto expr = Downcast<Tuple>(post);

auto target_n_args = AnnotateArgs(expr->fields);
Expand All @@ -234,7 +275,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
return std::move(new_expr);
}

Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
virtual Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) override {
auto expr = Downcast<TupleGetItem>(post);

auto target_n_args = AnnotateArgs(Array<Expr>({expr->tuple}));
Expand All @@ -243,7 +284,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
return std::move(new_expr);
}

Expr Rewrite_(const FunctionNode* fn, const Expr& post) final {
virtual Expr Rewrite_(const FunctionNode* fn, const Expr& post) override {
Function func;
Expr new_body;
// don't step into composite functions
Expand All @@ -257,7 +298,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs);
}

Expr Rewrite_(const LetNode* op, const Expr& post) final {
virtual Expr Rewrite_(const LetNode* op, const Expr& post) override {
auto let = Downcast<Let>(post);

Expr new_expr;
Expand All @@ -274,7 +315,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
return std::move(new_expr);
}

Expr Rewrite_(const IfNode* op, const Expr& post) final {
virtual Expr Rewrite_(const IfNode* op, const Expr& post) override {
auto expr = Downcast<If>(post);
Expr new_cond = InsertCompilerEndAndPropogateTarget(expr->cond);
Expr new_true_branch = InsertCompilerEndAndPropogateTarget(expr->true_branch);
Expand All @@ -284,7 +325,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
return std::move(new_expr);
}

Expr Rewrite_(const RefCreateNode* op, const Expr& post) final {
virtual Expr Rewrite_(const RefCreateNode* op, const Expr& post) override {
auto expr = Downcast<RefCreate>(post);

auto target_n_args = AnnotateArgs(Array<Expr>({expr->value}));
Expand All @@ -293,7 +334,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
return std::move(new_expr);
}

Expr Rewrite_(const RefReadNode* op, const Expr& post) final {
virtual Expr Rewrite_(const RefReadNode* op, const Expr& post) override {
auto expr = Downcast<RefRead>(post);

auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref}));
Expand All @@ -302,35 +343,93 @@ class AnnotateTargetRewriter : public ExprRewriter {
return std::move(new_expr);
}

Expr Rewrite_(const RefWriteNode* op, const Expr& post) final {
virtual Expr Rewrite_(const RefWriteNode* op, const Expr& post) override {
auto expr = Downcast<RefWrite>(post);

auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref, expr->value}));
auto new_expr = RefWrite(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]);
op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
return std::move(new_expr);
}
};

private:
/*! \brief The target backends for annotation. */
Array<runtime::String> targets_;
/*! \brief Maintain the decision of the target for each op expr. */
std::unordered_map<Expr, std::string, ObjectPtrHash, ObjectPtrEqual> op_expr_to_target_;
// A helper class to insert annotation boundaries for call ops and function nodes
// in a program region that will be handled by a specific compiler.
class CallOpsTargetRewriter : public AnnotateTargetRewriter {
public:
explicit CallOpsTargetRewriter(Array<runtime::String> targets)
: AnnotateTargetRewriter(std::move(targets)) {}

virtual std::unique_ptr<Call> RewriteVarCall(const Call& post_call) override {
Array<Expr> ends;
for (auto arg : post_call->args) {
ends.push_back(InsertCompilerEndAndPropogateTarget(arg));
}
auto new_call = std::make_unique<Call>(post_call->op, ends, post_call->attrs);
(*new_call)->checked_type_ = post_call->checked_type_;
return new_call;
}

virtual Expr Rewrite_(const TupleNode* op, const Expr& post) override {
auto expr = Downcast<Tuple>(post);
Array<Expr> new_fields;
for (auto f : expr->fields) {
new_fields.push_back(InsertCompilerEndAndPropogateTarget(f));
}
return std::move(Tuple(new_fields));
}

virtual Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) override {
auto expr = Downcast<TupleGetItem>(post);
return std::move(TupleGetItem(InsertCompilerEndAndPropogateTarget(expr->tuple), expr->index));
}

virtual Expr Rewrite_(const IfNode* op, const Expr& post) override {
auto expr = Downcast<If>(post);
Expr new_cond = InsertCompilerEndAndPropogateTarget(expr->cond);
Expr new_true_branch = InsertCompilerEndAndPropogateTarget(expr->true_branch);
Expr new_false_branch = InsertCompilerEndAndPropogateTarget(expr->false_branch);

auto new_expr = If(new_cond, new_true_branch, new_false_branch);
return std::move(new_expr);
}

virtual Expr Rewrite_(const RefCreateNode* op, const Expr& post) override {
auto expr = Downcast<RefCreate>(post);
auto new_expr = RefCreate(InsertCompilerEndAndPropogateTarget(expr->value));
return std::move(new_expr);
}

virtual Expr Rewrite_(const RefReadNode* op, const Expr& post) override {
auto expr = Downcast<RefRead>(post);
auto new_expr = RefRead(InsertCompilerEndAndPropogateTarget(expr->ref));
return std::move(new_expr);
}

virtual Expr Rewrite_(const RefWriteNode* op, const Expr& post) override {
auto expr = Downcast<RefWrite>(post);
auto new_expr = RefWrite(InsertCompilerEndAndPropogateTarget(expr->ref),
InsertCompilerEndAndPropogateTarget(expr->value));
return std::move(new_expr);
}
};

Expr AnnotateTarget(const Expr& expr, const Array<runtime::String>& targets) {
auto rewriter = AnnotateTargetRewriter(targets);
return PostOrderRewrite(expr, &rewriter);
Expr AnnotateTarget(const Expr& expr, const Array<runtime::String>& targets,
bool include_non_call_ops) {
auto r = include_non_call_ops ? std::make_unique<AnnotateTargetRewriter>(targets)
: std::make_unique<CallOpsTargetRewriter>(targets);
return PostOrderRewrite(expr, r.get());
}

} // namespace annotate_target

namespace transform {

Pass AnnotateTarget(const Array<runtime::String>& targets) {
Pass AnnotateTarget(const Array<runtime::String>& targets, bool include_non_call_ops) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, targets));
return Downcast<Function>(
relay::annotate_target::AnnotateTarget(f, targets, include_non_call_ops));
};
auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc", {"InferType"});
return transform::Sequential({func_pass, InferType()}, "AnnotateTarget");
Expand Down
Loading

0 comments on commit 301bd38

Please sign in to comment.