Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
templatizing code
Browse files Browse the repository at this point in the history
  • Loading branch information
Rohit Kumar Srivastava committed Aug 11, 2019
1 parent 304d992 commit aeb995a
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 81 deletions.
31 changes: 14 additions & 17 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();

API_BEGIN();
*name = e->name.c_str();
Expand Down Expand Up @@ -189,8 +189,8 @@ int MXNDArrayCreateNone(NDArrayHandle *out) {
API_END();
}

template<typename dtype, typename dimtype>
void CreateArray(const dtype* shape, dimtype ndim, int dev_type, int dev_id, int delay_alloc,
template<typename DataType, typename dimtype>
void CreateArray(const DataType* shape, dimtype ndim, int dev_type, int dev_id, int delay_alloc,
int dtype, NDArrayHandle* out) {
*out = new NDArray(mxnet::TShape(shape, shape + ndim),
Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
Expand All @@ -204,7 +204,9 @@ int MXNDArrayCreate(const mx_uint *shape,
int delay_alloc,
NDArrayHandle *out) {
API_BEGIN();
CreateArray<mx_uint, mx_uint>(shape, ndim, dev_type, dev_id, delay_alloc, dtype, out);
*out = new NDArray(mxnet::TShape(shape, shape + ndim),
Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
delay_alloc != 0);
API_END();
}

Expand Down Expand Up @@ -286,7 +288,7 @@ int MXNDArrayLoadFromRawBytes(const void *buf,
int MXNDArraySaveRawBytes(NDArrayHandle handle,
size_t *out_size,
const char **out_buf) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
API_BEGIN();
ret->ret_str.resize(0);
dmlc::MemoryStringStream strm(&ret->ret_str);
Expand Down Expand Up @@ -382,7 +384,7 @@ int MXNDArrayLoad(const char* fname,
NDArrayHandle** out_arr,
mx_uint *out_name_size,
const char*** out_names) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
ret->ret_vec_str.clear();
API_BEGIN();
std::vector<NDArray> data;
Expand Down Expand Up @@ -414,7 +416,7 @@ int MXNDArrayLoadFromBuffer(const void *ndarray_buffer,
NDArrayHandle** out_arr,
mx_uint *out_name_size,
const char*** out_names) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
ret->ret_vec_str.clear();
API_BEGIN();
CHECK_NOTNULL(ndarray_buffer);
Expand Down Expand Up @@ -538,7 +540,7 @@ int MXNDArrayGetStorageType(NDArrayHandle handle,
int MXNDArrayGetShape(NDArrayHandle handle,
mx_uint *out_dim,
const mx_uint **out_pdata) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
if (!arr->is_none()) {
Expand All @@ -556,7 +558,7 @@ int MXNDArrayGetShape(NDArrayHandle handle,

template<typename dtype>
inline void GetShape(NDArrayHandle handle, const dtype** out_pdata, int* out_dim,
MXAPIThreadLocalEntry* ret) {
MXAPIThreadLocalEntry<dtype>* ret) {
NDArray* arr = static_cast<NDArray*>(handle);
if (!arr->is_none()) {
mxnet::TShape s = arr->shape();
Expand All @@ -565,12 +567,7 @@ inline void GetShape(NDArrayHandle handle, const dtype** out_pdata, int* out_dim
}
*out_dim = s.ndim();
if (s.ndim() >= 0) {
std::vector<dtype> &buffer =
#if MXNET_USE_INT64_TENSOR_SIZE == 1
ret->arg_shape_buffer_ex_int64;
#else
ret->arg_shape_buffer_ex;
#endif
std::vector<dtype> &buffer = ret->arg_shape_buffer_ex;
buffer.resize(s.ndim());
mxnet::ShapeTypeCast(s.begin(), s.end(), buffer.data());
*out_pdata = buffer.data();
Expand All @@ -587,7 +584,7 @@ inline void GetShape(NDArrayHandle handle, const dtype** out_pdata, int* out_dim
int MXNDArrayGetShapeEx(NDArrayHandle handle,
int *out_dim,
const int **out_pdata) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
API_BEGIN();
GetShape<int>(handle, out_pdata, out_dim, ret);
API_END();
Expand All @@ -596,7 +593,7 @@ int MXNDArrayGetShapeEx(NDArrayHandle handle,
int MXNDArrayGetShapeExInt64(NDArrayHandle handle,
int *out_dim,
const mx_int64 **out_pdata) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
MXAPIThreadLocalEntry<int64_t> *ret = MXAPIThreadLocalStore<int64_t>::Get();
API_BEGIN();
GetShape<mx_int64>(handle, out_pdata, out_dim, ret);
API_END();
Expand Down
20 changes: 5 additions & 15 deletions src/c_api/c_api_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
using namespace mxnet;

/*! \brief entry to to easily hold returning information */
template<typename dtype = int>
struct MXAPIThreadLocalEntry {
/*! \brief result holder for returning string */
std::string ret_str;
Expand All @@ -81,21 +82,11 @@ struct MXAPIThreadLocalEntry {
/*! \brief result holder for returning shape pointer */
std::vector<const mx_uint*> arg_shape_data, out_shape_data, aux_shape_data;
/*! \brief result holder for returning shape pointer */
#if MXNET_USE_INT64_TENSOR_SIZE == 1
std::vector<const int64_t*>
#else
std::vector<const int*>
#endif
arg_shape_data_ex, out_shape_data_ex, aux_shape_data_ex;
std::vector<const dtype*> arg_shape_data_ex, out_shape_data_ex, aux_shape_data_ex;
/*! \brief uint32_t buffer for returning shape pointer */
std::vector<uint32_t> arg_shape_buffer, out_shape_buffer, aux_shape_buffer;
/*! \brief uint32_t buffer for returning shape pointer */
#if MXNET_USE_INT64_TENSOR_SIZE == 1
std::vector<int64_t>
#else
std::vector<int>
#endif
arg_shape_buffer_ex, out_shape_buffer_ex, aux_shape_buffer_ex;
std::vector<dtype> arg_shape_buffer_ex, out_shape_buffer_ex, aux_shape_buffer_ex;
/*! \brief bool buffer */
std::vector<bool> save_inputs, save_outputs;
// DEPRECATED. Use SetupShapeArrayReturnWithBufferEx instead.
Expand All @@ -118,7 +109,6 @@ struct MXAPIThreadLocalEntry {
}
}
// helper function to setup return value of shape array
template<typename dtype>
inline static void SetupShapeArrayReturnWithBufferEx(
const mxnet::ShapeVector &shapes,
std::vector<int> *ndim,
Expand All @@ -142,11 +132,11 @@ struct MXAPIThreadLocalEntry {
}
}
}

};

// define the threadlocal store.
typedef dmlc::ThreadLocalStore<MXAPIThreadLocalEntry> MXAPIThreadLocalStore;
template<typename dtype = int>
using MXAPIThreadLocalStore = dmlc::ThreadLocalStore<MXAPIThreadLocalEntry<dtype>>;

namespace mxnet {
// copy attributes from inferred vector back to the vector of each type.
Expand Down
12 changes: 6 additions & 6 deletions src/c_api/c_api_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

int MXExecutorPrint(ExecutorHandle handle, const char **out_str) {
Executor *exec = static_cast<Executor*>(handle);
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
API_BEGIN();
std::ostringstream os;
exec->Print(os);
Expand Down Expand Up @@ -78,7 +78,7 @@ int MXExecutorBackwardEx(ExecutorHandle handle,
int MXExecutorOutputs(ExecutorHandle handle,
mx_uint *out_size,
NDArrayHandle **out) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
API_BEGIN();
Executor *exec = static_cast<Executor*>(handle);
std::vector<NDArray> heads = exec->outputs();
Expand Down Expand Up @@ -252,7 +252,7 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
API_BEGIN();
nnvm::Symbol *sym = static_cast<nnvm::Symbol*>(symbol_handle);

Expand Down Expand Up @@ -586,7 +586,7 @@ int MXExecutorSimpleBindEx(SymbolHandle symbol_handle,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
API_BEGIN();
nnvm::Symbol *sym = static_cast<nnvm::Symbol*>(symbol_handle);

Expand Down Expand Up @@ -870,7 +870,7 @@ int MXExecutorReshape(int partial_shaping,
ExecutorHandle *out) {
Executor* new_exec = nullptr;

MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
API_BEGIN();
*out = nullptr; // ensure we can know whether to free executor on early abort
// create shape map for in_args and aux_states
Expand Down Expand Up @@ -961,7 +961,7 @@ int MXExecutorReshapeEx(int partial_shaping,
ExecutorHandle *out) {
Executor* new_exec = nullptr;

MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
API_BEGIN();
*out = nullptr; // ensure we can know whether to free executor on early abort
// create shape map for in_args and aux_states
Expand Down
10 changes: 5 additions & 5 deletions src/c_api/c_api_ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ void MXImperativeInvokeImpl(AtomicSymbolCreator creator,
const char **param_keys,
const char **param_vals) {
const nnvm::Op* op = static_cast<nnvm::Op*>(creator);
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();

nnvm::NodeAttrs attrs = imperative::ParseAttrs(op, num_inputs, num_params,
param_keys, param_vals);
Expand Down Expand Up @@ -138,7 +138,7 @@ int MXImperativeInvokeEx(AtomicSymbolCreator creator,
const char **param_keys,
const char **param_vals,
const int **out_stypes) { // outputs storage types
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
API_BEGIN();
MXImperativeInvokeImpl(creator, num_inputs, inputs, num_outputs, outputs,
num_params, param_keys, param_vals);
Expand Down Expand Up @@ -194,7 +194,7 @@ int MXInvokeCachedOp(CachedOpHandle handle,
NDArrayHandle *inputs,
int *num_outputs,
NDArrayHandle **outputs) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();

API_BEGIN();
CachedOpPtr op = *static_cast<CachedOpPtr*>(handle);
Expand Down Expand Up @@ -238,7 +238,7 @@ int MXInvokeCachedOpEx(CachedOpHandle handle,
int *num_outputs,
NDArrayHandle **outputs,
const int **out_stypes) { // outputs storage types
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
int err = MXInvokeCachedOp(handle, num_inputs, inputs, num_outputs, outputs);
if (err != 0) return err;
API_BEGIN();
Expand Down Expand Up @@ -331,7 +331,7 @@ int MXAutogradBackwardEx(mx_uint num_output,
int is_train,
NDArrayHandle **grad_handles,
int **grad_stypes) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
API_BEGIN();

std::vector<NDArray*> outputs, ograds, variables;
Expand Down
2 changes: 1 addition & 1 deletion src/c_api/c_api_profile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ int MXAggregateProfileStatsPrint(const char **out_str, int reset) {

int MXAggregateProfileStatsPrintEx(const char **out_str, int reset, int format, int sort_by,
int ascending) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
API_BEGIN();
CHECK_NOTNULL(out_str);
profiler::Profiler *profiler = profiler::Profiler::Get();
Expand Down
Loading

0 comments on commit aeb995a

Please sign in to comment.