Commit

Initial checkin

Add init functions for simple bind in graph_executor

Add simple_bind c_api

Add simple bind c-api

Assign zeros to in_args, arg_grads, and aux_states

Add simple_bind2 python interface

Fix python interface bugs

Interface changes

Fix

Fix core dump

Add bind_ith_exec c_api

Change simple_bind2

Fix seg fault

Finish simple_bind

Change _bind_ith_exec

Refactor simple_bind initialization flow for bind

Consolidate bind and simple_bind graph init flow

Fix bug

Clean up

Add comments

Clean up

Clean up

Minor correction

Rename APIs in graph executor

Refactor

Rebase

Delete deprecated functions

Move more front-end work to backend

Bug fix

Fix failed tests

Minor fix

Fix lint

Fix lint

Revert unnecessary changes

Revert

Revert

Clean up

Fix lint

Fix bind_ith_exec calling simple_bind

Fix bugs for _bind_ith_exec
reminisce committed May 31, 2017
1 parent 2a09f16 commit fa94bbb
Showing 9 changed files with 1,111 additions and 243 deletions.
32 changes: 32 additions & 0 deletions include/mxnet/c_api.h
@@ -1149,6 +1149,38 @@ MXNET_DLL int MXExecutorBindEX(SymbolHandle symbol_handle,
NDArrayHandle *aux_states,
ExecutorHandle shared_exec,
ExecutorHandle *out);

MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
int dev_type,
int dev_id,
const mx_uint num_g2c_keys,
const char** g2c_keys,
const int* g2c_dev_types,
const int* g2c_dev_ids,
const mx_uint provided_grad_req_list_len,
const char** provided_grad_req_names,
const char** provided_grad_req_types,
const mx_uint num_provided_arg_shapes,
const char** provided_arg_shape_names,
const mx_uint* provided_arg_shape_data,
const mx_uint* provided_arg_shape_idx,
const mx_uint num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
const mx_uint num_shared_arg_names,
const char** shared_arg_name_list,
int* shared_buffer_len,
const char** shared_buffer_name_list,
NDArrayHandle* shared_buffer_handle_list,
const char*** updated_shared_buffer_name_list,
NDArrayHandle** updated_shared_buffer_handle_list,
mx_uint* num_in_args,
NDArrayHandle** in_args,
NDArrayHandle** arg_grads,
mx_uint* num_aux_states,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out);
/*!
* \brief set a call back to notify the completion of operation
*/
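
For orientation, the sketch below shows roughly how this entry point is exercised from the Python front end; the symbol, layer names, and shapes are made up for illustration, and only the simple_bind keyword arguments that also appear later in this diff are assumed.

import mxnet as mx

# Toy symbol; layer names and shapes here are illustrative only.
data = mx.sym.Variable('data')
fc = mx.sym.FullyConnected(data, num_hidden=10, name='fc')
net = mx.sym.SoftmaxOutput(fc, name='softmax')

# simple_bind asks the backend (ultimately MXExecutorSimpleBind) to infer the
# remaining shapes/dtypes, allocate in_args/arg_grads/aux_states, and return a
# ready-to-run executor.
exe = net.simple_bind(ctx=mx.cpu(), grad_req='write', data=(32, 100))
print(list(exe.arg_dict))   # argument name -> NDArray allocated by the backend
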
32 changes: 32 additions & 0 deletions include/mxnet/executor.h
@@ -69,6 +69,21 @@ class Executor {
* \return array of outputs in the executor.
*/
virtual const std::vector<NDArray> &outputs() const = 0;
/*!
* \brief get input argument map, key is arg name, value is arg's NDArray.
* \return input argument map in the executor.
*/
virtual const std::unordered_map<std::string, NDArray>& in_arg_map() const = 0;
/*!
* \brief get input argument gradient map, key is arg name, value is gradient's NDArray.
* \return input argument gradient map in the executor.
*/
virtual const std::unordered_map<std::string, NDArray>& arg_grad_map() const = 0;
/*!
* \brief get aux state map, key is aux state name, value is aux state's NDArray.
* \return aux state map in the executor.
*/
virtual const std::unordered_map<std::string, NDArray>& aux_state_map() const = 0;
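
These per-name maps give the backend executor a dictionary view of the arrays it owns. As a rough front-end analogue (an assumption — this hunk does not show how the Python-side dicts are populated), continuing with the toy executor exe from the c_api.h sketch above:

import numpy as np

# Write inputs by name, run one training step, then read gradients by name.
exe.arg_dict['data'][:] = np.random.uniform(-1, 1, (32, 100))
exe.arg_dict['softmax_label'][:] = np.zeros((32,))
exe.forward(is_train=True)
exe.backward()
print(exe.grad_dict['fc_weight'].shape)   # gradient NDArray for the fc weight
print(list(exe.aux_dict))                 # aux states (empty for this toy net)
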
/*!
* \brief Create an operator by bind symbol with context and arguments.
* If the user does not want to compute the gradient of the i-th argument, grad_req_type[i] can be kNullOp.
@@ -91,6 +106,23 @@
const std::vector<OpReqType> &grad_req_type,
const std::vector<NDArray> &aux_states,
Executor* shared_exec = NULL);

static Executor* SimpleBind(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& group2ctx,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
const std::unordered_map<std::string, TShape>& arg_shape_map,
const std::unordered_map<std::string, int>& arg_dtype_map,
const std::vector<OpReqType>& grad_req_types,
const std::unordered_set<std::string>& param_names,
std::vector<NDArray>* in_args,
std::vector<NDArray>* arg_grads,
std::vector<NDArray>* aux_states,
std::unordered_map<std::string, NDArray>*
shared_data_arrays = nullptr,
Executor* shared_exec = nullptr);
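
In the Python interface added by this commit, shared_data_arrays and shared_exec correspond to the shared_buffer and shared_exec keywords used in executor_group.py below; a minimal sketch of memory sharing across two executors, with hypothetical parameter names and shapes:

shared_buffer = {}                        # name -> NDArray pool reused across executors
param_names = ['fc_weight', 'fc_bias']    # hypothetical parameter names for the toy net

exe_a = net.simple_bind(ctx=mx.cpu(), grad_req='write',
                        shared_arg_names=param_names,
                        shared_buffer=shared_buffer,
                        data=(32, 100))

# A second executor (e.g. another bucket) can reuse exe_a's parameter arrays and
# the pooled data arrays instead of allocating everything from scratch.
exe_b = net.simple_bind(ctx=mx.cpu(), grad_req='write',
                        shared_arg_names=param_names,
                        shared_buffer=shared_buffer,
                        shared_exec=exe_a,
                        data=(16, 100))
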
/*!
* \brief the prototype of user-defined monitor callback
*/
81 changes: 5 additions & 76 deletions python/mxnet/module/executor_group.py
@@ -4,7 +4,6 @@

import logging
from collections import OrderedDict

import numpy as np

from .. import context as ctx
@@ -564,6 +563,7 @@ def update_metric(self, eval_metric, labels):

def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
"""Internal utility function to bind the i-th executor.
This function utilizes the simple_bind python interface.
"""
shared_exec = None if shared_group is None else shared_group.execs[i]
context = self.contexts[i]
@@ -573,85 +573,14 @@ def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
if label_shapes is not None:
input_shapes.update(dict(label_shapes))

arg_shapes, _, aux_shapes = self.symbol.infer_shape(**input_shapes)
assert arg_shapes is not None, "shape inference failed"

input_types = {x.name: x.dtype for x in data_shapes}
if label_shapes is not None:
input_types.update({x.name: x.dtype for x in label_shapes})
arg_types, _, aux_types = self.symbol.infer_type(**input_types)
assert arg_types is not None, "type inference failed"

arg_arrays = []
grad_arrays = {} if self.for_training else None

def _get_or_reshape(name, shared_data_arrays, arg_shape, arg_type, context, logger):
"""Internal helper to get a memory block or re-use by re-shaping."""
if name in shared_data_arrays:
arg_arr = shared_data_arrays[name]

if np.prod(arg_arr.shape) >= np.prod(arg_shape):
# nice, we can directly re-use this data blob
assert arg_arr.dtype == arg_type
arg_arr = arg_arr.reshape(arg_shape)
else:
logger.warning(('bucketing: data "%s" has a shape %s' % (name, arg_shape)) +
(', which is larger than already allocated ') +
('shape %s' % (arg_arr.shape,)) +
('. Need to re-allocate. Consider putting ') +
('default_bucket_key to') +
(' be the bucket taking the largest input for better ') +
('memory sharing.'))
arg_arr = nd.zeros(arg_shape, context, dtype=arg_type)

# replace existing shared array because the new one is bigger
shared_data_arrays[name] = arg_arr
else:
arg_arr = nd.zeros(arg_shape, context, dtype=arg_type)
shared_data_arrays[name] = arg_arr

return arg_arr

# create or borrow arguments and gradients
for j in range(len(self.arg_names)):
name = self.arg_names[j]
if name in self.param_names: # model parameters
if shared_exec is None:
arg_arr = nd.zeros(arg_shapes[j], context, dtype=arg_types[j])
if self.grad_req[name] != 'null':
grad_arr = nd.zeros(arg_shapes[j], context, dtype=arg_types[j])
grad_arrays[name] = grad_arr
else:
arg_arr = shared_exec.arg_dict[name]
assert arg_arr.shape == arg_shapes[j]
assert arg_arr.dtype == arg_types[j]
if self.grad_req[name] != 'null':
grad_arrays[name] = shared_exec.grad_dict[name]
else: # data, label, or states
arg_arr = _get_or_reshape(name, shared_data_arrays, arg_shapes[j], arg_types[j],
context, self.logger)

# data might also need grad if inputs_need_grad is True
if self.grad_req[name] != 'null':
grad_arrays[name] = _get_or_reshape('grad of ' + name, shared_data_arrays,
arg_shapes[j], arg_types[j], context,
self.logger)

arg_arrays.append(arg_arr)

# create or borrow aux variables
if shared_exec is None:
aux_arrays = [nd.zeros(s, context, dtype=t) for s, t in zip(aux_shapes, aux_types)]
else:
for j, arr in enumerate(shared_exec.aux_arrays):
assert aux_shapes[j] == arr.shape
assert aux_types[j] == arr.dtype
aux_arrays = shared_exec.aux_arrays[:]

executor = self.symbol.bind(ctx=context, args=arg_arrays,
args_grad=grad_arrays, aux_states=aux_arrays,
grad_req=self.grad_req, shared_exec=shared_exec)
# Get the total bytes allocated for this executor
executor = self.symbol.simple_bind(ctx=context, grad_req=self.grad_req,
type_dict=input_types, shared_arg_names=self.param_names,
shared_exec=shared_exec,
shared_buffer=shared_data_arrays, **input_shapes)
self._total_exec_bytes += int(executor.debug_str().split('\n')[-3].split()[1])
return executor
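
For comparison, a minimal sketch (under the same toy-symbol assumption as above) of what the removed front-end path did by hand versus what the new call delegates to the backend:

# Old front-end path (removed above): infer shapes, allocate NDArrays, then bind.
arg_shapes, _, aux_shapes = net.infer_shape(data=(32, 100))
args = [mx.nd.zeros(s, mx.cpu()) for s in arg_shapes]
grads = {n: mx.nd.zeros(s, mx.cpu())
         for n, s in zip(net.list_arguments(), arg_shapes)}
aux = [mx.nd.zeros(s, mx.cpu()) for s in aux_shapes]
exe_old = net.bind(ctx=mx.cpu(), args=args, args_grad=grads,
                   aux_states=aux, grad_req='write')

# New path: one call; inference, allocation, and sharing happen in the backend.
exe_new = net.simple_bind(ctx=mx.cpu(), grad_req='write', data=(32, 100))
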

