Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
C Api for simplebind, fix comment for trigoops, add atol to assert (#…
Browse files Browse the repository at this point in the history
…16585)

* C Api for simplebind, fix comment for trigoops, add atol to assert

* fix build issues

* fix lint and add regression test

* fix indent

* api doc and function name change

* fix lint and add infer shape test
  • Loading branch information
ChaiBapchya authored and anirudh2290 committed Oct 25, 2019
1 parent 4e03e6a commit c0e616f
Show file tree
Hide file tree
Showing 5 changed files with 353 additions and 81 deletions.
38 changes: 38 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2255,6 +2255,44 @@ MXNET_DLL int MXExecutorSimpleBindEx(SymbolHandle symbol_handle,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out);


MXNET_DLL int MXExecutorSimpleBindEx64(SymbolHandle symbol_handle,
int dev_type,
int dev_id,
const uint32_t num_g2c_keys,
const char** g2c_keys,
const int* g2c_dev_types,
const int* g2c_dev_ids,
const uint32_t provided_grad_req_list_len,
const char** provided_grad_req_names,
const char** provided_grad_req_types,
const uint32_t num_provided_arg_shapes,
const char** provided_arg_shape_names,
const int64_t* provided_arg_shape_data,
const uint32_t* provided_arg_shape_idx,
const uint32_t num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
const uint32_t num_provided_arg_stypes,
const char** provided_arg_stype_names,
const int* provided_arg_stypes,
const uint32_t num_shared_arg_names,
const char** shared_arg_name_list,
int* shared_buffer_len,
const char** shared_buffer_name_list,
NDArrayHandle* shared_buffer_handle_list,
const char*** updated_shared_buffer_name_list,
NDArrayHandle** updated_shared_buffer_handle_list,
uint32_t* num_in_args,
NDArrayHandle** in_args,
NDArrayHandle** arg_grads,
uint32_t* num_aux_states,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out);


/*!
* \brief DEPRECATED. Use MXExecutorReshapeEx instead.
* Return a new executor with the same symbol and shared memory,
Expand Down
110 changes: 74 additions & 36 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1695,42 +1695,80 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None,
aux_state_handles = ctypes.POINTER(NDArrayHandle)()

try:
check_call(_LIB.MXExecutorSimpleBindEx(self.handle,
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
num_ctx_map_keys,
ctx_map_keys,
ctx_map_dev_types,
ctx_map_dev_ids,
mx_uint(provided_req_type_list_len),
provided_grad_req_names,
provided_grad_req_types,
mx_uint(len(provided_arg_shape_names)),
c_str_array(provided_arg_shape_names),
c_array_buf(mx_int,
array('I', provided_arg_shape_data)),
c_array_buf(mx_uint,
array('i', provided_arg_shape_idx)),
num_provided_arg_types,
provided_arg_type_names,
provided_arg_type_data,
num_provided_arg_stypes,
provided_arg_stype_names,
provided_arg_stype_data,
mx_uint(len(shared_arg_name_list)),
c_str_array(shared_arg_name_list),
ctypes.byref(shared_buffer_len),
shared_buffer_names,
shared_buffer_handles,
ctypes.byref(updated_shared_buffer_names),
ctypes.byref(updated_shared_buffer_handles),
ctypes.byref(num_in_args),
ctypes.byref(in_arg_handles),
ctypes.byref(arg_grad_handles),
ctypes.byref(num_aux_states),
ctypes.byref(aux_state_handles),
shared_exec_handle,
ctypes.byref(exe_handle)))
if sys.version_info[0] > 2 and _int64_enabled():
check_call(_LIB.MXExecutorSimpleBindEx64(self.handle,
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
num_ctx_map_keys,
ctx_map_keys,
ctx_map_dev_types,
ctx_map_dev_ids,
mx_uint(provided_req_type_list_len),
provided_grad_req_names,
provided_grad_req_types,
mx_uint(len(provided_arg_shape_names)),
c_str_array(provided_arg_shape_names),
c_array_buf(mx_int64,
array('q', provided_arg_shape_data)),
c_array_buf(mx_uint,
array('i', provided_arg_shape_idx)),
num_provided_arg_types,
provided_arg_type_names,
provided_arg_type_data,
num_provided_arg_stypes,
provided_arg_stype_names,
provided_arg_stype_data,
mx_uint(len(shared_arg_name_list)),
c_str_array(shared_arg_name_list),
ctypes.byref(shared_buffer_len),
shared_buffer_names,
shared_buffer_handles,
ctypes.byref(updated_shared_buffer_names),
ctypes.byref(updated_shared_buffer_handles),
ctypes.byref(num_in_args),
ctypes.byref(in_arg_handles),
ctypes.byref(arg_grad_handles),
ctypes.byref(num_aux_states),
ctypes.byref(aux_state_handles),
shared_exec_handle,
ctypes.byref(exe_handle)))
else:
check_call(_LIB.MXExecutorSimpleBindEx(self.handle,
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
num_ctx_map_keys,
ctx_map_keys,
ctx_map_dev_types,
ctx_map_dev_ids,
mx_uint(provided_req_type_list_len),
provided_grad_req_names,
provided_grad_req_types,
mx_uint(len(provided_arg_shape_names)),
c_str_array(provided_arg_shape_names),
c_array_buf(mx_int,
array('I', provided_arg_shape_data)),
c_array_buf(mx_uint,
array('i', provided_arg_shape_idx)),
num_provided_arg_types,
provided_arg_type_names,
provided_arg_type_data,
num_provided_arg_stypes,
provided_arg_stype_names,
provided_arg_stype_data,
mx_uint(len(shared_arg_name_list)),
c_str_array(shared_arg_name_list),
ctypes.byref(shared_buffer_len),
shared_buffer_names,
shared_buffer_handles,
ctypes.byref(updated_shared_buffer_names),
ctypes.byref(updated_shared_buffer_handles),
ctypes.byref(num_in_args),
ctypes.byref(in_arg_handles),
ctypes.byref(arg_grad_handles),
ctypes.byref(num_aux_states),
ctypes.byref(aux_state_handles),
shared_exec_handle,
ctypes.byref(exe_handle)))
except MXNetError as e:
error_msg = "simple_bind error. Arguments:\n"
for k, v in kwargs.items():
Expand Down
Loading

0 comments on commit c0e616f

Please sign in to comment.