Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix bugs for _bind_ith_exec
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed May 23, 2017
1 parent e8b77b5 commit a07229b
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 30 deletions.
8 changes: 5 additions & 3 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1101,9 +1101,11 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
const int* provided_arg_dtypes,
const mx_uint num_shared_arg_names,
const char** shared_arg_name_list,
mx_uint* shared_buffer_len,
const char*** shared_buffer_name_list,
NDArrayHandle** shared_buffer_handle_list,
int* shared_buffer_len,
const char** shared_buffer_name_list,
NDArrayHandle* shared_buffer_handle_list,
const char*** updated_shared_buffer_name_list,
NDArrayHandle** updated_shared_buffer_handle_list,
mx_uint* num_in_args,
NDArrayHandle** in_args,
NDArrayHandle** arg_grads,
Expand Down
20 changes: 11 additions & 9 deletions python/mxnet/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,7 +1300,7 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,

# prepare shared_buffer
if shared_buffer is None:
shared_buffer_len = mx_uint()
shared_buffer_len = ctypes.c_int(-1)
shared_buffer_names = ctypes.POINTER(ctypes.c_char_p)()
shared_buffer_handles = ctypes.POINTER(NDArrayHandle)()
else:
Expand All @@ -1312,8 +1312,10 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
shared_buffer_names.append(c_str(k))
shared_buffer_handles.append(v.handle)
shared_buffer_names = c_array(ctypes.c_char_p, shared_buffer_names)
shared_buffer_len = mx_uint(len(shared_buffer_handles))
shared_buffer_len = ctypes.c_int(len(shared_buffer_handles))
shared_buffer_handles = c_array(NDArrayHandle, shared_buffer_handles)
updated_shared_buffer_names = ctypes.POINTER(ctypes.c_char_p)()
updated_shared_buffer_handles = ctypes.POINTER(NDArrayHandle)()

# prepare shared_exec_handle
shared_exec_handle = shared_exec.handle if shared_exec is not None else ExecutorHandle()
Expand Down Expand Up @@ -1348,8 +1350,10 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
mx_uint(len(shared_arg_name_list)),
c_array(ctypes.c_char_p, shared_arg_name_list),
ctypes.byref(shared_buffer_len),
ctypes.byref(shared_buffer_names),
ctypes.byref(shared_buffer_handles),
shared_buffer_names,
shared_buffer_handles,
ctypes.byref(updated_shared_buffer_names),
ctypes.byref(updated_shared_buffer_handles),
ctypes.byref(num_in_args),
ctypes.byref(in_arg_handles),
ctypes.byref(arg_grad_handles),
Expand All @@ -1360,11 +1364,9 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,

# update shared_buffer
if shared_buffer is not None:
updated_shared_buffer = [NDArray(NDArrayHandle(shared_buffer_handles[i]))
for i in range(shared_buffer_len.value)]
updated_shared_buffer_names = [py_str(shared_buffer_names[i])
for i in range(shared_buffer_len.value)]
for k, v in zip(updated_shared_buffer_names, updated_shared_buffer):
for i in range(shared_buffer_len.value):
k = py_str(updated_shared_buffer_names[i])
v = NDArray(NDArrayHandle(updated_shared_buffer_handles[i]))
shared_buffer[k] = v

# create in_args, arg_grads, and aux_states for the current executor
Expand Down
26 changes: 15 additions & 11 deletions src/c_api/c_api_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ int MXExecutorBindEX(SymbolHandle symbol_handle,
* \param shared_buffer_len number of shared data arrays passed from _bind_ith_exec
* \param shared_buffer_name_list shared data array names passed from _bind_ith_exec
* \param shared_buffer_handle_list shared data array handles passed from _bind_ith_exec
* \param updated_shared_buffer_name_list updated shared data array names after binding
* \param updated_shared_buffer_handle_list updated shared data arrays after binding
* \param num_in_args number of input arguments of this sym
* \param in_args list_arguments associated with the current executor
* \param arg_grads list of gradients of in_args associated with the current executor
Expand Down Expand Up @@ -205,9 +207,11 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
const int* provided_arg_dtypes,
const mx_uint num_shared_arg_names,
const char** shared_arg_name_list,
mx_uint* shared_buffer_len,
const char*** shared_buffer_name_list,
NDArrayHandle** shared_buffer_handle_list,
int* shared_buffer_len,
const char** shared_buffer_name_list,
NDArrayHandle* shared_buffer_handle_list,
const char*** updated_shared_buffer_name_list,
NDArrayHandle** updated_shared_buffer_handle_list,
mx_uint* num_in_args,
NDArrayHandle** in_args,
NDArrayHandle** arg_grads,
Expand Down Expand Up @@ -373,14 +377,14 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
std::vector<NDArray> shared_exec_in_args;
std::vector<NDArray> shared_exec_arg_grads;
std::vector<NDArray> shared_exec_aux_states;
bool use_shared_buffer = (nullptr != *shared_buffer_handle_list);
if (use_shared_buffer) {
bool use_shared_buffer = (*shared_buffer_len >= 0);
if (*shared_buffer_len > 0) {
// create shared_buffer_map
shared_buffer_map.reserve(*shared_buffer_len);
NDArray*** shared_buffer_ptrs =
reinterpret_cast<NDArray***>(shared_buffer_handle_list);
for (mx_uint i = 0; i < *shared_buffer_len; ++i) {
shared_buffer_map[*shared_buffer_name_list[i]] = *(*shared_buffer_ptrs)[i];
NDArray** shared_buffer_ptrs =
reinterpret_cast<NDArray**>(shared_buffer_handle_list);
for (int i = 0; i < *shared_buffer_len; ++i) {
shared_buffer_map[shared_buffer_name_list[i]] = *(shared_buffer_ptrs[i]);
}
}

Expand Down Expand Up @@ -449,8 +453,8 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
ret->ret_vec_charp.push_back(kv.first.c_str());
}
*shared_buffer_len = shared_buffer_map.size();
*shared_buffer_handle_list = &(ret->ret_handles[nd_idx]);
*shared_buffer_name_list = &(ret->ret_vec_charp[0]);
*updated_shared_buffer_handle_list = &(ret->ret_handles[nd_idx]);
*updated_shared_buffer_name_list = &(ret->ret_vec_charp[0]);
}

API_END();
Expand Down
10 changes: 3 additions & 7 deletions src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -539,10 +539,6 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
data_entry_.resize(idx.num_node_entries());
size_t arg_top = 0, aux_top = 0;
auto mutable_nodes = idx.mutable_input_nodes();
const auto& shared_exec_in_args = shared_exec->in_arg_map();
const auto& shared_exec_arg_grads = shared_exec->arg_grad_map();
const auto& shared_exec_aux_states = shared_exec->aux_state_map();
// TODO(junwu): populate in_arg_map, arg_grad_map, and aux_state_map
for (size_t i = 0; i < num_forward_inputs_; ++i) {
const uint32_t nid = idx.input_nodes().at(i);
const uint32_t eid = idx.entry_id(nid, 0);
Expand All @@ -551,7 +547,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
const std::string& arg_name = idx[nid].source->attrs.name;
if (mutable_nodes.count(nid)) { // aux_states
if (nullptr != shared_exec) {
const NDArray& aux_nd = shared_exec_aux_states.at(arg_name);
const NDArray& aux_nd = shared_exec->aux_state_map().at(arg_name);
CHECK_EQ(inferred_shape, aux_nd.shape())
<< "Inferred shape does not match shared_exec.aux_array's shape."
" Therefore, the allocated memory for shared_exec.aux_array cannot"
Expand All @@ -574,7 +570,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
} else { // in_args
if (shared_arg_names.count(arg_name)) { // model parameter
if (nullptr != shared_exec) {
const NDArray& in_arg_nd = shared_exec_in_args.at(arg_name);
const NDArray& in_arg_nd = shared_exec->in_arg_map().at(arg_name);
CHECK_EQ(inferred_shape, in_arg_nd.shape())
<< "Inferred shape does not match shared_exec.arg_array's shape"
" Therefore, the allocated memory for shared_exec.arg_array cannot"
Expand All @@ -589,7 +585,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
if (kNullOp == grad_req_types[arg_top]) {
arg_grad_vec->emplace_back();
} else {
arg_grad_vec->emplace_back(shared_exec_arg_grads.at(arg_name));
arg_grad_vec->emplace_back(shared_exec->arg_grad_map().at(arg_name));
grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back());
} // if (kNullOp == grad_req_types[arg_top])
} else { // !has shared_exec
Expand Down

0 comments on commit a07229b

Please sign in to comment.