diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index e5dd1f21820d..9e65fb25c203 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -2235,6 +2235,7 @@ MXNET_DLL int MXOptimizeForBackend(SymbolHandle sym_handle,
                                    const mx_uint num_options,
                                    const char** keys,
                                    const char** vals,
+                                   bool deferred_infer,
                                    int* new_args_cnt,
                                    NDArrayHandle** new_args_handle,
                                    char*** new_arg_names_handle,
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index d4ff7954c181..2fb5da38b00e 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1447,7 +1447,7 @@ def _gen_atomic_symbol(self):
 
 
     # pylint: disable=too-many-locals
-    def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
+    def optimize_for(self, backend, args=None, aux=None, ctx=None, deferred_infer=False, **kwargs):
         """Partitions current symbol and optimizes it for a given backend,
         returns new partitioned symbol.
 
@@ -1460,19 +1460,25 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
             Input arguments to the symbol, required to infer shapes/types before partitioning
 
             - If type is a list of `NDArray`, the order is the same as that of `list_arguments()`.
+              Undefined arguments' `NDArray`s should be marked as None to respect the order.
             - If type is a dict of str to `NDArray`, then it maps the name of arguments
-              to the corresponding `NDArray`.
+              to the corresponding `NDArray`. Undefined arguments' `NDArray`s don't have to be
+              specified in the dict.
 
         aux : list of NDArray or dict of str to NDArray, optional
             Input auxiliary arguments to the symbol
 
-            - If type is a list of `NDArray`, the order is the same as that of `list_arguments()`.
+            - If type is a list of `NDArray`, the order is the same as that of `list_auxiliary_states()`.
            - If type is a dict of str to `NDArray`, then it maps the name of arguments
              to the corresponding `NDArray`.
 
         ctx : Context, optional
             Device context, used to infer stypes
 
+        deferred_infer : bool, optional
+            If True, the optimization skips the shape, type and storage type inference pass
+            and defers it to `bind`.
+
         kwargs : optional arguments
             Passed on to `PrePartition` and `PostPartition` functions of `SubgraphProperty`
 
@@ -1489,7 +1495,7 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
             args_handle = c_array(NDArrayHandle, [])
         else:
             args_handle, args_ = self._get_ndarray_inputs('args', args,
-                                                           self.list_arguments(), False)
+                                                           self.list_arguments(), True)
 
         if aux is None or len(aux) == 0:
             aux_ = []
@@ -1524,6 +1530,7 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
                                              mx_uint(len(key_list)),
                                              c_str_array(key_list),
                                              c_str_array(val_list),
+                                             ctypes.c_bool(deferred_infer),
                                              ctypes.byref(new_args_size),
                                              ctypes.byref(new_args_handle),
                                              ctypes.byref(new_arg_names),
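For illustration, here is a minimal sketch of how the new Python parameter might be used; the backend name `'myBackend'`, the toy symbol and the shapes are assumptions for this example, not part of the patch. With `deferred_infer=True`, `optimize_for` no longer needs a complete set of `args`: missing arguments can be omitted from the dict form (or passed as `None` in the list form), and the shape/type/stype inference pass is left to `bind`.

```python
# Hypothetical usage sketch: 'myBackend' stands in for whatever subgraph
# backend is registered; the symbol and shapes are purely illustrative.
import mxnet as mx

data = mx.sym.var('data')
weight = mx.sym.var('weight')
net = mx.sym.FullyConnected(data=data, weight=weight, num_hidden=10,
                            no_bias=True, name='fc')

# Only 'weight' is known up front; 'data' is simply left out of the dict.
known_args = {'weight': mx.nd.ones((10, 4))}

# Skip shape/type/stype inference during partitioning; it runs at bind time.
psym = net.optimize_for('myBackend', args=known_args, deferred_infer=True)
```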
- g.attrs["context"] = std::make_shared( - exec::ContextVector(indexed_graph.num_nodes(), default_ctx)); + g.attrs["context"] = std::make_shared( + exec::ContextVector(indexed_graph.num_nodes(), default_ctx)); - // infer shapes - g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__"); - // infer dtypes - g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__"); - if (g.GetAttr("dtype_num_unknown_nodes") != 0U) { - common::HandleInferTypeError(num_forward_inputs, indexed_graph, - g.GetAttr("dtype")); - } - // infer stypes - g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__"); - if (g.GetAttr("storage_type_num_unknown_nodes") != 0U) { - common::HandleInferStorageTypeError(num_forward_inputs, indexed_graph, - g.GetAttr("storage_type")); + // infer shapes + g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__"); + // infer dtypes + g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__"); + // infer stypes + g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__"); } // set args/aux as attributes on graph so that subgraph property can use them std::vector arg_names = sym->ListInputNames(nnvm::Symbol::kReadOnlyArgs);