diff --git a/src/operator/subgraph/subgraph_op.cc b/src/operator/subgraph/subgraph_op.cc
index fb71ea0f6afb..d0a55f8a003f 100644
--- a/src/operator/subgraph/subgraph_op.cc
+++ b/src/operator/subgraph/subgraph_op.cc
@@ -154,10 +154,27 @@ bool SubgraphOpShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(idx_g.input_nodes().size(), in_shapes->size());
   CHECK_EQ(idx_g.outputs().size(), out_shapes->size());
   // TODO: make sure shape inputs matches the order from in_shapes
-  nnvm::ShapeVector shape_inputs = *in_shapes;
-  imperative::CheckAndInferShape(&g, std::move(shape_inputs), true);
-  const nnvm::ShapeVector& shapes = g.GetAttr<nnvm::ShapeVector>("shape");
-  const std::vector<uint32_t>& input_nids = idx_g.input_nodes();
+
+  // Put the input and output shapes to the shape vector.
+  nnvm::ShapeVector shapes(idx_g.num_node_entries());
+  const auto &input_nids = idx_g.input_nodes();
+  CHECK_EQ(input_nids.size(), in_shapes->size());
+  for (size_t i = 0; i < in_shapes->size(); i++) {
+    auto eid = idx_g.entry_id(input_nids[i], 0);
+    shapes[eid] = in_shapes->at(i);
+  }
+  CHECK_EQ(g.outputs.size(), out_shapes->size());
+  for (size_t i = 0; i < out_shapes->size(); i++) {
+    auto eid = idx_g.entry_id(g.outputs[i]);
+    shapes[eid] = out_shapes->at(i);
+  }
+
+  // Infer shape of the graph.
+  g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
+  g = exec::InferShape(std::move(g));
+
+  // Copy the inferred shape back to the input shapes and the output shapes.
+  shapes = g.GetAttr<nnvm::ShapeVector>("shape");
   // assign to in_shapes
   for (size_t i = 0; i < in_shapes->size(); ++i) {
     const auto eid = idx_g.entry_id(input_nids[i], 0);
@@ -168,7 +185,8 @@ bool SubgraphOpShape(const nnvm::NodeAttrs& attrs,
     const auto eid = idx_g.entry_id(g.outputs[i]);
     SHAPE_ASSIGN_CHECK(*out_shapes, i, shapes[eid]);
   }
-  return true;
+  // Check if we have inferred the shapes correctly.
+  return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
 }
 
 bool SubgraphOpType(const nnvm::NodeAttrs& attrs,
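All three inference functions in this file now follow the same pattern: size a per-entry attribute vector by num_node_entries(), seed the entries that belong to the subgraph's inputs and outputs, run the corresponding executor pass, and read the inferred values back. Below is a minimal sketch of that pattern for shapes, not part of the patch, assuming MXNet's internal exec::InferShape pass and the shape_num_unknown_nodes attribute it records; the helper name InferSubgraphShapes and the include paths are hypothetical:

#include <dmlc/logging.h>
#include <nnvm/graph.h>
#include <nnvm/graph_attr_types.h>
#include "../../executor/exec_pass.h"  // for mxnet::exec::InferShape

namespace mxnet {

// Hypothetical helper, not part of the patch: seed per-entry shapes for the
// subgraph inputs, run the InferShape pass, and report whether every entry
// was resolved.
bool InferSubgraphShapes(nnvm::Graph* g, const nnvm::ShapeVector& in_shapes) {
  const auto& idx = g->indexed_graph();
  // One slot per node entry; a default-constructed TShape means "unknown".
  nnvm::ShapeVector shapes(idx.num_node_entries());
  const auto& input_nids = idx.input_nodes();
  CHECK_EQ(input_nids.size(), in_shapes.size());
  // Each input node is a variable with a single output: entry (nid, 0).
  for (size_t i = 0; i < input_nids.size(); i++) {
    shapes[idx.entry_id(input_nids[i], 0)] = in_shapes[i];
  }
  // Hand the seeded vector to the pass and run it.
  g->attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
  *g = exec::InferShape(std::move(*g));
  // The pass counts entries that remain unknown; 0 means fully inferred.
  return g->GetAttr<size_t>("shape_num_unknown_nodes") == 0;
}

}  // namespace mxnet

The dtype and storage-type hunks that follow repeat this structure with their own seed values and passes.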
+ return g.GetAttr("dtype_num_unknown_nodes") == 0; } bool SubgraphOpStorageType(const nnvm::NodeAttrs& attrs, @@ -209,13 +244,33 @@ bool SubgraphOpStorageType(const nnvm::NodeAttrs& attrs, const auto& idx_g = g.indexed_graph(); CHECK_EQ(idx_g.input_nodes().size(), in_stypes->size()); CHECK_EQ(idx_g.outputs().size(), out_stypes->size()); - exec::DevMaskVector dev_masks(idx_g.num_nodes(), dev_mask); + exec::DevMaskVector dev_masks(idx_g.num_node_entries(), dev_mask); // TODO: make sure type inputs matches the order from in_types - StorageTypeVector stype_inputs = *in_stypes; - imperative::CheckAndInferStorageType(&g, std::move(dev_masks), - std::move(stype_inputs), true); - const StorageTypeVector& stypes = g.GetAttr("storage_type"); - const std::vector& input_nids = idx_g.input_nodes(); + + // Put the input and output storages to the storage vector. + nnvm::StorageVector stypes(idx_g.num_node_entries(), exec::kBadStorageID); + const auto &input_nids = idx_g.input_nodes(); + CHECK_EQ(input_nids.size(), in_stypes->size()); + for (size_t i = 0; i < in_stypes->size(); i++) { + auto eid = idx_g.entry_id(input_nids[i], 0); + stypes[eid] = in_stypes->at(i); + } + CHECK_EQ(g.outputs.size(), out_stypes->size()); + for (size_t i = 0; i < out_stypes->size(); i++) { + auto eid = idx_g.entry_id(g.outputs[i]); + stypes[eid] = out_stypes->at(i); + } + + // Infer storage type of the graph. + bool dev_match = g.attrs.count("dev_mask") && + g.GetAttr("dev_mask") == dev_masks; + if (!dev_match) { + g.attrs["dev_mask"] = std::make_shared(std::move(dev_masks)); + } + g.attrs["storage_type"] = std::make_shared(std::move(stypes)); + g = exec::InferStorageType(std::move(g)); + + stypes = g.GetAttr("storage_type"); // assign to in_types for (size_t i = 0; i < in_stypes->size(); ++i) { const auto eid = idx_g.entry_id(input_nids[i], 0); @@ -228,7 +283,8 @@ bool SubgraphOpStorageType(const nnvm::NodeAttrs& attrs, const auto eid = idx_g.entry_id(g.outputs[i]); STORAGE_TYPE_ASSIGN_CHECK(*out_stypes, i, stypes[eid]); } - return true; + // Check if we have inferred the storages correctly. 
+ return g.GetAttr("storage_type_num_unknown_nodes") == 0; } void SubgraphOpForward(const OpStatePtr& state_ptr, diff --git a/tests/python/unittest/test_subgraph_op.py b/tests/python/unittest/test_subgraph_op.py index 444907cc1d14..5d328f6fd977 100644 --- a/tests/python/unittest/test_subgraph_op.py +++ b/tests/python/unittest/test_subgraph_op.py @@ -5,45 +5,46 @@ import numpy as np -def test_subgraph_op_whole_graph(): - data1 = mx.sym.Variable('data1', shape=(3, 3, 10, 10), dtype=np.float32) - data2 = mx.sym.Variable('data2', shape=(1, 0, 2, 2)) - data3 = mx.sym.sin(data2) - - regular_sym = mx.sym.Convolution(data=data1, weight=data3, kernel=(2, 2), num_filter=1) - - out = SymbolHandle() - - op_names = [] - op_names = [mx.sym.sin.__name__, mx.sym.Convolution.__name__] - - check_call(_LIB.MXPartitionGraph(regular_sym.handle, mx_uint(len(op_names)), - c_str_array(op_names), ctypes.byref(out))) - - subgraph_sym = Symbol(out) - assert regular_sym.list_inputs() == subgraph_sym.list_inputs() - input_names = subgraph_sym.list_inputs() +def test_subgraph(): + def get_graph(): + data1 = mx.sym.Variable('data1', shape=(3, 3, 10, 10), dtype=np.float32) + data2 = mx.sym.Variable('data2', shape=(1, 0, 2, 2)) + data3 = mx.sym.sin(data2) + conv = mx.sym.Convolution(data=data1, weight=data3, kernel=(2, 2), num_filter=1) + rets = [] + rets.append((conv, [])) + rets.append((conv, [mx.sym.sin.__name__])) + rets.append((conv, [mx.sym.Convolution.__name__])) + rets.append((conv, [mx.sym.sin.__name__, mx.sym.Convolution.__name__])) + return rets + + for regular_sym, op_names in get_graph(): + input_names = regular_sym.list_inputs() + shapes = regular_sym.infer_shape() + types = regular_sym.infer_type() + out = SymbolHandle() - #assert regular_sym.list_outputs() == subgraph_sym.list_outputs() - print(subgraph_sym.list_outputs()) - assert regular_sym.infer_shape() == subgraph_sym.infer_shape() - assert regular_sym.infer_type() == subgraph_sym.infer_type() + check_call(_LIB.MXPartitionGraph(regular_sym.handle, mx_uint(len(op_names)), + c_str_array(op_names), ctypes.byref(out))) + subgraph_sym = Symbol(out) + assert input_names == subgraph_sym.list_inputs() - regular_exec = regular_sym.simple_bind(ctx=mx.cpu(), grad_req='null') - subgraph_exec = subgraph_sym.simple_bind(ctx=mx.cpu(), grad_req='null') + print(subgraph_sym.list_outputs()) + assert shapes == subgraph_sym.infer_shape() + assert types == subgraph_sym.infer_type() - regular_exec.arg_dict[data1.name][:] = mx.nd.random.normal(shape=regular_exec.arg_dict[data1.name].shape) - regular_exec.arg_dict[data2.name][:] = mx.nd.random.normal(shape=regular_exec.arg_dict[data2.name].shape) - regular_exec.arg_dict[input_names[-1]][:] = mx.nd.random.normal(shape=regular_exec.arg_dict[input_names[-1]].shape) + regular_exec = regular_sym.simple_bind(ctx=mx.cpu(), grad_req='null') + subgraph_exec = subgraph_sym.simple_bind(ctx=mx.cpu(), grad_req='null') - subgraph_exec.arg_dict[data1.name][:] = regular_exec.arg_dict[data1.name] - subgraph_exec.arg_dict[data2.name][:] = regular_exec.arg_dict[data2.name] - subgraph_exec.arg_dict[input_names[-1]][:] = regular_exec.arg_dict[input_names[-1]] + for name in input_names: + regular_exec.arg_dict[name][:] = mx.nd.random.normal( + shape=regular_exec.arg_dict[name].shape) + subgraph_exec.arg_dict[name][:] = regular_exec.arg_dict[name] - subgraph_exec.forward() - regular_exec.forward() - mx.nd.waitall() - assert (subgraph_exec.outputs[0] - regular_exec.outputs[0]).abs().sum().asscalar() == 0.0 + subgraph_exec.forward() + 
+        regular_exec.forward()
+        mx.nd.waitall()
+        assert (subgraph_exec.outputs[0] - regular_exec.outputs[0]).abs().sum().asscalar() == 0.0
 
 def test_input_name_order():