
move exec.reshape to backend (#10882)
* move exec.reshape to backend

* fix lint

* fix lint

* fix Symbol._get_ndarray_inputs

* update

* update

* move Reshape as a member function of Executor

* address comments
ZiyueHuang authored and eric-haibin-lin committed May 24, 2018
1 parent 30ca4e3 commit 704d218
Showing 8 changed files with 355 additions and 61 deletions.
41 changes: 41 additions & 0 deletions include/mxnet/c_api.h
@@ -1654,6 +1654,47 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out);

/*!
* \brief Return a new executor with the same symbol and shared memory,
* but different input/output shapes.
*
* \param partial_shaping Whether to allow changing the shape of unspecified arguments.
 * \param allow_up_sizing Whether to allow allocating new ndarrays that are larger than the original.
* \param dev_type device type of default context
* \param dev_id device id of default context
* \param num_map_keys size of group2ctx map
* \param map_keys keys of group2ctx map
* \param map_dev_types device type of group2ctx map
 * \param map_dev_ids device id of group2ctx map
 * \param num_provided_arg_shapes number of user provided in_arg and aux_state shapes
 * \param provided_arg_shape_names name of provided arg shapes
 * \param provided_arg_shape_data provided arg shape data
 * \param provided_arg_shape_idx provided arg shape data index
* \param num_in_args length of in_args
* \param in_args in args array
* \param arg_grads arg grads handle array
* \param num_aux_states length of auxiliary states
* \param aux_states auxiliary states array
* \param shared_exec input executor handle for memory sharing
* \param out output executor handle
* \return a new executor
*/
MXNET_DLL int MXExecutorReshape(int partial_shaping,
int allow_up_sizing,
int dev_type,
int dev_id,
mx_uint num_map_keys,
const char** map_keys,
const int* map_dev_types,
const int* map_dev_ids,
const mx_uint num_provided_arg_shapes,
const char** provided_arg_shape_names,
const mx_uint* provided_arg_shape_data,
const mx_uint* provided_arg_shape_idx,
mx_uint* num_in_args,
NDArrayHandle** in_args,
NDArrayHandle** arg_grads,
mx_uint* num_aux_states,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec,
ExecutorHandle *out);
/*!
* \brief set a call back to notify the completion of operation
*/
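A note on the shape arguments in the declaration above: all provided shapes are flattened into one data array, and an index array records where each argument's shape begins and ends. The snippet below is a minimal Python sketch of that packing convention as implied by the signature; pack_shape_args is a hypothetical helper, and 'data'/'label' are example argument names (the real packing lives in python/mxnet/executor.py, shown further down).

def pack_shape_args(shape_dict):
    """Pack {name: shape} into the parallel arrays MXExecutorReshape expects."""
    names = []   # provided_arg_shape_names
    data = []    # provided_arg_shape_data: all dimensions, flattened
    idx = [0]    # provided_arg_shape_idx: offsets into data, one per name
    for name, shape in shape_dict.items():
        names.append(name)
        data.extend(shape)
        idx.append(len(data))
    return names, data, idx

# Example: {'data': (15, 4), 'label': (15,)} packs to
#   names = ['data', 'label'], data = [15, 4, 15], idx = [0, 2, 3]
# so the shape of argument i is tuple(data[idx[i]:idx[i+1]]).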
23 changes: 23 additions & 0 deletions include/mxnet/executor.h
@@ -103,6 +103,29 @@ class Executor {
* \return aux state map in the executor.
*/
virtual const std::unordered_map<std::string, NDArray>& aux_state_map() const = 0;
/*!
* \brief Return a new executor with the same symbol and shared memory,
* but different input/output shapes.
*
* \param partial_shaping Whether to allow changing the shape of unspecified arguments.
 * \param allow_up_sizing Whether to allow allocating new ndarrays that are larger than the original.
* \param default_ctx the default context of binding.
* \param ctx_map Context mapping group to context.
* \param provided_arg_shapes New shape for arguments.
* \param in_args the NDArray that stores the input arguments.
* \param arg_grads NDArray that is used to store the gradient output of the input arguments.
* \param aux_states NDArray that is used as internal states.
* \return a new executor.
*/
virtual Executor* Reshape(const bool partial_shaping,
const bool allow_up_sizing,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::unordered_map<std::string, TShape>&
provided_arg_shapes,
std::vector<NDArray>* in_args,
std::vector<NDArray>* arg_grads,
std::vector<NDArray>* aux_states) = 0;
/*!
* \brief Create an operator by bind symbol with context and arguments.
 * If the user does not want to compute the gradient of the i-th argument, grad_req_type[i] can be kNullOp.
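At the Python level, Executor::Reshape surfaces as Executor.reshape(), and the "shared memory" in the brief above means that arrays whose shapes do not change keep pointing at the same storage as the original executor. Below is a hedged sketch of that behaviour; it assumes `import mxnet as mx` and the MXNet 1.x symbolic API, and the symbol and shapes are made up for illustration.

import mxnet as mx

data = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data, num_hidden=4, name='fc')
exe = net.simple_bind(mx.cpu(), data=(8, 16))

# Shrink the batch dimension; 'fc_weight' and 'fc_bias' keep their shapes.
new_exe = exe.reshape(data=(4, 16))

# If the parameter buffers are shared, as the brief above describes,
# an update through one executor is visible through the other.
exe.arg_dict['fc_weight'][:] = 1.0
print(new_exe.arg_dict['fc_weight'].asnumpy().mean())  # expected: 1.0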
129 changes: 70 additions & 59 deletions python/mxnet/executor.py
@@ -20,15 +20,15 @@
"""Symbolic Executor component of MXNet."""
from __future__ import absolute_import

from array import array as py_array
import ctypes
import copy
import numpy as np
from .base import _LIB
from .base import mx_uint, NDArrayHandle, ExecutorHandle
from .base import check_call, c_handle_array, py_str
from .base import mx_uint, NDArrayHandle, ExecutorHandle, py_str
from .base import check_call, c_handle_array, c_array_buf, c_str_array
from .ndarray import NDArray
from .ndarray import _ndarray_cls
from . import ndarray as nd

# these functions are not used here; we just import them to keep backward compatibility
# in case the end user calls them, as they originally lived here
@@ -399,62 +399,73 @@ def reshape(self, partial_shaping=False, allow_up_sizing=False, **kwargs):
>>> texec.reshape(allow_up_sizing=True, **new_shape)
"""
# pylint: disable=too-many-branches
arg_shapes, _, aux_shapes = self._symbol.infer_shape(**kwargs)
if arg_shapes is None:
raise ValueError("Insufficient argument shapes provided.")

new_arg_dict = {}
new_grad_dict = {}
for i, name in enumerate(self._symbol.list_arguments()):
new_shape = arg_shapes[i]
arr = self.arg_arrays[i]
darr = None if self.grad_arrays is None else self.grad_arrays[i]
if partial_shaping or name in kwargs or new_shape == arr.shape:
if np.prod(new_shape) > np.prod(arr.shape):
assert allow_up_sizing, "New shape of arg:%s larger than original. "%name + \
"First making a big executor and then down sizing it " + \
"is more efficient than the reverse." + \
"If you really want to up size, set allow_up_sizing=True " + \
"to enable allocation of new arrays."
new_arg_dict[name] = nd.empty(new_shape, ctx=arr.context, dtype=arr.dtype)
if darr is not None:
new_grad_dict[name] = nd.empty(new_shape, ctx=darr.context, dtype=arr.dtype)
else:
new_arg_dict[name] = arr.reshape(new_shape)
if darr is not None:
new_grad_dict[name] = darr.reshape(new_shape)
else:
raise AssertionError("Shape of unspecified array arg:%s changed. "%name + \
"This can cause the new executor to not share parameters " + \
"with the old one. Please check for error in network." +\
"If this is intended, set partial_shaping=True to suppress this warning.")

new_aux_dict = {}
for name, new_shape, arr in zip(self._symbol.list_auxiliary_states(),
aux_shapes, self.aux_arrays):
if partial_shaping or new_shape == arr.shape:
if np.prod(new_shape) > np.prod(arr.shape):
assert allow_up_sizing, "New shape of arg:%s larger than original. "%name + \
"First making a big executor and then down sizing it " + \
"is more efficient than the reverse." + \
"If you really want to up size, set allow_up_sizing=True " + \
"to enable allocation of new arrays."
new_aux_dict[name] = nd.empty(new_shape, ctx=arr.context, dtype=arr.dtype)
else:
new_aux_dict[name] = arr.reshape(new_shape)
else:
raise AssertionError("Shape of unspecified array aux:%s changed. "%name + \
"This can cause the new executor to not share parameters " + \
"with the old one. Please check for error in network." +\
"If this is intended, set partial_shaping=True to suppress this warning.")

return self._symbol.bind(self._ctx,
args=new_arg_dict,
args_grad=new_grad_dict,
grad_req=self._grad_req,
aux_states=new_aux_dict,
group2ctx=self._group2ctx,
shared_exec=self)
provided_arg_shape_data = [] # shape data
# index of each argument's shape in provided_arg_shape_data, e.g.
# provided_arg_shape_data[provided_arg_shape_idx[0]:provided_arg_shape_idx[1]]
# is the shape of the first argument
provided_arg_shape_idx = [0]
provided_arg_shape_names = [] # provided argument names
for k, v in kwargs.items():
if isinstance(v, tuple):
provided_arg_shape_names.append(k)
provided_arg_shape_data.extend(v)
provided_arg_shape_idx.append(len(provided_arg_shape_data))

ctx_map_keys = []
ctx_map_dev_types = []
ctx_map_dev_ids = []

if self._group2ctx:
for key, val in self._group2ctx.items():
ctx_map_keys.append(key)
ctx_map_dev_types.append(val.device_typeid)
ctx_map_dev_ids.append(val.device_id)

handle = ExecutorHandle()
shared_handle = self.handle

num_in_args = ctypes.c_uint()
in_arg_handles = ctypes.POINTER(NDArrayHandle)()
arg_grad_handles = ctypes.POINTER(NDArrayHandle)()
num_aux_states = ctypes.c_uint()
aux_state_handles = ctypes.POINTER(NDArrayHandle)()

check_call(_LIB.MXExecutorReshape(ctypes.c_int(int(partial_shaping)),
ctypes.c_int(int(allow_up_sizing)),
ctypes.c_int(self._ctx.device_typeid),
ctypes.c_int(self._ctx.device_id),
mx_uint(len(ctx_map_keys)),
c_str_array(ctx_map_keys),
c_array_buf(ctypes.c_int,
py_array('i', ctx_map_dev_types)),
c_array_buf(ctypes.c_int,
py_array('i', ctx_map_dev_ids)),
mx_uint(len(provided_arg_shape_names)),
c_str_array(provided_arg_shape_names),
c_array_buf(mx_uint,
py_array('I', provided_arg_shape_data)),
c_array_buf(mx_uint,
py_array('I', provided_arg_shape_idx)),
ctypes.byref(num_in_args),
ctypes.byref(in_arg_handles),
ctypes.byref(arg_grad_handles),
ctypes.byref(num_aux_states),
ctypes.byref(aux_state_handles),
shared_handle,
ctypes.byref(handle)))

arg_arrays = [_ndarray_cls(NDArrayHandle(in_arg_handles[i]))
for i in range(num_in_args.value)]
grad_arrays = [_ndarray_cls(NDArrayHandle(arg_grad_handles[i]))
if arg_grad_handles[i] is not None
else None for i in range(num_in_args.value)]
aux_arrays = [_ndarray_cls(NDArrayHandle(aux_state_handles[i]))
for i in range(num_aux_states.value)]

executor = Executor(handle, self._symbol, self._ctx, self._grad_req, self._group2ctx)
executor.arg_arrays = arg_arrays
executor.grad_arrays = grad_arrays
executor.aux_arrays = aux_arrays
return executor

def debug_str(self):
"""Get a debug string about internal execution plan.
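For reference, a hedged usage sketch of the reshape() method above, illustrating the two flags it accepts; it assumes `import mxnet as mx`, and the network and shapes are arbitrary.

import mxnet as mx

data = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data, num_hidden=10, name='fc')
exe = net.simple_bind(mx.cpu(), data=(32, 20))

# A smaller batch fits in the existing buffers, so the arrays are reused.
small_exe = exe.reshape(data=(16, 20))

# A larger batch needs bigger input/output buffers, which is only allowed
# when allow_up_sizing=True.
big_exe = exe.reshape(allow_up_sizing=True, data=(64, 20))

# Changing the shape of an argument that is not passed as a keyword would
# raise an error unless partial_shaping=True is given.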
7 changes: 5 additions & 2 deletions python/mxnet/symbol/symbol.py
@@ -1259,9 +1259,12 @@ def _get_ndarray_inputs(arg_key, args, arg_names, allow_missing):
if len(args) != len(arg_names):
raise ValueError('Length of %s does not match the number of arguments' % arg_key)
for narr in args:
if not isinstance(narr, NDArray):
if narr is None and allow_missing:
arg_handles.append(None)
elif not isinstance(narr, NDArray):
raise TypeError('Only accept list of NDArrays or dict of str to NDArray')
arg_handles.append(narr.handle)
else:
arg_handles.append(narr.handle)
arg_arrays = args
elif isinstance(args, dict):
for name in arg_names:
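The _get_ndarray_inputs change above lets a list passed for args or args_grad contain None entries when allow_missing is true; Symbol.bind passes allow_missing=True for args_grad, so gradient buffers can now be omitted per argument. Below is a hedged sketch of the kind of call this permits; it assumes `import mxnet as mx`, and the symbol and shapes are made up.

import mxnet as mx

x = mx.sym.Variable('x')
w = mx.sym.Variable('w')
y = mx.sym.broadcast_mul(x, w)

x_nd = mx.nd.ones((2, 3))
w_nd = mx.nd.ones((2, 3))
w_grad = mx.nd.zeros((2, 3))

# No gradient buffer for 'x': pass None in its slot and mark it 'null'.
exe = y.bind(mx.cpu(), args=[x_nd, w_nd],
             args_grad=[None, w_grad],
             grad_req={'x': 'null', 'w': 'write'})
exe.forward(is_train=True)
exe.backward(mx.nd.ones((2, 3)))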
87 changes: 87 additions & 0 deletions src/c_api/c_api_executor.cc
@@ -510,6 +510,93 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
API_END();
}

int MXExecutorReshape(int partial_shaping,
int allow_up_sizing,
int dev_type,
int dev_id,
mx_uint num_map_keys,
const char** map_keys,
const int* map_dev_types,
const int* map_dev_ids,
const mx_uint num_provided_arg_shapes,
const char** provided_arg_shape_names,
const mx_uint* provided_arg_shape_data,
const mx_uint* provided_arg_shape_idx,
mx_uint* num_in_args,
NDArrayHandle** in_args,
NDArrayHandle** arg_grads,
mx_uint* num_aux_states,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec,
ExecutorHandle *out) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
API_BEGIN();
// create shape map for in_args and aux_states
std::unordered_map<std::string, TShape> kwargs(num_provided_arg_shapes);
for (mx_uint i = 0; i < num_provided_arg_shapes; ++i) {
auto p = kwargs.emplace(provided_arg_shape_names[i],
TShape(provided_arg_shape_data+provided_arg_shape_idx[i],
provided_arg_shape_data+provided_arg_shape_idx[i+1]));
CHECK(p.second) << "Duplicate shapes are provided for argument "
<< provided_arg_shape_names[i] << " in reshape of executor";
}

Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
std::map<std::string, Context> ctx_map;
for (mx_uint i = 0; i < num_map_keys; ++i) {
ctx_map[std::string(map_keys[i])] = Context::Create(
static_cast<Context::DeviceType>(map_dev_types[i]), map_dev_ids[i]);
}
std::vector<NDArray> in_arg_vec;
std::vector<NDArray> arg_grad_vec;
std::vector<NDArray> aux_state_vec;

Executor* exec = static_cast<Executor*>(shared_exec);
*out = exec->Reshape(partial_shaping, allow_up_sizing, ctx, ctx_map, kwargs,
&in_arg_vec, &arg_grad_vec, &aux_state_vec);

ret->ret_handles.clear();
ret->ret_handles.reserve(in_arg_vec.size()+arg_grad_vec.size()+aux_state_vec.size());

size_t nd_idx = 0;
for (const auto& nd : in_arg_vec) {
if (nd.is_none()) {
LOG(FATAL) << "Input argument NDArray cannot be un-allocated";
}
ret->ret_handles.push_back(new NDArray(nd));
}
if (in_arg_vec.size() > 0) {
*num_in_args = in_arg_vec.size();
*in_args = &(ret->ret_handles[nd_idx]);
nd_idx = ret->ret_handles.size();
}

for (const auto& nd : arg_grad_vec) {
if (nd.is_none()) {
ret->ret_handles.push_back(nullptr);
} else {
ret->ret_handles.push_back(new NDArray(nd));
}
}
if (arg_grad_vec.size() > 0) {
*arg_grads = &(ret->ret_handles[nd_idx]);
nd_idx = ret->ret_handles.size();
}

for (const auto& nd : aux_state_vec) {
if (nd.is_none()) {
LOG(FATAL) << "Auxiliary argument NDArray cannot be un-allocated";
}
ret->ret_handles.push_back(new NDArray(nd));
}
if (aux_state_vec.size() > 0) {
*num_aux_states = aux_state_vec.size();
*aux_states = &(ret->ret_handles[nd_idx]);
nd_idx = ret->ret_handles.size();
}
API_END_HANDLE_ERROR(delete out);
}

int MXExecutorSetMonitorCallback(ExecutorHandle handle,
ExecutorMonitorCallback callback,
void* callback_handle) {
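The marshalling code above returns in_args, arg_grads and aux_states as parallel handle arrays, pushing a null handle wherever a gradient NDArray is absent; on the Python side those null handles become None entries in grad_arrays. A hedged sketch of how that surfaces to the user (assuming `import mxnet as mx`; the symbol, shapes and grad_req choices are illustrative only):

import mxnet as mx

data = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data, num_hidden=4, name='fc')
# No gradient is requested for the input, so its gradient slot stays empty.
exe = net.simple_bind(mx.cpu(), data=(8, 16),
                      grad_req={'data': 'null', 'fc_weight': 'write', 'fc_bias': 'write'})

new_exe = exe.reshape(data=(4, 16))
# Null gradient handles come back as None in the reshaped executor.
print([g is None for g in new_exe.grad_arrays])  # e.g. [True, False, False]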