Skip to content

Commit

Permalink
[Relay][heterogeneous pass] remove on_device op after annotation (apa…
Browse files Browse the repository at this point in the history
…che#3204)

* remove on_device op after annotation

* Update src/relay/pass/device_annotation.cc

Co-Authored-By: MORINAGA <[email protected]>
  • Loading branch information
2 people authored and Wei Chen committed Jun 26, 2019
1 parent 04d2b04 commit 1bf6184
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 23 deletions.
47 changes: 46 additions & 1 deletion src/relay/pass/device_annotation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,52 @@ class DeviceInfo {

Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) {
RewriteAnnotation rewrote = RewriteAnnotation();
return rewrote.Rewrite(expr, fallback_device);
Expr new_expr = rewrote.Rewrite(expr, fallback_device);

// Remove OnDevice operators. Note that these operators are only present at the
// leaves after annotation. Therefore, we can simply reconstruct the
// Function/Expr by removing them directly.
if (const FunctionNode* fn = new_expr.as<FunctionNode>()) {
auto params = fn->params;
auto body = fn->body;
std::vector<Expr> new_body;
if (const TupleNode* tuple = body.as<TupleNode>()) {
for (const auto& field : tuple->fields) {
if (!IsOnDeviceNode(field.operator->())) {
new_body.push_back(field);
}
}
CHECK_GT(new_body.size(), 0U);
if (new_body.size() == 1) {
return FunctionNode::make(params, new_body[0], Type(nullptr),
fn->type_params, fn->attrs);
} else if (tuple->fields.size() == new_body.size()) {
return new_expr;
} else {
Tuple tuple_body = TupleNode::make(new_body);
return FunctionNode::make(params, tuple_body, Type(nullptr),
fn->type_params, fn->attrs);
}
} else {
return new_expr;
}
} else if (const TupleNode* tuple = new_expr.as<TupleNode>()) {
std::vector<Expr> new_fields;
for (const auto& field : tuple->fields) {
if (!IsOnDeviceNode(field.operator->())) {
new_fields.push_back(field);
}
}
CHECK_GT(new_fields.size(), 0U);
if (tuple->fields.size() == new_fields.size()) {
return new_fields.size() == 1 ? new_fields[0] : new_expr;
} else {
return new_fields.size() == 1 ? new_fields[0]
: TupleNode::make(new_fields);
}
} else {
return new_expr;
}
}

Map<Expr, Integer> CollectDeviceInfo(const Expr& expr) {
Expand Down
61 changes: 39 additions & 22 deletions tests/python/relay/test_pass_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ def annotated():
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
ctx1.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[2]),
func.body[2])
return func

def expected():
add = relay.add(x, y)
Expand All @@ -58,6 +56,35 @@ def expected():
assert relay.ir_pass.alpha_equal(annotated_func, expected_func)


def test_annotate_expr():
ctx1 = tvm.context(1)
ctx2 = tvm.context(2)
x = relay.var("x", shape=(3,))
y = relay.var("y", shape=(3,))
z = relay.var("z", shape=(3,))

def annotated():
add = relay.add(x, y)
_add = relay.annotation.on_device(add, ctx1)
sub = relay.subtract(add, z)
_sub = relay.annotation.on_device(sub, ctx2)
expr = relay.Tuple([sub, _add, _sub])
expr = relay.ir_pass.infer_type(expr)
expr = relay.ir_pass.rewrite_annotated_ops(expr,
ctx1.device_type)
return expr

def expected():
add = relay.add(x, y)
copy_add_sub = relay.device_copy(add, ctx1, ctx2)
sub = relay.subtract(copy_add_sub, z)
return sub

annotated_expr = relay.ir_pass.infer_type(annotated())
expected_expr = relay.ir_pass.infer_type(expected())
assert relay.ir_pass.graph_equal(annotated_expr, expected_expr)


def test_annotate_all():
ctx1 = tvm.context(1)
ctx2 = tvm.context(2)
Expand All @@ -77,9 +104,7 @@ def annotated():
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
ctx1.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[2]),
func.body[2])
return func

def expected():
add = relay.add(x, y)
Expand All @@ -91,6 +116,7 @@ def expected():
expected_func = relay.ir_pass.infer_type(expected())
assert relay.ir_pass.alpha_equal(annotated_func, expected_func)


def test_annotate_none():
ctx1 = tvm.context(1)
ctx2 = tvm.context(2)
Expand Down Expand Up @@ -174,9 +200,7 @@ def annotated():
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
tvm.context(3).device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[4]),
func.body[4])
return func

def expected():
conv2d_1 = relay.nn.conv2d(
Expand All @@ -202,7 +226,7 @@ def expected():
kernel_size=(3, 3),
padding=(1, 1))

func = relay.Function([data1, weight, data2], conv2d_3)
func = relay.Function([data1, data2, weight], conv2d_3)
return func

def check_storage_and_device_types():
Expand Down Expand Up @@ -306,9 +330,7 @@ def annotated():
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
cpu_ctx.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[2]),
func.body[2])
return func

def expected():
add = relay.add(x, y)
Expand Down Expand Up @@ -358,9 +380,7 @@ def annotated():
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
cpu_ctx.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[5]),
func.body[5])
return func

annotated_func = annotated()
expected_func = get_func()
Expand All @@ -386,9 +406,7 @@ def annotated():
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
dev_ctx.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[1]),
func.body[1])
return func

def expected():
add = relay.add(x, y)
Expand Down Expand Up @@ -462,9 +480,7 @@ def annotated():
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
dev_ctx.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[3]),
func.body[3])
return func

def expected():
add = relay.add(a, b)
Expand Down Expand Up @@ -506,6 +522,7 @@ def test_check_run():

if __name__ == "__main__":
test_redundant_annotation()
test_annotate_expr()
test_annotate_all()
test_annotate_none()
test_conv_network()
Expand Down

0 comments on commit 1bf6184

Please sign in to comment.