Commit

Add better partial args/aux handling in symbol optimize_for (#18350)
* Add missing args/aux support in optimize_for and deferred inference option

Signed-off-by: Serge Panev <[email protected]>

* Add input shape_dict, type_dict and stype_dict to optimize_for

Signed-off-by: Serge Panev <[email protected]>

* Remove warnings for Werror

Signed-off-by: Serge Panev <[email protected]>

* Address PR comments

Signed-off-by: Serge Panev <[email protected]>
Kh4L committed Jul 14, 2020
1 parent 9d62392 commit 7f7e1c5
Showing 4 changed files with 204 additions and 116 deletions.
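To show what the new partial args/aux handling and the input dictionaries enable, here is a minimal usage sketch. The symbol, input names, and the backend name 'myBackend' are assumptions made for illustration and are not part of this commit; the keyword arguments shape_dict, type_dict and stype_dict, and the ability to pass only a subset of args, are what this change adds.

```python
import mxnet as mx
import numpy as np

# Hypothetical two-input graph; the names and the backend are illustrative only.
data = mx.sym.Variable('data')
weight = mx.sym.Variable('weight')
sym = mx.sym.FullyConnected(data=data, weight=weight, num_hidden=10, no_bias=True)

# Only 'weight' is supplied as an NDArray; 'data' is described through the new
# shape_dict/type_dict/stype_dict arguments instead of being required in args.
optimized = sym.optimize_for(
    'myBackend',                      # assumed backend name
    args={'weight': mx.nd.ones((10, 4))},
    shape_dict={'data': (1, 4)},      # shape of the input missing from args
    type_dict={'data': np.float32},   # dtype of the input missing from args
    stype_dict={'data': 'default'},   # storage type of the input missing from args
)
```

Previously, args either had to cover every input or be omitted entirely; with this change, inputs missing from args can be described through the dictionaries instead.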
30 changes: 30 additions & 0 deletions include/mxnet/c_api.h
@@ -2235,6 +2235,25 @@ MXNET_DLL int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle
* \param num_options number of key value pairs
* \param keys keys for options
* \param vals values corresponding to keys
* \param num_input_shapes number of input shapes
* \param input_shape_names names of the input shapes
* \param input_shape_data pointer to the contiguous data shapes
* \param input_shape_idx array of per-shape starting indices; the length of the i-th input shape
* is calculated as input_shape_idx[i+1] - input_shape_idx[i]
* \param num_input_dtypes number of input data types
* \param input_dtype_names array of names of the input data types
* \param input_dtypes array of values of the input data types
* \param num_input_stypes number of input storage types
* \param input_stype_names array of names of the input storage types
* \param input_stypes array of values of input storage types
* \param skip_infer whether the optimization should skip the shape/type/storage-type inference passes
* (use when the backend does not require shape inference)
* \param new_args_cnt pointer to an integer to store the number of new args
* \param new_args_handle pointer to an array to store the new arg handles
* \param new_arg_names_handle pointer to an array to store the new arg names
* \param new_aux_cnt pointer to an integer to store the number of new aux
* \param new_aux_handle pointer to an array to store the new aux handles
* \param new_aux_names_handle pointer to an array to store the new aux names
*/
MXNET_DLL int MXOptimizeForBackend(SymbolHandle sym_handle,
const char* backend_name,
@@ -2247,6 +2266,17 @@ MXNET_DLL int MXOptimizeForBackend(SymbolHandle sym_handle,
const mx_uint num_options,
const char** keys,
const char** vals,
const uint32_t num_input_shapes,
const char** input_shape_names,
const int64_t* input_shape_data,
const uint32_t* input_shape_idx,
const uint32_t num_input_dtypes,
const char** input_dtype_names,
const int* input_dtypes,
const uint32_t num_input_stypes,
const char** input_stype_names,
const int* input_stypes,
bool skip_infer,
int* new_args_cnt,
NDArrayHandle** new_args_handle,
char*** new_arg_names_handle,
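The input_shape_data/input_shape_idx pair in the signature above is a flat encoding of a name-to-shape dictionary. The sketch below is a Python rendering of what the wrapper in symbol.py later in this diff builds before calling MXOptimizeForBackend; the example shapes are made up to show the layout.

```python
# Flatten a {name: shape} dict into the arrays MXOptimizeForBackend expects.
# The example shapes are illustrative.
shape_dict = {'data': (1, 3, 224, 224), 'weight': (64, 3)}

input_shape_names = []      # names of the inputs with a provided shape
input_shape_data = []       # all dimensions, concatenated
input_shape_idx = [0]       # start offset of each shape within input_shape_data

for name, shape in shape_dict.items():
    input_shape_names.append(name)
    input_shape_data.extend(shape)
    input_shape_idx.append(len(input_shape_data))

# input_shape_names == ['data', 'weight']
# input_shape_data  == [1, 3, 224, 224, 64, 3]
# input_shape_idx   == [0, 4, 6]
# so the length of shape i is input_shape_idx[i+1] - input_shape_idx[i]
```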
98 changes: 92 additions & 6 deletions python/mxnet/symbol/symbol.py
@@ -1446,7 +1446,8 @@ def _gen_atomic_symbol(self):


# pylint: disable=too-many-locals
def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
def optimize_for(self, backend, args=None, aux=None, ctx=None,
shape_dict=None, type_dict=None, stype_dict=None, skip_infer=False, **kwargs):
"""Partitions current symbol and optimizes it for a given backend,
returns new partitioned symbol.
@@ -1457,19 +1458,33 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
args : dict of str to NDArray, optional
Input arguments to the symbol, required to infer shapes/types before partitioning
- If type is a dict of str to `NDArray`, then it maps the name of arguments
to the corresponding `NDArray`.
to the corresponding `NDArray`. Arguments for which no `NDArray` is available may be
omitted from the dict.
aux : dict of str to NDArray, optional
Input auxiliary arguments to the symbol
- If type is a dict of str to `NDArray`, then it maps the name of arguments
to the corresponding `NDArray`.
ctx : Context, optional
Device context, used to infer stypes
shape_dict : Dict of str->tuple, optional
Input shape dictionary.
Used only if the input's `NDArray` is not provided in `args`.
type_dict : Dict of str->numpy.dtype, optional
Input type dictionary.
Used only if the input's `NDArray` is not provided in `args`.
stype_dict : Dict of str->str, optional
Input storage type dictionary.
Used only if the input's `NDArray` is not provided in `args`.
skip_infer : bool, optional
If True, the optimization skips the shape, type and storage type inference pass.
kwargs : optional arguments
Passed on to `PrePartition` and `PostPartition` functions of `SubgraphProperty`
@@ -1488,18 +1503,78 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
args_handle = c_array(NDArrayHandle, [])
else:
args_handle, args_ = self._get_ndarray_inputs('args', args,
self.list_arguments(), False)
self.list_arguments(), True)

if aux is None or len(aux) == 0:
aux_ = []
aux_handle = c_array(NDArrayHandle, [])
else:
aux_handle, aux_ = self._get_ndarray_inputs('aux_states', aux,
self.list_auxiliary_states(), False)
self.list_auxiliary_states(), True)
if ctx is None:
ctx = current_context()
assert isinstance(ctx, Context)


# parse input data shape dict
num_input_shapes = 0
input_shape_names = ctypes.POINTER(ctypes.c_char_p)()
input_shape_data = ctypes.POINTER(mx_int64)()
input_shape_idx = ctypes.POINTER(mx_uint)()
if shape_dict is not None:
input_shape_names = []
input_shape_data = []
input_shape_idx = [0]
for k, v in shape_dict.items():
if isinstance(v, (tuple, list)):
input_shape_names.append(k)
input_shape_data.extend(v)
input_shape_idx.append(len(input_shape_data))
else:
raise ValueError(str(v) + " has to be a tuple or list.")
num_input_shapes = mx_uint(len(input_shape_names))
input_shape_names = c_str_array(input_shape_names)
input_shape_data = c_array_buf(mx_int64, array('q', input_shape_data))
input_shape_idx = c_array_buf(mx_uint, array('i', input_shape_idx))

# parse input data types dict
num_input_types = 0
input_type_names = ctypes.POINTER(ctypes.c_char_p)() # provided type argument names
input_type_data = ctypes.POINTER(mx_uint)() # provided types
if type_dict is not None:
input_type_names = []
input_type_data = []
for k, v in type_dict.items():
v = _numpy.dtype(v).type
if v in _DTYPE_NP_TO_MX:
input_type_names.append(k)
input_type_data.append(_DTYPE_NP_TO_MX[v])
else:
raise ValueError(str(v) + " is not an MXNet type.")

num_input_types = mx_uint(len(input_type_names))
input_type_names = c_str_array(input_type_names)
input_type_data = c_array_buf(ctypes.c_int, array('i', input_type_data))

# parse input data storage types dict
num_input_stypes = 0
# provided storage type argument names
input_stype_names = ctypes.POINTER(ctypes.c_char_p)()
input_stype_data = ctypes.POINTER(mx_uint)() # provided storage types
if stype_dict is not None:
input_stype_names = []
input_stype_data = []
for k, v in stype_dict.items():
if v in _STORAGE_TYPE_STR_TO_ID:
input_stype_names.append(k)
input_stype_data.append(_STORAGE_TYPE_STR_TO_ID[v])
else:
raise ValueError(str(v) + " is not an MXNet storage type.")

num_input_stypes = mx_uint(len(input_stype_names))
input_stype_names = c_str_array(input_stype_names)
input_stype_data = c_array_buf(ctypes.c_int, array('i', input_stype_data))

new_args_size = ctypes.c_uint()
new_arg_names = ctypes.POINTER(ctypes.c_char_p)()
new_args_handle = ctypes.POINTER(NDArrayHandle)()
@@ -1523,6 +1598,17 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
mx_uint(len(key_list)),
c_str_array(key_list),
c_str_array(val_list),
num_input_shapes,
input_shape_names,
input_shape_data,
input_shape_idx,
num_input_types,
input_type_names,
input_type_data,
num_input_stypes,
input_stype_names,
input_stype_data,
ctypes.c_bool(skip_infer),
ctypes.byref(new_args_size),
ctypes.byref(new_args_handle),
ctypes.byref(new_arg_names),
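For backends whose PrePartition/PostPartition hooks do not need shapes, types or storage types, the new skip_infer flag lets partitioning run without the inference pass. A minimal sketch, again with an assumed backend name and an illustrative graph:

```python
import mxnet as mx

# Hypothetical graph; 'myBackend' is an assumed backend name for illustration.
sym = mx.sym.relu(mx.sym.Variable('data'))

# Partition without running shape/dtype/storage-type inference first.
optimized = sym.optimize_for('myBackend', skip_infer=True)
```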
120 changes: 82 additions & 38 deletions src/c_api/c_api_symbolic.cc
@@ -1360,6 +1360,17 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
const mx_uint num_options,
const char** keys,
const char** vals,
const uint32_t num_input_shapes,
const char** input_shape_names,
const int64_t* input_shape_data,
const uint32_t* input_shape_idx,
const uint32_t num_input_dtypes,
const char** input_dtype_names,
const int* input_dtypes,
const uint32_t num_input_stypes,
const char** input_stype_names,
const int* input_stypes,
bool skip_infer,
int* new_args_cnt,
NDArrayHandle** new_args_handle,
char*** new_arg_names_handle,
@@ -1383,47 +1394,80 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
if (args_len || aux_len) {
NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
mxnet::ShapeVector arg_shapes(args_len + aux_len);
nnvm::DTypeVector arg_dtypes(args_len + aux_len);
StorageTypeVector arg_stypes(args_len + aux_len);
size_t args_top = 0, aux_top = 0;
// loop over inputs to symbol in order and add to args/aux if mutable
for (size_t i = 0; i < num_forward_inputs; ++i) {
const uint32_t nid = indexed_graph.input_nodes().at(i);
if (mutable_nodes.count(nid)) {
CHECK_LT(aux_top, aux_len)
<< "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
const auto &in_arg = *(in_aux_ptr[aux_top++]);
arg_shapes[i] = in_arg.shape();
arg_dtypes[i] = in_arg.dtype();
arg_stypes[i] = in_arg.storage_type();
} else {
CHECK_LT(args_top, args_len)
<< "Cannot find arg '" << input_names[i] << "' in provided args to optimize_for";
const auto &in_arg = *(in_args_ptr[args_top++]);
arg_shapes[i] = in_arg.shape();
arg_dtypes[i] = in_arg.dtype();
arg_stypes[i] = in_arg.storage_type();
if (!skip_infer) {
Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
mxnet::ShapeVector arg_shapes(args_len + aux_len);
nnvm::DTypeVector arg_dtypes(args_len + aux_len);
StorageTypeVector arg_stypes(args_len + aux_len);

// create the input shape, dtype and stype maps
std::unordered_map<std::string, mxnet::TShape> input_shape_map(num_input_shapes);
for (uint32_t i = 0; i < num_input_shapes; ++i) {
input_shape_map.emplace(input_shape_names[i],
mxnet::TShape(input_shape_data + input_shape_idx[i],
input_shape_data + input_shape_idx[i+1]));
}
std::unordered_map<std::string, int> input_dtype_map(num_input_dtypes);
for (uint32_t i = 0; i < num_input_dtypes; ++i) {
input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]);
}
std::unordered_map<std::string, int> input_stype_map(num_input_stypes);
for (uint32_t i = 0; i < num_input_stypes; ++i) {
input_stype_map.emplace(input_stype_names[i], input_stypes[i]);
}
}

g.attrs["context"] = std::make_shared<nnvm::any>(
exec::ContextVector(indexed_graph.num_nodes(), default_ctx));
size_t args_top = 0, aux_top = 0;
// loop over inputs to symbol in order and add to args/aux if mutable
for (size_t i = 0; i < num_forward_inputs; ++i) {
const uint32_t nid = indexed_graph.input_nodes().at(i);
if (mutable_nodes.count(nid)) {
CHECK_LT(aux_top, aux_len)
<< "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
if (in_aux_ptr[aux_top] != nullptr) {
const auto &in_arg = *(in_aux_ptr[aux_top]);
arg_shapes[i] = in_arg.shape();
arg_dtypes[i] = in_arg.dtype();
arg_stypes[i] = in_arg.storage_type();
}
aux_top++;
} else {
auto name = input_names[i];
CHECK_LT(args_top, args_len)
<< "Cannot find arg '" << name << "' in provided args to optimize_for";
if (in_args_ptr[args_top] != nullptr) {
const auto &in_arg = *(in_args_ptr[args_top]);
arg_shapes[i] = in_arg.shape();
arg_dtypes[i] = in_arg.dtype();
arg_stypes[i] = in_arg.storage_type();
} else {
// input_names[i] is not in args but can be in the optional
// shape/type/stype attribute dicts.
auto it_shape = input_shape_map.find(name);
if (it_shape != input_shape_map.end()) {
arg_shapes[i] = it_shape->second;
}
auto it_type = input_dtype_map.find(name);
if (it_type != input_dtype_map.end()) {
arg_dtypes[i] = it_type->second;
}
it_type = input_stype_map.find(name);
if (it_type != input_stype_map.end()) {
arg_stypes[i] = it_type->second;
}
}
args_top++;
}
}

// infer shapes
g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
// infer dtypes
g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
common::HandleInferTypeError(num_forward_inputs, indexed_graph,
g.GetAttr<nnvm::DTypeVector>("dtype"));
}
// infer stypes
g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
common::HandleInferStorageTypeError(num_forward_inputs, indexed_graph,
g.GetAttr<StorageTypeVector>("storage_type"));
g.attrs["context"] = std::make_shared<nnvm::any>(
exec::ContextVector(indexed_graph.num_nodes(), default_ctx));

// infer shapes
g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
// infer dtypes
g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
// infer stypes
g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
}
// set args/aux as attributes on graph so that subgraph property can use them
std::vector<std::string> arg_names = sym->ListInputNames(nnvm::Symbol::kReadOnlyArgs);