diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 3d44d41012b3..5ab10b6b2204 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -574,13 +574,13 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape, int dtype, NDArrayHandle *out); -MXNET_DLL int MXNDArrayCreateExInt64(const mx_int64 *shape, - int ndim, - int dev_type, - int dev_id, - int delay_alloc, - int dtype, - NDArrayHandle *out); +MXNET_DLL int MXNDArrayCreateEx64(const mx_int64 *shape, + int ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + NDArrayHandle *out); /*! * \brief create an empty sparse NDArray with specified shape and data type @@ -612,18 +612,18 @@ MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type, const mx_uint *aux_shape, NDArrayHandle *out); -MXNET_DLL int MXNDArrayCreateSparseExInt64(int storage_type, - const mx_int64 *shape, - int ndim, - int dev_type, - int dev_id, - int delay_alloc, - int dtype, - mx_uint num_aux, - int *aux_type, - int *aux_ndims, - const mx_int64 *aux_shape, - NDArrayHandle *out); +MXNET_DLL int MXNDArrayCreateSparseEx64(int storage_type, + const mx_int64 *shape, + int ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + mx_uint num_aux, + int *aux_type, + int *aux_ndims, + const mx_int64 *aux_shape, + NDArrayHandle *out); /*! * \brief create a NDArray handle that is loaded from raw bytes. @@ -672,11 +672,11 @@ MXNET_DLL int MXNDArrayLoad(const char* fname, mx_uint *out_name_size, const char*** out_names); -MXNET_DLL int MXNDArrayLoadInt64(const char* fname, - mx_int64 *out_size, - NDArrayHandle** out_arr, - mx_int64 *out_name_size, - const char*** out_names); +MXNET_DLL int MXNDArrayLoad64(const char* fname, + mx_int64 *out_size, + NDArrayHandle** out_arr, + mx_int64 *out_name_size, + const char*** out_names); /*! * \brief Load list / dictionary of narrays from file content loaded into memory. @@ -699,12 +699,12 @@ MXNET_DLL int MXNDArrayLoadFromBuffer(const void *ndarray_buffer, mx_uint *out_name_size, const char*** out_names); -MXNET_DLL int MXNDArrayLoadFromBufferInt64(const void *ndarray_buffer, - size_t size, - mx_int64 *out_size, - NDArrayHandle** out_arr, - mx_int64 *out_name_size, - const char*** out_names); +MXNET_DLL int MXNDArrayLoadFromBuffer64(const void *ndarray_buffer, + size_t size, + mx_int64 *out_size, + NDArrayHandle** out_arr, + mx_int64 *out_name_size, + const char*** out_names); /*! * \brief Perform a synchronize copy from a continugous CPU memory region. @@ -845,9 +845,9 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle, mx_uint *out_dim, const mx_uint **out_pdata); -MXNET_DLL int MXNDArrayGetShapeInt64(NDArrayHandle handle, - int *out_dim, - const int64_t **out_pdata); +MXNET_DLL int MXNDArrayGetShape64(NDArrayHandle handle, + int *out_dim, + const int64_t **out_pdata); /*! * \brief get the shape of the array @@ -860,9 +860,9 @@ MXNET_DLL int MXNDArrayGetShapeEx(NDArrayHandle handle, int *out_dim, const int **out_pdata); -MXNET_DLL int MXNDArrayGetShapeExInt64(NDArrayHandle handle, - int *out_dim, - const mx_int64 **out_pdata); +MXNET_DLL int MXNDArrayGetShapeEx64(NDArrayHandle handle, + int *out_dim, + const mx_int64 **out_pdata); /*! * \brief get the content of the data in NDArray @@ -947,9 +947,9 @@ MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle, mx_uint i, int *out_type); -MXNET_DLL int MXNDArrayGetAuxTypeInt64(NDArrayHandle handle, - mx_int64 i, - int *out_type); +MXNET_DLL int MXNDArrayGetAuxType64(NDArrayHandle handle, + mx_int64 i, + int *out_type); /*! * \brief Get a deep copy of the ith aux data blob @@ -960,9 +960,9 @@ MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle, mx_uint i, NDArrayHandle *out); -MXNET_DLL int MXNDArrayGetAuxNDArrayInt64(NDArrayHandle handle, - mx_int64 i, - NDArrayHandle *out); +MXNET_DLL int MXNDArrayGetAuxNDArray64(NDArrayHandle handle, + mx_int64 i, + NDArrayHandle *out); /*! * \brief Get a deep copy of the data blob @@ -1020,8 +1020,8 @@ MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out); MXNET_DLL int MXListFunctions(mx_uint *out_size, FunctionHandle **out_array); -MXNET_DLL int MXListFunctionsInt64(mx_int64 *out_size, - FunctionHandle **out_array); +MXNET_DLL int MXListFunctions64(mx_int64 *out_size, + FunctionHandle **out_array); /*! * \brief get the function handle by name @@ -1291,8 +1291,8 @@ MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle, MXNET_DLL int MXListAllOpNames(mx_uint *out_size, const char ***out_array); -MXNET_DLL int MXListAllOpNamesInt64(mx_int64 *out_size, - const char ***out_array); +MXNET_DLL int MXListAllOpNames64(mx_int64 *out_size, + const char ***out_array); /*! * \brief list all the available AtomicSymbolEntry @@ -1303,8 +1303,8 @@ MXNET_DLL int MXListAllOpNamesInt64(mx_int64 *out_size, MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, AtomicSymbolCreator **out_array); -MXNET_DLL int MXSymbolListAtomicSymbolCreatorsInt64(mx_int64 *out_size, - AtomicSymbolCreator **out_array); +MXNET_DLL int MXSymbolListAtomicSymbolCreators64(mx_int64 *out_size, + AtomicSymbolCreator **out_array); /*! * \brief Get the name of an atomic symbol. @@ -1519,9 +1519,9 @@ MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol, mx_uint *out_size, const char ***out_str_array); -MXNET_DLL int MXSymbolListArgumentsInt64(SymbolHandle symbol, - size_t *out_size, - const char ***out_str_array); +MXNET_DLL int MXSymbolListArguments64(SymbolHandle symbol, + size_t *out_size, + const char ***out_str_array); /*! * \brief List returns in the symbol. @@ -1534,9 +1534,9 @@ MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol, mx_uint *out_size, const char ***out_str_array); -MXNET_DLL int MXSymbolListOutputsInt64(SymbolHandle symbol, - size_t *out_size, - const char ***out_str_array); +MXNET_DLL int MXSymbolListOutputs64(SymbolHandle symbol, + size_t *out_size, + const char ***out_str_array); /*! * \brief Get number of outputs of the symbol. @@ -1585,9 +1585,9 @@ MXNET_DLL int MXSymbolListAuxiliaryStates(SymbolHandle symbol, mx_uint *out_size, const char ***out_str_array); -MXNET_DLL int MXSymbolListAuxiliaryStatesInt64(SymbolHandle symbol, - size_t *out_size, - const char ***out_str_array); +MXNET_DLL int MXSymbolListAuxiliaryStates64(SymbolHandle symbol, + size_t *out_size, + const char ***out_str_array); /*! * \brief Compose the symbol on other symbols. @@ -1660,21 +1660,21 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym, const mx_uint ***aux_shape_data, int *complete); -MXNET_DLL int MXSymbolInferShapeInt64(SymbolHandle sym, - mx_uint num_args, - const char** keys, - const mx_int64 *arg_ind_ptr, - const mx_int64 *arg_shape_data, - size_t *in_shape_size, - const int **in_shape_ndim, - const mx_int64 ***in_shape_data, - size_t *out_shape_size, - const int **out_shape_ndim, - const mx_int64 ***out_shape_data, - size_t *aux_shape_size, - const int **aux_shape_ndim, - const mx_int64 ***aux_shape_data, - int *complete); +MXNET_DLL int MXSymbolInferShape64(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const mx_int64 *arg_ind_ptr, + const mx_int64 *arg_shape_data, + size_t *in_shape_size, + const int **in_shape_ndim, + const mx_int64 ***in_shape_data, + size_t *out_shape_size, + const int **out_shape_ndim, + const mx_int64 ***out_shape_data, + size_t *aux_shape_size, + const int **aux_shape_ndim, + const mx_int64 ***aux_shape_data, + int *complete); /*! * \brief infer shape of unknown input shapes given the known one. @@ -1714,21 +1714,21 @@ MXNET_DLL int MXSymbolInferShapeEx(SymbolHandle sym, const int ***aux_shape_data, int *complete); -MXNET_DLL int MXSymbolInferShapeExInt64(SymbolHandle sym, - mx_uint num_args, - const char** keys, - const mx_int64 *arg_ind_ptr, - const mx_int64 *arg_shape_data, - size_t *in_shape_size, - const int **in_shape_ndim, - const mx_int64 ***in_shape_data, - size_t *out_shape_size, - const int **out_shape_ndim, - const mx_int64 ***out_shape_data, - size_t *aux_shape_size, - const int **aux_shape_ndim, - const mx_int64 ***aux_shape_data, - int *complete); +MXNET_DLL int MXSymbolInferShapeEx64(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const mx_int64 *arg_ind_ptr, + const mx_int64 *arg_shape_data, + size_t *in_shape_size, + const int **in_shape_ndim, + const mx_int64 ***in_shape_data, + size_t *out_shape_size, + const int **out_shape_ndim, + const mx_int64 ***out_shape_data, + size_t *aux_shape_size, + const int **aux_shape_ndim, + const mx_int64 ***aux_shape_data, + int *complete); /*! * \brief DEPRECATED. Use MXSymbolInferShapePartialEx instead. @@ -1771,21 +1771,21 @@ MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym, const mx_uint ***aux_shape_data, int *complete); -MXNET_DLL int MXSymbolInferShapePartialInt64(SymbolHandle sym, - mx_uint num_args, - const char** keys, - const mx_int64 *arg_ind_ptr, - const mx_int64 *arg_shape_data, - size_t *in_shape_size, - const int **in_shape_ndim, - const mx_int64 ***in_shape_data, - size_t *out_shape_size, - const int **out_shape_ndim, - const mx_int64 ***out_shape_data, - size_t *aux_shape_size, - const int **aux_shape_ndim, - const mx_int64 ***aux_shape_data, - int *complete); +MXNET_DLL int MXSymbolInferShapePartial64(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const mx_int64 *arg_ind_ptr, + const mx_int64 *arg_shape_data, + size_t *in_shape_size, + const int **in_shape_ndim, + const mx_int64 ***in_shape_data, + size_t *out_shape_size, + const int **out_shape_ndim, + const mx_int64 ***out_shape_data, + size_t *aux_shape_size, + const int **aux_shape_ndim, + const mx_int64 ***aux_shape_data, + int *complete); /*! * \brief partially infer shape of unknown input shapes given the known one. @@ -1827,21 +1827,21 @@ MXNET_DLL int MXSymbolInferShapePartialEx(SymbolHandle sym, const int ***aux_shape_data, int *complete); -MXNET_DLL int MXSymbolInferShapePartialExInt64(SymbolHandle sym, - mx_uint num_args, - const char** keys, - const mx_int64 *arg_ind_ptr, - const mx_int64 *arg_shape_data, - size_t *in_shape_size, - const int **in_shape_ndim, - const mx_int64 ***in_shape_data, - size_t *out_shape_size, - const int **out_shape_ndim, - const mx_int64 ***out_shape_data, - size_t *aux_shape_size, - const int **aux_shape_ndim, - const mx_int64 ***aux_shape_data, - int *complete); +MXNET_DLL int MXSymbolInferShapePartialEx64(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const mx_int64 *arg_ind_ptr, + const mx_int64 *arg_shape_data, + size_t *in_shape_size, + const int **in_shape_ndim, + const mx_int64 ***in_shape_data, + size_t *out_shape_size, + const int **out_shape_ndim, + const mx_int64 ***out_shape_data, + size_t *aux_shape_size, + const int **aux_shape_ndim, + const mx_int64 ***aux_shape_data, + int *complete); /*! * \brief infer type of unknown input types given the known one. diff --git a/include/mxnet/tuple.h b/include/mxnet/tuple.h index b56d2e4e9765..f018c8faabea 100644 --- a/include/mxnet/tuple.h +++ b/include/mxnet/tuple.h @@ -374,6 +374,7 @@ class Tuple { } }; + /*! brief check if a shape's ndim is known. */ inline bool ndim_is_known(const int ndim) { CHECK_GE(ndim, -1) << "shape ndim must be >= -1, while received " << ndim; diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 82b8b07eb6dc..2b5a2d1dd45e 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -107,6 +107,14 @@ _NDARRAY_BASIC_INDEXING = 0 _NDARRAY_ADVANCED_INDEXING = 1 +# Caching whether MXNet was built with INT64 support or not +_INT64_TENSOR_SIZE_ENABLED = None + +def int64_enabled(): + global _INT64_TENSOR_SIZE_ENABLED + if _INT64_TENSOR_SIZE_ENABLED is None: + _INT64_TENSOR_SIZE_ENABLED = Features().is_enabled('INT64_TENSOR_SIZE') + return _INT64_TENSOR_SIZE_ENABLED def _new_empty_handle(): """Returns a new empty handle. @@ -134,8 +142,8 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t): A new empty `NDArray` handle. """ hdl = NDArrayHandle() - if Features().is_enabled('INT64_TENSOR_SIZE') and sys.version_info[0] > 2: - check_call(_LIB.MXNDArrayCreateExInt64( + if sys.version_info[0] > 2 and int64_enabled(): + check_call(_LIB.MXNDArrayCreateEx64( c_array_buf(mx_int64, native_array('q', shape)), ctypes.c_int(len(shape)), ctypes.c_int(ctx.device_typeid), @@ -2230,9 +2238,9 @@ def shape(self): (2L, 3L, 4L) """ ndim = mx_int() - if sys.version_info[0] > 2 and Features().is_enabled('INT64_TENSOR_SIZE'): + if _INT64_TENSOR_SIZE_ENABLED: pdata = ctypes.POINTER(mx_int64)() - check_call(_LIB.MXNDArrayGetShapeExInt64( + check_call(_LIB.MXNDArrayGetShapeEx64( self.handle, ctypes.byref(ndim), ctypes.byref(pdata))) else: pdata = ctypes.POINTER(mx_int)() diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 919d14ff8a7d..fc798263d52f 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -39,9 +39,8 @@ from ..base import check_call, MXNetError, NotImplementedForSymbol from ..context import Context, current_context from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP, _GRAD_REQ_MAP -from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID +from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID, int64_enabled from ..ndarray import _ndarray_cls -from ..runtime import Features from ..executor import Executor from . import _internal from . import op @@ -1213,14 +1212,14 @@ def _infer_shape_impl(self, partial, *args, **kwargs): aux_shape_size = mx_uint() aux_shape_ndim = ctypes.POINTER(mx_int)() complete = ctypes.c_int() - if Features().is_enabled('INT64_TENSOR_SIZE') and sys.version_info[0] > 2: + if sys.version_info[0] > 2 and int64_enabled(): arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))() out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))() aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))() if partial: - infer_func = _LIB.MXSymbolInferShapePartialExInt64 + infer_func = _LIB.MXSymbolInferShapePartialEx64 else: - infer_func = _LIB.MXSymbolInferShapeExInt64 + infer_func = _LIB.MXSymbolInferShapeEx64 check_call(infer_func( self.handle, mx_uint(len(indptr) - 1), diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index b9c5f1960ae9..c2b80b3f601c 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -190,8 +190,13 @@ int MXNDArrayCreateNone(NDArrayHandle *out) { } template -void CreateArray(const DataType* shape, dimtype ndim, int dev_type, int dev_id, int delay_alloc, - int dtype, NDArrayHandle* out) { +void CreateNDArray(const DataType* shape, + dimtype ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + NDArrayHandle* out) { *out = new NDArray(mxnet::TShape(shape, shape + ndim), Context::Create(static_cast(dev_type), dev_id), delay_alloc != 0, dtype); @@ -210,15 +215,15 @@ int MXNDArrayCreate(const mx_uint *shape, API_END(); } -int MXNDArrayCreateExInt64(const mx_int64 *shape, - int ndim, - int dev_type, - int dev_id, - int delay_alloc, - int dtype, - NDArrayHandle *out) { +int MXNDArrayCreateEx64(const mx_int64 *shape, + int ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + NDArrayHandle *out) { API_BEGIN(); - CreateArray(shape, ndim, dev_type, dev_id, delay_alloc, dtype, out); + CreateNDArray(shape, ndim, dev_type, dev_id, delay_alloc, dtype, out); API_END(); } @@ -230,11 +235,7 @@ int MXNDArrayCreateEx(const mx_uint *shape, int dtype, NDArrayHandle *out) { API_BEGIN(); - *out = new NDArray( - mxnet::TShape(shape, shape + ndim), - Context::Create(static_cast(dev_type), dev_id), - delay_alloc != 0, - dtype); + CreateNDArray(shape, ndim, dev_type, dev_id, delay_alloc, dtype, out); API_END(); } @@ -558,7 +559,7 @@ int MXNDArrayGetShape(NDArrayHandle handle, template inline void GetShape(NDArrayHandle handle, const dtype** out_pdata, int* out_dim, - MXAPIThreadLocalEntry* ret) { + MXAPIThreadLocalEntry* ret) { NDArray* arr = static_cast(handle); if (!arr->is_none()) { mxnet::TShape s = arr->shape(); @@ -590,9 +591,9 @@ int MXNDArrayGetShapeEx(NDArrayHandle handle, API_END(); } -int MXNDArrayGetShapeExInt64(NDArrayHandle handle, - int *out_dim, - const mx_int64 **out_pdata) { +int MXNDArrayGetShapeEx64(NDArrayHandle handle, + int *out_dim, + const mx_int64 **out_pdata) { MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); GetShape(handle, out_pdata, out_dim, ret); diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index f19658bbb26b..e8d59d90e99b 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -25,7 +25,6 @@ #include "mxnet/base.h" #include "mxnet/c_api.h" #include "mxnet/imperative.h" -#include "mxnet/libinfo.h" #include "nnvm/c_api.h" #include "nnvm/pass.h" #include "nnvm/pass_functions.h" @@ -588,74 +587,74 @@ int MXSymbolInferShape(SymbolHandle sym, template inline void SymbolInferShape(const char** keys, - mx_uint num_args, - const dtype* arg_shape_data, - const itype* arg_ind_ptr, - const int** in_shape_ndim, - const dtype*** in_shape_data, - const int** out_shape_ndim, - const dtype*** out_shape_data, - const int** aux_shape_ndim, - const dtype*** aux_shape_data, - nnvm::Symbol* s, - MXAPIThreadLocalEntry* ret, - stype* in_shape_size, - stype* out_shape_size, - stype* aux_shape_size, - int* complete) { -nnvm::Graph g = Symbol2Graph(*s); -mxnet::ShapeVector arg_shapes(g.indexed_graph().input_nodes().size(), mxnet::TShape()); -if (keys == nullptr && num_args != 0) { - std::vector < uint32_t > read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph()); - CHECK_LE(num_args, read_only_args.size()); - for (mx_uint i = 0; i < num_args; ++i) { - arg_shapes[read_only_args[i]] = mxnet::ShapeTypeCast(arg_shape_data + arg_ind_ptr[i], - arg_shape_data + arg_ind_ptr[i + 1]); + mx_uint num_args, + const dtype* arg_shape_data, + const itype* arg_ind_ptr, + const int** in_shape_ndim, + const dtype*** in_shape_data, + const int** out_shape_ndim, + const dtype*** out_shape_data, + const int** aux_shape_ndim, + const dtype*** aux_shape_data, + nnvm::Symbol* s, + MXAPIThreadLocalEntry* ret, + stype* in_shape_size, + stype* out_shape_size, + stype* aux_shape_size, + int* complete) { + nnvm::Graph g = Symbol2Graph(*s); + mxnet::ShapeVector arg_shapes(g.indexed_graph().input_nodes().size(), mxnet::TShape()); + if (keys == nullptr && num_args != 0) { + std::vector < uint32_t > read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph()); + CHECK_LE(num_args, read_only_args.size()); + for (mx_uint i = 0; i < num_args; ++i) { + arg_shapes[read_only_args[i]] = mxnet::ShapeTypeCast(arg_shape_data + arg_ind_ptr[i], + arg_shape_data + arg_ind_ptr[i + 1]); + } + } else { + std::unordered_map kwargs; + for (mx_uint i = 0; i < num_args; ++i) { + kwargs[keys[i]] = mxnet::ShapeTypeCast(arg_shape_data + arg_ind_ptr[i], + arg_shape_data + arg_ind_ptr[i + 1]); + } + mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_shapes, "InferShape"); } -} else { - std::unordered_map kwargs; - for (mx_uint i = 0; i < num_args; ++i) { - kwargs[keys[i]] = mxnet::ShapeTypeCast(arg_shape_data + arg_ind_ptr[i], - arg_shape_data + arg_ind_ptr[i + 1]); + try { + g = mxnet::exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__"); + } catch (const mxnet::op::InferShapeError& err) { + throw dmlc::Error(err.msg); } - mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_shapes, "InferShape"); -} -try { - g = mxnet::exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__"); -} catch (const mxnet::op::InferShapeError& err) { - throw dmlc::Error(err.msg); -} -// if use legacy shape definition, need to convert numpy shape to legacy shape -mxnet::ShapeVector shapes = g.GetAttr("shape"); -if (!Imperative::Get()->is_np_shape()) { - common::ConvertToLegacyShape(&shapes); -} -// copy back -CopyAttr(g.indexed_graph(), shapes, &(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes)); -// copy data back -MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferEx(ret->arg_shapes, - &(ret->arg_shape_ndim_ex), - &(ret->arg_shape_data_ex), - &(ret->arg_shape_buffer_ex)); -MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferEx(ret->out_shapes, - &(ret->out_shape_ndim_ex), - &(ret->out_shape_data_ex), - &(ret->out_shape_buffer_ex)); -MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferEx(ret->aux_shapes, - &(ret->aux_shape_ndim_ex), - &(ret->aux_shape_data_ex), - &(ret->aux_shape_buffer_ex)); - *in_shape_data = dmlc::BeginPtr(ret->arg_shape_data_ex); - *out_shape_data = dmlc::BeginPtr(ret->out_shape_data_ex); - *aux_shape_data = dmlc::BeginPtr(ret->aux_shape_data_ex); + // if use legacy shape definition, need to convert numpy shape to legacy shape + mxnet::ShapeVector shapes = g.GetAttr("shape"); + if (!Imperative::Get()->is_np_shape()) { + common::ConvertToLegacyShape(&shapes); + } + // copy back + CopyAttr(g.indexed_graph(), shapes, &(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes)); + // copy data back + MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferEx(ret->arg_shapes, + &(ret->arg_shape_ndim_ex), + &(ret->arg_shape_data_ex), + &(ret->arg_shape_buffer_ex)); + MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferEx(ret->out_shapes, + &(ret->out_shape_ndim_ex), + &(ret->out_shape_data_ex), + &(ret->out_shape_buffer_ex)); + MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferEx(ret->aux_shapes, + &(ret->aux_shape_ndim_ex), + &(ret->aux_shape_data_ex), + &(ret->aux_shape_buffer_ex)); *in_shape_size = static_cast(ret->arg_shapes.size()); *in_shape_ndim = dmlc::BeginPtr(ret->arg_shape_ndim_ex); + *in_shape_data = dmlc::BeginPtr(ret->arg_shape_data_ex); *out_shape_size = static_cast(ret->out_shapes.size()); *out_shape_ndim = dmlc::BeginPtr(ret->out_shape_ndim_ex); + *out_shape_data = dmlc::BeginPtr(ret->out_shape_data_ex); *aux_shape_size = static_cast(ret->aux_shapes.size()); *aux_shape_ndim = dmlc::BeginPtr(ret->aux_shape_ndim_ex); -// mark complete -*complete = (g.GetAttr("shape_num_unknown_nodes") == 0); + *aux_shape_data = dmlc::BeginPtr(ret->aux_shape_data_ex); + // mark complete + *complete = (g.GetAttr("shape_num_unknown_nodes") == 0); } int MXSymbolInferShapeEx(SymbolHandle sym, @@ -677,58 +676,58 @@ int MXSymbolInferShapeEx(SymbolHandle sym, MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); SymbolInferShape(keys, - num_args, - arg_shape_data, - arg_ind_ptr, - in_shape_ndim, - in_shape_data, - out_shape_ndim, - out_shape_data, - aux_shape_ndim, - aux_shape_data, - s, - ret, - in_shape_size, - out_shape_size, - aux_shape_size, - complete); + num_args, + arg_shape_data, + arg_ind_ptr, + in_shape_ndim, + in_shape_data, + out_shape_ndim, + out_shape_data, + aux_shape_ndim, + aux_shape_data, + s, + ret, + in_shape_size, + out_shape_size, + aux_shape_size, + complete); API_END(); } -int MXSymbolInferShapeExInt64(SymbolHandle sym, - mx_uint num_args, - const char** keys, - const int64_t *arg_ind_ptr, - const int64_t *arg_shape_data, - size_t *in_shape_size, - const int **in_shape_ndim, - const int64_t ***in_shape_data, - size_t *out_shape_size, - const int **out_shape_ndim, - const int64_t ***out_shape_data, - size_t *aux_shape_size, - const int **aux_shape_ndim, - const int64_t ***aux_shape_data, - int *complete) { +int MXSymbolInferShapeEx64(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const int64_t *arg_ind_ptr, + const int64_t *arg_shape_data, + size_t *in_shape_size, + const int **in_shape_ndim, + const int64_t ***in_shape_data, + size_t *out_shape_size, + const int **out_shape_ndim, + const int64_t ***out_shape_data, + size_t *aux_shape_size, + const int **aux_shape_ndim, + const int64_t ***aux_shape_data, + int *complete) { nnvm::Symbol *s = static_cast(sym); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); SymbolInferShape(keys, - num_args, - arg_shape_data, - arg_ind_ptr, - in_shape_ndim, - in_shape_data, - out_shape_ndim, - out_shape_data, - aux_shape_ndim, - aux_shape_data, - s, - ret, - in_shape_size, - out_shape_size, - aux_shape_size, - complete); + num_args, + arg_shape_data, + arg_ind_ptr, + in_shape_ndim, + in_shape_data, + out_shape_ndim, + out_shape_data, + aux_shape_ndim, + aux_shape_data, + s, + ret, + in_shape_size, + out_shape_size, + aux_shape_size, + complete); API_END(); } @@ -747,7 +746,7 @@ int MXSymbolInferShapePartial(SymbolHandle sym, const mx_uint **aux_shape_ndim, const mx_uint ***aux_shape_data, int *complete) { - int succ; + int succ = 0; *complete = 1; return MXSymbolInferShape(sym, num_args, keys, arg_ind_ptr, arg_shape_data, @@ -772,7 +771,7 @@ int MXSymbolInferShapePartialEx(SymbolHandle sym, const int **aux_shape_ndim, const int ***aux_shape_data, int *complete) { - int succ; + int succ = 0; *complete = 1; return MXSymbolInferShapeEx(sym, num_args, keys, arg_ind_ptr, arg_shape_data, @@ -782,29 +781,29 @@ int MXSymbolInferShapePartialEx(SymbolHandle sym, &succ); } -int MXSymbolInferShapePartialExInt64(SymbolHandle sym, - mx_uint num_args, - const char** keys, - const int64_t *arg_ind_ptr, - const int64_t *arg_shape_data, - size_t *in_shape_size, - const int **in_shape_ndim, - const int64_t ***in_shape_data, - size_t *out_shape_size, - const int **out_shape_ndim, - const int64_t ***out_shape_data, - size_t *aux_shape_size, - const int **aux_shape_ndim, - const int64_t ***aux_shape_data, - int *complete) { - int succ; +int MXSymbolInferShapePartialEx64(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const int64_t *arg_ind_ptr, + const int64_t *arg_shape_data, + size_t *in_shape_size, + const int **in_shape_ndim, + const int64_t ***in_shape_data, + size_t *out_shape_size, + const int **out_shape_ndim, + const int64_t ***out_shape_data, + size_t *aux_shape_size, + const int **aux_shape_ndim, + const int64_t ***aux_shape_data, + int *complete) { + int succ = 0; *complete = 1; - return MXSymbolInferShapeExInt64(sym, num_args, keys, - arg_ind_ptr, arg_shape_data, - in_shape_size, in_shape_ndim, in_shape_data, - out_shape_size, out_shape_ndim, out_shape_data, - aux_shape_size, aux_shape_ndim, aux_shape_data, - &succ); + return MXSymbolInferShapeEx64(sym, num_args, keys, + arg_ind_ptr, arg_shape_data, + in_shape_size, in_shape_ndim, in_shape_data, + out_shape_size, out_shape_ndim, out_shape_data, + aux_shape_size, aux_shape_ndim, aux_shape_data, + &succ); } int MXSymbolInferType(SymbolHandle sym, @@ -863,7 +862,7 @@ int MXSymbolInferTypePartial(SymbolHandle sym, mx_uint *aux_type_size, const int **aux_type_data, int *complete) { - int succ; + int succ = 0; *complete = 1; return MXSymbolInferType(sym, num_args, keys, arg_type_data, diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py index 8b150000af54..8c030f5bc20e 100644 --- a/tests/nightly/test_large_vector.py +++ b/tests/nightly/test_large_vector.py @@ -23,13 +23,13 @@ # dimension constants LARGE_X = 5000000000 -SMALL_Y = 1 +MEDIUM_X = 1000000000 def test_slice(): a = nd.ones(LARGE_X) - res = nd.slice(a, begin=(LARGE_X-1000000000), end=(LARGE_X)) - assert res.shape[0] == 1000000000 + res = nd.slice(a, begin=(LARGE_X - MEDIUM_X), end=LARGE_X) + assert res.shape[0] == MEDIUM_X if __name__ == '__main__':