Skip to content

Commit

Permalink
[BYOC] Added default_tuples parameter to AnnotateTarget pass
Browse files Browse the repository at this point in the history
Added default_tuples parameter to AnnotateTarget pass to prevent
tuples to be promoted to previously annotated operations
This is useful in case if you are not running MergeCompilerRegions
pass after AnnotateTarget.
  • Loading branch information
d-smirnov committed Oct 9, 2020
1 parent 0922d17 commit f22bf74
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 19 deletions.
6 changes: 4 additions & 2 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ def PartitionGraph():
return _ffi_api.PartitionGraph()


def AnnotateTarget(targets):
def AnnotateTarget(targets, default_tuples=False):
"""Annotate ops in an experession with a provied compiler/target and then
use it for codegen.
Expand All @@ -683,7 +683,9 @@ def AnnotateTarget(targets):
"""
if isinstance(targets, str):
targets = [targets]
return _ffi_api.AnnotateTarget([tvm.runtime.container.String(t) for t in targets])
return _ffi_api.AnnotateTarget(
[tvm.runtime.container.String(t) for t in targets], default_tuples
)


def DynamicToStatic():
Expand Down
49 changes: 32 additions & 17 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,14 @@ static const PackedFunc* make_begin_op =
static const PackedFunc* make_end_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_end");

#define DEFAULT_TARGET_NAME ("default")

// A helper class to insert annotation boundaries for a program region that will
// be handled by a specific compiler.
class AnnotateTargetRewriter : public ExprRewriter {
public:
explicit AnnotateTargetRewriter(Array<runtime::String> targets) : targets_(std::move(targets)) {}
explicit AnnotateTargetRewriter(Array<runtime::String> targets, bool default_tuples)
: default_tuples_{default_tuples}, targets_(std::move(targets)) {}

/*!
* \brief This function annotates a compiler end and a compiler begin to all arguments.
Expand All @@ -62,19 +65,22 @@ class AnnotateTargetRewriter : public ExprRewriter {
const std::string& target = "") {
std::string ref_target = "";
Array<Expr> compiler_ends;
Array<Expr> compiler_begins;
for (auto arg : args) {
std::string arg_target = "default";
std::string arg_target = DEFAULT_TARGET_NAME;
const CallNode* call = arg.as<CallNode>();

if (call && call->op == CompilerBeginOp()) {
// Argument is already compiler begin node meaning that this is not the first time
// running this pass, so we simply remove it and will add a new one later.
CHECK_EQ(call->args.size(), 1U);
const CallNode* end = call->args[0].as<CallNode>();
if (end->op == CompilerEndOp()) {
if (end && end->op == CompilerEndOp()) {
arg_target = end->attrs.as<CompilerAttrs>()->compiler;
}
compiler_ends.push_back(call->args[0]);
if (arg_target == "" || arg_target == DEFAULT_TARGET_NAME)
compiler_ends.push_back(call->args[0]);
else
compiler_begins.push_back(arg);
} else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
arg_target = op_expr_to_target_[arg];
compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op));
Expand All @@ -87,14 +93,13 @@ class AnnotateTargetRewriter : public ExprRewriter {
if (ref_target == "") {
ref_target = arg_target;
} else if (ref_target != arg_target) {
ref_target = "default";
ref_target = DEFAULT_TARGET_NAME;
}
}

// Determine compiler begin target.
std::string op_target = (target == "") ? ref_target : target;

Array<Expr> compiler_begins;
for (const auto& end : compiler_ends) {
compiler_begins.push_back(InsertAnnotation(end, op_target, make_begin_op));
}
Expand All @@ -113,18 +118,26 @@ class AnnotateTargetRewriter : public ExprRewriter {
std::vector<std::string> supported_targets;

auto op_node = pre->op.as<OpNode>();

// This graph has annotations, meaning that this is not the first time running this pass.
if (op_node && pre->op == CompilerBeginOp()) {
// Bypass compiler begin due to lack of target information. It will be processed
// when the following op handling arguments.
CHECK_EQ(pre->args.size(), 1U);
return post.as<CallNode>()->args[0];

std::string begin_target = pre->attrs.as<CompilerAttrs>()->compiler;
if ("" == begin_target || begin_target == DEFAULT_TARGET_NAME)
return post.as<CallNode>()->args[0];

return post;
} else if (op_node && pre->op == CompilerEndOp()) {
// Override compiler end with the new target.
CHECK_EQ(pre->args.size(), 1U);
auto input_expr = post.as<CallNode>()->args[0];
CHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end());

std::string begin_target = pre->attrs.as<CompilerAttrs>()->compiler;
if ("" != begin_target && begin_target != DEFAULT_TARGET_NAME) return post;

return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op);
}

Expand All @@ -133,7 +146,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
const CallNode* first_arg_call = pre->args[0].as<CallNode>();
if (first_arg_call && first_arg_call->op == CompilerBeginOp()) {
std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
if (arg_target != "default") {
if (arg_target != DEFAULT_TARGET_NAME) {
supported_targets.push_back(arg_target);
}
}
Expand Down Expand Up @@ -173,7 +186,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
}
}
}
supported_targets.push_back("default"); // Make default as the last option.
supported_targets.push_back(DEFAULT_TARGET_NAME); // Make default as the last option.

// TODO(@comaniac, @zhiics): Now we simply assign this node to the target with
// the highest priority, but we should preserve all supported targets so that
Expand All @@ -197,8 +210,8 @@ class AnnotateTargetRewriter : public ExprRewriter {

Expr Rewrite_(const TupleNode* op, const Expr& post) final {
auto expr = Downcast<Tuple>(post);

auto target_n_args = AnnotateArgs(expr->fields);
auto target_n_args = default_tuples_ ? AnnotateArgs(expr->fields, DEFAULT_TARGET_NAME)
: AnnotateArgs(expr->fields);
auto new_expr = Tuple(std::get<1>(target_n_args));
op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
return std::move(new_expr);
Expand Down Expand Up @@ -279,25 +292,27 @@ class AnnotateTargetRewriter : public ExprRewriter {
}

private:
bool default_tuples_;
/*! \brief The target backends for annotation. */
Array<runtime::String> targets_;
/*! \brief Maintain the decision of the target for each op expr. */
std::unordered_map<Expr, std::string, ObjectPtrHash, ObjectPtrEqual> op_expr_to_target_;
};

Expr AnnotateTarget(const Expr& expr, const Array<runtime::String>& targets) {
auto rewriter = AnnotateTargetRewriter(targets);
Expr AnnotateTarget(const Expr& expr, const Array<runtime::String>& targets, bool default_tuples) {
auto rewriter = AnnotateTargetRewriter(targets, default_tuples);
return PostOrderRewrite(expr, &rewriter);
}

} // namespace annotate_target

namespace transform {

Pass AnnotateTarget(const Array<runtime::String>& targets) {
Pass AnnotateTarget(const Array<runtime::String>& targets, bool default_tuples) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, targets));
return Downcast<Function>(
relay::annotate_target::AnnotateTarget(f, targets, default_tuples));
};
auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc", {"InferType"});
return transform::Sequential({func_pass, InferType()}, "AnnotateTarget");
Expand Down
37 changes: 37 additions & 0 deletions tests/python/relay/test_pass_annotate_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,42 @@ def after():
assert tvm.ir.structural_equal(expected, result)


def test_tuple_two_targets():
"""Tests whether the TupleNode is promoted to previously annotatated operation or is excluded."""
target_relu = "relu_target"
target_maximum = "maximum_target"
target_default = "default"

@tvm.ir.register_op_attr("nn.relu", "target." + target_relu)
def relu(attrs, args): # pylint: disable=unused-variable
return True

@tvm.ir.register_op_attr("maximum", "target." + target_maximum)
def maximum(attrs, args): # pylint: disable=unused-variable
return True

def before():
a = relay.var("a", shape=(10, 5))
b = relay.var("b", shape=(10, 5))
r = relay.nn.relu(b)
t1 = relay.Tuple((r, r))
r2 = relay.nn.relu(t1)
m = relay.maximum(a, b)
t2 = relay.Tuple((m, r2))
f = relay.Function([a, b], t2)
return tvm.IRModule.from_expr(f)

for default_tuples, parts in [(True, 3), (False, 2)]:
result = before()
result = transform.AnnotateTarget([target_relu], default_tuples)(result)
result = transform.AnnotateTarget([target_maximum], True)(result)
result = transform.MergeCompilerRegions()(result)
result = transform.PartitionGraph()(result)
assert parts == len(
list(filter(lambda _: "target" in _.name_hint, result.get_global_vars()))
)


def test_multiple_runs():
@tvm.ir.register_op_attr("nn.relu", "target.A")
def relu(attrs, args): # pylint: disable=unused-variable
Expand Down Expand Up @@ -361,3 +397,4 @@ def before():
test_type_propagation()
test_tuple()
test_multiple_runs()
test_tuple_two_targets()

0 comments on commit f22bf74

Please sign in to comment.