diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index ad794623246e..a0e90e2bed79 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1029,14 +1029,35 @@ def _build_cache(self, *args):
         arg_dict, aux_dict = dict(), dict()
         if self._backend:
-            ctx = args[0].context
+            # set context for inputs
+            _, _, ctx_set, _ = _gather_type_ctx_info(list(args))
+            ctx = ctx_set.pop() if len(ctx_set) > 0 else None
             # get list of params in the order of out.list_arguments
-            arg_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
-                             for name in out.list_arguments()})
-            aux_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
-                             for name in out.list_auxiliary_states()})
-            # Partition the graph.
-            out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts)
+            input_shapes = dict()
+            for name in out.list_arguments():
+                if name in data_names.keys() and data_names[name] < len(args):
+                    if isinstance(args[data_names[name]], NDArray):
+                        arg_dict[name] = args[data_names[name]]
+                    elif (isinstance(args[data_names[name]], symbol.Symbol) and
+                          '__shape__' in args[data_names[name]].list_attr()):
+                        shape_str = args[data_names[name]].list_attr()['__shape__']
+                        input_shapes[name] = tuple(map(int, shape_str.strip('()').split(',')))
+                elif name in params:
+                    arg_dict[name] = params[name].data()
+
+            for name in out.list_auxiliary_states():
+                if name in data_names.keys() and data_names[name] < len(args):
+                    if isinstance(args[data_names[name]], NDArray):
+                        aux_dict[name] = args[data_names[name]]
+                    elif (isinstance(args[data_names[name]], symbol.Symbol) and
+                          '__shape__' in args[data_names[name]].list_attr()):
+                        shape_str = args[data_names[name]].list_attr()['__shape__']
+                        input_shapes[name] = tuple(map(int, shape_str.strip('()').split(',')))
+                elif name in params:
+                    aux_dict[name] = params[name].data()
+
+            # Partition the graph
+            out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, input_shapes, **self._backend_opts)

         # convert to numpy symbol if needed
         if _mx_npx.is_np_array():
@@ -1079,7 +1100,7 @@ def _build_cache(self, *args):
                 param = Parameter(name, dtype=param_data.dtype)
                 param._var_name = name
                 serialization_name = name  # HybridBlock.export
-                param._load_init(param_data, args[0].context)
+                param._load_init(param_data, param_data.context)
             triple = (False, serialization_name, param)
             self._cached_op_args.append(triple)

@@ -1181,14 +1202,11 @@ def optimize_for(self, x, *args, backend=None, backend_opts=None, clear=True, **
         # do part of forward API call
         has_symbol, has_ndarray, ctx_set, _ = _gather_type_ctx_info([x] + list(args))
-        if has_symbol:
-            raise ValueError('Inputs must be NDArrays for the optimize_for API'
-                             ' Please check the type of the args.\n')
         if not has_symbol and not has_ndarray:
-            raise ValueError('In HybridBlock, there must be one NDArray as input.'
+            raise ValueError('In HybridBlock, there must be one NDArray or one Symbol in the input.'
                              ' Please check the type of the args.\n')
         if len(ctx_set) > 1:
-            raise ValueError('Find multiple contexts in the input, '
+            raise ValueError('Found multiple contexts in the input, '
                              'After hybridized, the HybridBlock only supports one input '
                              'context. You can print the ele.ctx in the '
                              'input arguments to inspect their contexts. '
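The new input_shapes path above lets callers pass shape-only Symbol placeholders instead of allocated NDArrays. A minimal standalone sketch of the round trip, mirroring the parsing in the hunk (names are illustrative; it assumes mx.sym.var stores the shape under the '__shape__' attribute as a tuple-style string, which is what the parsing code expects):

import mxnet as mx

a_var = mx.sym.var('a', shape=(3, 2))       # placeholder carrying only shape info
shape_str = a_var.list_attr()['__shape__']  # e.g. the string '(3, 2)'
shape = tuple(map(int, shape_str.strip('()').split(',')))
assert shape == (3, 2)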
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index a02c0a6b80c7..959f2e026473 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -1377,50 +1377,54 @@ void registerPasses(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
     // convert input args
     for (size_t i=0; i < in_arg_names.size(); i++) {
-      arg_names.push_back(in_arg_names[i].c_str());
-      const NDArray &in_arg = *(in_args_ptr[i]);
+      if (in_args_ptr[i] != nullptr) {
+        arg_names.push_back(in_arg_names[i].c_str());
+        const NDArray &in_arg = *(in_args_ptr[i]);

 #if MXNET_USE_MKLDNN == 1
-      // reorder data if in MKLDNN format
-      if (in_arg.IsMKLDNNData()) {
-        in_arg.Reorder2DefaultAsync();
-        in_arg.WaitToRead();
-      }
+        // reorder data if in MKLDNN format
+        if (in_arg.IsMKLDNNData()) {
+          in_arg.Reorder2DefaultAsync();
+          in_arg.WaitToRead();
+        }
 #endif

-      // pull out parts of NDArray to send to backend
-      arg_data.push_back(in_arg.data().dptr_);
-      arg_shapes.push_back(in_arg.shape().data());
-      arg_dims.push_back(in_arg.shape().ndim());
-      arg_types.push_back(in_arg.dtype());
-      arg_verIDs.push_back(in_arg.version());
-      const char* arg_ctx_str = in_arg.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
-      arg_dev_type.push_back(arg_ctx_str);
-      arg_dev_id.push_back(in_arg.ctx().real_dev_id());
+        // pull out parts of NDArray to send to backend
+        arg_data.push_back(in_arg.data().dptr_);
+        arg_shapes.push_back(in_arg.shape().data());
+        arg_dims.push_back(in_arg.shape().ndim());
+        arg_types.push_back(in_arg.dtype());
+        arg_verIDs.push_back(in_arg.version());
+        const char* arg_ctx_str = in_arg.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
+        arg_dev_type.push_back(arg_ctx_str);
+        arg_dev_id.push_back(in_arg.ctx().real_dev_id());
+      }
     }

     // convert input aux
     for (size_t i=0; i < in_aux_names.size(); i++) {
-      aux_names.push_back(in_aux_names[i].c_str());
-      const auto &in_aux = *(in_aux_ptr[i]);
+      if (in_aux_ptr[i] != nullptr) {
+        aux_names.push_back(in_aux_names[i].c_str());
+        const auto &in_aux = *(in_aux_ptr[i]);

 #if MXNET_USE_MKLDNN == 1
-      // reorder data if in MKLDNN format
-      if (in_aux.IsMKLDNNData()) {
-        in_aux.Reorder2DefaultAsync();
-        in_aux.WaitToRead();
-      }
+        // reorder data if in MKLDNN format
+        if (in_aux.IsMKLDNNData()) {
+          in_aux.Reorder2DefaultAsync();
+          in_aux.WaitToRead();
+        }
 #endif

-      // pull out parts of NDArray to send to backend
-      aux_data.push_back(in_aux.data().dptr_);
-      aux_shapes.push_back(in_aux.shape().data());
-      aux_dims.push_back(in_aux.shape().ndim());
-      aux_types.push_back(in_aux.dtype());
-      aux_verIDs.push_back(in_aux.version());
-      const char* aux_ctx_str = in_aux.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
-      aux_dev_type.push_back(aux_ctx_str);
-      aux_dev_id.push_back(in_aux.ctx().real_dev_id());
+        // pull out parts of NDArray to send to backend
+        aux_data.push_back(in_aux.data().dptr_);
+        aux_shapes.push_back(in_aux.shape().data());
+        aux_dims.push_back(in_aux.shape().ndim());
+        aux_types.push_back(in_aux.dtype());
+        aux_verIDs.push_back(in_aux.version());
+        const char* aux_ctx_str = in_aux.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
+        aux_dev_type.push_back(aux_ctx_str);
+        aux_dev_id.push_back(in_aux.ctx().real_dev_id());
+      }
     }

     // convert graph to string
"cpu" : "gpu"; + aux_dev_type.push_back(aux_ctx_str); + aux_dev_id.push_back(in_aux.ctx().real_dev_id()); + } } // convert graph to string diff --git a/src/operator/subgraph/partitioner/custom_subgraph_property.h b/src/operator/subgraph/partitioner/custom_subgraph_property.h index 49d5a8fb683f..cd4103563220 100644 --- a/src/operator/subgraph/partitioner/custom_subgraph_property.h +++ b/src/operator/subgraph/partitioner/custom_subgraph_property.h @@ -208,26 +208,28 @@ class CustomSubgraphProperty: public SubgraphProperty { arg_dev_type.clear(); arg_dev_id.clear(); for (size_t i=0; i < in_arg_names.size(); i++) { - arg_names.push_back(in_arg_names[i].c_str()); - const NDArray &in_arg = *(in_args_ptr[i]); + if (in_args_ptr[i] != nullptr) { + arg_names.push_back(in_arg_names[i].c_str()); + const NDArray &in_arg = *(in_args_ptr[i]); #if MXNET_USE_MKLDNN == 1 - // reorder data if in MKLDNN format - if (in_arg.IsMKLDNNData()) { - in_arg.Reorder2DefaultAsync(); - in_arg.WaitToRead(); - } + // reorder data if in MKLDNN format + if (in_arg.IsMKLDNNData()) { + in_arg.Reorder2DefaultAsync(); + in_arg.WaitToRead(); + } #endif - // pull out parts of NDArray to send to backend - arg_data.push_back(in_arg.data().dptr_); - arg_shapes.push_back(in_arg.shape().data()); - arg_dims.push_back(in_arg.shape().ndim()); - arg_types.push_back(in_arg.dtype()); - arg_verIDs.push_back(in_arg.version()); - const char* arg_ctx_str = in_arg.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu"; - arg_dev_type.push_back(arg_ctx_str); - arg_dev_id.push_back(in_arg.ctx().real_dev_id()); + // pull out parts of NDArray to send to backend + arg_data.push_back(in_arg.data().dptr_); + arg_shapes.push_back(in_arg.shape().data()); + arg_dims.push_back(in_arg.shape().ndim()); + arg_types.push_back(in_arg.dtype()); + arg_verIDs.push_back(in_arg.version()); + const char* arg_ctx_str = in_arg.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu"; + arg_dev_type.push_back(arg_ctx_str); + arg_dev_id.push_back(in_arg.ctx().real_dev_id()); + } } // convert input aux @@ -240,26 +242,28 @@ class CustomSubgraphProperty: public SubgraphProperty { aux_dev_type.clear(); aux_dev_id.clear(); for (size_t i=0; i < in_aux_names.size(); i++) { - aux_names.push_back(in_aux_names[i].c_str()); - const auto &in_aux = *(in_aux_ptr[i]); + if (in_aux_ptr[i] != nullptr) { + aux_names.push_back(in_aux_names[i].c_str()); + const auto &in_aux = *(in_aux_ptr[i]); #if MXNET_USE_MKLDNN == 1 - // reorder data if in MKLDNN format - if (in_aux.IsMKLDNNData()) { - in_aux.Reorder2DefaultAsync(); - in_aux.WaitToRead(); - } + // reorder data if in MKLDNN format + if (in_aux.IsMKLDNNData()) { + in_aux.Reorder2DefaultAsync(); + in_aux.WaitToRead(); + } #endif - // pull out parts of NDArray to send to backend - aux_data.push_back(in_aux.data().dptr_); - aux_shapes.push_back(in_aux.shape().data()); - aux_dims.push_back(in_aux.shape().ndim()); - aux_types.push_back(in_aux.dtype()); - aux_verIDs.push_back(in_aux.version()); - const char* aux_ctx_str = in_aux.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu"; - aux_dev_type.push_back(aux_ctx_str); - aux_dev_id.push_back(in_aux.ctx().real_dev_id()); + // pull out parts of NDArray to send to backend + aux_data.push_back(in_aux.data().dptr_); + aux_shapes.push_back(in_aux.shape().data()); + aux_dims.push_back(in_aux.shape().ndim()); + aux_types.push_back(in_aux.dtype()); + aux_verIDs.push_back(in_aux.version()); + const char* aux_ctx_str = in_aux.ctx().dev_mask() == Context::kCPU ? 
"cpu" : "gpu"; + aux_dev_type.push_back(aux_ctx_str); + aux_dev_id.push_back(in_aux.ctx().real_dev_id()); + } } // remove all graph attrs, some cannot be saved to json @@ -285,13 +289,17 @@ class CustomSubgraphProperty: public SubgraphProperty { for (unsigned oid = 0; oid < node->num_outputs(); oid++) { const uint32_t out_entry_id = indexed_graph.entry_id(nid, oid); mxnet::TShape& shape = shapes[out_entry_id]; - ss << shape; + if (shape.ndim() == -1) + ss << "[None]"; + else + ss << shape; if (oid < node->num_outputs()-1) ss << ","; } ss << "]"; node->attrs.dict[MX_STR_SHAPE] = ss.str(); } } + // set dtype attrs for each node in the graph if (g.HasAttr("dtype")) { std::vector dtypes = g.GetAttr >("dtype"); diff --git a/tests/python/unittest/test_extensions.py b/tests/python/unittest/test_extensions.py index 52f999571e13..8d9468024094 100644 --- a/tests/python/unittest/test_extensions.py +++ b/tests/python/unittest/test_extensions.py @@ -166,16 +166,23 @@ def test_subgraph(): # check that result matches one executed by MXNet assert_almost_equal(out[0].asnumpy(), out4[0].asnumpy(), rtol=1e-3, atol=1e-3) - # Gluon Hybridize partitioning with shapes/types + # Gluon Hybridize partitioning with sym.var sym_block2 = nn.SymbolBlock(sym, [a,b]) sym_block2.initialize() + a_var = mx.sym.var('a',shape=(3,2)) + b_var = mx.sym.var('b',shape=(3,2)) + sym_block2.optimize_for(a_var, b_var, backend='myProp') + + # Gluon Hybridize partitioning with shapes/types + sym_block3 = nn.SymbolBlock(sym, [a,b]) + sym_block3.initialize() a_data = mx.nd.ones((3,2)) b_data = mx.nd.ones((3,2)) - sym_block2.optimize_for(a_data, b_data, backend='myProp') - sym_block2.export('optimized') - sym_block3 = nn.SymbolBlock.imports('optimized-symbol.json',['a','b'], + sym_block3.optimize_for(a_data, b_data, backend='myProp') + sym_block3.export('optimized') + sym_block4 = nn.SymbolBlock.imports('optimized-symbol.json',['a','b'], 'optimized-0000.params') - out5 = sym_block3(a_data, b_data) + out5 = sym_block4(a_data, b_data) # check that result matches one executed by MXNet assert_almost_equal(out[0].asnumpy(), out5[0].asnumpy(), rtol=1e-3, atol=1e-3)