
Add missing args/aux support in optimize_for and deferred inference option

Signed-off-by: Serge Panev <[email protected]>
Kh4L committed May 15, 2020
1 parent 47a38d1 commit 308f0f2
Showing 3 changed files with 51 additions and 42 deletions.
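
For reference, the new option is exposed through Symbol.optimize_for. A minimal usage sketch (the 'myBackend' name and the toy symbol below are illustrative assumptions, not part of this commit):

import mxnet as mx

# Toy symbol; 'myBackend' stands in for whatever subgraph backend is registered.
data = mx.sym.Variable('data')
sym = mx.sym.exp(data)

# Eager path: provide args so shapes/dtypes/stypes are inferred during partitioning.
optimized = sym.optimize_for('myBackend', args={'data': mx.nd.ones((3, 2))})

# Deferred path (new in this commit): skip inference here and defer it to bind().
optimized = sym.optimize_for('myBackend', deferred_infer=True)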
include/mxnet/c_api.h (1 change: 1 addition & 0 deletions)
@@ -2235,6 +2235,7 @@ MXNET_DLL int MXOptimizeForBackend(SymbolHandle sym_handle,
                                    const mx_uint num_options,
                                    const char** keys,
                                    const char** vals,
+                                   bool deferred_infer,
                                    int* new_args_cnt,
                                    NDArrayHandle** new_args_handle,
                                    char*** new_arg_names_handle,
python/mxnet/symbol/symbol.py (15 changes: 11 additions & 4 deletions)
@@ -1447,7 +1447,7 @@ def _gen_atomic_symbol(self):


     # pylint: disable=too-many-locals
-    def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
+    def optimize_for(self, backend, args=None, aux=None, ctx=None, deferred_infer=False, **kwargs):
         """Partitions current symbol and optimizes it for a given backend,
         returns new partitioned symbol.
@@ -1460,19 +1460,25 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
             Input arguments to the symbol, required to infer shapes/types before partitioning
             - If type is a list of `NDArray`, the order is the same as that of `list_arguments()`.
+              Non-defined arguments' `NDArray`s should be marked as None to respect the order.
             - If type is a dict of str to `NDArray`, then it maps the name of arguments
-              to the corresponding `NDArray`.
+              to the corresponding `NDArray`. Non-defined arguments' `NDArray`s don't have to be
+              specified in the dict.
         aux : list of NDArray or dict of str to NDArray, optional
             Input auxiliary arguments to the symbol
-            - If type is a list of `NDArray`, the order is the same as that of `list_arguments()`.
+            - If type is a list of `NDArray`, the order is the same as that of `list_auxiliary_states()`.
             - If type is a dict of str to `NDArray`, then it maps the name of arguments
               to the corresponding `NDArray`.
         ctx : Context, optional
             Device context, used to infer stypes
+        deferred_infer : bool, optional
+            If True, the optimization skips the shape, type and storage type inference pass
+            (deferring it to `bind`).
         kwargs : optional arguments
             Passed on to `PrePartition` and `PostPartition` functions of `SubgraphProperty`
@@ -1489,7 +1495,7 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
             args_handle = c_array(NDArrayHandle, [])
         else:
             args_handle, args_ = self._get_ndarray_inputs('args', args,
-                                                           self.list_arguments(), False)
+                                                           self.list_arguments(), True)

         if aux is None or len(aux) == 0:
             aux_ = []
@@ -1524,6 +1530,7 @@
                                              mx_uint(len(key_list)),
                                              c_str_array(key_list),
                                              c_str_array(val_list),
+                                             ctypes.c_bool(deferred_infer),
                                              ctypes.byref(new_args_size),
                                              ctypes.byref(new_args_handle),
                                              ctypes.byref(new_arg_names),
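
As the updated docstring above notes, args and aux may now be partially specified in either form. A short sketch of the two equivalent call styles (toy symbol and hypothetical 'myBackend' name, assumptions for illustration):

import mxnet as mx

a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
sym = a + b

# Dict form: arguments without an NDArray are simply omitted.
sym.optimize_for('myBackend', args={'a': mx.nd.ones((2, 2))})

# List form: order follows list_arguments(), with None marking
# the arguments that are not provided ('b' here).
ordered = [mx.nd.ones((2, 2)), None]  # matches sym.list_arguments() == ['a', 'b']
sym.optimize_for('myBackend', args=ordered)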
src/c_api/c_api_symbolic.cc (77 changes: 39 additions & 38 deletions)
@@ -1360,6 +1360,7 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
                          const mx_uint num_options,
                          const char** keys,
                          const char** vals,
+                         bool deferred_infer,
                          int* new_args_cnt,
                          NDArrayHandle** new_args_handle,
                          char*** new_arg_names_handle,
@@ -1383,47 +1384,47 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
   if (args_len || aux_len) {
     NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
     NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
-    Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
-    mxnet::ShapeVector arg_shapes(args_len + aux_len);
-    nnvm::DTypeVector arg_dtypes(args_len + aux_len);
-    StorageTypeVector arg_stypes(args_len + aux_len);
-    size_t args_top = 0, aux_top = 0;
-    // loop over inputs to symbol in order and add to args/aux if mutable
-    for (size_t i = 0; i < num_forward_inputs; ++i) {
-      const uint32_t nid = indexed_graph.input_nodes().at(i);
-      if (mutable_nodes.count(nid)) {
-        CHECK_LT(aux_top, aux_len)
-          << "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
-        const auto &in_arg = *(in_aux_ptr[aux_top++]);
-        arg_shapes[i] = in_arg.shape();
-        arg_dtypes[i] = in_arg.dtype();
-        arg_stypes[i] = in_arg.storage_type();
-      } else {
-        CHECK_LT(args_top, args_len)
-          << "Cannot find arg '" << input_names[i] << "' in provided args to optimize_for";
-        const auto &in_arg = *(in_args_ptr[args_top++]);
-        arg_shapes[i] = in_arg.shape();
-        arg_dtypes[i] = in_arg.dtype();
-        arg_stypes[i] = in_arg.storage_type();
-      }
-    }
+    if (!deferred_infer) {
+      Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
+      mxnet::ShapeVector arg_shapes(args_len + aux_len);
+      nnvm::DTypeVector arg_dtypes(args_len + aux_len);
+      StorageTypeVector arg_stypes(args_len + aux_len);
+      size_t args_top = 0, aux_top = 0;
+      // loop over inputs to symbol in order and add to args/aux if mutable
+      for (size_t i = 0; i < num_forward_inputs; ++i) {
+        const uint32_t nid = indexed_graph.input_nodes().at(i);
+        if (mutable_nodes.count(nid)) {
+          CHECK_LT(aux_top, aux_len)
+            << "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
+          if (in_aux_ptr[aux_top] != nullptr) {
+            const auto &in_arg = *(in_aux_ptr[aux_top]);
+            arg_shapes[i] = in_arg.shape();
+            arg_dtypes[i] = in_arg.dtype();
+            arg_stypes[i] = in_arg.storage_type();
+          }
+          aux_top++;
+        } else {
+          CHECK_LT(args_top, args_len)
+            << "Cannot find arg '" << input_names[i] << "' in provided args to optimize_for";
+          if (in_args_ptr[args_top] != nullptr) {
+            const auto &in_arg = *(in_args_ptr[args_top]);
+            arg_shapes[i] = in_arg.shape();
+            arg_dtypes[i] = in_arg.dtype();
+            arg_stypes[i] = in_arg.storage_type();
+          }
+          args_top++;
+        }
+      }

-    g.attrs["context"] = std::make_shared<nnvm::any>(
-        exec::ContextVector(indexed_graph.num_nodes(), default_ctx));
+      g.attrs["context"] = std::make_shared<nnvm::any>(
+          exec::ContextVector(indexed_graph.num_nodes(), default_ctx));

-    // infer shapes
-    g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
-    // infer dtypes
-    g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
-    if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
-      common::HandleInferTypeError(num_forward_inputs, indexed_graph,
-                                   g.GetAttr<nnvm::DTypeVector>("dtype"));
-    }
-    // infer stypes
-    g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
-    if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
-      common::HandleInferStorageTypeError(num_forward_inputs, indexed_graph,
-                                          g.GetAttr<StorageTypeVector>("storage_type"));
-    }
+      // infer shapes
+      g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+      // infer dtypes
+      g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+      // infer stypes
+      g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
+    }
     // set args/aux as attributes on graph so that subgraph property can use them
     std::vector<std::string> arg_names = sym->ListInputNames(nnvm::Symbol::kReadOnlyArgs);
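The heart of the C++ change is the classification loop: forward inputs are visited in graph order, mutable inputs are matched against aux and the rest against args, and null entries (arguments the caller did not supply) are now skipped instead of being dereferenced. A rough Python rendering of that logic for readability (hypothetical helper, not MXNet source):

def collect_input_info(input_names, mutable_ids, args, aux):
    # input_names: forward inputs in graph order
    # mutable_ids: indices of inputs that are auxiliary states
    # args/aux: lists aligned with the symbol's arg/aux order; None marks
    #           an input the caller did not provide
    shapes, dtypes, stypes = {}, {}, {}
    args_top = aux_top = 0
    for i, name in enumerate(input_names):
        if i in mutable_ids:
            nd = aux[aux_top]
            aux_top += 1
        else:
            nd = args[args_top]
            args_top += 1
        if nd is not None:  # previously a missing (null) entry would be dereferenced
            shapes[name], dtypes[name], stypes[name] = nd.shape, nd.dtype, nd.stype
    return shapes, dtypes, stypes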
