From 1d4f26390b3c57024ccbed09abf130845ff3395a Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Mon, 12 Aug 2019 23:40:41 +0000 Subject: [PATCH] fixing local caching in ndarray shape --- python/mxnet/ndarray/ndarray.py | 6 +++--- python/mxnet/symbol/symbol.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 2b5a2d1dd45e..5f03c65a2e79 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -110,7 +110,7 @@ # Caching whether MXNet was built with INT64 support or not _INT64_TENSOR_SIZE_ENABLED = None -def int64_enabled(): +def _int64_enabled(): global _INT64_TENSOR_SIZE_ENABLED if _INT64_TENSOR_SIZE_ENABLED is None: _INT64_TENSOR_SIZE_ENABLED = Features().is_enabled('INT64_TENSOR_SIZE') @@ -142,7 +142,7 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t): A new empty `NDArray` handle. """ hdl = NDArrayHandle() - if sys.version_info[0] > 2 and int64_enabled(): + if sys.version_info[0] > 2 and _int64_enabled(): check_call(_LIB.MXNDArrayCreateEx64( c_array_buf(mx_int64, native_array('q', shape)), ctypes.c_int(len(shape)), @@ -2238,7 +2238,7 @@ def shape(self): (2L, 3L, 4L) """ ndim = mx_int() - if _INT64_TENSOR_SIZE_ENABLED: + if _int64_enabled(): pdata = ctypes.POINTER(mx_int64)() check_call(_LIB.MXNDArrayGetShapeEx64( self.handle, ctypes.byref(ndim), ctypes.byref(pdata))) diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index fc798263d52f..3ac44176a87b 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -39,7 +39,7 @@ from ..base import check_call, MXNetError, NotImplementedForSymbol from ..context import Context, current_context from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP, _GRAD_REQ_MAP -from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID, int64_enabled +from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID, _int64_enabled from ..ndarray import _ndarray_cls from ..executor import Executor from . import _internal @@ -1212,7 +1212,7 @@ def _infer_shape_impl(self, partial, *args, **kwargs): aux_shape_size = mx_uint() aux_shape_ndim = ctypes.POINTER(mx_int)() complete = ctypes.c_int() - if sys.version_info[0] > 2 and int64_enabled(): + if sys.version_info[0] > 2 and _int64_enabled(): arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))() out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))() aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))()