Skip to content

Commit

Permalink
[Relay] Add support for TupleGetItem in op fusion (apache#2914)
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi authored and wweic committed Apr 11, 2019
1 parent f4e3837 commit 93a272c
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 4 deletions.
44 changes: 41 additions & 3 deletions src/relay/pass/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,30 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
}

void VisitExpr_(const TupleGetItemNode* op) final {
CHECK(graph_.node_map.count(op));
Node* node = graph_.node_map.at(op);
this->Update(op->tuple, node, kOpaque);
auto tuple_type = op->tuple->checked_type().as<TupleTypeNode>();
CHECK(tuple_type);
// If this tuple contain a reference type, and we fuse TupleGetItem and
// the reference, a fused function will have a tuple containing a reference
// in its parameters. But when TVM lowers a fused function, it expects all
// arguments to be a Tensor or a tuple containing only Tensors.
// To avoid modifying codegen logic, we do not allow fusing through a reference.
// The reference itself will be recursively visited via call to ExprVisitor::VisitExpr_(op)
// below and corresponding visitor methods
bool has_reference = false;
for (auto ty : tuple_type->fields) {
if (ty.as<RefTypeNode>()) {
has_reference = true;
break;
}
}
if (has_reference) {
this->Update(op->tuple, nullptr, kOpaque);
} else {
CHECK(graph_.node_map.count(op));
Node* node = graph_.node_map.at(op);
node->pattern = kInjective;
this->Update(op->tuple, node, kInjective);
}
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}
Expand Down Expand Up @@ -809,6 +830,23 @@ class FuseMutator : private ExprMutator {
return TupleNode::make(new_fields);
}

Expr VisitExpr_(const TupleGetItemNode* tuple_get) {
auto* ret_group = gmap_.at(tuple_get)->FindRoot();
auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0];
auto new_node = TupleGetItemNode::make(new_tuple, tuple_get->index);
if (ret_group == gmap_.at(tuple_get)) {
if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) {
// Isolated. This case occurs when tuple is created by an Opaque op
// e.g. multibox_transform_loc
return ExprMutator::VisitExpr_(tuple_get);
}
// A new function whose output is a tuple field access
return MakeNewFunction(ret_group, tuple_get->checked_type(), new_node);
}
// This is an intermediate node in the group
return new_node;
}

Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
const GroupInfo& ginfo = ginfo_[group];
auto func = FunctionNode::make(ginfo.params, body, ret_type, {});
Expand Down
39 changes: 39 additions & 0 deletions tests/python/relay/test_backend_graph_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.op import add
from tvm.relay.module import Module
from tvm.relay.testing.config import ctx_list

# @tq, @jr should we put this in testing ns?
def check_rts(expr, args, expected_result, mod=None):
Expand Down Expand Up @@ -127,9 +128,47 @@ def test_plan_memory():
assert len(device_types) == 1


def test_gru_like():
def unit(rnn_dim):
X = relay.var("X", shape=(1, rnn_dim))
W = relay.var("y", shape=(3 * rnn_dim, rnn_dim))
matmul = relay.nn.dense(X, W)
splitted = relay.split(matmul, indices_or_sections=3, axis=1)
out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
return relay.Function([X, W], out)

def sigmoid(x):
return 1 / (1 + np.exp(-x))

def unit_numpy(X, W):
prod = np.dot(X, W.transpose())
splits = np.split(prod, indices_or_sections=3, axis=1)
return sigmoid(splits[0]) + np.tanh(splits[1]) * np.exp(splits[2])

dtype = "float32"
rnn_dim = 1000
x = np.random.rand(1, rnn_dim).astype(dtype)
y = np.random.rand(3*rnn_dim, rnn_dim).astype(dtype) * 0.01 - 0.005
out_shape = (1, rnn_dim)
z = unit(rnn_dim)

for target, ctx in ctx_list():
with relay.build_config(opt_level=2):
graph, lib, params = relay.build(z, target)
m = graph_runtime.create(graph, lib, ctx)
m.set_input("X", tvm.nd.array(x.astype(dtype)))
m.set_input("y", tvm.nd.array(y.astype(dtype)))
m.set_input(**params)
m.run()
out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
ref = unit_numpy(x, y)
tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)


if __name__ == "__main__":
test_plan_memory()
test_with_params()
test_add_op_scalar()
test_add_op_tensor()
test_add_op_broadcast()
test_gru_like()
78 changes: 77 additions & 1 deletion tests/python/relay/test_pass_fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ def expected(dshape):
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(dshape))
assert relay.ir_pass.alpha_equal(zz, after)
print(zz.astext())


def test_stop_fusion():
Expand Down Expand Up @@ -287,6 +286,81 @@ def expected(dshape, dtype):
assert relay.ir_pass.alpha_equal(f, after)


def test_fuse_tuple_get_elemwise():
def before(dim):
X = relay.var("X", shape=(1, dim))
W = relay.var("W", shape=(3 * dim, dim))
matmul = relay.nn.dense(X, W)
splitted = relay.split(matmul, indices_or_sections=3, axis=1)
out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
return relay.Function([X, W], out)

def expected(dim):
p0 = relay.var("p0", shape=(1, dim))
p1 = relay.var("p1", shape=(3 * dim, dim))
matmul = relay.nn.dense(p0, p1)
f0 = relay.Function([p0, p1], matmul)

p01 = relay.var("p01", shape=(1, 3 * dim))
splitted = relay.split(p01, indices_or_sections=3, axis=1)
out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
f1 = relay.Function([p01], out)

X = relay.var("X", shape=(1, dim))
W = relay.var("W", shape=(3 * dim, dim))
y = relay.Call(f0, [X, W])
z = relay.Call(f1, [y])
return relay.Function([X, W], z)

dim = 10
z = before(dim)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2)
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(dim))
assert relay.ir_pass.alpha_equal(zz, after)


def test_tuple_get_root():
def before(dim):
X = relay.var("X", shape=(1, 3 * dim))
W = relay.var("W", shape=(dim, dim))
splitted = relay.split(X, indices_or_sections=3, axis=1)
out = relay.nn.dense(splitted[0], W)
return relay.Function([X, W], out)

def expected(dim):
p0 = relay.var("p0", shape=(1, 3 * dim))
splitted = relay.split(p0, indices_or_sections=3, axis=1)
out = splitted[0]
f0 = relay.Function([p0], out)

p01 = relay.var("p01", shape=(1, dim))
p1 = relay.var("p1", shape=(dim, dim))
out = relay.nn.dense(p01, p1)
f1 = relay.Function([p01, p1], out)

X = relay.var("X", shape=(1, 3 * dim))
W = relay.var("W", shape=(dim, dim))
y = relay.Call(f0, [X])
z = relay.Call(f1, [y, W])
return relay.Function([X, W], z)

dim = 10
z = before(dim)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2)
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(dim))
assert relay.ir_pass.alpha_equal(zz, after)


if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
Expand All @@ -295,3 +369,5 @@ def expected(dshape, dtype):
test_tuple_strided_slice()
test_stop_fusion()
test_fuse_myia_regression()
test_fuse_tuple_get_elemwise()
test_tuple_get_root()

0 comments on commit 93a272c

Please sign in to comment.