diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 20b2aa2d5c9b..5ab10b6b2204 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -55,7 +55,9 @@ extern "C" { #endif /*! \brief manually define unsigned int */ -typedef unsigned int mx_uint; +typedef uint32_t mx_uint; +/*! \brief manually define 64-bit int */ +typedef int64_t mx_int64; /*! \brief manually define float */ typedef float mx_float; /*! \brief data type to store dim size */ @@ -572,6 +574,13 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape, int dtype, NDArrayHandle *out); +MXNET_DLL int MXNDArrayCreateEx64(const mx_int64 *shape, + int ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + NDArrayHandle *out); /*! * \brief create an empty sparse NDArray with specified shape and data type @@ -603,6 +612,19 @@ MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type, const mx_uint *aux_shape, NDArrayHandle *out); +MXNET_DLL int MXNDArrayCreateSparseEx64(int storage_type, + const mx_int64 *shape, + int ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + mx_uint num_aux, + int *aux_type, + int *aux_ndims, + const mx_int64 *aux_shape, + NDArrayHandle *out); + /*! * \brief create a NDArray handle that is loaded from raw bytes. * \param buf the head of the raw bytes @@ -650,6 +672,12 @@ MXNET_DLL int MXNDArrayLoad(const char* fname, mx_uint *out_name_size, const char*** out_names); +MXNET_DLL int MXNDArrayLoad64(const char* fname, + mx_int64 *out_size, + NDArrayHandle** out_arr, + mx_int64 *out_name_size, + const char*** out_names); + /*! * \brief Load list / dictionary of narrays from file content loaded into memory. * This will load a list of ndarrays in a similar @@ -665,11 +693,18 @@ MXNET_DLL int MXNDArrayLoad(const char* fname, * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXNDArrayLoadFromBuffer(const void *ndarray_buffer, - size_t size, - mx_uint *out_size, - NDArrayHandle** out_arr, - mx_uint *out_name_size, - const char*** out_names); + size_t size, + mx_uint *out_size, + NDArrayHandle** out_arr, + mx_uint *out_name_size, + const char*** out_names); + +MXNET_DLL int MXNDArrayLoadFromBuffer64(const void *ndarray_buffer, + size_t size, + mx_int64 *out_size, + NDArrayHandle** out_arr, + mx_int64 *out_name_size, + const char*** out_names); /*! * \brief Perform a synchronize copy from a continugous CPU memory region. @@ -809,6 +844,11 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle, MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle, mx_uint *out_dim, const mx_uint **out_pdata); + +MXNET_DLL int MXNDArrayGetShape64(NDArrayHandle handle, + int *out_dim, + const int64_t **out_pdata); + /*! * \brief get the shape of the array * \param handle the handle to the narray @@ -819,6 +859,11 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle, MXNET_DLL int MXNDArrayGetShapeEx(NDArrayHandle handle, int *out_dim, const int **out_pdata); + +MXNET_DLL int MXNDArrayGetShapeEx64(NDArrayHandle handle, + int *out_dim, + const mx_int64 **out_pdata); + /*! * \brief get the content of the data in NDArray * \param handle the handle to the ndarray @@ -902,6 +947,10 @@ MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle, mx_uint i, int *out_type); +MXNET_DLL int MXNDArrayGetAuxType64(NDArrayHandle handle, + mx_int64 i, + int *out_type); + /*! * \brief Get a deep copy of the ith aux data blob * in the form of an NDArray of default storage type. @@ -911,6 +960,10 @@ MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle, mx_uint i, NDArrayHandle *out); +MXNET_DLL int MXNDArrayGetAuxNDArray64(NDArrayHandle handle, + mx_int64 i, + NDArrayHandle *out); + /*! * \brief Get a deep copy of the data blob * in the form of an NDArray of default storage type. @@ -966,6 +1019,10 @@ MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out); */ MXNET_DLL int MXListFunctions(mx_uint *out_size, FunctionHandle **out_array); + +MXNET_DLL int MXListFunctions64(mx_int64 *out_size, + FunctionHandle **out_array); + /*! * \brief get the function handle by name * \param name the name of the function @@ -1233,6 +1290,10 @@ MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle, */ MXNET_DLL int MXListAllOpNames(mx_uint *out_size, const char ***out_array); + +MXNET_DLL int MXListAllOpNames64(mx_int64 *out_size, + const char ***out_array); + /*! * \brief list all the available AtomicSymbolEntry * \param out_size the size of returned array @@ -1242,6 +1303,9 @@ MXNET_DLL int MXListAllOpNames(mx_uint *out_size, MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, AtomicSymbolCreator **out_array); +MXNET_DLL int MXSymbolListAtomicSymbolCreators64(mx_int64 *out_size, + AtomicSymbolCreator **out_array); + /*! * \brief Get the name of an atomic symbol. * \param creator the AtomicSymbolCreator. @@ -1454,6 +1518,11 @@ MXNET_DLL int MXSymbolListAttrShallow(SymbolHandle symbol, MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol, mx_uint *out_size, const char ***out_str_array); + +MXNET_DLL int MXSymbolListArguments64(SymbolHandle symbol, + size_t *out_size, + const char ***out_str_array); + /*! * \brief List returns in the symbol. * \param symbol the symbol @@ -1465,6 +1534,10 @@ MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol, mx_uint *out_size, const char ***out_str_array); +MXNET_DLL int MXSymbolListOutputs64(SymbolHandle symbol, + size_t *out_size, + const char ***out_str_array); + /*! * \brief Get number of outputs of the symbol. * \param symbol The symbol @@ -1472,7 +1545,7 @@ MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol, * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolGetNumOutputs(SymbolHandle symbol, - mx_uint *output_count); + mx_uint *output_count); /*! * \brief Get a symbol that contains all the internals. @@ -1511,6 +1584,11 @@ MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol, MXNET_DLL int MXSymbolListAuxiliaryStates(SymbolHandle symbol, mx_uint *out_size, const char ***out_str_array); + +MXNET_DLL int MXSymbolListAuxiliaryStates64(SymbolHandle symbol, + size_t *out_size, + const char ***out_str_array); + /*! * \brief Compose the symbol on other symbols. * @@ -1582,6 +1660,22 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym, const mx_uint ***aux_shape_data, int *complete); +MXNET_DLL int MXSymbolInferShape64(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const mx_int64 *arg_ind_ptr, + const mx_int64 *arg_shape_data, + size_t *in_shape_size, + const int **in_shape_ndim, + const mx_int64 ***in_shape_data, + size_t *out_shape_size, + const int **out_shape_ndim, + const mx_int64 ***out_shape_data, + size_t *aux_shape_size, + const int **aux_shape_ndim, + const mx_int64 ***aux_shape_data, + int *complete); + /*! * \brief infer shape of unknown input shapes given the known one. * The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data @@ -1619,6 +1713,23 @@ MXNET_DLL int MXSymbolInferShapeEx(SymbolHandle sym, const int **aux_shape_ndim, const int ***aux_shape_data, int *complete); + +MXNET_DLL int MXSymbolInferShapeEx64(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const mx_int64 *arg_ind_ptr, + const mx_int64 *arg_shape_data, + size_t *in_shape_size, + const int **in_shape_ndim, + const mx_int64 ***in_shape_data, + size_t *out_shape_size, + const int **out_shape_ndim, + const mx_int64 ***out_shape_data, + size_t *aux_shape_size, + const int **aux_shape_ndim, + const mx_int64 ***aux_shape_data, + int *complete); + /*! * \brief DEPRECATED. Use MXSymbolInferShapePartialEx instead. * partially infer shape of unknown input shapes given the known one. @@ -1660,6 +1771,21 @@ MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym, const mx_uint ***aux_shape_data, int *complete); +MXNET_DLL int MXSymbolInferShapePartial64(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const mx_int64 *arg_ind_ptr, + const mx_int64 *arg_shape_data, + size_t *in_shape_size, + const int **in_shape_ndim, + const mx_int64 ***in_shape_data, + size_t *out_shape_size, + const int **out_shape_ndim, + const mx_int64 ***out_shape_data, + size_t *aux_shape_size, + const int **aux_shape_ndim, + const mx_int64 ***aux_shape_data, + int *complete); /*! * \brief partially infer shape of unknown input shapes given the known one. @@ -1701,6 +1827,22 @@ MXNET_DLL int MXSymbolInferShapePartialEx(SymbolHandle sym, const int ***aux_shape_data, int *complete); +MXNET_DLL int MXSymbolInferShapePartialEx64(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const mx_int64 *arg_ind_ptr, + const mx_int64 *arg_shape_data, + size_t *in_shape_size, + const int **in_shape_ndim, + const mx_int64 ***in_shape_data, + size_t *out_shape_size, + const int **out_shape_ndim, + const mx_int64 ***out_shape_data, + size_t *aux_shape_size, + const int **aux_shape_ndim, + const mx_int64 ***aux_shape_data, + int *complete); + /*! * \brief infer type of unknown input types given the known one. * The types are packed into a CSR matrix represented by arg_ind_ptr and arg_type_data diff --git a/include/mxnet/c_predict_api.h b/include/mxnet/c_predict_api.h index 18bec625f05f..c79baa4a86ff 100644 --- a/include/mxnet/c_predict_api.h +++ b/include/mxnet/c_predict_api.h @@ -42,7 +42,9 @@ extern "C" { #endif /*! \brief manually define unsigned int */ -typedef unsigned int mx_uint; +typedef uint32_t mx_uint; +/*! \brief manually define 64-bit int */ +typedef int64_t mx_int64; /*! \brief manually define float */ typedef float mx_float; /*! \brief handle to Predictor */ diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 848c36fb4d89..dd5fcf0f6db9 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -215,6 +215,7 @@ def _load_lib(): # type definitions mx_int = ctypes.c_int mx_uint = ctypes.c_uint +mx_int64 = ctypes.c_int64 mx_float = ctypes.c_float mx_float_p = ctypes.POINTER(mx_float) mx_real_t = _np.float32 diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 171ba0a5008c..5f03c65a2e79 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -33,11 +33,13 @@ import warnings import operator from functools import reduce # pylint: disable=redefined-builtin +import sys import numpy as np from ..base import _LIB, numeric_types, integer_types from ..base import c_str, c_array, c_array_buf, c_handle_array, mx_real_t -from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int +from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int, mx_int64 from ..base import ctypes2buffer +from ..runtime import Features from ..context import Context, current_context from . import _internal from . import op @@ -105,6 +107,14 @@ _NDARRAY_BASIC_INDEXING = 0 _NDARRAY_ADVANCED_INDEXING = 1 +# Caching whether MXNet was built with INT64 support or not +_INT64_TENSOR_SIZE_ENABLED = None + +def _int64_enabled(): + global _INT64_TENSOR_SIZE_ENABLED + if _INT64_TENSOR_SIZE_ENABLED is None: + _INT64_TENSOR_SIZE_ENABLED = Features().is_enabled('INT64_TENSOR_SIZE') + return _INT64_TENSOR_SIZE_ENABLED def _new_empty_handle(): """Returns a new empty handle. @@ -132,14 +142,24 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t): A new empty `NDArray` handle. """ hdl = NDArrayHandle() - check_call(_LIB.MXNDArrayCreateEx( - c_array_buf(mx_uint, native_array('I', shape)), - mx_uint(len(shape)), - ctypes.c_int(ctx.device_typeid), - ctypes.c_int(ctx.device_id), - ctypes.c_int(int(delay_alloc)), - ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])), - ctypes.byref(hdl))) + if sys.version_info[0] > 2 and _int64_enabled(): + check_call(_LIB.MXNDArrayCreateEx64( + c_array_buf(mx_int64, native_array('q', shape)), + ctypes.c_int(len(shape)), + ctypes.c_int(ctx.device_typeid), + ctypes.c_int(ctx.device_id), + ctypes.c_int(int(delay_alloc)), + ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])), + ctypes.byref(hdl))) + else: + check_call(_LIB.MXNDArrayCreateEx( + c_array_buf(mx_uint, native_array('I', shape)), + mx_uint(len(shape)), + ctypes.c_int(ctx.device_typeid), + ctypes.c_int(ctx.device_id), + ctypes.c_int(int(delay_alloc)), + ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])), + ctypes.byref(hdl))) return hdl @@ -2218,9 +2238,14 @@ def shape(self): (2L, 3L, 4L) """ ndim = mx_int() - pdata = ctypes.POINTER(mx_int)() - check_call(_LIB.MXNDArrayGetShapeEx( - self.handle, ctypes.byref(ndim), ctypes.byref(pdata))) + if _int64_enabled(): + pdata = ctypes.POINTER(mx_int64)() + check_call(_LIB.MXNDArrayGetShapeEx64( + self.handle, ctypes.byref(ndim), ctypes.byref(pdata))) + else: + pdata = ctypes.POINTER(mx_int)() + check_call(_LIB.MXNDArrayGetShapeEx( + self.handle, ctypes.byref(ndim), ctypes.byref(pdata))) if ndim.value == -1: return None else: diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 542b3796a9fe..3ac44176a87b 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -29,17 +29,17 @@ import ctypes import warnings from numbers import Number - +import sys import numpy as _numpy # pylint: disable=relative-import from ..attribute import AttrScope from ..base import _LIB, numeric_types, c_array, c_array_buf, c_str, c_str_array, c_handle_array -from ..base import mx_uint, py_str, string_types, integer_types, mx_int +from ..base import mx_uint, py_str, string_types, integer_types, mx_int, mx_int64 from ..base import NDArrayHandle, ExecutorHandle, SymbolHandle from ..base import check_call, MXNetError, NotImplementedForSymbol from ..context import Context, current_context from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP, _GRAD_REQ_MAP -from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID +from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID, _int64_enabled from ..ndarray import _ndarray_cls from ..executor import Executor from . import _internal @@ -1207,34 +1207,59 @@ def _infer_shape_impl(self, partial, *args, **kwargs): keys = c_str_array(str_keys) arg_shape_size = mx_uint() arg_shape_ndim = ctypes.POINTER(mx_int)() - arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))() out_shape_size = mx_uint() out_shape_ndim = ctypes.POINTER(mx_int)() - out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))() aux_shape_size = mx_uint() aux_shape_ndim = ctypes.POINTER(mx_int)() - aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))() complete = ctypes.c_int() - if partial: - infer_func = _LIB.MXSymbolInferShapePartialEx + if sys.version_info[0] > 2 and _int64_enabled(): + arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))() + out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))() + aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))() + if partial: + infer_func = _LIB.MXSymbolInferShapePartialEx64 + else: + infer_func = _LIB.MXSymbolInferShapeEx64 + check_call(infer_func( + self.handle, + mx_uint(len(indptr) - 1), + keys, + c_array_buf(mx_int64, array('q', indptr)), + c_array_buf(mx_int64, array('q', sdata)), + ctypes.byref(arg_shape_size), + ctypes.byref(arg_shape_ndim), + ctypes.byref(arg_shape_data), + ctypes.byref(out_shape_size), + ctypes.byref(out_shape_ndim), + ctypes.byref(out_shape_data), + ctypes.byref(aux_shape_size), + ctypes.byref(aux_shape_ndim), + ctypes.byref(aux_shape_data), + ctypes.byref(complete))) else: - infer_func = _LIB.MXSymbolInferShapeEx - check_call(infer_func( - self.handle, - mx_uint(len(indptr) - 1), - keys, - c_array_buf(mx_uint, array('I', indptr)), - c_array_buf(mx_int, array('i', sdata)), - ctypes.byref(arg_shape_size), - ctypes.byref(arg_shape_ndim), - ctypes.byref(arg_shape_data), - ctypes.byref(out_shape_size), - ctypes.byref(out_shape_ndim), - ctypes.byref(out_shape_data), - ctypes.byref(aux_shape_size), - ctypes.byref(aux_shape_ndim), - ctypes.byref(aux_shape_data), - ctypes.byref(complete))) + arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))() + out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))() + aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))() + if partial: + infer_func = _LIB.MXSymbolInferShapePartialEx + else: + infer_func = _LIB.MXSymbolInferShapeEx + check_call(infer_func( + self.handle, + mx_uint(len(indptr) - 1), + keys, + c_array_buf(mx_uint, array('I', indptr)), + c_array_buf(mx_int, array('i', sdata)), + ctypes.byref(arg_shape_size), + ctypes.byref(arg_shape_ndim), + ctypes.byref(arg_shape_data), + ctypes.byref(out_shape_size), + ctypes.byref(out_shape_ndim), + ctypes.byref(out_shape_data), + ctypes.byref(aux_shape_size), + ctypes.byref(aux_shape_ndim), + ctypes.byref(aux_shape_data), + ctypes.byref(complete))) if complete.value != 0: arg_shapes = [tuple(arg_shape_data[i][:arg_shape_ndim[i]]) if arg_shape_ndim[i] >= 0 else None diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 17648458ba22..c2b80b3f601c 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -67,7 +67,7 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e, const char ***arg_type_infos, const char ***arg_descriptions, const char **return_type) { - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); *name = e->name.c_str(); @@ -189,6 +189,19 @@ int MXNDArrayCreateNone(NDArrayHandle *out) { API_END(); } +template +void CreateNDArray(const DataType* shape, + dimtype ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + NDArrayHandle* out) { + *out = new NDArray(mxnet::TShape(shape, shape + ndim), + Context::Create(static_cast(dev_type), dev_id), + delay_alloc != 0, dtype); +} + int MXNDArrayCreate(const mx_uint *shape, mx_uint ndim, int dev_type, @@ -196,41 +209,48 @@ int MXNDArrayCreate(const mx_uint *shape, int delay_alloc, NDArrayHandle *out) { API_BEGIN(); - *out = new NDArray( - mxnet::TShape(shape, shape + ndim), - Context::Create(static_cast(dev_type), dev_id), - delay_alloc != 0); + *out = new NDArray(mxnet::TShape(shape, shape + ndim), + Context::Create(static_cast(dev_type), dev_id), + delay_alloc != 0); + API_END(); +} + +int MXNDArrayCreateEx64(const mx_int64 *shape, + int ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + NDArrayHandle *out) { + API_BEGIN(); + CreateNDArray(shape, ndim, dev_type, dev_id, delay_alloc, dtype, out); API_END(); } int MXNDArrayCreateEx(const mx_uint *shape, - mx_uint ndim, - int dev_type, - int dev_id, - int delay_alloc, - int dtype, - NDArrayHandle *out) { + mx_uint ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + NDArrayHandle *out) { API_BEGIN(); - *out = new NDArray( - mxnet::TShape(shape, shape + ndim), - Context::Create(static_cast(dev_type), dev_id), - delay_alloc != 0, - dtype); + CreateNDArray(shape, ndim, dev_type, dev_id, delay_alloc, dtype, out); API_END(); } int MXNDArrayCreateSparseEx(int storage_type, - const mx_uint *shape, - mx_uint ndim, - int dev_type, - int dev_id, - int delay_alloc, - int dtype, - mx_uint num_aux, - int *aux_type, - mx_uint *aux_ndims, - const mx_uint *aux_shape, - NDArrayHandle *out) { + const mx_uint *shape, + mx_uint ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + mx_uint num_aux, + int *aux_type, + mx_uint *aux_ndims, + const mx_uint *aux_shape, + NDArrayHandle *out) { API_BEGIN(); std::vector aux_types; mxnet::ShapeVector aux_shapes; @@ -269,7 +289,7 @@ int MXNDArrayLoadFromRawBytes(const void *buf, int MXNDArraySaveRawBytes(NDArrayHandle handle, size_t *out_size, const char **out_buf) { - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); ret->ret_str.resize(0); dmlc::MemoryStringStream strm(&ret->ret_str); @@ -365,7 +385,7 @@ int MXNDArrayLoad(const char* fname, NDArrayHandle** out_arr, mx_uint *out_name_size, const char*** out_names) { - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); ret->ret_vec_str.clear(); API_BEGIN(); std::vector data; @@ -397,7 +417,7 @@ int MXNDArrayLoadFromBuffer(const void *ndarray_buffer, NDArrayHandle** out_arr, mx_uint *out_name_size, const char*** out_names) { - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); ret->ret_vec_str.clear(); API_BEGIN(); CHECK_NOTNULL(ndarray_buffer); @@ -521,7 +541,7 @@ int MXNDArrayGetStorageType(NDArrayHandle handle, int MXNDArrayGetShape(NDArrayHandle handle, mx_uint *out_dim, const mx_uint **out_pdata) { - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); NDArray *arr = static_cast(handle); if (!arr->is_none()) { @@ -537,12 +557,10 @@ int MXNDArrayGetShape(NDArrayHandle handle, API_END(); } -int MXNDArrayGetShapeEx(NDArrayHandle handle, - int *out_dim, - const int **out_pdata) { - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); - API_BEGIN(); - NDArray *arr = static_cast(handle); +template +inline void GetShape(NDArrayHandle handle, const dtype** out_pdata, int* out_dim, + MXAPIThreadLocalEntry* ret) { + NDArray* arr = static_cast(handle); if (!arr->is_none()) { mxnet::TShape s = arr->shape(); if (!Imperative::Get()->is_np_shape()) { @@ -550,7 +568,7 @@ int MXNDArrayGetShapeEx(NDArrayHandle handle, } *out_dim = s.ndim(); if (s.ndim() >= 0) { - std::vector &buffer = ret->arg_shape_buffer_ex; + std::vector &buffer = ret->arg_shape_buffer_ex; buffer.resize(s.ndim()); mxnet::ShapeTypeCast(s.begin(), s.end(), buffer.data()); *out_pdata = buffer.data(); @@ -562,6 +580,23 @@ int MXNDArrayGetShapeEx(NDArrayHandle handle, *out_dim = 0; } } +} + +int MXNDArrayGetShapeEx(NDArrayHandle handle, + int *out_dim, + const int **out_pdata) { + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); + API_BEGIN(); + GetShape(handle, out_pdata, out_dim, ret); + API_END(); +} + +int MXNDArrayGetShapeEx64(NDArrayHandle handle, + int *out_dim, + const mx_int64 **out_pdata) { + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + API_BEGIN(); + GetShape(handle, out_pdata, out_dim, ret); API_END(); } diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h index 233acc85f36b..93fcff09a8e6 100644 --- a/src/c_api/c_api_common.h +++ b/src/c_api/c_api_common.h @@ -57,6 +57,7 @@ using namespace mxnet; /*! \brief entry to to easily hold returning information */ +template struct MXAPIThreadLocalEntry { /*! \brief result holder for returning string */ std::string ret_str; @@ -81,11 +82,11 @@ struct MXAPIThreadLocalEntry { /*! \brief result holder for returning shape pointer */ std::vector arg_shape_data, out_shape_data, aux_shape_data; /*! \brief result holder for returning shape pointer */ - std::vector arg_shape_data_ex, out_shape_data_ex, aux_shape_data_ex; + std::vector arg_shape_data_ex, out_shape_data_ex, aux_shape_data_ex; /*! \brief uint32_t buffer for returning shape pointer */ std::vector arg_shape_buffer, out_shape_buffer, aux_shape_buffer; /*! \brief uint32_t buffer for returning shape pointer */ - std::vector arg_shape_buffer_ex, out_shape_buffer_ex, aux_shape_buffer_ex; + std::vector arg_shape_buffer_ex, out_shape_buffer_ex, aux_shape_buffer_ex; /*! \brief bool buffer */ std::vector save_inputs, save_outputs; // DEPRECATED. Use SetupShapeArrayReturnWithBufferEx instead. @@ -111,8 +112,8 @@ struct MXAPIThreadLocalEntry { inline static void SetupShapeArrayReturnWithBufferEx( const mxnet::ShapeVector &shapes, std::vector *ndim, - std::vector *data, - std::vector *buffer) { + std::vector *data, + std::vector *buffer) { ndim->resize(shapes.size()); data->resize(shapes.size()); size_t size = 0; @@ -122,7 +123,7 @@ struct MXAPIThreadLocalEntry { } } buffer->resize(size); - int *ptr = buffer->data(); + dtype* ptr = buffer->data(); for (size_t i = 0; i < shapes.size(); ++i) { ndim->at(i) = shapes[i].ndim(); data->at(i) = ptr; @@ -134,7 +135,8 @@ struct MXAPIThreadLocalEntry { }; // define the threadlocal store. -typedef dmlc::ThreadLocalStore MXAPIThreadLocalStore; +template +using MXAPIThreadLocalStore = dmlc::ThreadLocalStore>; namespace mxnet { // copy attributes from inferred vector back to the vector of each type. diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index ebe3f17d7d90..31b74b503f80 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -32,7 +32,7 @@ int MXExecutorPrint(ExecutorHandle handle, const char **out_str) { Executor *exec = static_cast(handle); - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); std::ostringstream os; exec->Print(os); @@ -78,7 +78,7 @@ int MXExecutorBackwardEx(ExecutorHandle handle, int MXExecutorOutputs(ExecutorHandle handle, mx_uint *out_size, NDArrayHandle **out) { - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); Executor *exec = static_cast(handle); std::vector heads = exec->outputs(); @@ -252,7 +252,7 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, NDArrayHandle** aux_states, ExecutorHandle shared_exec_handle, ExecutorHandle* out) { - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); nnvm::Symbol *sym = static_cast(symbol_handle); @@ -586,7 +586,7 @@ int MXExecutorSimpleBindEx(SymbolHandle symbol_handle, NDArrayHandle** aux_states, ExecutorHandle shared_exec_handle, ExecutorHandle* out) { - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); nnvm::Symbol *sym = static_cast(symbol_handle); @@ -870,7 +870,7 @@ int MXExecutorReshape(int partial_shaping, ExecutorHandle *out) { Executor* new_exec = nullptr; - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); *out = nullptr; // ensure we can know whether to free executor on early abort // create shape map for in_args and aux_states @@ -961,7 +961,7 @@ int MXExecutorReshapeEx(int partial_shaping, ExecutorHandle *out) { Executor* new_exec = nullptr; - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); *out = nullptr; // ensure we can know whether to free executor on early abort // create shape map for in_args and aux_states diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index c9c6000e2f6f..4546659ca64e 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -87,7 +87,7 @@ void MXImperativeInvokeImpl(AtomicSymbolCreator creator, const char **param_keys, const char **param_vals) { const nnvm::Op* op = static_cast(creator); - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); nnvm::NodeAttrs attrs = imperative::ParseAttrs(op, num_inputs, num_params, param_keys, param_vals); @@ -138,7 +138,7 @@ int MXImperativeInvokeEx(AtomicSymbolCreator creator, const char **param_keys, const char **param_vals, const int **out_stypes) { // outputs storage types - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); MXImperativeInvokeImpl(creator, num_inputs, inputs, num_outputs, outputs, num_params, param_keys, param_vals); @@ -194,7 +194,7 @@ int MXInvokeCachedOp(CachedOpHandle handle, NDArrayHandle *inputs, int *num_outputs, NDArrayHandle **outputs) { - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); CachedOpPtr op = *static_cast(handle); @@ -238,7 +238,7 @@ int MXInvokeCachedOpEx(CachedOpHandle handle, int *num_outputs, NDArrayHandle **outputs, const int **out_stypes) { // outputs storage types - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); int err = MXInvokeCachedOp(handle, num_inputs, inputs, num_outputs, outputs); if (err != 0) return err; API_BEGIN(); @@ -331,7 +331,7 @@ int MXAutogradBackwardEx(mx_uint num_output, int is_train, NDArrayHandle **grad_handles, int **grad_stypes) { - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); std::vector outputs, ograds, variables; diff --git a/src/c_api/c_api_profile.cc b/src/c_api/c_api_profile.cc index cec70288b126..5eb219a4e525 100644 --- a/src/c_api/c_api_profile.cc +++ b/src/c_api/c_api_profile.cc @@ -312,7 +312,7 @@ int MXAggregateProfileStatsPrint(const char **out_str, int reset) { int MXAggregateProfileStatsPrintEx(const char **out_str, int reset, int format, int sort_by, int ascending) { - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); CHECK_NOTNULL(out_str); profiler::Profiler *profiler = profiler::Profiler::Get(); diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 020c0d17f0d1..e8d59d90e99b 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -98,7 +98,7 @@ int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, const char **return_type) { static auto& map_key_var_args = nnvm::Op::GetAttr("key_var_num_args"); const Op* op = static_cast(creator); - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); ret->ret_str.resize(0); if (map_key_var_args.count(op) != 0) { @@ -203,7 +203,7 @@ int MXSymbolGetAttr(SymbolHandle symbol, const char** out, int* success) { nnvm::Symbol *s = static_cast(symbol); - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); if (s->GetAttr(key, &(ret->ret_str))) { *out = (ret->ret_str).c_str(); @@ -251,7 +251,7 @@ int MXSymbolListAttr(SymbolHandle symbol, mx_uint *out_size, const char*** out) { nnvm::Symbol *s = static_cast(symbol); - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); std::vector > attr = s->ListAttrsRecursive(); @@ -281,7 +281,7 @@ int MXSymbolListAttrShallow(SymbolHandle symbol, mx_uint *out_size, const char*** out) { nnvm::Symbol *s = static_cast(symbol); - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); std::unordered_map attr = s->ListAttrs(static_cast(1)); // NOLINT(*) @@ -360,7 +360,7 @@ int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **input_arr, int *inp std::vector input_syms = mxnet::GetInputSymbols(*s); *input_size = input_syms.size(); - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); ret->ret_handles.clear(); ret->ret_handles.reserve(*input_size); for (int i = 0; i < *input_size; ++i) ret->ret_handles.push_back(input_syms[i]); @@ -405,7 +405,7 @@ int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **input_symbols, } *input_size = input_syms.size(); - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); ret->ret_handles.clear(); ret->ret_handles.reserve(*input_size); for (int i = 0; i < *input_size; ++i) ret->ret_handles.push_back(input_syms[i]); @@ -464,7 +464,7 @@ int MXSymbolSaveToFile(SymbolHandle symbol, const char *fname) { int MXSymbolSaveToJSON(SymbolHandle symbol, const char **out_json) { nnvm::Symbol *s = static_cast(symbol); - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); ret->ret_str = nnvm::pass::SaveJSON(Symbol2Graph(*s)); *out_json = ret->ret_str.c_str(); @@ -528,7 +528,7 @@ int MXSymbolInferShape(SymbolHandle sym, const mx_uint ***aux_shape_data, int *complete) { nnvm::Symbol *s = static_cast(sym); - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); nnvm::Graph g = Symbol2Graph(*s); mxnet::ShapeVector arg_shapes(g.indexed_graph().input_nodes().size(), mxnet::TShape()); @@ -565,11 +565,11 @@ int MXSymbolInferShape(SymbolHandle sym, &(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes)); // copy data back - MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBuffer(ret->arg_shapes, + MXAPIThreadLocalEntry<>::SetupShapeArrayReturnWithBuffer(ret->arg_shapes, &(ret->arg_shape_ndim), &(ret->arg_shape_data), &(ret->arg_shape_buffer)); - MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBuffer(ret->out_shapes, + MXAPIThreadLocalEntry<>::SetupShapeArrayReturnWithBuffer(ret->out_shapes, &(ret->out_shape_ndim), &(ret->out_shape_data), &(ret->out_shape_buffer)); - MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBuffer(ret->aux_shapes, + MXAPIThreadLocalEntry<>::SetupShapeArrayReturnWithBuffer(ret->aux_shapes, &(ret->aux_shape_ndim), &(ret->aux_shape_data), &(ret->aux_shape_buffer)); *in_shape_size = static_cast(ret->arg_shapes.size()); *in_shape_ndim = dmlc::BeginPtr(ret->arg_shape_ndim); @@ -585,76 +585,149 @@ int MXSymbolInferShape(SymbolHandle sym, API_END(); } -int MXSymbolInferShapeEx(SymbolHandle sym, - mx_uint num_args, - const char** keys, - const mx_uint *arg_ind_ptr, - const int *arg_shape_data, - mx_uint *in_shape_size, - const int **in_shape_ndim, - const int ***in_shape_data, - mx_uint *out_shape_size, - const int **out_shape_ndim, - const int ***out_shape_data, - mx_uint *aux_shape_size, - const int **aux_shape_ndim, - const int ***aux_shape_data, - int *complete) { - nnvm::Symbol *s = static_cast(sym); - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); - API_BEGIN(); +template +inline void SymbolInferShape(const char** keys, + mx_uint num_args, + const dtype* arg_shape_data, + const itype* arg_ind_ptr, + const int** in_shape_ndim, + const dtype*** in_shape_data, + const int** out_shape_ndim, + const dtype*** out_shape_data, + const int** aux_shape_ndim, + const dtype*** aux_shape_data, + nnvm::Symbol* s, + MXAPIThreadLocalEntry* ret, + stype* in_shape_size, + stype* out_shape_size, + stype* aux_shape_size, + int* complete) { nnvm::Graph g = Symbol2Graph(*s); mxnet::ShapeVector arg_shapes(g.indexed_graph().input_nodes().size(), mxnet::TShape()); if (keys == nullptr && num_args != 0) { - std::vector read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph()); + std::vector < uint32_t > read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph()); CHECK_LE(num_args, read_only_args.size()); for (mx_uint i = 0; i < num_args; ++i) { - arg_shapes[read_only_args[i]] = mxnet::ShapeTypeCast( - arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1]); + arg_shapes[read_only_args[i]] = mxnet::ShapeTypeCast(arg_shape_data + arg_ind_ptr[i], + arg_shape_data + arg_ind_ptr[i + 1]); } } else { std::unordered_map kwargs; for (mx_uint i = 0; i < num_args; ++i) { - kwargs[keys[i]] = mxnet::ShapeTypeCast( - arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1]); + kwargs[keys[i]] = mxnet::ShapeTypeCast(arg_shape_data + arg_ind_ptr[i], + arg_shape_data + arg_ind_ptr[i + 1]); } mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_shapes, "InferShape"); } - try { g = mxnet::exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__"); - } catch (const mxnet::op::InferShapeError &err) { + } catch (const mxnet::op::InferShapeError& err) { throw dmlc::Error(err.msg); } - // if use legacy shape definition, need to convert numpy shape to legacy shape mxnet::ShapeVector shapes = g.GetAttr("shape"); if (!Imperative::Get()->is_np_shape()) { common::ConvertToLegacyShape(&shapes); } - // copy back - CopyAttr(g.indexed_graph(), shapes, - &(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes)); - + CopyAttr(g.indexed_graph(), shapes, &(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes)); // copy data back - MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferEx(ret->arg_shapes, - &(ret->arg_shape_ndim_ex), &(ret->arg_shape_data_ex), &(ret->arg_shape_buffer_ex)); - MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferEx(ret->out_shapes, - &(ret->out_shape_ndim_ex), &(ret->out_shape_data_ex), &(ret->out_shape_buffer_ex)); - MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferEx(ret->aux_shapes, - &(ret->aux_shape_ndim_ex), &(ret->aux_shape_data_ex), &(ret->aux_shape_buffer_ex)); - *in_shape_size = static_cast(ret->arg_shapes.size()); + MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferEx(ret->arg_shapes, + &(ret->arg_shape_ndim_ex), + &(ret->arg_shape_data_ex), + &(ret->arg_shape_buffer_ex)); + MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferEx(ret->out_shapes, + &(ret->out_shape_ndim_ex), + &(ret->out_shape_data_ex), + &(ret->out_shape_buffer_ex)); + MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferEx(ret->aux_shapes, + &(ret->aux_shape_ndim_ex), + &(ret->aux_shape_data_ex), + &(ret->aux_shape_buffer_ex)); + *in_shape_size = static_cast(ret->arg_shapes.size()); *in_shape_ndim = dmlc::BeginPtr(ret->arg_shape_ndim_ex); *in_shape_data = dmlc::BeginPtr(ret->arg_shape_data_ex); - *out_shape_size = static_cast(ret->out_shapes.size()); + *out_shape_size = static_cast(ret->out_shapes.size()); *out_shape_ndim = dmlc::BeginPtr(ret->out_shape_ndim_ex); *out_shape_data = dmlc::BeginPtr(ret->out_shape_data_ex); - *aux_shape_size = static_cast(ret->aux_shapes.size()); + *aux_shape_size = static_cast(ret->aux_shapes.size()); *aux_shape_ndim = dmlc::BeginPtr(ret->aux_shape_ndim_ex); *aux_shape_data = dmlc::BeginPtr(ret->aux_shape_data_ex); // mark complete *complete = (g.GetAttr("shape_num_unknown_nodes") == 0); +} + +int MXSymbolInferShapeEx(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const mx_uint *arg_ind_ptr, + const int *arg_shape_data, + mx_uint *in_shape_size, + const int **in_shape_ndim, + const int ***in_shape_data, + mx_uint *out_shape_size, + const int **out_shape_ndim, + const int ***out_shape_data, + mx_uint *aux_shape_size, + const int **aux_shape_ndim, + const int ***aux_shape_data, + int *complete) { + nnvm::Symbol *s = static_cast(sym); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); + API_BEGIN(); + SymbolInferShape(keys, + num_args, + arg_shape_data, + arg_ind_ptr, + in_shape_ndim, + in_shape_data, + out_shape_ndim, + out_shape_data, + aux_shape_ndim, + aux_shape_data, + s, + ret, + in_shape_size, + out_shape_size, + aux_shape_size, + complete); + API_END(); +} + +int MXSymbolInferShapeEx64(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const int64_t *arg_ind_ptr, + const int64_t *arg_shape_data, + size_t *in_shape_size, + const int **in_shape_ndim, + const int64_t ***in_shape_data, + size_t *out_shape_size, + const int **out_shape_ndim, + const int64_t ***out_shape_data, + size_t *aux_shape_size, + const int **aux_shape_ndim, + const int64_t ***aux_shape_data, + int *complete) { + nnvm::Symbol *s = static_cast(sym); + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + API_BEGIN(); + SymbolInferShape(keys, + num_args, + arg_shape_data, + arg_ind_ptr, + in_shape_ndim, + in_shape_data, + out_shape_ndim, + out_shape_data, + aux_shape_ndim, + aux_shape_data, + s, + ret, + in_shape_size, + out_shape_size, + aux_shape_size, + complete); API_END(); } @@ -673,7 +746,7 @@ int MXSymbolInferShapePartial(SymbolHandle sym, const mx_uint **aux_shape_ndim, const mx_uint ***aux_shape_data, int *complete) { - int succ; + int succ = 0; *complete = 1; return MXSymbolInferShape(sym, num_args, keys, arg_ind_ptr, arg_shape_data, @@ -698,7 +771,7 @@ int MXSymbolInferShapePartialEx(SymbolHandle sym, const int **aux_shape_ndim, const int ***aux_shape_data, int *complete) { - int succ; + int succ = 0; *complete = 1; return MXSymbolInferShapeEx(sym, num_args, keys, arg_ind_ptr, arg_shape_data, @@ -708,6 +781,31 @@ int MXSymbolInferShapePartialEx(SymbolHandle sym, &succ); } +int MXSymbolInferShapePartialEx64(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const int64_t *arg_ind_ptr, + const int64_t *arg_shape_data, + size_t *in_shape_size, + const int **in_shape_ndim, + const int64_t ***in_shape_data, + size_t *out_shape_size, + const int **out_shape_ndim, + const int64_t ***out_shape_data, + size_t *aux_shape_size, + const int **aux_shape_ndim, + const int64_t ***aux_shape_data, + int *complete) { + int succ = 0; + *complete = 1; + return MXSymbolInferShapeEx64(sym, num_args, keys, + arg_ind_ptr, arg_shape_data, + in_shape_size, in_shape_ndim, in_shape_data, + out_shape_size, out_shape_ndim, out_shape_data, + aux_shape_size, aux_shape_ndim, aux_shape_data, + &succ); +} + int MXSymbolInferType(SymbolHandle sym, mx_uint num_args, const char** keys, @@ -720,7 +818,7 @@ int MXSymbolInferType(SymbolHandle sym, const int **aux_type_data, int *complete) { nnvm::Symbol *s = static_cast(sym); - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); nnvm::Graph g = Symbol2Graph(*s); nnvm::DTypeVector arg_types(g.indexed_graph().input_nodes().size(), -1); @@ -764,7 +862,7 @@ int MXSymbolInferTypePartial(SymbolHandle sym, mx_uint *aux_type_size, const int **aux_type_data, int *complete) { - int succ; + int succ = 0; *complete = 1; return MXSymbolInferType(sym, num_args, keys, arg_type_data, diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc index c00021c44d1d..e6a177e27847 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -48,7 +48,7 @@ OpStatePtr Imperative::InvokeOp( using namespace imperative; static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); static auto& is_layer_backward = Op::GetAttr("TIsLayerOpBackward"); - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); const nnvm::Op *op = attrs.op; @@ -197,7 +197,7 @@ void Imperative::RecordOp( const OpStatePtr& state, std::vector* p_save_inputs, std::vector* p_save_outputs) { - MXAPIThreadLocalEntry *local_buff = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *local_buff = MXAPIThreadLocalStore<>::Get(); for (auto output : outputs) { CHECK(AGInfo::IsNone(*output)) diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 477139fd84b8..b0476fc8491a 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -104,7 +104,7 @@ inline void SetShapeType(const Context& ctx, static auto& infershape = nnvm::Op::GetAttr("FInferShape"); static auto& infertype = nnvm::Op::GetAttr("FInferType"); static auto& inferstorage = nnvm::Op::GetAttr("FInferStorageType"); - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); // infer shape mxnet::ShapeVector& in_shapes = ret->arg_shapes; in_shapes.clear(); diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 96c86c42a6c4..611dd7287206 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -668,9 +668,9 @@ void SliceEx(const nnvm::NodeAttrs& attrs, template inline void GetIndexRange(const mxnet::TShape& dshape, - const mxnet::Tuple>& param_begin, - const mxnet::Tuple>& param_end, - const mxnet::Tuple>& param_step, + const mxnet::Tuple>& param_begin, + const mxnet::Tuple>& param_end, + const mxnet::Tuple>& param_step, common::StaticArray* begin, common::StaticArray* end, common::StaticArray* step) { @@ -1033,8 +1033,8 @@ void SliceAssignOpForward(const nnvm::NodeAttrs& attrs, struct SliceAssignScalarParam : public dmlc::Parameter { double scalar; - mxnet::Tuple> begin, end; - mxnet::Tuple> step; + mxnet::Tuple> begin, end; + mxnet::Tuple> step; DMLC_DECLARE_PARAMETER(SliceAssignScalarParam) { DMLC_DECLARE_FIELD(scalar) .set_default(0) @@ -1044,7 +1044,7 @@ struct SliceAssignScalarParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(end) .describe("ending indices for the slice operation, supports negative indices."); DMLC_DECLARE_FIELD(step) - .set_default(mxnet::Tuple>()) + .set_default(mxnet::Tuple>()) .describe("step for the slice operation, supports negative values."); } }; @@ -1346,12 +1346,12 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs, inline void SliceLikeInferRanges(const mxnet::TShape& dshape, const mxnet::TShape& fshape, const mxnet::Tuple& axes, - mxnet::Tuple>* param_begin, - mxnet::Tuple>* param_end, - mxnet::Tuple>* param_step) { - std::vector> pb(dshape.ndim()); - std::vector> pe(dshape.ndim()); - std::vector> ps(dshape.ndim()); + mxnet::Tuple>* param_begin, + mxnet::Tuple>* param_end, + mxnet::Tuple>* param_step) { + std::vector> pb(dshape.ndim()); + std::vector> pe(dshape.ndim()); + std::vector> ps(dshape.ndim()); if (axes.ndim() == 0) { for (int i = 0; i < dshape.ndim(); ++i) { pb[i] = 0; @@ -1375,9 +1375,9 @@ inline void SliceLikeInferRanges(const mxnet::TShape& dshape, ps[axis] = 1; } } - *param_begin = mxnet::Tuple>(pb.begin(), pb.end()); - *param_end = mxnet::Tuple>(pe.begin(), pe.end()); - *param_step = mxnet::Tuple>(ps.begin(), ps.end()); + *param_begin = mxnet::Tuple>(pb.begin(), pb.end()); + *param_end = mxnet::Tuple>(pe.begin(), pe.end()); + *param_step = mxnet::Tuple>(ps.begin(), ps.end()); } template @@ -1396,9 +1396,9 @@ void SliceLikeForward(const nnvm::NodeAttrs& attrs, const TBlob& out = outputs[0]; const mxnet::TShape& ishape = data.shape_; const mxnet::TShape& from_shape = inputs[1].shape_; - mxnet::Tuple> param_begin; - mxnet::Tuple> param_end; - mxnet::Tuple> param_step; + mxnet::Tuple> param_begin; + mxnet::Tuple> param_end; + mxnet::Tuple> param_step; SliceLikeInferRanges(ishape, from_shape, param.axes, ¶m_begin, ¶m_end, ¶m_step); MXNET_NDIM_SWITCH(data.ndim(), ndim, { @@ -1444,9 +1444,9 @@ void SliceLikeBackward(const nnvm::NodeAttrs& attrs, const mxnet::TShape& ishape = ograd.shape_; const mxnet::TShape& from_shape = outputs[1].shape_; - mxnet::Tuple> param_begin; - mxnet::Tuple> param_end; - mxnet::Tuple> param_step; + mxnet::Tuple> param_begin; + mxnet::Tuple> param_end; + mxnet::Tuple> param_step; SliceLikeInferRanges(ishape, from_shape, param.axes, ¶m_begin, ¶m_end, ¶m_step); MXNET_NDIM_SWITCH(ograd.ndim(), ndim, { diff --git a/src/operator/tensor/slice-inl.h b/src/operator/tensor/slice-inl.h index 78a2bd8c7b45..7450e466440a 100644 --- a/src/operator/tensor/slice-inl.h +++ b/src/operator/tensor/slice-inl.h @@ -34,15 +34,15 @@ namespace mxnet { namespace op { struct SliceParam : public dmlc::Parameter { - mxnet::Tuple> begin, end; - mxnet::Tuple> step; + mxnet::Tuple> begin, end; + mxnet::Tuple> step; DMLC_DECLARE_PARAMETER(SliceParam) { DMLC_DECLARE_FIELD(begin) .describe("starting indices for the slice operation, supports negative indices."); DMLC_DECLARE_FIELD(end) .describe("ending indices for the slice operation, supports negative indices."); DMLC_DECLARE_FIELD(step) - .set_default(mxnet::Tuple>()) + .set_default(mxnet::Tuple>()) .describe("step for the slice operation, supports negative values."); } bool operator==(const SliceParam& other) const { diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 0df481a01987..9dc29eb1584e 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -24,7 +24,6 @@ # dimension constants MEDIUM_X = 10000 LARGE_X = 100000000 -LARGE_Y = 50000000 SMALL_Y = 50 LARGE_SIZE = LARGE_X * SMALL_Y diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py new file mode 100644 index 000000000000..8c030f5bc20e --- /dev/null +++ b/tests/nightly/test_large_vector.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +import mxnet as mx +from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d +from mxnet import gluon, nd +from tests.python.unittest.common import with_seed + +# dimension constants +LARGE_X = 5000000000 +MEDIUM_X = 1000000000 + + +def test_slice(): + a = nd.ones(LARGE_X) + res = nd.slice(a, begin=(LARGE_X - MEDIUM_X), end=LARGE_X) + assert res.shape[0] == MEDIUM_X + + +if __name__ == '__main__': + import nose + nose.runmodule()