diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 66ff9caf4ae4..c7b16da9036c 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -261,9 +261,30 @@ class IndexedForwardGraph::Creator : private ExprVisitor { } void VisitExpr_(const TupleGetItemNode* op) final { - CHECK(graph_.node_map.count(op)); - Node* node = graph_.node_map.at(op); - this->Update(op->tuple, node, kOpaque); + auto tuple_type = op->tuple->checked_type().as(); + CHECK(tuple_type); + // If this tuple contain a reference type, and we fuse TupleGetItem and + // the reference, a fused function will have a tuple containing a reference + // in its parameters. But when TVM lowers a fused function, it expects all + // arguments to be a Tensor or a tuple containing only Tensors. + // To avoid modifying codegen logic, we do not allow fusing through a reference. + // The reference itself will be recursively visited via call to ExprVisitor::VisitExpr_(op) + // below and corresponding visitor methods + bool has_reference = false; + for (auto ty : tuple_type->fields) { + if (ty.as()) { + has_reference = true; + break; + } + } + if (has_reference) { + this->Update(op->tuple, nullptr, kOpaque); + } else { + CHECK(graph_.node_map.count(op)); + Node* node = graph_.node_map.at(op); + node->pattern = kInjective; + this->Update(op->tuple, node, kInjective); + } ExprVisitor::VisitExpr_(op); this->AddNode(op); } @@ -809,6 +830,23 @@ class FuseMutator : private ExprMutator { return TupleNode::make(new_fields); } + Expr VisitExpr_(const TupleGetItemNode* tuple_get) { + auto* ret_group = gmap_.at(tuple_get)->FindRoot(); + auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0]; + auto new_node = TupleGetItemNode::make(new_tuple, tuple_get->index); + if (ret_group == gmap_.at(tuple_get)) { + if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) { + // Isolated. This case occurs when tuple is created by an Opaque op + // e.g. multibox_transform_loc + return ExprMutator::VisitExpr_(tuple_get); + } + // A new function whose output is a tuple field access + return MakeNewFunction(ret_group, tuple_get->checked_type(), new_node); + } + // This is an intermediate node in the group + return new_node; + } + Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) { const GroupInfo& ginfo = ginfo_[group]; auto func = FunctionNode::make(ginfo.params, body, ret_type, {}); diff --git a/tests/python/relay/test_backend_graph_runtime.py b/tests/python/relay/test_backend_graph_runtime.py index 434b0e6ddfa1..56da263c9b4e 100644 --- a/tests/python/relay/test_backend_graph_runtime.py +++ b/tests/python/relay/test_backend_graph_runtime.py @@ -7,6 +7,7 @@ from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.op import add from tvm.relay.module import Module +from tvm.relay.testing.config import ctx_list # @tq, @jr should we put this in testing ns? def check_rts(expr, args, expected_result, mod=None): @@ -127,9 +128,47 @@ def test_plan_memory(): assert len(device_types) == 1 +def test_gru_like(): + def unit(rnn_dim): + X = relay.var("X", shape=(1, rnn_dim)) + W = relay.var("y", shape=(3 * rnn_dim, rnn_dim)) + matmul = relay.nn.dense(X, W) + splitted = relay.split(matmul, indices_or_sections=3, axis=1) + out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2]) + return relay.Function([X, W], out) + + def sigmoid(x): + return 1 / (1 + np.exp(-x)) + + def unit_numpy(X, W): + prod = np.dot(X, W.transpose()) + splits = np.split(prod, indices_or_sections=3, axis=1) + return sigmoid(splits[0]) + np.tanh(splits[1]) * np.exp(splits[2]) + + dtype = "float32" + rnn_dim = 1000 + x = np.random.rand(1, rnn_dim).astype(dtype) + y = np.random.rand(3*rnn_dim, rnn_dim).astype(dtype) * 0.01 - 0.005 + out_shape = (1, rnn_dim) + z = unit(rnn_dim) + + for target, ctx in ctx_list(): + with relay.build_config(opt_level=2): + graph, lib, params = relay.build(z, target) + m = graph_runtime.create(graph, lib, ctx) + m.set_input("X", tvm.nd.array(x.astype(dtype))) + m.set_input("y", tvm.nd.array(y.astype(dtype))) + m.set_input(**params) + m.run() + out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy() + ref = unit_numpy(x, y) + tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) + + if __name__ == "__main__": test_plan_memory() test_with_params() test_add_op_scalar() test_add_op_tensor() test_add_op_broadcast() + test_gru_like() diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 634d69bae823..5df6ad7d5226 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -217,7 +217,6 @@ def expected(dshape): assert not relay.ir_pass.free_vars(zz) after = relay.ir_pass.infer_type(expected(dshape)) assert relay.ir_pass.alpha_equal(zz, after) - print(zz.astext()) def test_stop_fusion(): @@ -287,6 +286,81 @@ def expected(dshape, dtype): assert relay.ir_pass.alpha_equal(f, after) +def test_fuse_tuple_get_elemwise(): + def before(dim): + X = relay.var("X", shape=(1, dim)) + W = relay.var("W", shape=(3 * dim, dim)) + matmul = relay.nn.dense(X, W) + splitted = relay.split(matmul, indices_or_sections=3, axis=1) + out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2]) + return relay.Function([X, W], out) + + def expected(dim): + p0 = relay.var("p0", shape=(1, dim)) + p1 = relay.var("p1", shape=(3 * dim, dim)) + matmul = relay.nn.dense(p0, p1) + f0 = relay.Function([p0, p1], matmul) + + p01 = relay.var("p01", shape=(1, 3 * dim)) + splitted = relay.split(p01, indices_or_sections=3, axis=1) + out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2]) + f1 = relay.Function([p01], out) + + X = relay.var("X", shape=(1, dim)) + W = relay.var("W", shape=(3 * dim, dim)) + y = relay.Call(f0, [X, W]) + z = relay.Call(f1, [y]) + return relay.Function([X, W], z) + + dim = 10 + z = before(dim) + z = relay.ir_pass.infer_type(z) + zz = relay.ir_pass.fuse_ops(z, opt_level=0) + assert not relay.ir_pass.free_vars(zz) + zz = relay.ir_pass.fuse_ops(z, opt_level=2) + zz = relay.ir_pass.infer_type(zz) + assert not relay.ir_pass.free_vars(zz) + after = relay.ir_pass.infer_type(expected(dim)) + assert relay.ir_pass.alpha_equal(zz, after) + + +def test_tuple_get_root(): + def before(dim): + X = relay.var("X", shape=(1, 3 * dim)) + W = relay.var("W", shape=(dim, dim)) + splitted = relay.split(X, indices_or_sections=3, axis=1) + out = relay.nn.dense(splitted[0], W) + return relay.Function([X, W], out) + + def expected(dim): + p0 = relay.var("p0", shape=(1, 3 * dim)) + splitted = relay.split(p0, indices_or_sections=3, axis=1) + out = splitted[0] + f0 = relay.Function([p0], out) + + p01 = relay.var("p01", shape=(1, dim)) + p1 = relay.var("p1", shape=(dim, dim)) + out = relay.nn.dense(p01, p1) + f1 = relay.Function([p01, p1], out) + + X = relay.var("X", shape=(1, 3 * dim)) + W = relay.var("W", shape=(dim, dim)) + y = relay.Call(f0, [X]) + z = relay.Call(f1, [y, W]) + return relay.Function([X, W], z) + + dim = 10 + z = before(dim) + z = relay.ir_pass.infer_type(z) + zz = relay.ir_pass.fuse_ops(z, opt_level=0) + assert not relay.ir_pass.free_vars(zz) + zz = relay.ir_pass.fuse_ops(z, opt_level=2) + zz = relay.ir_pass.infer_type(zz) + assert not relay.ir_pass.free_vars(zz) + after = relay.ir_pass.infer_type(expected(dim)) + assert relay.ir_pass.alpha_equal(zz, after) + + if __name__ == "__main__": test_fuse_simple() test_conv2d_fuse() @@ -295,3 +369,5 @@ def expected(dshape, dtype): test_tuple_strided_slice() test_stop_fusion() test_fuse_myia_regression() + test_fuse_tuple_get_elemwise() + test_tuple_get_root()