diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index bd30e44f910c..f4f4e2b3438e 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -55,7 +55,9 @@ extern "C" {
 #endif
 
 /*! \brief manually define unsigned int */
-typedef unsigned int mx_uint;
+typedef uint32_t mx_uint;
+/*! \brief manually define 64-bit int */
+typedef int64_t mx_int64;
 /*! \brief manually define float */
 typedef float mx_float;
 /*! \brief data type to store dim size */
@@ -556,6 +558,13 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape,
                                 int dtype,
                                 NDArrayHandle *out);
+MXNET_DLL int MXNDArrayCreateExInt64(const mx_int64 *shape,
+                                     mx_uint ndim,
+                                     int dev_type,
+                                     int dev_id,
+                                     int delay_alloc,
+                                     int dtype,
+                                     NDArrayHandle *out);
 /*!
  * \brief create an empty sparse NDArray with specified shape and data type
@@ -587,6 +596,19 @@ MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
                                       const mx_uint *aux_shape,
                                       NDArrayHandle *out);
+MXNET_DLL int MXNDArrayCreateSparseExInt64(int storage_type,
+                                           const mx_int64 *shape,
+                                           mx_int64 ndim,
+                                           int dev_type,
+                                           int dev_id,
+                                           int delay_alloc,
+                                           int dtype,
+                                           mx_uint num_aux,
+                                           int *aux_type,
+                                           mx_uint *aux_ndims,
+                                           const mx_uint *aux_shape,
+                                           NDArrayHandle *out);
+
 /*!
  * \brief create a NDArray handle that is loaded from raw bytes.
  * \param buf the head of the raw bytes
@@ -634,6 +656,12 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
                             mx_uint *out_name_size,
                             const char*** out_names);
+MXNET_DLL int MXNDArrayLoadInt64(const char* fname,
+                                 mx_int64 *out_size,
+                                 NDArrayHandle** out_arr,
+                                 mx_int64 *out_name_size,
+                                 const char*** out_names);
+
 /*!
  * \brief Load list / dictionary of narrays from file content loaded into memory.
  * This will load a list of ndarrays in a similar
@@ -649,11 +677,18 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
  * \return 0 when success, -1 when failure happens
  */
 MXNET_DLL int MXNDArrayLoadFromBuffer(const void *ndarray_buffer,
-                                        size_t size,
-                                        mx_uint *out_size,
-                                        NDArrayHandle** out_arr,
-                                        mx_uint *out_name_size,
-                                        const char*** out_names);
+                                      size_t size,
+                                      mx_uint *out_size,
+                                      NDArrayHandle** out_arr,
+                                      mx_uint *out_name_size,
+                                      const char*** out_names);
+
+MXNET_DLL int MXNDArrayLoadFromBufferInt64(const void *ndarray_buffer,
+                                           size_t size,
+                                           mx_int64 *out_size,
+                                           NDArrayHandle** out_arr,
+                                           mx_int64 *out_name_size,
+                                           const char*** out_names);
 
 /*!
  * \brief Perform a synchronize copy from a continugous CPU memory region.
@@ -793,6 +828,11 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
 MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
                                 mx_uint *out_dim,
                                 const mx_uint **out_pdata);
+
+MXNET_DLL int MXNDArrayGetShapeInt64(NDArrayHandle handle,
+                                     mx_int64 *out_dim,
+                                     const mx_int64 **out_pdata);
+
 /*!
  * \brief get the shape of the array
  * \param handle the handle to the narray
@@ -803,6 +843,11 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
 MXNET_DLL int MXNDArrayGetShapeEx(NDArrayHandle handle,
                                   int *out_dim,
                                   const int **out_pdata);
+
+MXNET_DLL int MXNDArrayGetShapeExInt64(NDArrayHandle handle,
+                                       int *out_dim,
+                                       const int64_t **out_pdata);
+
 /*!
  * \brief get the content of the data in NDArray
  * \param handle the handle to the ndarray
@@ -886,6 +931,10 @@ MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle,
                                   mx_uint i,
                                   int *out_type);
+MXNET_DLL int MXNDArrayGetAuxTypeInt64(NDArrayHandle handle,
+                                       mx_int64 i,
+                                       int *out_type);
+
 /*!
  * \brief Get a deep copy of the ith aux data blob
  * in the form of an NDArray of default storage type.
@@ -895,6 +944,10 @@ MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle,
                                      mx_uint i,
                                      NDArrayHandle *out);
+MXNET_DLL int MXNDArrayGetAuxNDArrayInt64(NDArrayHandle handle,
+                                          mx_int64 i,
+                                          NDArrayHandle *out);
+
 /*!
  * \brief Get a deep copy of the data blob
  * in the form of an NDArray of default storage type.
@@ -950,6 +1003,10 @@ MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out);
  */
 MXNET_DLL int MXListFunctions(mx_uint *out_size,
                               FunctionHandle **out_array);
+
+MXNET_DLL int MXListFunctionsInt64(mx_int64 *out_size,
+                                   FunctionHandle **out_array);
+
 /*!
  * \brief get the function handle by name
  * \param name the name of the function
@@ -1217,6 +1274,10 @@ MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
  */
 MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
                                const char ***out_array);
+
+MXNET_DLL int MXListAllOpNamesInt64(mx_int64 *out_size,
+                                    const char ***out_array);
+
 /*!
  * \brief list all the available AtomicSymbolEntry
  * \param out_size the size of returned array
@@ -1226,6 +1287,9 @@ MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
 MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
                                                AtomicSymbolCreator **out_array);
+MXNET_DLL int MXSymbolListAtomicSymbolCreatorsInt64(mx_int64 *out_size,
+                                                    AtomicSymbolCreator **out_array);
+
 /*!
  * \brief Get the name of an atomic symbol.
  * \param creator the AtomicSymbolCreator.
@@ -1438,6 +1502,11 @@ MXNET_DLL int MXSymbolListAttrShallow(SymbolHandle symbol,
 MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol,
                                     mx_uint *out_size,
                                     const char ***out_str_array);
+
+MXNET_DLL int MXSymbolListArgumentsInt64(SymbolHandle symbol,
+                                         mx_int64 *out_size,
+                                         const char ***out_str_array);
+
 /*!
  * \brief List returns in the symbol.
  * \param symbol the symbol
@@ -1449,6 +1518,10 @@ MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol,
                                   mx_uint *out_size,
                                   const char ***out_str_array);
+MXNET_DLL int MXSymbolListOutputsInt64(SymbolHandle symbol,
+                                       mx_int64 *out_size,
+                                       const char ***out_str_array);
+
 /*!
  * \brief Get number of outputs of the symbol.
  * \param symbol The symbol
@@ -1456,7 +1529,7 @@ MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol,
  * \return 0 when success, -1 when failure happens
  */
 MXNET_DLL int MXSymbolGetNumOutputs(SymbolHandle symbol,
-                                    mx_uint *output_count);
+                                     mx_uint *output_count);
 
 /*!
  * \brief Get a symbol that contains all the internals.
@@ -1495,6 +1568,11 @@ MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol,
 MXNET_DLL int MXSymbolListAuxiliaryStates(SymbolHandle symbol,
                                           mx_uint *out_size,
                                           const char ***out_str_array);
+
+MXNET_DLL int MXSymbolListAuxiliaryStatesInt64(SymbolHandle symbol,
+                                               mx_int64 *out_size,
+                                               const char ***out_str_array);
+
 /*!
  * \brief Compose the symbol on other symbols.
  *
@@ -1566,6 +1644,22 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
                                  const mx_uint ***aux_shape_data,
                                  int *complete);
+MXNET_DLL int MXSymbolInferShapeInt64(SymbolHandle sym,
+                                      mx_uint num_args,
+                                      const char** keys,
+                                      const mx_int64 *arg_ind_ptr,
+                                      const mx_int64 *arg_shape_data,
+                                      mx_int64 *in_shape_size,
+                                      const mx_int64 **in_shape_ndim,
+                                      const mx_int64 ***in_shape_data,
+                                      mx_int64 *out_shape_size,
+                                      const mx_int64 **out_shape_ndim,
+                                      const mx_int64 ***out_shape_data,
+                                      mx_int64 *aux_shape_size,
+                                      const mx_int64 **aux_shape_ndim,
+                                      const mx_int64 ***aux_shape_data,
+                                      int *complete);
+
 /*!
  * \brief infer shape of unknown input shapes given the known one.
  * The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
@@ -1603,6 +1697,23 @@ MXNET_DLL int MXSymbolInferShapeEx(SymbolHandle sym,
                                    const int **aux_shape_ndim,
                                    const int ***aux_shape_data,
                                    int *complete);
+
+MXNET_DLL int MXSymbolInferShapeExInt64(SymbolHandle sym,
+                                        mx_uint num_args,
+                                        const char** keys,
+                                        const mx_uint *arg_ind_ptr,
+                                        const int *arg_shape_data,
+                                        mx_uint *in_shape_size,
+                                        const int **in_shape_ndim,
+                                        const int64_t ***in_shape_data,
+                                        mx_uint *out_shape_size,
+                                        const int **out_shape_ndim,
+                                        const int64_t ***out_shape_data,
+                                        mx_uint *aux_shape_size,
+                                        const int **aux_shape_ndim,
+                                        const int64_t ***aux_shape_data,
+                                        int *complete);
+
 /*!
  * \brief DEPRECATED. Use MXSymbolInferShapePartialEx instead.
  * partially infer shape of unknown input shapes given the known one.
@@ -1644,6 +1755,21 @@ MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym,
                                         const mx_uint ***aux_shape_data,
                                         int *complete);
+MXNET_DLL int MXSymbolInferShapePartialInt64(SymbolHandle sym,
+                                             mx_uint num_args,
+                                             const char** keys,
+                                             const mx_int64 *arg_ind_ptr,
+                                             const mx_int64 *arg_shape_data,
+                                             mx_int64 *in_shape_size,
+                                             const mx_int64 **in_shape_ndim,
+                                             const mx_int64 ***in_shape_data,
+                                             mx_int64 *out_shape_size,
+                                             const mx_int64 **out_shape_ndim,
+                                             const mx_int64 ***out_shape_data,
+                                             mx_int64 *aux_shape_size,
+                                             const mx_int64 **aux_shape_ndim,
+                                             const mx_int64 ***aux_shape_data,
+                                             int *complete);
 
 /*!
  * \brief partially infer shape of unknown input shapes given the known one.
@@ -1685,6 +1811,22 @@ MXNET_DLL int MXSymbolInferShapePartialEx(SymbolHandle sym,
                                           const int ***aux_shape_data,
                                           int *complete);
+
+MXNET_DLL int MXSymbolInferShapePartialExInt64(SymbolHandle sym,
+                                               mx_uint num_args,
+                                               const char** keys,
+                                               const mx_uint *arg_ind_ptr,
+                                               const int *arg_shape_data,
+                                               mx_uint *in_shape_size,
+                                               const int **in_shape_ndim,
+                                               const int64_t ***in_shape_data,
+                                               mx_uint *out_shape_size,
+                                               const int **out_shape_ndim,
+                                               const int64_t ***out_shape_data,
+                                               mx_uint *aux_shape_size,
+                                               const int **aux_shape_ndim,
+                                               const int64_t ***aux_shape_data,
+                                               int *complete);
+
 /*!
  * \brief infer type of unknown input types given the known one.
  * The types are packed into a CSR matrix represented by arg_ind_ptr and arg_type_data
diff --git a/include/mxnet/c_predict_api.h b/include/mxnet/c_predict_api.h
index 18bec625f05f..c79baa4a86ff 100644
--- a/include/mxnet/c_predict_api.h
+++ b/include/mxnet/c_predict_api.h
@@ -42,7 +42,9 @@ extern "C" {
 #endif
 
 /*! \brief manually define unsigned int */
-typedef unsigned int mx_uint;
+typedef uint32_t mx_uint;
+/*! \brief manually define 64-bit int */
+typedef int64_t mx_int64;
 /*! \brief manually define float */
 typedef float mx_float;
 /*! \brief handle to Predictor */
diff --git a/include/mxnet/tuple.h b/include/mxnet/tuple.h
index bc630f153744..77468d0ebdef 100644
--- a/include/mxnet/tuple.h
+++ b/include/mxnet/tuple.h
@@ -366,7 +366,6 @@ class Tuple {
   }
 };
 
-
 /*! brief check if a shape's ndim is known. */
 inline bool ndim_is_known(const int ndim) {
   CHECK_GE(ndim, -1) << "shape ndim must be >= -1, while received " << ndim;
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 73fae4876873..f99ba03eb7b5 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -215,6 +215,7 @@ def _load_lib():
 # type definitions
 mx_int = ctypes.c_int
 mx_uint = ctypes.c_uint
+mx_int64 = ctypes.c_int64
 mx_float = ctypes.c_float
 mx_float_p = ctypes.POINTER(mx_float)
 mx_real_t = _np.float32
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 3fb1af6a7336..3731e81e7ea0 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -35,8 +35,9 @@ import numpy as np
 from ..base import _LIB, numeric_types, integer_types
 from ..base import c_str, c_array, c_array_buf, c_handle_array, mx_real_t
-from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int
+from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int, mx_int64
 from ..base import ctypes2buffer
+from ..runtime import Features
 from ..context import Context, current_context
 from . import _internal
 from . import op
@@ -131,14 +132,24 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
         A new empty `NDArray` handle.
     """
     hdl = NDArrayHandle()
-    check_call(_LIB.MXNDArrayCreateEx(
-        c_array_buf(mx_uint, native_array('I', shape)),
-        mx_uint(len(shape)),
-        ctypes.c_int(ctx.device_typeid),
-        ctypes.c_int(ctx.device_id),
-        ctypes.c_int(int(delay_alloc)),
-        ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
-        ctypes.byref(hdl)))
+    if Features().is_enabled('INT64_TENSOR_SIZE'):
+        check_call(_LIB.MXNDArrayCreateExInt64(
+            c_array_buf(mx_int64, native_array('q', shape)),
+            mx_int64(len(shape)),
+            ctypes.c_int(ctx.device_typeid),
+            ctypes.c_int(ctx.device_id),
+            ctypes.c_int(int(delay_alloc)),
+            ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
+            ctypes.byref(hdl)))
+    else:
+        check_call(_LIB.MXNDArrayCreateEx(
+            c_array_buf(mx_uint, native_array('I', shape)),
+            mx_uint(len(shape)),
+            ctypes.c_int(ctx.device_typeid),
+            ctypes.c_int(ctx.device_id),
+            ctypes.c_int(int(delay_alloc)),
+            ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
+            ctypes.byref(hdl)))
     return hdl
@@ -1847,9 +1858,14 @@ def shape(self):
         (2L, 3L, 4L)
         """
         ndim = mx_int()
-        pdata = ctypes.POINTER(mx_int)()
-        check_call(_LIB.MXNDArrayGetShapeEx(
-            self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
+        if Features().is_enabled('INT64_TENSOR_SIZE'):
+            pdata = ctypes.POINTER(mx_int64)()
+            check_call(_LIB.MXNDArrayGetShapeExInt64(
+                self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
+        else:
+            pdata = ctypes.POINTER(mx_int)()
+            check_call(_LIB.MXNDArrayGetShapeEx(
+                self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
         if ndim.value == -1:
             return None
         else:
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index d3cd519b9a8c..696c5b06f7bd 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1180,18 +1180,27 @@ def _infer_shape_impl(self, partial, *args, **kwargs):
         keys = c_str_array(str_keys)
         arg_shape_size = mx_uint()
         arg_shape_ndim = ctypes.POINTER(mx_int)()
-        arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
+        if Features().is_enabled('INT64_TENSOR_SIZE'):
+            arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))()
+            out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))()
+            aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))()
+            if partial:
+                infer_func = _LIB.MXSymbolInferShapePartialExInt64
+            else:
+                infer_func = _LIB.MXSymbolInferShapeExInt64
+        else:
+            arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
+            out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
+            aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
+            if partial:
+                infer_func = _LIB.MXSymbolInferShapePartialEx
+            else:
+                infer_func = _LIB.MXSymbolInferShapeEx
         out_shape_size = mx_uint()
         out_shape_ndim = ctypes.POINTER(mx_int)()
-        out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
         aux_shape_size = mx_uint()
         aux_shape_ndim = ctypes.POINTER(mx_int)()
-        aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
         complete = ctypes.c_int()
-        if partial:
-            infer_func = _LIB.MXSymbolInferShapePartialEx
-        else:
-            infer_func = _LIB.MXSymbolInferShapeEx
         check_call(infer_func(
             self.handle,
             mx_uint(len(indptr) - 1),
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 35bd3eeb477a..0735d4c365e5 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -179,13 +179,29 @@ int MXNDArrayCreate(const mx_uint *shape,
   API_END();
 }
 
+int MXNDArrayCreateExInt64(const mx_int64 *shape,
+                           mx_uint ndim,
+                           int dev_type,
+                           int dev_id,
+                           int delay_alloc,
+                           int dtype,
+                           NDArrayHandle *out) {
+  API_BEGIN();
+  *out = new NDArray(
+      mxnet::TShape(shape, shape + ndim),
+      Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
+      delay_alloc != 0,
+      dtype);
+  API_END();
+}
+
 int MXNDArrayCreateEx(const mx_uint *shape,
-                       mx_uint ndim,
-                       int dev_type,
-                       int dev_id,
-                       int delay_alloc,
-                       int dtype,
-                       NDArrayHandle *out) {
+                      mx_uint ndim,
+                      int dev_type,
+                      int dev_id,
+                      int delay_alloc,
+                      int dtype,
+                      NDArrayHandle *out) {
   API_BEGIN();
   *out = new NDArray(
       mxnet::TShape(shape, shape + ndim),
@@ -196,17 +212,17 @@ int MXNDArrayCreateEx(const mx_uint *shape,
 }
 
 int MXNDArrayCreateSparseEx(int storage_type,
-                             const mx_uint *shape,
-                             mx_uint ndim,
-                             int dev_type,
-                             int dev_id,
-                             int delay_alloc,
-                             int dtype,
-                             mx_uint num_aux,
-                             int *aux_type,
-                             mx_uint *aux_ndims,
-                             const mx_uint *aux_shape,
-                             NDArrayHandle *out) {
+                            const mx_uint *shape,
+                            mx_uint ndim,
+                            int dev_type,
+                            int dev_id,
+                            int delay_alloc,
+                            int dtype,
+                            mx_uint num_aux,
+                            int *aux_type,
+                            mx_uint *aux_ndims,
+                            const mx_uint *aux_shape,
+                            NDArrayHandle *out) {
   API_BEGIN();
   std::vector<int> aux_types;
   mxnet::ShapeVector aux_shapes;
@@ -541,6 +557,34 @@ int MXNDArrayGetShapeEx(NDArrayHandle handle,
   API_END();
 }
 
+int MXNDArrayGetShapeExInt64(NDArrayHandle handle,
+                             int *out_dim,
+                             const int64_t **out_pdata) {
+  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  API_BEGIN();
+  NDArray *arr = static_cast<NDArray*>(handle);
+  if (!arr->is_none()) {
+    mxnet::TShape s = arr->shape();
+    if (!Imperative::Get()->is_np_shape()) {
+      common::ConvertToLegacyShape(&s);
+    }
+    *out_dim = s.ndim();
+    if (s.ndim() >= 0) {
+      std::vector<int64_t> &buffer = ret->arg_shape_buffer_ex_int64;
+      buffer.resize(s.ndim());
+      mxnet::ShapeTypeCast(s.begin(), s.end(), buffer.data());
+      *out_pdata = buffer.data();
+    }
+  } else {
+    if (Imperative::Get()->is_np_shape()) {
+      *out_dim = -1;
+    } else {
+      *out_dim = 0;
+    }
+  }
+  API_END();
+}
+
 int MXNDArrayGetData(NDArrayHandle handle,
                      void **out_pdata) {
   API_BEGIN();
diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h
index 013ecab93da8..a985295d2a0c 100644
--- a/src/c_api/c_api_common.h
+++ b/src/c_api/c_api_common.h
@@ -81,10 +81,14 @@ struct MXAPIThreadLocalEntry {
   std::vector<const uint32_t*> arg_shape_data, out_shape_data, aux_shape_data;
   /*! \brief result holder for returning shape pointer */
   std::vector<const int*> arg_shape_data_ex, out_shape_data_ex, aux_shape_data_ex;
+  std::vector<const int64_t*> arg_shape_data_ex_int64, out_shape_data_ex_int64,
+                              aux_shape_data_ex_int64;
   /*! \brief uint32_t buffer for returning shape pointer */
   std::vector<uint32_t> arg_shape_buffer, out_shape_buffer, aux_shape_buffer;
   /*! \brief uint32_t buffer for returning shape pointer */
   std::vector<int> arg_shape_buffer_ex, out_shape_buffer_ex, aux_shape_buffer_ex;
+  std::vector<int64_t> arg_shape_buffer_ex_int64, out_shape_buffer_ex_int64,
+                       aux_shape_buffer_ex_int64;
   /*! \brief bool buffer */
   std::vector<char> save_inputs, save_outputs;
   // DEPRECATED. Use SetupShapeArrayReturnWithBufferEx instead.
@@ -130,6 +134,30 @@ struct MXAPIThreadLocalEntry {
       }
     }
   }
+
+  inline static void SetupShapeArrayReturnWithBufferExInt64(
+      const mxnet::ShapeVector &shapes,
+      std::vector<int> *ndim,
+      std::vector<const int64_t*> *data,
+      std::vector<int64_t> *buffer) {
+    ndim->resize(shapes.size());
+    data->resize(shapes.size());
+    size_t size = 0;
+    for (const auto& s : shapes) {
+      if (s.ndim() > 0) {
+        size += s.ndim();
+      }
+    }
+    buffer->resize(size);
+    int64_t *ptr = buffer->data();
+    for (size_t i = 0; i < shapes.size(); ++i) {
+      ndim->at(i) = shapes[i].ndim();
+      data->at(i) = ptr;
+      if (shapes[i].ndim() > 0) {
+        ptr = mxnet::ShapeTypeCast(shapes[i].begin(), shapes[i].end(), ptr);
+      }
+    }
+  }
 };
 
 // define the threadlocal store.
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index 80ae5438c20d..3e3006c11b83 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -658,6 +658,85 @@ int MXSymbolInferShapeEx(SymbolHandle sym,
   API_END();
 }
 
+int MXSymbolInferShapeExInt64(SymbolHandle sym,
+                              mx_uint num_args,
+                              const char** keys,
+                              const mx_uint *arg_ind_ptr,
+                              const int *arg_shape_data,
+                              mx_uint *in_shape_size,
+                              const int **in_shape_ndim,
+                              const int64_t ***in_shape_data,
+                              mx_uint *out_shape_size,
+                              const int **out_shape_ndim,
+                              const int64_t ***out_shape_data,
+                              mx_uint *aux_shape_size,
+                              const int **aux_shape_ndim,
+                              const int64_t ***aux_shape_data,
+                              int *complete) {
+  nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
+  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  API_BEGIN();
+  nnvm::Graph g = Symbol2Graph(*s);
+  mxnet::ShapeVector arg_shapes(g.indexed_graph().input_nodes().size(), mxnet::TShape());
+  if (keys == nullptr && num_args != 0) {
+    std::vector<uint32_t> read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph());
+    CHECK_LE(num_args, read_only_args.size());
+    for (mx_uint i = 0; i < num_args; ++i) {
+      arg_shapes[read_only_args[i]] = mxnet::ShapeTypeCast(
+          arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1]);
+    }
+  } else {
+    std::unordered_map<std::string, mxnet::TShape> kwargs;
+    for (mx_uint i = 0; i < num_args; ++i) {
+      kwargs[keys[i]] = mxnet::ShapeTypeCast(
+          arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1]);
+    }
+    mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_shapes, "InferShape");
+  }
+
+  try {
+    g = mxnet::exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+  } catch (const mxnet::op::InferShapeError &err) {
+    throw dmlc::Error(err.msg);
+  }
+
+  // if use legacy shape definition, need to convert numpy shape to legacy shape
+  mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
+  if (!Imperative::Get()->is_np_shape()) {
+    common::ConvertToLegacyShape(&shapes);
+  }
+
+  // copy back
+  CopyAttr(g.indexed_graph(), shapes,
+           &(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes));
+
+  // copy data back
+  MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferExInt64(ret->arg_shapes,
+      &(ret->arg_shape_ndim_ex),
+      &(ret->arg_shape_data_ex_int64),
+      &(ret->arg_shape_buffer_ex_int64));
+  MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferExInt64(ret->out_shapes,
+      &(ret->out_shape_ndim_ex),
+      &(ret->out_shape_data_ex_int64),
+      &(ret->out_shape_buffer_ex_int64));
+  MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferExInt64(ret->aux_shapes,
+      &(ret->aux_shape_ndim_ex),
+      &(ret->aux_shape_data_ex_int64),
+      &(ret->aux_shape_buffer_ex_int64));
+  *in_shape_size = static_cast<mx_uint>(ret->arg_shapes.size());
+  *in_shape_ndim = dmlc::BeginPtr(ret->arg_shape_ndim_ex);
+  *in_shape_data = dmlc::BeginPtr(ret->arg_shape_data_ex_int64);
+  *out_shape_size = static_cast<mx_uint>(ret->out_shapes.size());
+  *out_shape_ndim = dmlc::BeginPtr(ret->out_shape_ndim_ex);
+  *out_shape_data = dmlc::BeginPtr(ret->out_shape_data_ex_int64);
+  *aux_shape_size = static_cast<mx_uint>(ret->aux_shapes.size());
+  *aux_shape_ndim = dmlc::BeginPtr(ret->aux_shape_ndim_ex);
+  *aux_shape_data = dmlc::BeginPtr(ret->aux_shape_data_ex_int64);
+  // mark complete
+  *complete = (g.GetAttr<size_t>("shape_num_unknown_nodes") == 0);
+  API_END();
+}
+
 int MXSymbolInferShapePartial(SymbolHandle sym,
                               mx_uint num_args,
                               const char** keys,
@@ -708,6 +787,31 @@ int MXSymbolInferShapePartialEx(SymbolHandle sym,
                          &succ);
 }
 
+int MXSymbolInferShapePartialExInt64(SymbolHandle sym,
+                                     mx_uint num_args,
+                                     const char** keys,
+                                     const mx_uint *arg_ind_ptr,
+                                     const int *arg_shape_data,
+                                     mx_uint *in_shape_size,
+                                     const int **in_shape_ndim,
+                                     const int64_t ***in_shape_data,
+                                     mx_uint *out_shape_size,
+                                     const int **out_shape_ndim,
+                                     const int64_t ***out_shape_data,
+                                     mx_uint *aux_shape_size,
+                                     const int **aux_shape_ndim,
+                                     const int64_t ***aux_shape_data,
+                                     int *complete) {
+  int succ;
+  *complete = 1;
+  return MXSymbolInferShapeExInt64(sym, num_args, keys,
+                                   arg_ind_ptr, arg_shape_data,
+                                   in_shape_size, in_shape_ndim, in_shape_data,
+                                   out_shape_size, out_shape_ndim, out_shape_data,
+                                   aux_shape_size, aux_shape_ndim, aux_shape_data,
+                                   &succ);
+}
+
 int MXSymbolInferType(SymbolHandle sym,
                       mx_uint num_args,
                       const char** keys,
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 5cd7bf6652d3..1e6d7fb579f3 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -647,9 +647,9 @@ void SliceEx(const nnvm::NodeAttrs& attrs,
 template<int ndim>
 inline void GetIndexRange(const mxnet::TShape& dshape,
-                          const mxnet::Tuple<dmlc::optional<int>>& param_begin,
-                          const mxnet::Tuple<dmlc::optional<int>>& param_end,
-                          const mxnet::Tuple<dmlc::optional<int>>& param_step,
+                          const mxnet::Tuple<dmlc::optional<index_t>>& param_begin,
+                          const mxnet::Tuple<dmlc::optional<index_t>>& param_end,
+                          const mxnet::Tuple<dmlc::optional<index_t>>& param_step,
                           common::StaticArray<index_t, ndim>* begin,
                           common::StaticArray<index_t, ndim>* end,
                           common::StaticArray<index_t, ndim>* step) {
@@ -1010,8 +1010,8 @@ void SliceAssignOpForward(const nnvm::NodeAttrs& attrs,
 struct SliceAssignScalarParam : public dmlc::Parameter<SliceAssignScalarParam> {
   double scalar;
-  mxnet::Tuple<dmlc::optional<int>> begin, end;
-  mxnet::Tuple<dmlc::optional<int>> step;
+  mxnet::Tuple<dmlc::optional<index_t>> begin, end;
+  mxnet::Tuple<dmlc::optional<index_t>> step;
   DMLC_DECLARE_PARAMETER(SliceAssignScalarParam) {
     DMLC_DECLARE_FIELD(scalar)
     .set_default(0)
@@ -1021,7 +1021,7 @@ struct SliceAssignScalarParam : public dmlc::Parameter<SliceAssignScalarParam> {
     DMLC_DECLARE_FIELD(end)
     .describe("ending indices for the slice operation, supports negative indices.");
     DMLC_DECLARE_FIELD(step)
-    .set_default(mxnet::Tuple<dmlc::optional<int>>())
+    .set_default(mxnet::Tuple<dmlc::optional<index_t>>())
     .describe("step for the slice operation, supports negative values.");
   }
 };
@@ -1323,12 +1323,12 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
 inline void SliceLikeInferRanges(const mxnet::TShape& dshape,
                                  const mxnet::TShape& fshape,
                                  const mxnet::Tuple<int>& axes,
-                                 mxnet::Tuple<dmlc::optional<int>>* param_begin,
-                                 mxnet::Tuple<dmlc::optional<int>>* param_end,
-                                 mxnet::Tuple<dmlc::optional<int>>* param_step) {
-  std::vector<dmlc::optional<int>> pb(dshape.ndim());
-  std::vector<dmlc::optional<int>> pe(dshape.ndim());
-  std::vector<dmlc::optional<int>> ps(dshape.ndim());
+                                 mxnet::Tuple<dmlc::optional<index_t>>* param_begin,
+                                 mxnet::Tuple<dmlc::optional<index_t>>* param_end,
+                                 mxnet::Tuple<dmlc::optional<index_t>>* param_step) {
+  std::vector<dmlc::optional<index_t>> pb(dshape.ndim());
+  std::vector<dmlc::optional<index_t>> pe(dshape.ndim());
+  std::vector<dmlc::optional<index_t>> ps(dshape.ndim());
   if (axes.ndim() == 0) {
     for (int i = 0; i < dshape.ndim(); ++i) {
       pb[i] = 0;
@@ -1352,9 +1352,9 @@ inline void SliceLikeInferRanges(const mxnet::TShape& dshape,
       ps[axis] = 1;
     }
   }
-  *param_begin = mxnet::Tuple<dmlc::optional<int>>(pb.begin(), pb.end());
-  *param_end = mxnet::Tuple<dmlc::optional<int>>(pe.begin(), pe.end());
-  *param_step = mxnet::Tuple<dmlc::optional<int>>(ps.begin(), ps.end());
+  *param_begin = mxnet::Tuple<dmlc::optional<index_t>>(pb.begin(), pb.end());
+  *param_end = mxnet::Tuple<dmlc::optional<index_t>>(pe.begin(), pe.end());
+  *param_step = mxnet::Tuple<dmlc::optional<index_t>>(ps.begin(), ps.end());
 }
 
 template<typename xpu>
@@ -1373,9 +1373,9 @@ void SliceLikeForward(const nnvm::NodeAttrs& attrs,
   const TBlob& out = outputs[0];
   const mxnet::TShape& ishape = data.shape_;
   const mxnet::TShape& from_shape = inputs[1].shape_;
-  mxnet::Tuple<dmlc::optional<int>> param_begin;
-  mxnet::Tuple<dmlc::optional<int>> param_end;
-  mxnet::Tuple<dmlc::optional<int>> param_step;
+  mxnet::Tuple<dmlc::optional<index_t>> param_begin;
+  mxnet::Tuple<dmlc::optional<index_t>> param_end;
+  mxnet::Tuple<dmlc::optional<index_t>> param_step;
   SliceLikeInferRanges(ishape, from_shape, param.axes, &param_begin, &param_end, &param_step);
   MXNET_NDIM_SWITCH(data.ndim(), ndim, {
@@ -1421,9 +1421,9 @@ void SliceLikeBackward(const nnvm::NodeAttrs& attrs,
   const mxnet::TShape& ishape = ograd.shape_;
   const mxnet::TShape& from_shape = outputs[1].shape_;
-  mxnet::Tuple<dmlc::optional<int>> param_begin;
-  mxnet::Tuple<dmlc::optional<int>> param_end;
-  mxnet::Tuple<dmlc::optional<int>> param_step;
+  mxnet::Tuple<dmlc::optional<index_t>> param_begin;
+  mxnet::Tuple<dmlc::optional<index_t>> param_end;
+  mxnet::Tuple<dmlc::optional<index_t>> param_step;
   SliceLikeInferRanges(ishape, from_shape, param.axes, &param_begin, &param_end, &param_step);
   MXNET_NDIM_SWITCH(ograd.ndim(), ndim, {
diff --git a/src/operator/tensor/slice-inl.h b/src/operator/tensor/slice-inl.h
index 78a2bd8c7b45..7450e466440a 100644
--- a/src/operator/tensor/slice-inl.h
+++ b/src/operator/tensor/slice-inl.h
@@ -34,15 +34,15 @@ namespace mxnet {
 namespace op {
 
 struct SliceParam : public dmlc::Parameter<SliceParam> {
-  mxnet::Tuple<dmlc::optional<int>> begin, end;
-  mxnet::Tuple<dmlc::optional<int>> step;
+  mxnet::Tuple<dmlc::optional<index_t>> begin, end;
+  mxnet::Tuple<dmlc::optional<index_t>> step;
   DMLC_DECLARE_PARAMETER(SliceParam) {
     DMLC_DECLARE_FIELD(begin)
    .describe("starting indices for the slice operation, supports negative indices.");
     DMLC_DECLARE_FIELD(end)
    .describe("ending indices for the slice operation, supports negative indices.");
     DMLC_DECLARE_FIELD(step)
-   .set_default(mxnet::Tuple<dmlc::optional<int>>())
+   .set_default(mxnet::Tuple<dmlc::optional<index_t>>())
    .describe("step for the slice operation, supports negative values.");
   }
   bool operator==(const SliceParam& other) const {
diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py
index 0df481a01987..58ac71f3de1f 100644
--- a/tests/nightly/test_large_array.py
+++ b/tests/nightly/test_large_array.py
@@ -157,8 +157,8 @@ def test_take():
 
 def test_slice():
-    a = nd.ones(shape=(LARGE_X, SMALL_Y))
-    res = nd.slice(a, begin=(LARGE_X-1000, 1), end=(LARGE_X, SMALL_Y))
+    a = nd.ones(shape=(2, LARGE_SIZE))
+    res = nd.slice(a, begin=(1, LARGE_SIZE-1000000000), end=(2, LARGE_SIZE))
     assert np.sum(res[-1].asnumpy() == 1) == res.shape[1]
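
Usage note: the sketch below shows, end to end, how the feature-gated paths in this patch are exercised from the Python frontend. It is a minimal illustration, not part of the patch; the build flag name (USE_INT64_TENSOR_SIZE) and the concrete array size are assumptions, and running it with a size this large needs on the order of 20 GB of free memory.

import numpy as np
import mxnet as mx
from mxnet.runtime import Features

# _new_alloc_handle() and NDArray.shape dispatch to the *Int64 C API entry
# points only when the build advertises the INT64_TENSOR_SIZE feature.
if Features().is_enabled('INT64_TENSOR_SIZE'):
    large_size = 5000000000  # illustrative: total element count > 2**32
    a = mx.nd.ones(shape=(2, large_size // 2))
    # The shape round-trips through MXNDArrayGetShapeExInt64 on this path.
    assert a.shape == (2, large_size // 2)
    res = mx.nd.slice(a, begin=(1, large_size // 2 - 1000), end=(2, large_size // 2))
    assert np.sum(res[-1].asnumpy() == 1) == res.shape[1]
else:
    print('This MXNet build was compiled without int64 tensor size support.')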