Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fixing local caching in ndarray shape
Browse files Browse the repository at this point in the history
  • Loading branch information
Rohit Kumar Srivastava committed Aug 13, 2019
1 parent 24ea3c1 commit 1d4f263
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@
# Caching whether MXNet was built with INT64 support or not
_INT64_TENSOR_SIZE_ENABLED = None

def int64_enabled():
def _int64_enabled():
global _INT64_TENSOR_SIZE_ENABLED
if _INT64_TENSOR_SIZE_ENABLED is None:
_INT64_TENSOR_SIZE_ENABLED = Features().is_enabled('INT64_TENSOR_SIZE')
Expand Down Expand Up @@ -142,7 +142,7 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
A new empty `NDArray` handle.
"""
hdl = NDArrayHandle()
if sys.version_info[0] > 2 and int64_enabled():
if sys.version_info[0] > 2 and _int64_enabled():
check_call(_LIB.MXNDArrayCreateEx64(
c_array_buf(mx_int64, native_array('q', shape)),
ctypes.c_int(len(shape)),
Expand Down Expand Up @@ -2238,7 +2238,7 @@ def shape(self):
(2L, 3L, 4L)
"""
ndim = mx_int()
if _INT64_TENSOR_SIZE_ENABLED:
if _int64_enabled():
pdata = ctypes.POINTER(mx_int64)()
check_call(_LIB.MXNDArrayGetShapeEx64(
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from ..base import check_call, MXNetError, NotImplementedForSymbol
from ..context import Context, current_context
from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP, _GRAD_REQ_MAP
from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID, int64_enabled
from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID, _int64_enabled
from ..ndarray import _ndarray_cls
from ..executor import Executor
from . import _internal
Expand Down Expand Up @@ -1212,7 +1212,7 @@ def _infer_shape_impl(self, partial, *args, **kwargs):
aux_shape_size = mx_uint()
aux_shape_ndim = ctypes.POINTER(mx_int)()
complete = ctypes.c_int()
if sys.version_info[0] > 2 and int64_enabled():
if sys.version_info[0] > 2 and _int64_enabled():
arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))()
out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))()
aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))()
Expand Down

0 comments on commit 1d4f263

Please sign in to comment.