Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simple bind with infer storage type #32

Merged
merged 4 commits into from
May 17, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,39 @@ MXNET_DLL int MXExecutorBindEX(SymbolHandle symbol_handle,
NDArrayHandle *aux_states,
ExecutorHandle shared_exec,
ExecutorHandle *out);

/*!
 * \brief Simplified executor binding: infers shapes, dtypes and storage types
 *        from the provided per-argument hints and allocates the in-arg,
 *        arg-grad and aux-state NDArrays internally, optionally reusing
 *        memory from a shared buffer and/or a shared executor.
 * NOTE(review): per-parameter semantics below are inferred from the parameter
 * names and the sibling MXExecutorBindEX declaration — confirm against the
 * C API implementation.
 * \param symbol_handle handle of the symbol to bind
 * \param dev_type device type of the default context
 * \param dev_id device id of the default context
 * \param num_g2c_keys number of group2ctx mapping entries
 * \param g2c_keys group names of the group2ctx mapping
 * \param g2c_dev_types device types of the group2ctx mapping
 * \param g2c_dev_ids device ids of the group2ctx mapping
 * \param provided_grad_req_list_len number of provided gradient requests
 * \param provided_grad_req_names argument names of the gradient requests
 * \param provided_grad_req_types gradient request type strings
 * \param num_provided_arg_shapes number of provided argument shapes
 * \param provided_arg_shape_names argument names of the provided shapes
 * \param provided_arg_shape_data flattened shape dimension values
 * \param provided_arg_shape_idx offsets into shape data delimiting each shape
 * \param num_provided_arg_dtypes number of provided argument dtypes
 * \param provided_arg_dtype_names argument names of the provided dtypes
 * \param provided_arg_dtypes provided dtype values
 * \param num_provided_arg_stypes number of provided argument storage types
 * \param provided_arg_stype_names argument names of the provided storage types
 * \param provided_arg_stypes provided storage type values
 * \param num_shared_arg_names number of shared argument names
 * \param shared_arg_name_list names of arguments shared with other executors
 * \param shared_buffer_len in/out: length of the shared data-array buffer
 * \param shared_buffer_name_list in/out: names in the shared data-array buffer
 * \param shared_buffer_handle_list in/out: NDArray handles of the shared buffer
 * \param num_in_args out: number of input arguments
 * \param in_args out: input-argument NDArray handles
 * \param arg_grads out: argument-gradient NDArray handles
 * \param num_aux_states out: number of auxiliary states
 * \param aux_states out: auxiliary-state NDArray handles
 * \param shared_exec_handle executor whose memory may be reused
 * \param out the created executor
 */
MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
int dev_type,
int dev_id,
const mx_uint num_g2c_keys,
const char** g2c_keys,
const int* g2c_dev_types,
const int* g2c_dev_ids,
const mx_uint provided_grad_req_list_len,
const char** provided_grad_req_names,
const char** provided_grad_req_types,
const mx_uint num_provided_arg_shapes,
const char** provided_arg_shape_names,
const mx_uint* provided_arg_shape_data,
const mx_uint* provided_arg_shape_idx,
const mx_uint num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
const mx_uint num_provided_arg_stypes,
const char** provided_arg_stype_names,
const int* provided_arg_stypes,
const mx_uint num_shared_arg_names,
const char** shared_arg_name_list,
mx_uint* shared_buffer_len,
const char*** shared_buffer_name_list,
NDArrayHandle** shared_buffer_handle_list,
mx_uint* num_in_args,
NDArrayHandle** in_args,
NDArrayHandle** arg_grads,
mx_uint* num_aux_states,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out);
/*!
* \brief set a call back to notify the completion of operation
*/
Expand Down
33 changes: 33 additions & 0 deletions include/mxnet/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,21 @@ class Executor {
* \return array of outputs in the executor.
*/
virtual const std::vector<NDArray> &outputs() const = 0;
/*!
 * \brief get input argument map, key is arg name, value is arg's NDArray.
 * \return input argument map in the executor.
 */
virtual const std::unordered_map<std::string, NDArray>& in_arg_map() const = 0;
/*!
 * \brief get input argument gradient map, key is arg name, value is gradient's NDArray.
 * \return input argument gradient map in the executor.
 */
virtual const std::unordered_map<std::string, NDArray>& arg_grad_map() const = 0;
/*!
 * \brief get aux state map, key is aux state name, value is aux state's NDArray.
 * \return aux state map in the executor.
 */
virtual const std::unordered_map<std::string, NDArray>& aux_state_map() const = 0;
/*!
* \brief Create an operator by bind symbol with context and arguments.
* If user do not want to compute the gradients of i-th argument, grad_req_type[i] can be kNullOp.
Expand All @@ -91,6 +106,24 @@ class Executor {
const std::vector<OpReqType> &grad_req_type,
const std::vector<NDArray> &aux_states,
Executor* shared_exec = NULL);

/*!
 * \brief Bind the symbol and allocate all argument, gradient and auxiliary
 *        NDArrays internally, inferring shapes/dtypes/storage types from the
 *        provided maps — counterpart of Bind, which takes pre-allocated arrays.
 * NOTE(review): semantics inferred from the signature and the sibling Bind
 * declaration — confirm against the implementation.
 * \param symbol the symbol to bind
 * \param default_ctx default context used for allocation
 * \param group2ctx context mapping for node groups
 * \param in_arg_ctxes contexts of the input arguments
 * \param arg_grad_ctxes contexts of the argument gradients
 * \param aux_state_ctxes contexts of the auxiliary states
 * \param arg_shape_map provided shapes, keyed by argument name
 * \param arg_dtype_map provided dtypes, keyed by argument name
 * \param arg_stype_map provided storage types, keyed by argument name
 * \param grad_req_types gradient request type for each input argument
 * \param param_names names of the model parameters
 * \param in_args out: input-argument NDArrays
 * \param arg_grads out: argument-gradient NDArrays
 * \param aux_states out: auxiliary-state NDArrays
 * \param shared_data_arrays optional map of NDArrays shared across executors
 * \param shared_exec optional executor whose memory may be reused
 * \return the created executor
 */
static Executor* SimpleBind(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& group2ctx,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
const std::unordered_map<std::string, TShape>& arg_shape_map,
const std::unordered_map<std::string, int>& arg_dtype_map,
const std::unordered_map<std::string, int>& arg_stype_map,
const std::vector<OpReqType>& grad_req_types,
const std::unordered_set<std::string>& param_names,
std::vector<NDArray>* in_args,
std::vector<NDArray>* arg_grads,
std::vector<NDArray>* aux_states,
std::unordered_map<std::string, NDArray>*
shared_data_arrays = nullptr,
Executor* shared_exec = nullptr);
/*!
* \brief the prototype of user-defined monitor callback
*/
Expand Down
8 changes: 4 additions & 4 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class NDArray {
} else if (storage_type == kCSRStorage) {
aux_types = {CSR_IND_PTR_TYPE, CSR_IDX_DTYPE};
} else {
LOG(FATAL) << "Unknown storage type";
LOG(FATAL) << "Unknown storage type" << storage_type;
}
}
// Assign default shapes if not given
Expand All @@ -139,7 +139,7 @@ class NDArray {
// aux shapes for indptr and indices
aux_shapes = {TShape({0}), TShape({0})};
} else {
LOG(FATAL) << "Unknown storage type";
LOG(FATAL) << "Unknown storage type" << storage_type;
}
}
if (storage_shape.Size() == 0) {
Expand All @@ -149,7 +149,7 @@ class NDArray {
} else if (storage_type == kCSRStorage) {
storage_shape = aux_shapes[csr::kIdx];
} else {
LOG(FATAL) << "Unknown storage type";
LOG(FATAL) << "Unknown storage type" << storage_type;
}
}
ptr_ = std::make_shared<Chunk>(storage_type, storage_shape, ctx, delay_alloc,
Expand Down Expand Up @@ -735,7 +735,7 @@ class NDArray {
if (skip_free == false) {
Storage::Get()->Free(h);
for (size_t i = 0; i < aux_h.size(); i++) {
Storage::Get()->Free(aux_h[i]);
if (aux_h[i].size > 0) Storage::Get()->Free(aux_h[i]);
}
}
}, shandle.ctx, var);
Expand Down
81 changes: 5 additions & 76 deletions python/mxnet/module/executor_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import logging
from collections import OrderedDict

import numpy as np

from .. import context as ctx
Expand Down Expand Up @@ -559,6 +558,7 @@ def update_metric(self, eval_metric, labels):

def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
"""Internal utility function to bind the i-th executor.
This function utilizes simple_bind python interface.
"""
shared_exec = None if shared_group is None else shared_group.execs[i]
context = self.contexts[i]
Expand All @@ -568,85 +568,14 @@ def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
if label_shapes is not None:
input_shapes.update(dict(label_shapes))

arg_shapes, _, aux_shapes = self.symbol.infer_shape(**input_shapes)
assert arg_shapes is not None, "shape inference failed"

input_types = {x.name: x.dtype for x in data_shapes}
if label_shapes is not None:
input_types.update({x.name: x.dtype for x in label_shapes})
arg_types, _, aux_types = self.symbol.infer_type(**input_types)
assert arg_types is not None, "type inference failed"

arg_arrays = []
grad_arrays = {} if self.for_training else None

def _get_or_reshape(name, shared_data_arrays, arg_shape, arg_type, context, logger):
"""Internal helper to get a memory block or re-use by re-shaping."""
if name in shared_data_arrays:
arg_arr = shared_data_arrays[name]

if np.prod(arg_arr.shape) >= np.prod(arg_shape):
# nice, we can directly re-use this data blob
assert arg_arr.dtype == arg_type
arg_arr = arg_arr.reshape(arg_shape)
else:
logger.warning(('bucketing: data "%s" has a shape %s' % (name, arg_shape)) +
(', which is larger than already allocated ') +
('shape %s' % (arg_arr.shape,)) +
('. Need to re-allocate. Consider putting ') +
('default_bucket_key to') +
(' be the bucket taking the largest input for better ') +
('memory sharing.'))
arg_arr = nd.zeros(arg_shape, context, dtype=arg_type)

# replace existing shared array because the new one is bigger
shared_data_arrays[name] = arg_arr
else:
arg_arr = nd.zeros(arg_shape, context, dtype=arg_type)
shared_data_arrays[name] = arg_arr

return arg_arr

# create or borrow arguments and gradients
for j in range(len(self.arg_names)):
name = self.arg_names[j]
if name in self.param_names: # model parameters
if shared_exec is None:
arg_arr = nd.zeros(arg_shapes[j], context, dtype=arg_types[j])
if self.grad_req[name] != 'null':
grad_arr = nd.zeros(arg_shapes[j], context, dtype=arg_types[j])
grad_arrays[name] = grad_arr
else:
arg_arr = shared_exec.arg_dict[name]
assert arg_arr.shape == arg_shapes[j]
assert arg_arr.dtype == arg_types[j]
if self.grad_req[name] != 'null':
grad_arrays[name] = shared_exec.grad_dict[name]
else: # data, label, or states
arg_arr = _get_or_reshape(name, shared_data_arrays, arg_shapes[j], arg_types[j],
context, self.logger)

# data might also need grad if inputs_need_grad is True
if self.grad_req[name] != 'null':
grad_arrays[name] = _get_or_reshape('grad of ' + name, shared_data_arrays,
arg_shapes[j], arg_types[j], context,
self.logger)

arg_arrays.append(arg_arr)

# create or borrow aux variables
if shared_exec is None:
aux_arrays = [nd.zeros(s, context, dtype=t) for s, t in zip(aux_shapes, aux_types)]
else:
for j, arr in enumerate(shared_exec.aux_arrays):
assert aux_shapes[j] == arr.shape
assert aux_types[j] == arr.dtype
aux_arrays = shared_exec.aux_arrays[:]

executor = self.symbol.bind(ctx=context, args=arg_arrays,
args_grad=grad_arrays, aux_states=aux_arrays,
grad_req=self.grad_req, shared_exec=shared_exec)
# Get the total bytes allocated for this executor
executor = self.symbol.simple_bind(ctx=context, grad_req=self.grad_req,
type_dict=input_types, param_names=self.param_names,
shared_exec=shared_exec,
shared_data_arrays=shared_data_arrays, **input_shapes)
self._total_exec_bytes += int(executor.debug_str().split('\n')[-3].split()[1])
return executor

Expand Down
Loading