Add input shape_dict, type_dict and stype_dict to optimize_for
Signed-off-by: Serge Panev <[email protected]>
Kh4L committed May 18, 2020
1 parent 308f0f2 commit f7bbe7b
Showing 3 changed files with 142 additions and 9 deletions.
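A minimal usage sketch (not from the diff) of the extended optimize_for API, assuming a hypothetical registered subgraph backend named 'myBackend':

import mxnet as mx
import numpy as np

# Hypothetical toy symbol; 'fc_weight' / 'fc_bias' are its parameter inputs.
data = mx.sym.Variable('data')
sym = mx.sym.FullyConnected(data, num_hidden=10, name='fc')

# Parameters are passed as NDArrays; the 'data' input is not, so its shape,
# dtype and storage type come from the new dictionaries instead.
args = {'fc_weight': mx.nd.zeros((10, 128)), 'fc_bias': mx.nd.zeros((10,))}
psym = sym.optimize_for('myBackend', args,
                        shape_dict={'data': (1, 128)},
                        type_dict={'data': np.float32},
                        stype_dict={'data': 'default'})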
12 changes: 11 additions & 1 deletion include/mxnet/c_api.h
@@ -2235,7 +2235,17 @@ MXNET_DLL int MXOptimizeForBackend(SymbolHandle sym_handle,
const mx_uint num_options,
const char** keys,
const char** vals,
bool deferred_infer,
const uint32_t num_input_shapes,
const char** input_shape_names,
const int64_t* input_shape_data,
const uint32_t* input_shape_idx,
const uint32_t num_input_dtypes,
const char** input_dtype_names,
const int* input_dtypes,
const uint32_t num_input_stypes,
const char** input_stype_names,
const int* input_stypes,
bool skip_infer,
int* new_args_cnt,
NDArrayHandle** new_args_handle,
char*** new_arg_names_handle,
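The new shape arguments use a flattened, offset-indexed layout: input_shape_data holds every dimension of every named shape concatenated, and input_shape_idx[i]..input_shape_idx[i+1] delimits the shape belonging to input_shape_names[i]. A small sketch of that encoding (input names are hypothetical):

# Illustration of the flattened shape encoding expected by MXOptimizeForBackend.
shape_dict = {'data': (1, 3, 224, 224), 'softmax_label': (1,)}

input_shape_names = []   # one entry per named input
input_shape_data = []    # all dimensions, concatenated
input_shape_idx = [0]    # offsets: shape of names[i] is data[idx[i]:idx[i+1]]
for name, shape in shape_dict.items():
    input_shape_names.append(name)
    input_shape_data.extend(shape)
    input_shape_idx.append(len(input_shape_data))

# Result: names == ['data', 'softmax_label'], data == [1, 3, 224, 224, 1], idx == [0, 4, 5]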
92 changes: 87 additions & 5 deletions python/mxnet/symbol/symbol.py
@@ -1447,7 +1447,8 @@ def _gen_atomic_symbol(self):


# pylint: disable=too-many-locals
def optimize_for(self, backend, args=None, aux=None, ctx=None, deferred_infer=False, **kwargs):
def optimize_for(self, backend, args=None, aux=None, ctx=None,
shape_dict=None, type_dict=None, stype_dict=None, skip_infer=False, **kwargs):
"""Partitions current symbol and optimizes it for a given backend,
returns new partitioned symbol.
@@ -1475,9 +1476,20 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, deferred_infer=Fa
ctx : Context, optional
Device context, used to infer stypes
deferred_infer : bool, optional
shape_dict : Dict of str->tuple, optional
Input shape dictionary.
Used iff input NDArray is not in `args`.
type_dict : Dict of str->numpy.dtype, optional
Input type dictionary.
Used iff input NDArray is not in `args`.
stype_dict : Dict of str->str, optional
Input storage type dictionary.
Used iff input NDArray is not in `args`.
skip_infer : bool, optional
If True, the optimization skips the shape, type and storage type inference pass.
(Deferring it to `bind`.)
kwargs : optional arguments
Passed on to `PrePartition` and `PostPartition` functions of `SubgraphProperty`
@@ -1502,11 +1514,71 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, deferred_infer=Fa
aux_handle = c_array(NDArrayHandle, [])
else:
aux_handle, aux_ = self._get_ndarray_inputs('aux_states', aux,
self.list_auxiliary_states(), False)
self.list_auxiliary_states(), True)
if ctx is None:
ctx = current_context()
assert isinstance(ctx, Context)


# parse input data shape dict
num_input_shapes = 0
input_shape_names = ctypes.POINTER(ctypes.c_char_p)()
input_shape_data = ctypes.POINTER(mx_int64)()
input_shape_idx = ctypes.POINTER(mx_uint)()
if shape_dict is not None:
input_shape_names = []
input_shape_data = []
input_shape_idx = [0]
for k, v in shape_dict.items():
if isinstance(v, (tuple, list)):
input_shape_names.append(k)
input_shape_data.extend(v)
input_shape_idx.append(len(input_shape_data))
else:
raise ValueError(str(v) + " has to be a tuple or list.")
num_input_shapes = mx_uint(len(input_shape_names))
input_shape_names = c_str_array(input_shape_names)
input_shape_data = c_array_buf(mx_int64, array('q', input_shape_data))
input_shape_idx = c_array_buf(mx_uint, array('i', input_shape_idx))

# parse input data types dict
num_input_types = 0
input_type_names = ctypes.POINTER(ctypes.c_char_p)() # provided type argument names
input_type_data = ctypes.POINTER(mx_uint)() # provided types
if type_dict is not None:
input_type_names = []
input_type_data = []
for k, v in type_dict.items():
v = _numpy.dtype(v).type
if v in _DTYPE_NP_TO_MX:
input_type_names.append(k)
input_type_data.append(_DTYPE_NP_TO_MX[v])
else:
raise ValueError(str(v) + " is not a MXNet type.")

num_input_types = mx_uint(len(input_type_names))
input_type_names = c_str_array(input_type_names)
input_type_data = c_array_buf(ctypes.c_int, array('i', input_type_data))

# parse input data storage types dict
num_input_stypes = 0
# provided storage type argument names
input_stype_names = ctypes.POINTER(ctypes.c_char_p)()
input_stype_data = ctypes.POINTER(mx_uint)() # provided storage types
if stype_dict is not None:
input_stype_names = []
input_stype_data = []
for k, v in stype_dict.items():
if v in _STORAGE_TYPE_STR_TO_ID:
input_stype_names.append(k)
input_stype_data.append(_STORAGE_TYPE_STR_TO_ID[v])
else:
raise ValueError(str(v) + " is not a MXNet storage type.")

num_input_stypes = mx_uint(len(input_stype_names))
input_stype_names = c_str_array(input_stype_names)
input_stype_data = c_array_buf(ctypes.c_int, array('i', input_stype_data))

new_args_size = ctypes.c_uint()
new_arg_names = ctypes.POINTER(ctypes.c_char_p)()
new_args_handle = ctypes.POINTER(NDArrayHandle)()
@@ -1530,7 +1602,17 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, deferred_infer=Fa
mx_uint(len(key_list)),
c_str_array(key_list),
c_str_array(val_list),
ctypes.c_bool(deferred_infer),
num_input_shapes,
input_shape_names,
input_shape_data,
input_shape_idx,
num_input_types,
input_type_names,
input_type_data,
num_input_stypes,
input_stype_names,
input_stype_data,
ctypes.c_bool(skip_infer),
ctypes.byref(new_args_size),
ctypes.byref(new_args_handle),
ctypes.byref(new_arg_names),
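As a counterpart to the dictionaries, the skip_infer flag (the renamed deferred_infer) skips the shape/type/stype inference pass during partitioning and leaves it to bind. A one-line sketch, reusing the hypothetical symbol, args and backend from the first example:

# Defer all inference to bind() instead of supplying shape/type/stype hints.
psym = sym.optimize_for('myBackend', args, skip_infer=True)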
47 changes: 44 additions & 3 deletions src/c_api/c_api_symbolic.cc
@@ -1360,7 +1360,17 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
const mx_uint num_options,
const char** keys,
const char** vals,
bool deferred_infer,
const uint32_t num_input_shapes,
const char** input_shape_names,
const int64_t* input_shape_data,
const uint32_t* input_shape_idx,
const uint32_t num_input_dtypes,
const char** input_dtype_names,
const int* input_dtypes,
const uint32_t num_input_stypes,
const char** input_stype_names,
const int* input_stypes,
bool skip_infer,
int* new_args_cnt,
NDArrayHandle** new_args_handle,
char*** new_arg_names_handle,
@@ -1384,23 +1394,54 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
if (args_len || aux_len) {
NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
if (!deferred_infer) {
if (!skip_infer) {
Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
mxnet::ShapeVector arg_shapes(args_len + aux_len);
nnvm::DTypeVector arg_dtypes(args_len + aux_len);
StorageTypeVector arg_stypes(args_len + aux_len);

// create the input shape, dtype and stype maps
std::unordered_map<std::string, mxnet::TShape> input_shape_map(num_input_shapes);
for (int i = 0; i < num_input_shapes; ++i) {
input_shape_map.emplace(input_shape_names[i],
mxnet::TShape(input_shape_data + input_shape_idx[i],
input_shape_data + input_shape_idx[i+1]));
}
std::unordered_map<std::string, int> input_dtype_map(num_input_dtypes);
for (int i = 0; i < num_input_dtypes; ++i) {
input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]);
}
std::unordered_map<std::string, int> input_stype_map(num_input_stypes);
for (int i = 0; i < num_input_stypes; ++i) {
input_stype_map.emplace(input_stype_names[i], input_stypes[i]);
}

size_t args_top = 0, aux_top = 0;
// loop over inputs to symbol in order and add to args/aux if mutable
for (size_t i = 0; i < num_forward_inputs; ++i) {
const uint32_t nid = indexed_graph.input_nodes().at(i);
if (mutable_nodes.count(nid)) {
auto name = input_names[i];
CHECK_LT(aux_top, aux_len)
<< "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
<< "Cannot find aux '" << name << "' in provided aux to optimize_for";
if (in_aux_ptr[aux_top] != nullptr) {
const auto &in_arg = *(in_aux_ptr[aux_top]);
arg_shapes[i] = in_arg.shape();
arg_dtypes[i] = in_arg.dtype();
arg_stypes[i] = in_arg.storage_type();
} else {
auto it_shape = input_shape_map.find(name);
if (it_shape != input_shape_map.end()) {
  arg_shapes[i] = it_shape->second;
}
auto it_type = input_dtype_map.find(name);
if (it_type != input_dtype_map.end()) {
  arg_dtypes[i] = it_type->second;
}
it_type = input_stype_map.find(name);
if (it_type != input_stype_map.end()) {
  arg_stypes[i] = it_type->second;
}
}
aux_top++;
} else {
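The per-input resolution order implemented by the loop above can be summarized as: take shape/dtype/stype from the provided NDArray if one was passed, otherwise fall back to the corresponding dictionary entry, and leave anything still unknown to the inference pass. A simplified Python sketch of that fallback (the helper name is hypothetical and the args/aux split is ignored):

def resolve_input(name, provided, shape_dict, type_dict, stype_dict):
    # 'provided' maps input names to NDArrays passed via args/aux (value may be None).
    nd = provided.get(name)
    if nd is not None:
        return nd.shape, nd.dtype, nd.stype
    return (shape_dict.get(name), type_dict.get(name), stype_dict.get(name))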
