diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index bd30e44f910c..deab404d4114 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -55,7 +55,8 @@ extern "C" { #endif /*! \brief manually define unsigned int */ -typedef unsigned int mx_uint; +typedef int64_t mx_int64; +typedef uint32_t mx_uint; /*! \brief manually define float */ typedef float mx_float; /*! \brief data type to store dim size */ @@ -556,6 +557,13 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape, int dtype, NDArrayHandle *out); +MXNET_DLL int MXNDArrayCreateExInt64(const mx_int64 *shape, + mx_uint ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + NDArrayHandle *out); /*! * \brief create an empty sparse NDArray with specified shape and data type @@ -793,6 +801,11 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle, MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle, mx_uint *out_dim, const mx_uint **out_pdata); + +MXNET_DLL int MXNDArrayGetShapeInt64(NDArrayHandle handle, + mx_int64 *out_dim, + const mx_int64 **out_pdata); + /*! * \brief get the shape of the array * \param handle the handle to the narray @@ -802,7 +815,8 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle, */ MXNET_DLL int MXNDArrayGetShapeEx(NDArrayHandle handle, int *out_dim, - const int **out_pdata); + const int64_t **out_pdata); + /*! * \brief get the content of the data in NDArray * \param handle the handle to the ndarray @@ -1456,7 +1470,7 @@ MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol, * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolGetNumOutputs(SymbolHandle symbol, - mx_uint *output_count); + mx_uint *output_count); /*! * \brief Get a symbol that contains all the internals. diff --git a/include/mxnet/c_predict_api.h b/include/mxnet/c_predict_api.h index 18bec625f05f..b18a43241687 100644 --- a/include/mxnet/c_predict_api.h +++ b/include/mxnet/c_predict_api.h @@ -42,7 +42,8 @@ extern "C" { #endif /*! \brief manually define unsigned int */ -typedef unsigned int mx_uint; +typedef int64_t mx_int64; +typedef uint32_t mx_uint; /*! \brief manually define float */ typedef float mx_float; /*! \brief handle to Predictor */ diff --git a/include/mxnet/tuple.h b/include/mxnet/tuple.h index bc630f153744..77468d0ebdef 100644 --- a/include/mxnet/tuple.h +++ b/include/mxnet/tuple.h @@ -366,7 +366,6 @@ class Tuple { } }; - /*! brief check if a shape's ndim is known. */ inline bool ndim_is_known(const int ndim) { CHECK_GE(ndim, -1) << "shape ndim must be >= -1, while received " << ndim; diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 73fae4876873..f99ba03eb7b5 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -215,6 +215,7 @@ def _load_lib(): # type definitions mx_int = ctypes.c_int mx_uint = ctypes.c_uint +mx_int64 = ctypes.c_int64 mx_float = ctypes.c_float mx_float_p = ctypes.POINTER(mx_float) mx_real_t = _np.float32 diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 3fb1af6a7336..f18f9b1f9d8a 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -35,7 +35,7 @@ import numpy as np from ..base import _LIB, numeric_types, integer_types from ..base import c_str, c_array, c_array_buf, c_handle_array, mx_real_t -from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int +from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int, mx_int64 from ..base import ctypes2buffer from ..context import Context, current_context from . import _internal @@ -130,10 +130,12 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t): handle A new empty `NDArray` handle. """ + print("shape={}".format(shape)) hdl = NDArrayHandle() - check_call(_LIB.MXNDArrayCreateEx( - c_array_buf(mx_uint, native_array('I', shape)), - mx_uint(len(shape)), +# check_call(_LIB.MXNDArrayCreateEx( + check_call(_LIB.MXNDArrayCreateExInt64( + c_array_buf(mx_int64, native_array('q', shape)), + mx_int64(len(shape)), ctypes.c_int(ctx.device_typeid), ctypes.c_int(ctx.device_id), ctypes.c_int(int(delay_alloc)), @@ -1847,7 +1849,7 @@ def shape(self): (2L, 3L, 4L) """ ndim = mx_int() - pdata = ctypes.POINTER(mx_int)() + pdata = ctypes.POINTER(mx_int64)() check_call(_LIB.MXNDArrayGetShapeEx( self.handle, ctypes.byref(ndim), ctypes.byref(pdata))) if ndim.value == -1: diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 35bd3eeb477a..b0f41f53929d 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -179,13 +179,29 @@ int MXNDArrayCreate(const mx_uint *shape, API_END(); } +int MXNDArrayCreateExInt64(const mx_int64 *shape, + mx_uint ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + NDArrayHandle *out) { + API_BEGIN(); + *out = new NDArray( + mxnet::TShape(shape, shape + ndim), + Context::Create(static_cast(dev_type), dev_id), + delay_alloc != 0, + dtype); + API_END(); +} + int MXNDArrayCreateEx(const mx_uint *shape, - mx_uint ndim, - int dev_type, - int dev_id, - int delay_alloc, - int dtype, - NDArrayHandle *out) { + mx_uint ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + NDArrayHandle *out) { API_BEGIN(); *out = new NDArray( mxnet::TShape(shape, shape + ndim), @@ -513,9 +529,10 @@ int MXNDArrayGetShape(NDArrayHandle handle, API_END(); } -int MXNDArrayGetShapeEx(NDArrayHandle handle, +template +int MXNDArrayGetShapeExInt(NDArrayHandle handle, int *out_dim, - const int **out_pdata) { + const IntType **out_pdata) { MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); NDArray *arr = static_cast(handle); @@ -526,7 +543,7 @@ int MXNDArrayGetShapeEx(NDArrayHandle handle, } *out_dim = s.ndim(); if (s.ndim() >= 0) { - std::vector &buffer = ret->arg_shape_buffer_ex; + std::vector &buffer = ret->arg_shape_buffer_ex_int64; buffer.resize(s.ndim()); mxnet::ShapeTypeCast(s.begin(), s.end(), buffer.data()); *out_pdata = buffer.data(); @@ -541,6 +558,16 @@ int MXNDArrayGetShapeEx(NDArrayHandle handle, API_END(); } +int MXNDArrayGetShapeEx(NDArrayHandle handle, + int *out_dim, + const int64_t **out_pdata) { +#if MXNET_USE_INT64_TENSOR_SIZE == 1 + return MXNDArrayGetShapeExInt(handle, out_dim, out_pdata); +#else + return MXNDArrayGetShapeExInt(handle, out_dim, out_pdata); +#endif +} + int MXNDArrayGetData(NDArrayHandle handle, void **out_pdata) { API_BEGIN(); diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h index 013ecab93da8..a62fbad82700 100644 --- a/src/c_api/c_api_common.h +++ b/src/c_api/c_api_common.h @@ -81,10 +81,12 @@ struct MXAPIThreadLocalEntry { std::vector arg_shape_data, out_shape_data, aux_shape_data; /*! \brief result holder for returning shape pointer */ std::vector arg_shape_data_ex, out_shape_data_ex, aux_shape_data_ex; + std::vector arg_shape_data_ex_int64, out_shape_data_ex_int64, aux_shape_data_ex_int64; /*! \brief uint32_t buffer for returning shape pointer */ std::vector arg_shape_buffer, out_shape_buffer, aux_shape_buffer; /*! \brief uint32_t buffer for returning shape pointer */ std::vector arg_shape_buffer_ex, out_shape_buffer_ex, aux_shape_buffer_ex; + std::vector arg_shape_buffer_ex_int64, out_shape_buffer_ex_int64, aux_shape_buffer_ex_int64; /*! \brief bool buffer */ std::vector save_inputs, save_outputs; // DEPRECATED. Use SetupShapeArrayReturnWithBufferEx instead. @@ -130,6 +132,31 @@ struct MXAPIThreadLocalEntry { } } } + + inline static void SetupShapeArrayReturnWithBufferExInt64( + const mxnet::ShapeVector &shapes, + std::vector *ndim, + std::vector *data, + std::vector *buffer) { + ndim->resize(shapes.size()); + data->resize(shapes.size()); + size_t size = 0; + for (const auto& s : shapes) { + if (s.ndim() > 0) { + size += s.ndim(); + } + } + buffer->resize(size); + int64_t *ptr = buffer->data(); + for (size_t i = 0; i < shapes.size(); ++i) { + ndim->at(i) = shapes[i].ndim(); + data->at(i) = ptr; + if (shapes[i].ndim() > 0) { + ptr = mxnet::ShapeTypeCast(shapes[i].begin(), shapes[i].end(), ptr); + } + } + } + }; // define the threadlocal store. diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 5cd7bf6652d3..1e6d7fb579f3 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -647,9 +647,9 @@ void SliceEx(const nnvm::NodeAttrs& attrs, template inline void GetIndexRange(const mxnet::TShape& dshape, - const mxnet::Tuple>& param_begin, - const mxnet::Tuple>& param_end, - const mxnet::Tuple>& param_step, + const mxnet::Tuple>& param_begin, + const mxnet::Tuple>& param_end, + const mxnet::Tuple>& param_step, common::StaticArray* begin, common::StaticArray* end, common::StaticArray* step) { @@ -1010,8 +1010,8 @@ void SliceAssignOpForward(const nnvm::NodeAttrs& attrs, struct SliceAssignScalarParam : public dmlc::Parameter { double scalar; - mxnet::Tuple> begin, end; - mxnet::Tuple> step; + mxnet::Tuple> begin, end; + mxnet::Tuple> step; DMLC_DECLARE_PARAMETER(SliceAssignScalarParam) { DMLC_DECLARE_FIELD(scalar) .set_default(0) @@ -1021,7 +1021,7 @@ struct SliceAssignScalarParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(end) .describe("ending indices for the slice operation, supports negative indices."); DMLC_DECLARE_FIELD(step) - .set_default(mxnet::Tuple>()) + .set_default(mxnet::Tuple>()) .describe("step for the slice operation, supports negative values."); } }; @@ -1323,12 +1323,12 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs, inline void SliceLikeInferRanges(const mxnet::TShape& dshape, const mxnet::TShape& fshape, const mxnet::Tuple& axes, - mxnet::Tuple>* param_begin, - mxnet::Tuple>* param_end, - mxnet::Tuple>* param_step) { - std::vector> pb(dshape.ndim()); - std::vector> pe(dshape.ndim()); - std::vector> ps(dshape.ndim()); + mxnet::Tuple>* param_begin, + mxnet::Tuple>* param_end, + mxnet::Tuple>* param_step) { + std::vector> pb(dshape.ndim()); + std::vector> pe(dshape.ndim()); + std::vector> ps(dshape.ndim()); if (axes.ndim() == 0) { for (int i = 0; i < dshape.ndim(); ++i) { pb[i] = 0; @@ -1352,9 +1352,9 @@ inline void SliceLikeInferRanges(const mxnet::TShape& dshape, ps[axis] = 1; } } - *param_begin = mxnet::Tuple>(pb.begin(), pb.end()); - *param_end = mxnet::Tuple>(pe.begin(), pe.end()); - *param_step = mxnet::Tuple>(ps.begin(), ps.end()); + *param_begin = mxnet::Tuple>(pb.begin(), pb.end()); + *param_end = mxnet::Tuple>(pe.begin(), pe.end()); + *param_step = mxnet::Tuple>(ps.begin(), ps.end()); } template @@ -1373,9 +1373,9 @@ void SliceLikeForward(const nnvm::NodeAttrs& attrs, const TBlob& out = outputs[0]; const mxnet::TShape& ishape = data.shape_; const mxnet::TShape& from_shape = inputs[1].shape_; - mxnet::Tuple> param_begin; - mxnet::Tuple> param_end; - mxnet::Tuple> param_step; + mxnet::Tuple> param_begin; + mxnet::Tuple> param_end; + mxnet::Tuple> param_step; SliceLikeInferRanges(ishape, from_shape, param.axes, ¶m_begin, ¶m_end, ¶m_step); MXNET_NDIM_SWITCH(data.ndim(), ndim, { @@ -1421,9 +1421,9 @@ void SliceLikeBackward(const nnvm::NodeAttrs& attrs, const mxnet::TShape& ishape = ograd.shape_; const mxnet::TShape& from_shape = outputs[1].shape_; - mxnet::Tuple> param_begin; - mxnet::Tuple> param_end; - mxnet::Tuple> param_step; + mxnet::Tuple> param_begin; + mxnet::Tuple> param_end; + mxnet::Tuple> param_step; SliceLikeInferRanges(ishape, from_shape, param.axes, ¶m_begin, ¶m_end, ¶m_step); MXNET_NDIM_SWITCH(ograd.ndim(), ndim, { diff --git a/src/operator/tensor/slice-inl.h b/src/operator/tensor/slice-inl.h index 78a2bd8c7b45..7450e466440a 100644 --- a/src/operator/tensor/slice-inl.h +++ b/src/operator/tensor/slice-inl.h @@ -34,15 +34,15 @@ namespace mxnet { namespace op { struct SliceParam : public dmlc::Parameter { - mxnet::Tuple> begin, end; - mxnet::Tuple> step; + mxnet::Tuple> begin, end; + mxnet::Tuple> step; DMLC_DECLARE_PARAMETER(SliceParam) { DMLC_DECLARE_FIELD(begin) .describe("starting indices for the slice operation, supports negative indices."); DMLC_DECLARE_FIELD(end) .describe("ending indices for the slice operation, supports negative indices."); DMLC_DECLARE_FIELD(step) - .set_default(mxnet::Tuple>()) + .set_default(mxnet::Tuple>()) .describe("step for the slice operation, supports negative values."); } bool operator==(const SliceParam& other) const { diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 0df481a01987..0e3a5ff7707c 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -157,8 +157,8 @@ def test_take(): def test_slice(): - a = nd.ones(shape=(LARGE_X, SMALL_Y)) - res = nd.slice(a, begin=(LARGE_X-1000, 1), end=(LARGE_X, SMALL_Y)) + a = nd.ones(shape=(LARGE_SIZE, 2)) + res = nd.slice(a, begin=(LARGE_SIZE-1000, 1), end=(LARGE_SIZE, 2)) assert np.sum(res[-1].asnumpy() == 1) == res.shape[1]