caching results of runtime features and minor refactoring
Rohit Kumar Srivastava committed Aug 12, 2019
1 parent a51114d commit 24ea3c1
Showing 7 changed files with 287 additions and 279 deletions.
234 changes: 117 additions & 117 deletions include/mxnet/c_api.h

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions include/mxnet/tuple.h
@@ -374,6 +374,7 @@ class Tuple {
   }
 };
 
+
 /*! \brief check if a shape's ndim is known. */
 inline bool ndim_is_known(const int ndim) {
   CHECK_GE(ndim, -1) << "shape ndim must be >= -1, while received " << ndim;
16 changes: 12 additions & 4 deletions python/mxnet/ndarray/ndarray.py
@@ -107,6 +107,14 @@
 _NDARRAY_BASIC_INDEXING = 0
 _NDARRAY_ADVANCED_INDEXING = 1
 
+# Caching whether MXNet was built with INT64 support or not
+_INT64_TENSOR_SIZE_ENABLED = None
+
+def int64_enabled():
+    global _INT64_TENSOR_SIZE_ENABLED
+    if _INT64_TENSOR_SIZE_ENABLED is None:
+        _INT64_TENSOR_SIZE_ENABLED = Features().is_enabled('INT64_TENSOR_SIZE')
+    return _INT64_TENSOR_SIZE_ENABLED
 
 def _new_empty_handle():
     """Returns a new empty handle.
@@ -134,8 +142,8 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
         A new empty `NDArray` handle.
     """
     hdl = NDArrayHandle()
-    if Features().is_enabled('INT64_TENSOR_SIZE') and sys.version_info[0] > 2:
-        check_call(_LIB.MXNDArrayCreateExInt64(
+    if sys.version_info[0] > 2 and int64_enabled():
+        check_call(_LIB.MXNDArrayCreateEx64(
             c_array_buf(mx_int64, native_array('q', shape)),
             ctypes.c_int(len(shape)),
             ctypes.c_int(ctx.device_typeid),
@@ -2230,9 +2238,9 @@ def shape(self):
         (2L, 3L, 4L)
         """
         ndim = mx_int()
-        if sys.version_info[0] > 2 and Features().is_enabled('INT64_TENSOR_SIZE'):
+        if _INT64_TENSOR_SIZE_ENABLED:
             pdata = ctypes.POINTER(mx_int64)()
-            check_call(_LIB.MXNDArrayGetShapeExInt64(
+            check_call(_LIB.MXNDArrayGetShapeEx64(
                 self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
         else:
             pdata = ctypes.POINTER(mx_int)()
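The cached flag above replaces a fresh Features() construction on every NDArray allocation. As a design note, the same one-shot memoization could be written with the standard library; a minimal sketch, assuming only the public mxnet.runtime.Features API (an illustrative alternative, not what the commit ships):

from functools import lru_cache
from mxnet.runtime import Features

@lru_cache(maxsize=1)
def int64_enabled():
    # Features() enumerates libmxnet's compile-time feature flags; the
    # decorator ensures that query runs only on the first call.
    return Features().is_enabled('INT64_TENSOR_SIZE')

The explicit module-level global in the diff achieves the same effect without an extra import.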
9 changes: 4 additions & 5 deletions python/mxnet/symbol/symbol.py
@@ -39,9 +39,8 @@
 from ..base import check_call, MXNetError, NotImplementedForSymbol
 from ..context import Context, current_context
 from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP, _GRAD_REQ_MAP
-from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
+from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID, int64_enabled
 from ..ndarray import _ndarray_cls
-from ..runtime import Features
 from ..executor import Executor
 from . import _internal
 from . import op
@@ -1213,14 +1212,14 @@ def _infer_shape_impl(self, partial, *args, **kwargs):
         aux_shape_size = mx_uint()
         aux_shape_ndim = ctypes.POINTER(mx_int)()
         complete = ctypes.c_int()
-        if Features().is_enabled('INT64_TENSOR_SIZE') and sys.version_info[0] > 2:
+        if sys.version_info[0] > 2 and int64_enabled():
             arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))()
             out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))()
             aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))()
             if partial:
-                infer_func = _LIB.MXSymbolInferShapePartialExInt64
+                infer_func = _LIB.MXSymbolInferShapePartialEx64
             else:
-                infer_func = _LIB.MXSymbolInferShapeExInt64
+                infer_func = _LIB.MXSymbolInferShapeEx64
             check_call(infer_func(
                 self.handle,
                 mx_uint(len(indptr) - 1),
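Symbol shape inference and the NDArray shape property sit on hot paths, which is why the cached helper matters. A rough way to observe the saving (hypothetical micro-benchmark; absolute numbers depend on the build):

import timeit
from mxnet.runtime import Features
from mxnet.ndarray.ndarray import int64_enabled

# Uncached: each call re-reads the feature list from libmxnet.
print(timeit.timeit(lambda: Features().is_enabled('INT64_TENSOR_SIZE'),
                    number=10000))
# Cached: only the first call crosses into the C API.
print(timeit.timeit(int64_enabled, number=10000))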
39 changes: 20 additions & 19 deletions src/c_api/c_api.cc
@@ -190,8 +190,13 @@ int MXNDArrayCreateNone(NDArrayHandle *out) {
 }
 
 template<typename DataType, typename dimtype>
-void CreateArray(const DataType* shape, dimtype ndim, int dev_type, int dev_id, int delay_alloc,
-                 int dtype, NDArrayHandle* out) {
+void CreateNDArray(const DataType* shape,
+                   dimtype ndim,
+                   int dev_type,
+                   int dev_id,
+                   int delay_alloc,
+                   int dtype,
+                   NDArrayHandle* out) {
   *out = new NDArray(mxnet::TShape(shape, shape + ndim),
                      Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
                      delay_alloc != 0, dtype);
@@ -210,15 +215,15 @@ int MXNDArrayCreate(const mx_uint *shape,
   API_END();
 }
 
-int MXNDArrayCreateExInt64(const mx_int64 *shape,
-                           int ndim,
-                           int dev_type,
-                           int dev_id,
-                           int delay_alloc,
-                           int dtype,
-                           NDArrayHandle *out) {
+int MXNDArrayCreateEx64(const mx_int64 *shape,
+                        int ndim,
+                        int dev_type,
+                        int dev_id,
+                        int delay_alloc,
+                        int dtype,
+                        NDArrayHandle *out) {
   API_BEGIN();
-  CreateArray<mx_int64, int>(shape, ndim, dev_type, dev_id, delay_alloc, dtype, out);
+  CreateNDArray<mx_int64, int>(shape, ndim, dev_type, dev_id, delay_alloc, dtype, out);
   API_END();
 }
 
@@ -230,11 +235,7 @@ int MXNDArrayCreateEx(const mx_uint *shape,
                       int dtype,
                       NDArrayHandle *out) {
   API_BEGIN();
-  *out = new NDArray(
-      mxnet::TShape(shape, shape + ndim),
-      Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
-      delay_alloc != 0,
-      dtype);
+  CreateNDArray<mx_uint, mx_uint>(shape, ndim, dev_type, dev_id, delay_alloc, dtype, out);
   API_END();
 }
 
@@ -558,7 +559,7 @@ int MXNDArrayGetShape(NDArrayHandle handle,
 
 template<typename dtype>
 inline void GetShape(NDArrayHandle handle, const dtype** out_pdata, int* out_dim,
-                            MXAPIThreadLocalEntry<dtype>* ret) {
+                     MXAPIThreadLocalEntry<dtype>* ret) {
   NDArray* arr = static_cast<NDArray*>(handle);
   if (!arr->is_none()) {
     mxnet::TShape s = arr->shape();
@@ -590,9 +591,9 @@ int MXNDArrayGetShapeEx(NDArrayHandle handle,
   API_END();
 }
 
-int MXNDArrayGetShapeExInt64(NDArrayHandle handle,
-                             int *out_dim,
-                             const mx_int64 **out_pdata) {
+int MXNDArrayGetShapeEx64(NDArrayHandle handle,
+                          int *out_dim,
+                          const mx_int64 **out_pdata) {
   MXAPIThreadLocalEntry<int64_t> *ret = MXAPIThreadLocalStore<int64_t>::Get();
   API_BEGIN();
   GetShape<mx_int64>(handle, out_pdata, out_dim, ret);
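Since the *ExInt64 entry points are renamed rather than aliased, any external caller still resolving the old names will fail at symbol lookup. A quick presence check from Python, assuming the ctypes handle mxnet.base._LIB that the frontend itself uses (a sketch, not part of the commit):

from mxnet.base import _LIB

# ctypes.CDLL.__getattr__ performs a dlsym lookup, so hasattr() reports
# whether libmxnet actually exports the renamed C API symbols.
for sym in ('MXNDArrayCreateEx64', 'MXNDArrayGetShapeEx64',
            'MXSymbolInferShapeEx64', 'MXSymbolInferShapePartialEx64'):
    assert hasattr(_LIB, sym), 'missing C API symbol: ' + sym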