This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit b9854e4: Clean up
reminisce committed May 14, 2017
1 parent 4743a0d

Showing 6 changed files with 112 additions and 304 deletions.
10 changes: 5 additions & 5 deletions include/mxnet/c_api.h
@@ -1096,11 +1096,11 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
                                   const mx_uint num_provided_arg_dtypes,
                                   const char** provided_arg_dtype_names,
                                   const int* provided_arg_dtypes,
-                                  const mx_uint num_param_names,
-                                  const char** param_name_list,
-                                  mx_uint* num_shared_data_arrays,
-                                  const char*** shared_data_array_name_list,
-                                  NDArrayHandle** shared_data_array_handle_list,
+                                  const mx_uint num_shared_arg_names,
+                                  const char** shared_arg_name_list,
+                                  mx_uint* shared_buffer_len,
+                                  const char*** shared_buffer_name_list,
+                                  NDArrayHandle** shared_buffer_handle_list,
                                   mx_uint* num_in_args,
                                   NDArrayHandle** in_args,
                                   NDArrayHandle** arg_grads,
98 changes: 0 additions & 98 deletions python/mxnet/module/executor_group.py
@@ -557,104 +557,6 @@ def update_metric(self, eval_metric, labels):
preds = OrderedDict(zip(self.output_names, texec.outputs))
eval_metric.update_dict(labels, preds)

-    def _bind_ith_exec_v1(self, i, data_shapes, label_shapes, shared_group):
-        """Internal utility function to bind the i-th executor.
-        """
-        warnings.warn(
-            '\033[91mmxnet.module.executor_group.DataParallelExecutorGroup._bind_ith_exec_v1' +
-            'has been deprecated. ' +
-            'Please use _bind_ith_exec instead.\033[0m',
-            DeprecationWarning, stacklevel=2)
-        shared_exec = None if shared_group is None else shared_group.execs[i]
-        context = self.contexts[i]
-        shared_data_arrays = self.shared_data_arrays[i]
-
-        input_shapes = dict(data_shapes)
-        if label_shapes is not None:
-            input_shapes.update(dict(label_shapes))
-
-        arg_shapes, _, aux_shapes = self.symbol.infer_shape(**input_shapes)
-        assert arg_shapes is not None, "shape inference failed"
-
-        input_types = {x.name: x.dtype for x in data_shapes}
-        if label_shapes is not None:
-            input_types.update({x.name: x.dtype for x in label_shapes})
-        arg_types, _, aux_types = self.symbol.infer_type(**input_types)
-        assert arg_types is not None, "type inference failed"
-
-        arg_arrays = []
-        grad_arrays = {} if self.for_training else None
-
-        def _get_or_reshape(name, shared_data_arrays, arg_shape, arg_type, context, logger):
-            """Internal helper to get a memory block or re-use by re-shaping."""
-            if name in shared_data_arrays:
-                arg_arr = shared_data_arrays[name]
-
-                if np.prod(arg_arr.shape) >= np.prod(arg_shape):
-                    # nice, we can directly re-use this data blob
-                    assert arg_arr.dtype == arg_type
-                    arg_arr = arg_arr.reshape(arg_shape)
-                else:
-                    logger.warning(('bucketing: data "%s" has a shape %s' % (name, arg_shape)) +
-                                   (', which is larger than already allocated ') +
-                                   ('shape %s' % (arg_arr.shape,)) +
-                                   ('. Need to re-allocate. Consider putting ') +
-                                   ('default_bucket_key to') +
-                                   (' be the bucket taking the largest input for better ') +
-                                   ('memory sharing.'))
-                    arg_arr = nd.zeros(arg_shape, context, dtype=arg_type)
-
-                    # replace existing shared array because the new one is bigger
-                    shared_data_arrays[name] = arg_arr
-            else:
-                arg_arr = nd.zeros(arg_shape, context, dtype=arg_type)
-                shared_data_arrays[name] = arg_arr
-
-            return arg_arr
-
-        # create or borrow arguments and gradients
-        for j in range(len(self.arg_names)):
-            name = self.arg_names[j]
-            if name in self.param_names:  # model parameters
-                if shared_exec is None:
-                    arg_arr = nd.zeros(arg_shapes[j], context, dtype=arg_types[j])
-                    if self.grad_req[name] != 'null':
-                        grad_arr = nd.zeros(arg_shapes[j], context, dtype=arg_types[j])
-                        grad_arrays[name] = grad_arr
-                else:
-                    arg_arr = shared_exec.arg_dict[name]
-                    assert arg_arr.shape == arg_shapes[j]
-                    assert arg_arr.dtype == arg_types[j]
-                    if self.grad_req[name] != 'null':
-                        grad_arrays[name] = shared_exec.grad_dict[name]
-            else:  # data, label, or states
-                arg_arr = _get_or_reshape(name, shared_data_arrays, arg_shapes[j], arg_types[j],
-                                          context, self.logger)
-
-                # data might also need grad if inputs_need_grad is True
-                if self.grad_req[name] != 'null':
-                    grad_arrays[name] = _get_or_reshape('grad of ' + name, shared_data_arrays,
-                                                        arg_shapes[j], arg_types[j], context,
-                                                        self.logger)
-
-            arg_arrays.append(arg_arr)
-
-        # create or borrow aux variables
-        if shared_exec is None:
-            aux_arrays = [nd.zeros(s, context, dtype=t) for s, t in zip(aux_shapes, aux_types)]
-        else:
-            for j, arr in enumerate(shared_exec.aux_arrays):
-                assert aux_shapes[j] == arr.shape
-                assert aux_types[j] == arr.dtype
-            aux_arrays = shared_exec.aux_arrays[:]
-
-        executor = self.symbol.bind(ctx=context, args=arg_arrays,
-                                    args_grad=grad_arrays, aux_states=aux_arrays,
-                                    grad_req=self.grad_req, shared_exec=shared_exec)
-        # Get the total bytes allocated for this executor
-        self._total_exec_bytes += int(executor.debug_str().split('\n')[-3].split()[1])
-        return executor
-
def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
"""Internal utility function to bind the i-th executor.
        This function utilizes the simple_bind Python interface.
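The body of the new `_bind_ith_exec` is truncated in this excerpt. As a hedged sketch only (the attribute names mirror the deleted v1 code above; this is not the verbatim body from the commit), the hand-rolled allocation logic collapses into a single `simple_bind` call:

    def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
        # Sketch, not the commit's verbatim implementation.
        shared_exec = None if shared_group is None else shared_group.execs[i]
        context = self.contexts[i]
        shared_data_arrays = self.shared_data_arrays[i]

        input_shapes = dict(data_shapes)
        if label_shapes is not None:
            input_shapes.update(dict(label_shapes))
        input_types = {x.name: x.dtype for x in data_shapes}
        if label_shapes is not None:
            input_types.update({x.name: x.dtype for x in label_shapes})

        # simple_bind now allocates, or reuses via shared_exec/shared_buffer,
        # every argument, gradient, and auxiliary array in one call.
        return self.symbol.simple_bind(ctx=context, grad_req=self.grad_req,
                                       type_dict=input_types,
                                       shared_arg_names=self.param_names,
                                       shared_exec=shared_exec,
                                       shared_buffer=shared_data_arrays,
                                       **input_shapes)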
190 changes: 48 additions & 142 deletions python/mxnet/symbol.py
@@ -1115,7 +1115,7 @@ def _get_ndarray_inputs(arg_key, args, arg_names, allow_missing):
return c_array(NDArrayHandle, arg_handles), arg_arrays

def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
-                    param_names=None, shared_exec=None, shared_data_arrays=None, **kwargs):
+                    shared_arg_names=None, shared_exec=None, shared_buffer=None, **kwargs):
"""Bind current symbol to get an executor, allocate all the arguments needed.
Allows specifying data types.
@@ -1160,6 +1160,19 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
group2ctx : Dict of string to mx.Context
The dict mapping the `ctx_group` attribute to the context assignment.
+        shared_arg_names : List of string
+            The argument names whose `NDArray`s in `shared_exec` can be reused for
+            initializing the current executor.
+        shared_exec : Executor
+            The executor whose arg_arrays, grad_arrays, and aux_arrays can be
+            reused for initializing the current executor.
+        shared_buffer : Dict of string to `NDArray`
+            The dict mapping argument names to `NDArray`s that can be reused for
+            initializing the current executor. This buffer is checked for reuse
+            when an argument name of the current executor is not found in
+            `shared_arg_names`.
kwargs : Dict of str->shape
Input shape dictionary, name->shape
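To make the three new parameters concrete, a minimal usage sketch follows (the symbol, shapes, and argument names here are illustrative, not taken from this commit):

    import mxnet as mx

    x = mx.sym.Variable('x')
    y = mx.sym.FullyConnected(x, num_hidden=4, name='fc')

    shared_buffer = {}  # filled in place by simple_bind with reusable NDArrays
    exe_large = y.simple_bind(mx.cpu(), grad_req='null',
                              shared_buffer=shared_buffer, x=(16, 10))
    # The second executor reuses the parameter arrays of the first and recycles
    # data blobs from shared_buffer where the requested shapes fit.
    exe_small = y.simple_bind(mx.cpu(), grad_req='null',
                              shared_arg_names=['fc_weight', 'fc_bias'],
                              shared_exec=exe_large,
                              shared_buffer=shared_buffer, x=(4, 10))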
@@ -1239,28 +1252,28 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
ctx_map_dev_ids = c_array(ctypes.c_int, ctx_map_dev_ids)

-        # prepare param names
-        param_name_list = []
-        if param_names is not None:
-            if not isinstance(param_names, list):
-                raise ValueError('param_names in simple_bind must be a list or None')
-            param_name_list = [c_str(name) for name in param_names]
-
-        # prepare shared_data_arrays
-        if shared_data_arrays is None:
-            num_shared_data_arrays = mx_uint()
-            shared_data_array_names = ctypes.POINTER(ctypes.c_char_p)()
-            shared_data_array_handles = ctypes.POINTER(NDArrayHandle)()
+        # prepare shared_arg_names
+        shared_arg_name_list = []
+        if shared_arg_names is not None:
+            if not isinstance(shared_arg_names, list):
+                raise ValueError('shared_arg_names in simple_bind must be a list or None')
+            shared_arg_name_list = [c_str(name) for name in shared_arg_names]
+
+        # prepare shared_buffer
+        if shared_buffer is None:
+            shared_buffer_len = mx_uint()
+            shared_buffer_names = ctypes.POINTER(ctypes.c_char_p)()
+            shared_buffer_handles = ctypes.POINTER(NDArrayHandle)()
else:
-            if not isinstance(shared_data_arrays, dict):
-                raise ValueError('shared_data_arrays in simple_bind must be dict or None')
-            shared_data_array_names = []
-            shared_data_array_handles = []
-            for k, v in shared_data_arrays.items():
-                shared_data_array_names.append(c_str(k))
-                shared_data_array_handles.append(v.handle)
-            shared_data_array_names = c_array(ctypes.c_char_p, shared_data_array_names)
-            num_shared_data_arrays = mx_uint(len(shared_data_array_handles))
-            shared_data_array_handles = c_array(NDArrayHandle, shared_data_array_handles)
+            if not isinstance(shared_buffer, dict):
+                raise ValueError('shared_buffer in simple_bind must be dict or None')
+            shared_buffer_names = []
+            shared_buffer_handles = []
+            for k, v in shared_buffer.items():
+                shared_buffer_names.append(c_str(k))
+                shared_buffer_handles.append(v.handle)
+            shared_buffer_names = c_array(ctypes.c_char_p, shared_buffer_names)
+            shared_buffer_len = mx_uint(len(shared_buffer_handles))
+            shared_buffer_handles = c_array(NDArrayHandle, shared_buffer_handles)

# prepare shared_exec_handle
shared_exec_handle = shared_exec.handle if shared_exec is not None else ExecutorHandle()
@@ -1292,11 +1305,11 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
num_provided_arg_types,
provided_arg_type_names,
provided_arg_type_data,
-            mx_uint(len(param_name_list)),
-            c_array(ctypes.c_char_p, param_name_list),
-            ctypes.byref(num_shared_data_arrays),
-            ctypes.byref(shared_data_array_names),
-            ctypes.byref(shared_data_array_handles),
+            mx_uint(len(shared_arg_name_list)),
+            c_array(ctypes.c_char_p, shared_arg_name_list),
+            ctypes.byref(shared_buffer_len),
+            ctypes.byref(shared_buffer_names),
+            ctypes.byref(shared_buffer_handles),
ctypes.byref(num_in_args),
ctypes.byref(in_arg_handles),
ctypes.byref(arg_grad_handles),
@@ -1305,14 +1318,14 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
shared_exec_handle,
ctypes.byref(exe_handle)))

-        # update shared_data_arrays
-        if shared_data_arrays is not None:
-            updated_shared_data_arrays = [NDArray(NDArrayHandle(shared_data_array_handles[i]))
-                                          for i in range(num_shared_data_arrays.value)]
-            updated_shared_data_array_names = [py_str(shared_data_array_names[i])
-                                               for i in range(num_shared_data_arrays.value)]
-            for k, v in zip(updated_shared_data_array_names, updated_shared_data_arrays):
-                shared_data_arrays[k] = v
+        # update shared_buffer
+        if shared_buffer is not None:
+            updated_shared_buffer = [NDArray(NDArrayHandle(shared_buffer_handles[i]))
+                                     for i in range(shared_buffer_len.value)]
+            updated_shared_buffer_names = [py_str(shared_buffer_names[i])
+                                           for i in range(shared_buffer_len.value)]
+            for k, v in zip(updated_shared_buffer_names, updated_shared_buffer):
+                shared_buffer[k] = v

# create in_args, arg_grads, and aux_states for the current executor
arg_arrays = [NDArray(NDArrayHandle(in_arg_handles[i])) for i in range(num_in_args.value)]
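The `# update shared_buffer` block above makes the dict an in/out argument: the C side may allocate or replace entries, and the write-back keeps the caller's dict current. An illustrative sketch of that contract, reusing the toy symbol from the earlier sketch (the reuse-by-element-count behavior described here mirrors the intent of `_get_or_reshape` in the deleted v1 code, and is an assumption, not quoted from this commit):

    shared_buffer = {'x': mx.nd.zeros((16, 10))}  # seed with a large data blob
    exe = y.simple_bind(mx.cpu(), grad_req='null',
                        shared_buffer=shared_buffer, x=(4, 10))
    # The same dict object now reflects whatever the backend reused or newly
    # allocated; pass it to the next simple_bind call to keep sharing memory
    # across buckets.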
@@ -1328,113 +1341,6 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
executor.aux_arrays = aux_arrays
return executor

-    def simple_bind_v1(self, ctx,
-                       grad_req='write',
-                       type_dict=None,
-                       group2ctx=None,
-                       **kwargs):
-        """This function is DEPRECATED.
-        Bind current symbol to get an executor, allocate all the arguments needed.
-        Allows specifying data types.
-        This function simplifies the binding procedure. You need to specify only input data shapes.
-        Before binding the executor, the function allocates arguments and auxiliary states
-        that were not explicitly specified. Allows specifying data types.
-        Example usage:
-        ----------
-        >>> x = mx.sym.Variable('x')
-        >>> y = mx.sym.FullyConnected(x, num_hidden=4)
-        >>> exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req='null')
-        >>> exe.forward()
-        [<NDArray 5x4 @cpu(0)>]
-        >>> exe.outputs[0].asnumpy()
-        array([[ 0.,  0.,  0.,  0.],
-               [ 0.,  0.,  0.,  0.],
-               [ 0.,  0.,  0.,  0.],
-               [ 0.,  0.,  0.,  0.],
-               [ 0.,  0.,  0.,  0.]], dtype=float32)
-        >>> exe.arg_arrays
-        [<NDArray 5x4 @cpu(0)>, <NDArray 4x4 @cpu(0)>, <NDArray 4 @cpu(0)>]
-        >>> exe.grad_arrays
-        [<NDArray 5x4 @cpu(0)>, <NDArray 4x4 @cpu(0)>, <NDArray 4 @cpu(0)>]
-        Parameters
-        ----------
-        ctx : Context
-            The device context the generated executor to run on.
-        grad_req: string
-            {'write', 'add', 'null'}, or list of str or dict of str to str, optional
-            To specify how we should update the gradient to the `args_grad`.
-            - 'write' means every time gradient is written to specified `args_grad` NDArray.
-            - 'add' means every time gradient is added to the specified NDArray.
-            - 'null' means no action is taken, the gradient may not be calculated.
-        type_dict : Dict of str->numpy.dtype
-            Input type dictionary, name->dtype
-        group2ctx : Dict of string to mx.Context
-            The dict mapping the `ctx_group` attribute to the context assignment.
-        kwargs : Dict of str->shape
-            Input shape dictionary, name->shape
-        Returns
-        -------
-        executor : mxnet.Executor
-            The generated executor
-        """
-        # pylint: disable=too-many-locals
-        warnings.warn(
-            '\033[91mmxnet.symbol.simple_bind_v1' +
-            'has been deprecated. ' +
-            'Please use simple_bind instead.\033[0m',
-            DeprecationWarning, stacklevel=2)
-
-        if type_dict is None:
-            attrs = self.attr_dict()
-            type_dict = {k: mx_real_t for k in self.list_arguments()
-                         if k not in attrs or '__dtype__' not in attrs[k]}
-        arg_shapes, _, aux_shapes = self.infer_shape(**kwargs)
-        arg_types, _, aux_types = self.infer_type(**type_dict)
-
-        if arg_shapes is None or arg_types is None:
-            raise ValueError("Input node is not complete")
-
-        if group2ctx is not None:
-            attr_dict = self.attr_dict()
-            arg_ctx = [group2ctx.get(attr_dict[name]['__ctx_group__'], ctx) \
-                       if name in attr_dict and '__ctx_group__' in attr_dict[name] \
-                       else ctx for name in self.list_arguments()]
-            aux_ctx = [group2ctx.get(attr_dict[name]['__ctx_group__'], ctx) \
-                       if name in attr_dict and '__ctx_group__' in attr_dict[name] \
-                       else ctx for name in self.list_auxiliary_states()]
-        else:
-            arg_ctx = [ctx] * len(arg_shapes)
-            aux_ctx = [ctx] * len(aux_shapes)
-
-        # alloc space
-        arg_ndarrays = [
-            _nd_zeros(shape, dev, dtype=dtype)
-            for dtype, dev, shape in zip(arg_types, arg_ctx, arg_shapes)]
-        if grad_req != 'null':
-            grad_ndarrays = {}
-            for name, shape, dev, dtype in zip(
-                    self.list_arguments(), arg_shapes, arg_ctx, arg_types):
-                if not isinstance(grad_req, dict) or grad_req[name] != 'null':
-                    grad_ndarrays[name] = _nd_zeros(shape, dev, dtype=dtype)
-        else:
-            grad_ndarrays = None
-
-        aux_ndarrays = [_nd_zeros(shape, dev, dtype=dtype)
-                        for shape, dev, dtype in zip(aux_shapes, aux_ctx, aux_types)]
-        executor = self.bind(ctx, arg_ndarrays,
-                             grad_ndarrays, grad_req, aux_ndarrays,
-                             group2ctx=group2ctx)
-        return executor
-
def bind(self, ctx, args, args_grad=None, grad_req='write',
aux_states=None, group2ctx=None, shared_exec=None):
"""Binds the current symbol to an executor and returns it.
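For contrast with `simple_bind`, `bind` expects every array to be allocated by the caller, which is exactly the bookkeeping the deleted `*_v1` helpers performed by hand. A minimal sketch (the toy symbol and shapes are illustrative, not from this commit):

    import mxnet as mx

    x = mx.sym.Variable('x')
    y = mx.sym.FullyConnected(x, num_hidden=4, name='fc')
    args = {'x': mx.nd.zeros((5, 10)),
            'fc_weight': mx.nd.zeros((4, 10)),
            'fc_bias': mx.nd.zeros((4,))}
    exe = y.bind(ctx=mx.cpu(), args=args)  # caller owns all allocation
    exe.forward()
    print(exe.outputs[0].shape)  # (5, 4)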