diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h index cf44da699156..4b8663bba6ea 100644 --- a/src/operator/elemwise_op_common.h +++ b/src/operator/elemwise_op_common.h @@ -100,7 +100,7 @@ inline bool ElemwiseStorageAttr(const nnvm::NodeAttrs& attrs, * \tparam rsp whether row sparse stype is supported * \tparam rsp whether csr stype is supported */ -template +template inline bool ElemwiseStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, @@ -115,7 +115,7 @@ inline bool ElemwiseStorageType(const nnvm::NodeAttrs& attrs, template + index_t n_in = -1, index_t n_out = -1> inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs, @@ -154,7 +154,7 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs, return true; } -template +template inline bool ElemwiseShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { @@ -168,7 +168,7 @@ inline bool ElemwiseShape(const nnvm::NodeAttrs& attrs, attrs, in_attrs, out_attrs, TShape()); } -template +template inline bool ElemwiseType(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index 5b106afd8d5b..6cab1990858b 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -289,8 +289,8 @@ inline int get_num_threads(const int N) { /* \brief Compute flattened index given coordinates and shape. */ template -MSHADOW_XINLINE int ravel(const Shape& coord, const Shape& shape) { - int ret = 0; +MSHADOW_XINLINE index_t ravel(const Shape& coord, const Shape& shape) { + index_t ret = 0; #pragma unroll for (int i = 0; i < ndim; ++i) { ret = ret * shape[i] + (shape[i] > coord[i]) * coord[i]; @@ -301,11 +301,11 @@ MSHADOW_XINLINE int ravel(const Shape& coord, const Shape& shape) { /* Compute coordinates from flattened index given shape */ template -MSHADOW_XINLINE Shape unravel(const int idx, const Shape& shape) { +MSHADOW_XINLINE Shape unravel(const index_t idx, const Shape& shape) { Shape ret; #pragma unroll - for (int i = ndim-1, j = idx; i >=0; --i) { - int tmp = j / shape[i]; + for (index_t i = ndim-1, j = idx; i >=0; --i) { + auto tmp = j / shape[i]; ret[i] = j - tmp*shape[i]; j = tmp; } @@ -315,8 +315,8 @@ MSHADOW_XINLINE Shape unravel(const int idx, const Shape& shape) { /* Compute dot product of two vector */ template -MSHADOW_XINLINE int dot(const Shape& coord, const Shape& stride) { - int ret = 0; +MSHADOW_XINLINE index_t dot(const Shape& coord, const Shape& stride) { + index_t ret = 0; #pragma unroll for (int i = 0; i < ndim; ++i) { ret += coord[i] * stride[i]; @@ -327,12 +327,12 @@ MSHADOW_XINLINE int dot(const Shape& coord, const Shape& stride) { /* Combining unravel and dot */ template -MSHADOW_XINLINE int unravel_dot(const int idx, const Shape& shape, +MSHADOW_XINLINE index_t unravel_dot(const index_t idx, const Shape& shape, const Shape& stride) { - int ret = 0; + index_t ret = 0; #pragma unroll - for (int i = ndim-1, j = idx; i >=0; --i) { - int tmp = j / shape[i]; + for (index_t i = ndim-1, j = idx; i >=0; --i) { + auto tmp = j / shape[i]; ret += (j - tmp*shape[i])*stride[i]; j = tmp; } @@ -433,51 +433,51 @@ struct op_with_req { /*! \brief input is one tensor */ template - MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in) { + MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *in) { KERNEL_ASSIGN(out[i], req, OP::Map(in[i])); } /*! 
\brief inputs are two tensors */ template - MSHADOW_XINLINE static void Map(int i, DType *out, const DType *lhs, const DType *rhs) { + MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *lhs, const DType *rhs) { KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i])); } /*! \brief input is tensor and a scalar value */ template - MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in, const DType value) { + MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *in, const DType value) { KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value)); } /*! \brief input is tensor and two scalar value */ template - MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in, + MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *in, const DType value_1, const DType value_2) { KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value_1, value_2)); } /*! \brief No inputs (ie fill to constant value) */ template - MSHADOW_XINLINE static void Map(int i, DType *out) { + MSHADOW_XINLINE static void Map(index_t i, DType *out) { KERNEL_ASSIGN(out[i], req, OP::Map()); } /*! \brief input is single scalar value */ template - MSHADOW_XINLINE static void Map(int i, DType *out, const DType value) { + MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType value) { KERNEL_ASSIGN(out[i], req, OP::Map(value)); } /*! \brief inputs are two tensors and a scalar value */ template - MSHADOW_XINLINE static void Map(int i, DType *out, + MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *input_1, const DType *input_2, const DType value) { KERNEL_ASSIGN(out[i], req, OP::Map(input_1[i], input_2[i], value)); } /*! \brief inputs are three tensors (ie backward grad with binary grad function) */ template - MSHADOW_XINLINE static void Map(int i, DType *out, + MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *input_1, const DType *input_2, const DType *input_3) { @@ -503,21 +503,21 @@ struct Kernel { * \param args Varargs to eventually pass to the OP::Map() function */ template - inline static bool Launch(mshadow::Stream *, const int N, Args... args) { + inline static bool Launch(mshadow::Stream *, const size_t N, Args... args) { #ifdef _OPENMP const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); if (omp_threads < 2) { - for (int i = 0; i < N; ++i) { + for (size_t i = 0; i < N; ++i) { OP::Map(i, args...); } } else { #pragma omp parallel for num_threads(omp_threads) - for (int i = 0; i < N; ++i) { + for (index_t i = 0; i < static_cast(N); ++i) { OP::Map(i, args...); } } #else - for (int i = 0; i < N; ++i) { + for (size_t i = 0; i < N; ++i) { OP::Map(i, args...); } #endif @@ -567,22 +567,22 @@ struct Kernel { * \param args Varargs to eventually pass to the OP::Map() function */ template - static void LaunchTuned(mshadow::Stream *, const int N, Args... args) { + static void LaunchTuned(mshadow::Stream *, const size_t N, Args... 
args) { #ifdef _OPENMP const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); if (omp_threads < 2 || !tuned_op::UseOMP( - static_cast(N), static_cast(omp_threads))) { - for (int i = 0; i < N; ++i) { + N, static_cast(omp_threads))) { + for (size_t i = 0; i < N; ++i) { OP::Map(i, args...); } } else { #pragma omp parallel for num_threads(omp_threads) - for (int i = 0; i < N; ++i) { + for (index_t i = 0; i < static_cast(N); ++i) { OP::Map(i, args...); } } #else - for (int i = 0; i < N; ++i) { + for (size_t i = 0; i < N; ++i) { OP::Map(i, args...); } #endif @@ -596,15 +596,15 @@ struct Kernel { * \param args Varargs to eventually pass to the UseOMP() and OP::Map() functions */ template - inline static void LaunchEx(mshadow::Stream *s, const int N, Args... args) { + inline static void LaunchEx(mshadow::Stream *s, const size_t N, Args... args) { #ifdef _OPENMP const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); if (omp_threads < 2) { OP::Map(0, N, args...); } else { - const int length = (N + omp_threads - 1) / omp_threads; + const auto length = (N + omp_threads - 1) / omp_threads; #pragma omp parallel for num_threads(omp_threads) - for (int i = 0; i < N; i += length) { + for (index_t i = 0; i < static_cast(N); i += length) { OP::Map(i, i + length > N ? N - i : length, args...); } } @@ -626,7 +626,7 @@ struct Kernel { template static MSHADOW_CINLINE typename std::enable_if::value, bool>::type - Launch(mshadow::Stream *s, const int N, DType *dest, Args... args) { + Launch(mshadow::Stream *s, const size_t N, DType *dest, Args... args) { LaunchTuned(s, N, dest, args...); return true; } @@ -644,7 +644,7 @@ struct Kernel { template static MSHADOW_CINLINE typename std::enable_if::value, bool>::type - Launch(mshadow::Stream *s, const int N, DType *dest, Args... args) { + Launch(mshadow::Stream *s, const size_t N, DType *dest, Args... args) { LaunchTuned(s, N, dest, args...); return true; } @@ -700,7 +700,7 @@ template struct set_to_int : public tunable { // mxnet_op version (when used directly with Kernel<>::Launch()) */ template - MSHADOW_XINLINE static void Map(int i, DType *out) { + MSHADOW_XINLINE static void Map(index_t i, DType *out) { out[i] = DType(val); } // mshadow_op version (when used with op_with_req<>) diff --git a/src/operator/random/sampler.h b/src/operator/random/sampler.h index ca764e706c64..00963a6785ee 100644 --- a/src/operator/random/sampler.h +++ b/src/operator/random/sampler.h @@ -43,32 +43,33 @@ namespace op { template inline static void LaunchRNG(mshadow::Stream *s, common::random::RandGenerator *gen, - const int N, Args... args) { + const index_t N, Args... args) { // minimal check to avoid division by zero, below. // if `N` is zero the map operation is a no-op in any case. if (N <= 0) { return; } - const int nloop = (N + RandGenerator::kMinNumRandomPerThread - 1) / + const index_t nloop = (N + RandGenerator::kMinNumRandomPerThread - 1) / RandGenerator::kMinNumRandomPerThread; - const int nthread = std::min(nloop, RandGenerator::kNumRandomStates); - const int step = (N + nthread - 1) / nthread; + const index_t nthread = std::min(nloop, + static_cast(RandGenerator::kNumRandomStates)); + const index_t step = (N + nthread - 1) / nthread; Kernel::Launch(s, nthread, *gen, N, step, args...); } #define RNG_KERNEL_LOOP(xpu, GType, thread_id, gen, N, step, ...) 
\ - const int start = thread_id * step; \ - const int end = start + step; \ + const index_t start = thread_id * step; \ + const index_t end = start + step; \ typename RandGenerator::Impl genImpl(&gen, thread_id); \ - for (int i = start; i < end && i < N; ++i) { \ + for (index_t i = start; i < end && i < N; ++i) { \ {__VA_ARGS__} \ } template struct SampleUniformKernel { template - MSHADOW_XINLINE static void Map(int id, RandGenerator gen, - const int N, const int step, + MSHADOW_XINLINE static void Map(index_t id, RandGenerator gen, + const index_t N, const index_t step, index_t nParm, index_t nSample, const IType *lower, const IType *upper, OType *out) { RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, { @@ -127,8 +128,8 @@ struct RandIntSampler { template struct SampleNormalKernel { template - MSHADOW_XINLINE static void Map(int id, RandGenerator gen, - const int N, const int step, + MSHADOW_XINLINE static void Map(index_t id, RandGenerator gen, + const index_t N, const index_t step, index_t nParm, index_t nSample, const IType *mean, const IType *std, OType *out) { RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, { @@ -154,8 +155,8 @@ struct NormalSampler { template struct SampleExponentialKernel { template - MSHADOW_XINLINE static void Map(int id, RandGenerator gen, - const int N, const int step, + MSHADOW_XINLINE static void Map(index_t id, RandGenerator gen, + const index_t N, const index_t step, index_t nParm, index_t nSample, const IType *lambda, OType *out) { RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, { @@ -202,8 +203,8 @@ MSHADOW_XINLINE OType SampleGamma(IType a, IType b, typename RandGenerator struct SampleGammaKernel { template - MSHADOW_XINLINE static void Map(int id, RandGenerator gen, - const int N, const int step, + MSHADOW_XINLINE static void Map(index_t id, RandGenerator gen, + const index_t N, const index_t step, index_t nParm, index_t nSample, const IType *alpha, const IType *beta, OType *out) { RNG_KERNEL_LOOP(xpu, FType, id, gen, N, step, { @@ -264,8 +265,8 @@ MSHADOW_XINLINE int SamplePoisson(float lambda, typename RandGenerator struct SamplePoissonKernel { template - MSHADOW_XINLINE static void Map(int id, RandGenerator gen, - const int N, const int step, + MSHADOW_XINLINE static void Map(index_t id, RandGenerator gen, + const index_t N, const index_t step, index_t nParm, index_t nSample, const IType *lambda, OType *out) { RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, { @@ -291,8 +292,8 @@ struct PoissonSampler { template struct SampleNegativeBinomialKernel { template - MSHADOW_XINLINE static void Map(int id, RandGenerator gen, - const int N, const int step, + MSHADOW_XINLINE static void Map(index_t id, RandGenerator gen, + const index_t N, const index_t step, index_t nParm, index_t nSample, const IType *k, const IType *p, OType *out) { RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, { @@ -323,8 +324,8 @@ struct NegativeBinomialSampler { template struct SampleGeneralizedNegativeBinomialKernel { template - MSHADOW_XINLINE static void Map(int id, RandGenerator gen, - const int N, const int step, + MSHADOW_XINLINE static void Map(index_t id, RandGenerator gen, + const index_t N, const index_t step, index_t nParm, index_t nSample, const IType *mu, const IType *alpha, OType *out) { RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, { diff --git a/src/operator/tensor/broadcast_reduce-inl.h b/src/operator/tensor/broadcast_reduce-inl.h index 167fa34b083f..141d2fb83d0d 100644 --- a/src/operator/tensor/broadcast_reduce-inl.h +++ b/src/operator/tensor/broadcast_reduce-inl.h @@ -53,14 
+53,14 @@ MSHADOW_XINLINE Shape calc_stride(const Shape& shape) { } template -MSHADOW_XINLINE void unravel_dot(const int idx, const Shape& shape, - const Shape& stridej, const Shape& stridek, int* j, int* k) { +MSHADOW_XINLINE void unravel_dot(const index_t idx, const Shape& shape, + const Shape& stridej, const Shape& stridek, index_t* j, index_t* k) { *j = 0; *k = 0; #pragma unroll - for (int i = ndim-1, idx_t = idx; i >=0; --i) { - const int tmp = idx_t / shape[i]; - const int coord = idx_t - tmp*shape[i]; + for (index_t i = ndim-1, idx_t = idx; i >=0; --i) { + const auto tmp = idx_t / shape[i]; + const auto coord = idx_t - tmp*shape[i]; *j += coord*stridej[i]; *k += coord*stridek[i]; idx_t = tmp; @@ -68,11 +68,11 @@ MSHADOW_XINLINE void unravel_dot(const int idx, const Shape& shape, } template -MSHADOW_XINLINE Shape unravel(const int idx, const Shape& shape) { +MSHADOW_XINLINE Shape unravel(const index_t idx, const Shape& shape) { Shape ret; #pragma unroll - for (int i = ndim-1, j = idx; i >=0; --i) { - int tmp = j / shape[i]; + for (index_t i = ndim-1, j = idx; i >=0; --i) { + auto tmp = j / shape[i]; ret[i] = j - tmp*shape[i]; j = tmp; } @@ -80,10 +80,10 @@ MSHADOW_XINLINE Shape unravel(const int idx, const Shape& shape) { } template -MSHADOW_XINLINE int ravel(const Shape& coord, const Shape& shape) { - int ret = 0; +MSHADOW_XINLINE index_t ravel(const Shape& coord, const Shape& shape) { + index_t ret = 0; #pragma unroll - for (int i = 0; i < ndim; ++i) { + for (index_t i = 0; i < ndim; ++i) { ret = ret * shape[i] + (shape[i] > 1) * coord[i]; } return ret; @@ -111,12 +111,12 @@ MSHADOW_XINLINE int diff(const Shape& small, const Shape& big, Shape } template -MSHADOW_XINLINE int unravel_dot(const int idx, const Shape& shape, +MSHADOW_XINLINE index_t unravel_dot(const index_t idx, const Shape& shape, const Shape& stride) { - int ret = 0; + index_t ret = 0; #pragma unroll - for (int i = ndim-1, j = idx; i >=0; --i) { - int tmp = j / shape[i]; + for (index_t i = ndim-1, j = idx; i >=0; --i) { + auto tmp = j / shape[i]; ret += (j - tmp*shape[i])*stride[i]; j = tmp; } @@ -124,8 +124,8 @@ MSHADOW_XINLINE int unravel_dot(const int idx, const Shape& shape, } template -MSHADOW_XINLINE int dot(const Shape& coord, const Shape& stride) { - int ret = 0; +MSHADOW_XINLINE index_t dot(const Shape& coord, const Shape& stride) { + index_t ret = 0; #pragma unroll for (int i = 0; i < ndim; ++i) ret += coord[i] * stride[i]; @@ -142,27 +142,27 @@ MSHADOW_XINLINE void assign(DType* dst, const bool addto, const DType src) { } template -MSHADOW_XINLINE void binary_broadcast_assign(const int idx, const bool addto, +MSHADOW_XINLINE void binary_broadcast_assign(const index_t idx, const bool addto, const DType* __restrict lhs, const DType* __restrict rhs, DType* out, const Shape& lshape, const Shape& rshape, const Shape& oshape) { const Shape coord = unravel(idx, oshape); - const int j = ravel(coord, lshape); - const int k = ravel(coord, rshape); + const index_t j = ravel(coord, lshape); + const index_t k = ravel(coord, rshape); assign(&out[idx], addto, OP::Map(lhs[j], rhs[k])); } template -MSHADOW_XINLINE void seq_reduce_assign(const int idx, const int M, const bool addto, +MSHADOW_XINLINE void seq_reduce_assign(const index_t idx, const size_t M, const bool addto, const DType* __restrict big, DType *small, const Shape& bshape, const Shape& sshape, const Shape& rshape, const Shape& rstride) { Shape coord = unravel(idx, sshape); - int j = ravel(coord, bshape); + index_t j = ravel(coord, bshape); DType val, residual; 
Reducer::SetInitValue(val, residual); - for (int k = 0; k < M; ++k) { + for (size_t k = 0; k < M; ++k) { coord = unravel(k, rshape); Reducer::Reduce(val, OP::Map(big[j + dot(coord, rstride)]), residual); } @@ -176,10 +176,10 @@ MSHADOW_XINLINE void seq_reduce_assign(const int idx, const int M, const bool ad #else template -void binary_broadcast_compute(const int N, const bool addto, const DType *lhs, +void binary_broadcast_compute(const size_t N, const bool addto, const DType *lhs, const DType *rhs, DType *out, const Shape lshape, const Shape rshape, const Shape oshape) { - for (int idx = 0; idx < N; ++idx) { + for (size_t idx = 0; idx < N; ++idx) { binary_broadcast_assign(idx, addto, lhs, rhs, out, lshape, rshape, oshape); } } @@ -188,26 +188,26 @@ template void BinaryBroadcastComputeImpl(Stream *s, const OpReqType req, const TBlob& lhs, const TBlob& rhs, const TBlob& out) { if (req == kNullOp) return; - int N = out.shape_.Size(); + size_t N = out.shape_.Size(); binary_broadcast_compute(N, req == kAddTo, lhs.dptr(), rhs.dptr(), out.dptr(), lhs.shape_.get(), rhs.shape_.get(), out.shape_.get()); } template -void seq_reduce_compute(const int N, const int M, const bool addto, +void seq_reduce_compute(const size_t N, const size_t M, const bool addto, const DType *big, DType *small, const Shape bshape, const Shape sshape, const Shape rshape, const Shape rstride) { #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) - for (int idx = 0; idx < N; ++idx) { + for (index_t idx = 0; idx < static_cast(N); ++idx) { seq_reduce_assign(idx, M, addto, big, small, bshape, sshape, rshape, rstride); } } template -void seq_reduce_compute_extra_mem(const int N, const int M, const bool addto, +void seq_reduce_compute_extra_mem(const size_t N, const size_t M, const bool addto, const DType* big, DType* small, const Shape bshape, const Shape sshape, @@ -215,12 +215,12 @@ void seq_reduce_compute_extra_mem(const int N, const int M, const bool addto, const Shape rstride, const index_t* ws_dptr) { #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) - for (int idx = 0; idx < N; ++idx) { + for (index_t idx = 0; idx < static_cast(N); ++idx) { Shape coord = unravel(idx, sshape); - int j = ravel(coord, bshape); + index_t j = ravel(coord, bshape); DType val, residual; Reducer::SetInitValue(val, residual); - for (int k = 0; k < M; ++k) { + for (size_t k = 0; k < M; ++k) { Reducer::Reduce(val, OP::Map(big[j + ws_dptr[k]]), residual); } assign(&small[idx], addto, val); @@ -233,7 +233,7 @@ void Reduce(Stream* s, const TBlob& small, const OpReqType req, if (req == kNullOp) return; Shape rshape, rstride; diff(small.shape_.get(), big.shape_.get(), &rshape, &rstride); - int N = small.shape_.Size(), M = rshape.Size(); + size_t N = small.shape_.Size(), M = rshape.Size(); seq_reduce_compute( N, M, req == kAddTo, big.dptr(), small.dptr(), big.shape_.get(), small.shape_.get(), rshape, rstride); @@ -247,9 +247,9 @@ void ReduceWithExtraMem(Stream* s, const TBlob& small, const OpReqType req, Shape rshape, rstride; diff(small.shape_.get(), big.shape_.get(), &rshape, &rstride); index_t* ws_dptr = reinterpret_cast(workspace.dptr_); - int N = small.shape_.Size(), M = rshape.Size(); + size_t N = small.shape_.Size(), M = rshape.Size(); #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) - for (int k = 0; k < M; k++) { + for (index_t k = 0; k < static_cast(M); k++) { Shape coord = unravel(k, rshape); ws_dptr[k] = dot(coord, 
rstride); } @@ -272,7 +272,7 @@ size_t ReduceWorkspaceSize(Stream *s, const TShape& small, const OpReqType } template -MSHADOW_XINLINE void seq_reduce_assign(const int idx, const int M, const bool addto, +MSHADOW_XINLINE void seq_reduce_assign(const index_t idx, const size_t M, const bool addto, const DType* __restrict big, const DType* __restrict lhs, const DType* __restrict rhs, DType *small, const Shape& big_shape, const Shape& lhs_shape0, @@ -282,20 +282,20 @@ MSHADOW_XINLINE void seq_reduce_assign(const int idx, const int M, const bool ad const Shape& rstride, const Shape& lhs_stride, const Shape& rhs_stride) { Shape coord = unravel(idx, small_shape); - const int idx_big0 = ravel(coord, big_shape); - const int idx_lhs0 = ravel(coord, lhs_shape0); - const int idx_rhs0 = ravel(coord, rhs_shape0); + const index_t idx_big0 = ravel(coord, big_shape); + const index_t idx_lhs0 = ravel(coord, lhs_shape0); + const index_t idx_rhs0 = ravel(coord, rhs_shape0); DType val, residual; Reducer::SetInitValue(val, residual); - for (int k = 0; k < M; ++k) { + for (size_t k = 0; k < M; ++k) { Shape coord_big = unravel(k, rshape); - int idx_big = idx_big0 + dot(coord_big, rstride); + index_t idx_big = idx_big0 + dot(coord_big, rstride); Shape coord_lhs = unravel(k, lhs_shape); - int idx_lhs = idx_lhs0 + dot(coord_lhs, lhs_stride); + index_t idx_lhs = idx_lhs0 + dot(coord_lhs, lhs_stride); Shape coord_rhs = unravel(k, rhs_shape); - int idx_rhs = idx_rhs0 + dot(coord_rhs, rhs_stride); + index_t idx_rhs = idx_rhs0 + dot(coord_rhs, rhs_stride); Reducer::Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual); } @@ -304,7 +304,7 @@ MSHADOW_XINLINE void seq_reduce_assign(const int idx, const int M, const bool ad } template -void seq_reduce_compute(const int N, const int M, const bool addto, +void seq_reduce_compute(const size_t N, const size_t M, const bool addto, const DType *big, const DType *lhs, const DType *rhs, DType *small, const Shape big_shape, const Shape small_shape, const Shape rshape, const Shape rstride, @@ -312,7 +312,7 @@ void seq_reduce_compute(const int N, const int M, const bool addto, const Shape rhs_shape, const Shape rhs_stride, const Shape& lhs_shape0, const Shape& rhs_shape0) { #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) - for (int idx = 0; idx < N; ++idx) { + for (index_t idx = 0; idx < static_cast(N); ++idx) { seq_reduce_assign(idx, M, addto, big, lhs, rhs, small, big_shape, lhs_shape0, rhs_shape0, small_shape, rshape, lhs_shape, rhs_shape, rstride, lhs_stride, rhs_stride); @@ -326,8 +326,8 @@ void Reduce(Stream *s, const TBlob& small, const OpReqType req, if (req == kNullOp) return; Shape rshape, rstride; diff(small.shape_.get(), big.shape_.get(), &rshape, &rstride); - int N = small.shape_.Size(); - int M = rshape.Size(); + size_t N = small.shape_.Size(); + size_t M = rshape.Size(); Shape lhs_shape, lhs_stride; diff(small.shape_.get(), lhs.shape_.get(), &lhs_shape, &lhs_stride); diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h index 391c35117128..304422038b89 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.h +++ b/src/operator/tensor/elemwise_binary_broadcast_op.h @@ -190,7 +190,7 @@ namespace mxnet_op { template struct binary_broadcast_kernel { /*! 
\brief Map function for binary_broadcast_kernel */ - MSHADOW_XINLINE static void Map(int base, int length, OpReqType req, + MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, const Shape &lstride, const Shape &rstride, const Shape &oshape, DType *lhs, DType *rhs, DType *out) { @@ -199,7 +199,7 @@ struct binary_broadcast_kernel { auto ridx = static_cast(dot(coord, rstride)); KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx])); // starts from 1 to avoid extra inc at end of loop - for (int i = 1; i < length; ++i) { + for (index_t i = 1; i < length; ++i) { inc(&coord, oshape, &lidx, lstride, &ridx, rstride); // When tuning, don't actually run the op, since it's not going to be tuned against // the actual op we'll eventually be using @@ -208,7 +208,7 @@ struct binary_broadcast_kernel { } /*! \brief Map function for binary_broadcast_kernel */ - MSHADOW_XINLINE static void Map(int base, int length, OpReqType req, + MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, const Shape &lstride, const Shape &rstride, const Shape &oshape, DType lhs, DType *rhs, DType *out) { @@ -217,7 +217,7 @@ struct binary_broadcast_kernel { auto ridx = static_cast(dot(coord, rstride)); KERNEL_ASSIGN(out[base], req, OP::Map(lhs, rhs[ridx])); // starts from 1 to avoid extra inc at end of loop - for (int i = 1; i < length; ++i) { + for (index_t i = 1; i < length; ++i) { inc(&coord, oshape, &lidx, lstride, &ridx, rstride); // When tuning, don't actually run the op, since it's not going to be tuned against // the actual op we'll eventually be using @@ -238,7 +238,7 @@ struct csr_dns_csr_broadcast_kernel { * \param out ptr to the data buffer of the result csr matrix */ template - MSHADOW_XINLINE static void Map(int row, const DType *csr_data, const CType *csr_indices, + MSHADOW_XINLINE static void Map(index_t row, const DType *csr_data, const CType *csr_indices, const RType *csr_indptr, const DType *dns, DType *out) { const nnvm::dim_t curr_row_i = csr_indptr[row]; const nnvm::dim_t next_row_i = csr_indptr[row + 1]; @@ -257,7 +257,7 @@ struct csr_dns_csr_broadcast_kernel { * \param nnz number of non-zero elements in input csr matrix */ template - MSHADOW_XINLINE static void Map(int i, const DType *csr_data, const DType* scalar_ptr, + MSHADOW_XINLINE static void Map(index_t i, const DType *csr_data, const DType* scalar_ptr, DType *out, const nnvm::dim_t nnz) { const DType scale = scalar_ptr[0]; if (i < nnz) { @@ -269,7 +269,7 @@ struct csr_dns_csr_broadcast_kernel { template struct csr_dns_map_kernel { template - MSHADOW_XINLINE static void Map(int row, const DType *csr_data, const CType *csr_indices, + MSHADOW_XINLINE static void Map(index_t row, const DType *csr_data, const CType *csr_indices, const RType *csr_indptr, DType *out, const nnvm::dim_t num_rows, const nnvm::dim_t num_cols) { if (row < num_rows) { diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index 77236e068f86..c39418dbe41d 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -36,7 +36,7 @@ struct TakeCPU { // K is the number of rows of in_data // i is the index of out_data template - MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data, + MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* in_data, const IType* idx, const size_t M, const int64_t K) { int64_t j = static_cast(idx[i]); if (clip) { @@ -420,19 +420,19 @@ inline void SparseEmbeddingOpBackwardRspImpl(const bool 
deterministic, template inline typename std::enable_if<(!std::is_same::value), void>::type -GatherNDBackwardImpl(int N, int M, int K, +GatherNDBackwardImpl(index_t N, index_t M, index_t K, const mshadow::Shape<10> strides, DType* out, const DType* data, const IType* indices, mshadow::Stream *s) { #pragma omp parallel for - for (int i = 0; i < N; i++) { - int offset = 0; - for (int j = 0; j < M; ++j) { - offset += strides[j] * static_cast(indices[j*N + i]); + for (index_t i = 0; i < N; i++) { + index_t offset = 0; + for (index_t j = 0; j < M; ++j) { + offset += strides[j] * static_cast(indices[j*N + i]); } - for (int j = 0; j < K; ++j) { + for (index_t j = 0; j < K; ++j) { #pragma omp atomic out[offset + j] += data[i * K + j]; } @@ -441,18 +441,18 @@ GatherNDBackwardImpl(int N, int M, int K, template inline typename std::enable_if::value, void>::type -GatherNDBackwardImpl(int N, int M, int K, +GatherNDBackwardImpl(index_t N, index_t M, index_t K, const mshadow::Shape<10> strides, DType* out, const DType* data, const IType* indices, mshadow::Stream *s) { - for (int i = 0; i < N; i++) { - int offset = 0; - for (int j = 0; j < M; ++j) { - offset += strides[j] * static_cast(indices[j*N + i]); + for (index_t i = 0; i < N; i++) { + index_t offset = 0; + for (index_t j = 0; j < M; ++j) { + offset += strides[j] * static_cast(indices[j*N + i]); } - for (int j = 0; j < K; ++j) { + for (index_t j = 0; j < K; ++j) { out[offset + j] += data[i * K + j]; } } diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu index 0d72b1815fde..bad3e5a1a6c5 100644 --- a/src/operator/tensor/indexing_op.cu +++ b/src/operator/tensor/indexing_op.cu @@ -439,22 +439,22 @@ inline void SparseEmbeddingOpBackwardRspImpl(const bool deterministic, struct backward_gather_nd_gpu { template - MSHADOW_XINLINE static void Map(int i, int N, int M, int K, + MSHADOW_XINLINE static void Map(index_t i, index_t N, index_t M, index_t K, const mshadow::Shape<10> strides, DType* out, const DType* data, const IType* indices) { - int offset = 0; - for (int j = 0; j < M; ++j) { + index_t offset = 0; + for (index_t j = 0; j < M; ++j) { offset += strides[j] * static_cast(indices[j*N + i]); } - for (int j = 0; j < K; ++j) { + for (index_t j = 0; j < K; ++j) { atomicAdd(out + (offset + j), data[i * K + j]); } } }; template -inline void GatherNDBackwardImpl(int N, int M, int K, +inline void GatherNDBackwardImpl(index_t N, index_t M, index_t K, const mshadow::Shape<10> strides, DType* out, const DType* data, diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 92b6e21018e5..fba331e25705 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -314,7 +314,8 @@ struct Take { * \param axis axis id */ template - MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data, const IType* idx, + MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* in_data, + const IType* idx, const mshadow::Shape<10> in_stride, const mshadow::Shape<10> out_stride, const int in_ndims, const int out_ndims, const int idx_ndims, @@ -361,7 +362,7 @@ struct TakeRspKernel { * \param nnr number of non-zero rows */ template - MSHADOW_XINLINE static void Map(int i, + MSHADOW_XINLINE static void Map(index_t i, const IType* data, DType* out, const RType* weight_idx, @@ -1395,15 +1396,15 @@ inline bool ScatterNDType(const nnvm::NodeAttrs& attrs, struct scatter_nd { template - MSHADOW_XINLINE static void Map(int i, OpReqType req, int N, int M, int K, + 
MSHADOW_XINLINE static void Map(index_t i, OpReqType req, index_t N, index_t M, index_t K, const mshadow::Shape<10> strides, DType* out, const DType* data, const IType* indices) { - int offset = 0; - for (int j = 0; j < M; ++j) { - offset += strides[j] * static_cast(indices[j*N + i]); + index_t offset = 0; + for (index_t j = 0; j < M; ++j) { + offset += strides[j] * static_cast(indices[j*N + i]); } - for (int j = 0; j < K; ++j) { + for (index_t j = 0; j < K; ++j) { KERNEL_ASSIGN(out[offset+j], req, data[i*K + j]); } } @@ -1416,17 +1417,18 @@ void ScatterNDForward(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { using namespace mshadow; + using nnvm::dim_t; CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); if (req[0] == kNullOp) return; mshadow::Stream *s = ctx.get_stream(); const TShape& oshape = outputs[0].shape_; const TShape& ishape = inputs[1].shape_; - int M = ishape[0]; - int N = ishape.Size() / M; - int K = oshape.ProdShape(M, oshape.ndim()); + dim_t M = ishape[0]; + dim_t N = ishape.Size() / M; + dim_t K = oshape.ProdShape(M, oshape.ndim()); mshadow::Shape<10> strides; - for (int i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride; + for (dim_t i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride; if (kWriteTo == req[0]) { Fill(s, outputs[0], req[0], 0); } @@ -1441,7 +1443,7 @@ void ScatterNDForward(const nnvm::NodeAttrs& attrs, template inline typename std::enable_if<(!std::is_same::value), void>::type -GatherNDBackwardImpl(int N, int M, int K, +GatherNDBackwardImpl(index_t N, index_t M, index_t K, const mshadow::Shape<10> strides, DType* out, const DType* data, @@ -1450,7 +1452,7 @@ GatherNDBackwardImpl(int N, int M, int K, template inline typename std::enable_if::value, void>::type -GatherNDBackwardImpl(int N, int M, int K, +GatherNDBackwardImpl(index_t N, index_t M, index_t K, const mshadow::Shape<10> strides, DType* out, const DType* data, @@ -1458,7 +1460,7 @@ GatherNDBackwardImpl(int N, int M, int K, mshadow::Stream *s); template -inline void GatherNDBackwardImpl(int N, int M, int K, +inline void GatherNDBackwardImpl(index_t N, index_t M, index_t K, const mshadow::Shape<10> strides, DType* out, const DType* data, @@ -1472,17 +1474,18 @@ void GatherNDBackward(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { using namespace mshadow; + using nnvm::dim_t; CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); if (req[0] == kNullOp) return; mshadow::Stream *s = ctx.get_stream(); const TShape& oshape = outputs[0].shape_; const TShape& ishape = inputs[1].shape_; - int M = ishape[0]; - int N = ishape.Size() / M; - int K = oshape.ProdShape(M, oshape.ndim()); + dim_t M = ishape[0]; + dim_t N = ishape.Size() / M; + dim_t K = oshape.ProdShape(M, oshape.ndim()); mshadow::Shape<10> strides; - for (int i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride; + for (dim_t i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride; if (kWriteTo == req[0]) { Fill(s, outputs[0], req[0], 0); } diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index 4e52b087f10a..e9e67cb1a4c5 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -453,7 +453,7 @@ void EyeFill(const nnvm::NodeAttrs& attrs, struct range_fwd { template - MSHADOW_XINLINE static void Map(int i, int repeat, DType start, DType step, + MSHADOW_XINLINE static void Map(index_t i, int repeat, DType start, DType step, int req, 
DType* out) { KERNEL_ASSIGN(out[i], req, start + (i/repeat) * step); } @@ -471,8 +471,8 @@ void RangeCompute(const nnvm::NodeAttrs& attrs, MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { // Force unsigned params to take two's complement form on ARM to ensure consistency with x86 // results. Casting negative floats to unsigned types is undefined in the CPP standard. - auto step = std::is_signed() ? param.step : static_cast(param.step); - auto start = std::is_signed() ? param.start : static_cast(param.start); + auto step = std::is_signed() ? param.step : static_cast(param.step); + auto start = std::is_signed() ? param.start : static_cast(param.start); Kernel::Launch(s, outputs[0].Size(), static_cast(param.repeat), diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 9c81d87464de..3b229cf38eba 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -626,9 +626,9 @@ inline void GetIndexRange(const TShape& dshape, const nnvm::Tuple>& param_begin, const nnvm::Tuple>& param_end, const nnvm::Tuple>& param_step, - common::StaticArray* begin, - common::StaticArray* end, - common::StaticArray* step) { + common::StaticArray* begin, + common::StaticArray* end, + common::StaticArray* step) { CHECK_NE(dshape.ndim(), 0U); CHECK_LE(param_begin.ndim(), dshape.ndim()) << "Slicing axis exceeds data dimensions"; @@ -646,8 +646,8 @@ inline void GetIndexRange(const TShape& dshape, } for (index_t i = 0; i < param_begin.ndim(); ++i) { - int b = 0, e = dshape[i], s = 1; - const int len = dshape[i]; + index_t b = 0, e = dshape[i], s = 1; + const index_t len = dshape[i]; if (param_step.ndim() != 0U) { const auto& opt_step_val = param_step[i]; if (opt_step_val.has_value()) { @@ -724,7 +724,7 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs, TShape oshape = dshape; MXNET_NDIM_SWITCH(dshape.ndim(), ndim, { - common::StaticArray begin, end, step; + common::StaticArray begin, end, step; GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step); for (index_t i = 0; i < param.begin.ndim(); ++i) { const int b = begin[i], e = end[i], s = step[i]; @@ -743,19 +743,19 @@ template struct slice_forward { // i is the i-th row after flattening out into 2D tensor template - MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data, + MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* data, const mshadow::Shape dshape, const mshadow::Shape oshape, - const common::StaticArray begin, - const common::StaticArray step) { - const int data_last_dim_size = dshape[ndim-1]; - const int out_last_dim_size = oshape[ndim-1]; - const int step_last_dim = step[ndim-1]; - const int begin_last_dim = begin[ndim-1]; - const int j = i % out_last_dim_size; - int irow = 0; // row id of flattend 2D data - int stride = 1; - int idx = i / out_last_dim_size; + const common::StaticArray begin, + const common::StaticArray step) { + const index_t data_last_dim_size = dshape[ndim-1]; + const index_t out_last_dim_size = oshape[ndim-1]; + const index_t step_last_dim = step[ndim-1]; + const index_t begin_last_dim = begin[ndim-1]; + const index_t j = i % out_last_dim_size; + index_t irow = 0; // row id of flattend 2D data + index_t stride = 1; + index_t idx = i / out_last_dim_size; #pragma unroll for (int k = ndim - 2; k >= 0; --k) { irow += stride * ((idx % oshape[k]) * step[k] + begin[k]); @@ -771,20 +771,20 @@ template struct slice_forward { // i is the i-th row after flattening out into 2D tensor template - MSHADOW_XINLINE 
static void Map(int i, DType* out, const DType* data, + MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* data, const mshadow::Shape dshape, const mshadow::Shape oshape, - const common::StaticArray begin, - const common::StaticArray step) { - const int data_last_dim_size = dshape[ndim-1]; - const int out_last_dim_size = oshape[ndim-1]; - const int step_last_dim = step[ndim-1]; - const int begin_last_dim = begin[ndim-1]; - int out_offset = i * out_last_dim_size; - for (int j = 0; j < out_last_dim_size; ++j) { - int irow = 0; // row id of flattend 2D data - int stride = 1; - int idx = i; + const common::StaticArray begin, + const common::StaticArray step) { + const index_t data_last_dim_size = dshape[ndim-1]; + const index_t out_last_dim_size = oshape[ndim-1]; + const index_t step_last_dim = step[ndim-1]; + const index_t begin_last_dim = begin[ndim-1]; + index_t out_offset = i * out_last_dim_size; + for (index_t j = 0; j < out_last_dim_size; ++j) { + index_t irow = 0; // row id of flattend 2D data + index_t stride = 1; + index_t idx = i; #pragma unroll for (int k = ndim - 2; k >= 0; --k) { irow += stride * ((idx % oshape[k]) * step[k] + begin[k]); @@ -813,11 +813,11 @@ void SliceOpForward(const nnvm::NodeAttrs& attrs, const TBlob& out = outputs[0]; const SliceParam& param = nnvm::get(attrs.parsed); MXNET_NDIM_SWITCH(data.ndim(), ndim, { - common::StaticArray begin, end, step; + common::StaticArray begin, end, step; GetIndexRange(data.shape_, param.begin, param.end, param.step, &begin, &end, &step); MSHADOW_TYPE_SWITCH(out.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - int num_threads = out.shape_.FlatTo2D()[0]; + size_t num_threads = out.shape_.FlatTo2D()[0]; if (std::is_same::value) { num_threads *= out.shape_.get()[ndim - 1]; } @@ -836,20 +836,20 @@ template struct slice_assign { // i is the i-th row after flattening out into 2D tensor template - MSHADOW_XINLINE static void Map(int i, DType* out, const DType* val, + MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* val, const mshadow::Shape oshape, const mshadow::Shape vshape, - const common::StaticArray begin, - const common::StaticArray step) { - const int data_last_dim_size = oshape[ndim-1]; - const int out_last_dim_size = vshape[ndim-1]; - const int step_last_dim = step[ndim-1]; - const int begin_last_dim = begin[ndim-1]; - int offset = i * out_last_dim_size; - for (int j = 0; j < out_last_dim_size; ++j) { - int irow = 0; // row id of flattend 2D out - int stride = 1; - int idx = i; + const common::StaticArray begin, + const common::StaticArray step) { + const index_t data_last_dim_size = oshape[ndim-1]; + const index_t out_last_dim_size = vshape[ndim-1]; + const index_t step_last_dim = step[ndim-1]; + const index_t begin_last_dim = begin[ndim-1]; + index_t offset = i * out_last_dim_size; + for (index_t j = 0; j < out_last_dim_size; ++j) { + index_t irow = 0; // row id of flattend 2D out + index_t stride = 1; + index_t idx = i; #pragma unroll for (int k = ndim - 2; k >= 0; --k) { irow += stride * ((idx % vshape[k]) * step[k] + begin[k]); @@ -866,19 +866,19 @@ template struct slice_assign { // i is the i-th row after flattening out into 2D tensor template - MSHADOW_XINLINE static void Map(int i, DType* out, const DType* val, + MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* val, const mshadow::Shape oshape, const mshadow::Shape vshape, - const common::StaticArray begin, - const common::StaticArray step) { - const int data_last_dim_size = oshape[ndim-1]; - const int 
out_last_dim_size = vshape[ndim-1]; - const int step_last_dim = step[ndim-1]; - const int begin_last_dim = begin[ndim-1]; - const int j = i % out_last_dim_size; - int irow = 0; // row id of flattend 2D out - int stride = 1; - int idx = i / out_last_dim_size; + const common::StaticArray begin, + const common::StaticArray step) { + const index_t data_last_dim_size = oshape[ndim-1]; + const index_t out_last_dim_size = vshape[ndim-1]; + const index_t step_last_dim = step[ndim-1]; + const index_t begin_last_dim = begin[ndim-1]; + const index_t j = i % out_last_dim_size; + index_t irow = 0; // row id of flattend 2D out + index_t stride = 1; + index_t idx = i / out_last_dim_size; #pragma unroll for (int k = ndim - 2; k >= 0; --k) { irow += stride * ((idx % vshape[k]) * step[k] + begin[k]); @@ -911,7 +911,7 @@ void SliceOpBackward(const nnvm::NodeAttrs& attrs, LOG(FATAL) << "_slice_backward does not support kWriteInplace"; } MXNET_NDIM_SWITCH(ograd.ndim(), ndim, { - common::StaticArray begin, end, step; + common::StaticArray begin, end, step; GetIndexRange(igrad.shape_, param.begin, param.end, param.step, &begin, &end, &step); MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { @@ -937,7 +937,7 @@ inline bool SliceAssignOpShape(const nnvm::NodeAttrs& attrs, TShape vshape = dshape; // vshape is the value shape on the right hand side const SliceParam& param = nnvm::get(attrs.parsed); MXNET_NDIM_SWITCH(dshape.ndim(), ndim, { - common::StaticArray begin, end, step; + common::StaticArray begin, end, step; GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step); for (index_t i = 0; i < param.begin.ndim(); ++i) { const int b = begin[i], e = end[i], s = step[i]; @@ -975,7 +975,7 @@ void SliceAssignOpForward(const nnvm::NodeAttrs& attrs, const SliceParam& param = nnvm::get(attrs.parsed); MXNET_NDIM_SWITCH(data.ndim(), ndim, { - common::StaticArray begin, end, step; + common::StaticArray begin, end, step; GetIndexRange(data.shape_, param.begin, param.end, param.step, &begin, &end, &step); MSHADOW_TYPE_SWITCH(out.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { @@ -1024,20 +1024,20 @@ template struct slice_assign_scalar { // i is the i-th row after flattening out into 2D tensor template - MSHADOW_XINLINE static void Map(int i, DType* out, const DType val, + MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType val, const OpReqType req, const mshadow::Shape oshape, const mshadow::Shape vshape, - const common::StaticArray begin, - const common::StaticArray step) { - const int data_last_dim_size = oshape[ndim-1]; - const int out_last_dim_size = vshape[ndim-1]; - const int step_last_dim = step[ndim-1]; - const int begin_last_dim = begin[ndim-1]; - for (int j = 0; j < out_last_dim_size; ++j) { - int irow = 0; // row id of flattend 2D out - int stride = 1; - int idx = i; + const common::StaticArray begin, + const common::StaticArray step) { + const index_t data_last_dim_size = oshape[ndim-1]; + const index_t out_last_dim_size = vshape[ndim-1]; + const index_t step_last_dim = step[ndim-1]; + const index_t begin_last_dim = begin[ndim-1]; + for (index_t j = 0; j < out_last_dim_size; ++j) { + index_t irow = 0; // row id of flattend 2D out + index_t stride = 1; + index_t idx = i; #pragma unroll for (int k = ndim - 2; k >= 0; --k) { irow += stride * ((idx % vshape[k]) * step[k] + begin[k]); @@ -1076,7 +1076,7 @@ void SliceAssignScalarOpForward(const nnvm::NodeAttrs& attrs, TShape vshape = data.shape_; const SliceAssignScalarParam& param = 
nnvm::get(attrs.parsed); MXNET_NDIM_SWITCH(data.ndim(), ndim, { - common::StaticArray begin, end, step; + common::StaticArray begin, end, step; GetIndexRange(data.shape_, param.begin, param.end, param.step, &begin, &end, &step); for (index_t i = 0; i < param.begin.ndim(); ++i) { const int b = begin[i], e = end[i], s = step[i]; @@ -1107,7 +1107,7 @@ struct SliceAxisParam : public dmlc::Parameter { }; inline void GetSliceAxisParams(const SliceAxisParam& param, const TShape& ishape, - int* axis, int* begin, int* end) { + int* axis, index_t* begin, index_t* end) { *axis = param.axis; if (*axis < 0) { *axis += static_cast(ishape.ndim()); @@ -1115,7 +1115,7 @@ inline void GetSliceAxisParams(const SliceAxisParam& param, const TShape& ishape CHECK(*axis < static_cast(ishape.ndim()) && *axis >= 0) << "Transformed axis must be smaller than the source ndim and larger than zero! Recieved axis=" << param.axis << ", src_ndim=" << ishape.ndim() << ", transformed axis=" << *axis; - int axis_size = static_cast(ishape[*axis]); + index_t axis_size = static_cast(ishape[*axis]); *begin = param.begin; *end = -1; if (*begin < 0) { @@ -1149,7 +1149,8 @@ inline bool SliceAxisShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); TShape& ishape = (*in_attrs)[0]; - int axis, begin, end; + int axis; + index_t begin, end; GetSliceAxisParams(param, ishape, &axis, &begin, &end); TShape shape(ishape.ndim()); for (index_t i = 0; i < ishape.ndim(); ++i) { @@ -1173,7 +1174,8 @@ void SliceAxis(const nnvm::NodeAttrs& attrs, using namespace mshadow::expr; const SliceAxisParam& param = nnvm::get(attrs.parsed); mshadow::Stream *s = ctx.get_stream(); - int axis, begin, end; + int axis; + index_t begin, end; GetSliceAxisParams(param, inputs[0].shape_, &axis, &begin, &end); int ndim = static_cast(outputs[0].ndim()); @@ -1207,7 +1209,8 @@ void SliceAxisGrad_(const nnvm::NodeAttrs& attrs, using namespace mshadow::op; using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); - int axis, begin, end; + int axis; + index_t begin, end; GetSliceAxisParams(param, outputs[0].shape_, &axis, &begin, &end); int ndim = static_cast(outputs[0].shape_.ndim()); @@ -1354,7 +1357,7 @@ void SliceLikeForward(const nnvm::NodeAttrs& attrs, SliceLikeInferRanges(ishape, from_shape, param.axes, ¶m_begin, ¶m_end, ¶m_step); MXNET_NDIM_SWITCH(data.ndim(), ndim, { - common::StaticArray begin, end, step; + common::StaticArray begin, end, step; GetIndexRange(data.shape_, param_begin, param_end, param_step, &begin, &end, &step); MSHADOW_TYPE_SWITCH(out.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { @@ -1400,7 +1403,7 @@ void SliceLikeBackward(const nnvm::NodeAttrs& attrs, SliceLikeInferRanges(ishape, from_shape, param.axes, ¶m_begin, ¶m_end, ¶m_step); MXNET_NDIM_SWITCH(ograd.ndim(), ndim, { - common::StaticArray begin, end, step; + common::StaticArray begin, end, step; GetIndexRange(ograd.shape_, param_begin, param_end, param_step, &begin, &end, &step); MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { @@ -1429,7 +1432,7 @@ struct ClipParam : public dmlc::Parameter { struct clip { template - MSHADOW_XINLINE static void Map(int i, DType* out, const DType* datas, + MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* datas, DType a_min, DType a_max) { DType data = datas[i]; if (data > a_max) { @@ -1445,7 +1448,7 @@ struct clip { struct clip_grad { template - MSHADOW_XINLINE static void Map(int i, DType* out, const DType* grad, const DType* datas, 
+ MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* grad, const DType* datas, DType a_min, DType a_max) { DType data = datas[i]; if (data > a_max) { @@ -1934,7 +1937,7 @@ struct reverse { } #ifdef __CUDACC__ template - __device__ static void Map(int index, index_t nreversedim, const DType *src, DType *dst, + __device__ static void Map(index_t index, index_t nreversedim, const DType *src, DType *dst, const index_t * stride_, const index_t * trailing_) { __shared__ index_t stride_share[REVERSE_MAX_DIM]; @@ -1949,7 +1952,7 @@ struct reverse { } #else template - MSHADOW_XINLINE static void Map(int index, index_t nreversedim, const DType *src, DType *dst, + MSHADOW_XINLINE static void Map(index_t index, index_t nreversedim, const DType *src, DType *dst, const index_t * stride_, const index_t * trailing_) { index_t new_idx = ReverseIndex(index, nreversedim, stride_, trailing_); @@ -2141,10 +2144,10 @@ struct SqueezeParam : public dmlc::Parameter { // move all the zeros to the last of the shape array // and keep the relative order of the non-zero values. // Returns the new shape size after moving all zeros to the end. -inline uint32_t SqueezeShapeHelper(TShape* shape) { +inline size_t SqueezeShapeHelper(TShape* shape) { CHECK(shape != nullptr); - uint32_t count = 0; - for (uint32_t i = 0; i < shape->ndim(); ++i) { + size_t count = 0; + for (size_t i = 0; i < shape->ndim(); ++i) { if ((*shape)[i] == 0) { ++count; } else { @@ -2167,7 +2170,7 @@ inline bool SqueezeShape(const nnvm::NodeAttrs& attrs, if (param.axis.has_value()) { // preprocess axis TShape axes = param.axis.value(); - for (uint32_t i = 0; i < axes.ndim(); ++i) { + for (size_t i = 0; i < axes.ndim(); ++i) { if (axes[i] < 0) { axes[i] += dndim; CHECK_GE(axes[i], 0) @@ -2182,11 +2185,11 @@ inline bool SqueezeShape(const nnvm::NodeAttrs& attrs, oshape[axes[i]] = 0; } } else { - for (uint32_t i = 0; i < oshape.ndim(); ++i) { + for (size_t i = 0; i < oshape.ndim(); ++i) { if (oshape[i] == 1) oshape[i] = 0; } } - uint32_t oshape_size = SqueezeShapeHelper(&oshape); + size_t oshape_size = SqueezeShapeHelper(&oshape); if (oshape_size == 0) { // corner case when dshape is (1, 1, 1, 1) oshape[0] = 1; oshape_size = 1; @@ -2229,7 +2232,7 @@ inline bool DepthToSpaceOpShape(const nnvm::NodeAttrs& attrs, expected_out[0] = in_shape[0]; expected_out[1] = in_shape[1] / (block * block); - uint32_t i = 2; + size_t i = 2; while (i < expected_out.ndim()) { expected_out[i] = in_shape[i] * block; ++i; @@ -2259,9 +2262,9 @@ inline bool DepthToSpaceOpType(const nnvm::NodeAttrs& attrs, * \param inp_index index within input tensor from where value is retrieved * \param offset_arr array containing the linear offset of input tensor */ -MSHADOW_XINLINE void update_index(int index_position, int dim_size, int *idx, - int *inp_index, const int* offset_arr) { - int next_idx_val = *idx / dim_size; +MSHADOW_XINLINE void update_index(index_t index_position, index_t dim_size, index_t *idx, + index_t *inp_index, const index_t* offset_arr) { + index_t next_idx_val = *idx / dim_size; *inp_index += (*idx - next_idx_val * dim_size) * offset_arr[index_position]; *idx = next_idx_val; } @@ -2280,9 +2283,9 @@ MSHADOW_XINLINE void update_index(int index_position, int dim_size, int *idx, template struct depth_to_space_forward { template - MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data, - const int block, const int* size, const int* offset_arr) { - int inp_index = 0, idx = i, dim_size; + MSHADOW_XINLINE static void Map(index_t i, DType* 
out_data, const DType* in_data, + const int block, const index_t* size, const index_t* offset_arr) { + index_t inp_index = 0, idx = i, dim_size; dim_size = block; update_index(2, dim_size, &idx, &inp_index, offset_arr); dim_size = size[3]; @@ -2315,9 +2318,9 @@ struct depth_to_space_forward { template struct compute_offset_for_depth_to_space { template - MSHADOW_XINLINE static void Map(int i, DType* offset_arr, DType* size, const int block, - const int32_t size0, const int32_t size1, const int32_t size2, - const int32_t size3) { + MSHADOW_XINLINE static void Map(index_t i, DType* offset_arr, DType* size, const int block, + const index_t size0, const index_t size1, const index_t size2, + const index_t size3) { size[0] = size0; size[1] = size1; size[2] = size2; @@ -2349,10 +2352,10 @@ void DepthToSpaceOpForward(const nnvm::NodeAttrs& attrs, int block = param.block_size; mshadow::Tensor workspace = - ctx.requested[0].get_space_typed(mshadow::Shape1(sizeof(int32_t) * 10), s); + ctx.requested[0].get_space_typed(mshadow::Shape1(sizeof(index_t) * 10), s); char* workspace_curr_ptr = workspace.dptr_; - int32_t* offset_arr = reinterpret_cast(workspace_curr_ptr); - int32_t* size = reinterpret_cast(workspace_curr_ptr + sizeof(int32_t) * 6); + index_t* offset_arr = reinterpret_cast(workspace_curr_ptr); + index_t* size = reinterpret_cast(workspace_curr_ptr + sizeof(index_t) * 6); MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { @@ -2431,9 +2434,9 @@ inline bool SpaceToDepthOpType(const nnvm::NodeAttrs& attrs, template struct space_to_depth_forward { template - MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data, const int block, - const int* size, const int* offset_arr) { - int inp_index = 0, idx = i, dim_size; + MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* in_data, const int block, + const index_t* size, const index_t* offset_arr) { + index_t inp_index = 0, idx = i, dim_size; dim_size = size[3] / block; update_index(4, dim_size, &idx, &inp_index, offset_arr); dim_size = size[2] / block; @@ -2466,9 +2469,9 @@ struct space_to_depth_forward { template struct compute_offset_for_space_to_depth { template - MSHADOW_XINLINE static void Map(int i, DType* offset_arr, DType* size, const int block, - const int32_t size0, const int32_t size1, - const int32_t size2, const int32_t size3) { + MSHADOW_XINLINE static void Map(index_t i, DType* offset_arr, DType* size, const int block, + const index_t size0, const index_t size1, + const index_t size2, const index_t size3) { size[0] = size0; size[1] = size1; size[2] = size2; @@ -2500,10 +2503,10 @@ void SpaceToDepthOpForward(const nnvm::NodeAttrs& attrs, int block = param.block_size; mshadow::Tensor workspace = - ctx.requested[0].get_space_typed(mshadow::Shape1(sizeof(int32_t) * 10), s); + ctx.requested[0].get_space_typed(mshadow::Shape1(sizeof(index_t) * 10), s); char* workspace_curr_ptr = workspace.dptr_; - int32_t* offset_arr = reinterpret_cast(workspace_curr_ptr); - int32_t* size = reinterpret_cast(workspace_curr_ptr + sizeof(int32_t) * 6); + index_t* offset_arr = reinterpret_cast(workspace_curr_ptr); + index_t* size = reinterpret_cast(workspace_curr_ptr + sizeof(index_t) * 6); MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 121acc174b51..a301362f2db7 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py 
@@ -15,20 +15,126 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import unittest
 import mxnet as mx
+import numpy as np
 from mxnet import gluon, nd
 
+# dimension constants
+MEDIUM_X = 10000
+LARGE_X = 100000000
+LARGE_Y = 50000000
+SMALL_Y = 50
+LARGE_SIZE = LARGE_X * SMALL_Y
+
+def test_gluon_embedding():
+    m = gluon.nn.Embedding(SMALL_Y, MEDIUM_X)
+    m.initialize()
+    a = nd.zeros((MEDIUM_X, SMALL_Y))
+    b = m(a)
+    assert b.shape == (MEDIUM_X, SMALL_Y, MEDIUM_X)
+    assert b.asnumpy().size == LARGE_SIZE
+
+def test_ndarray_zeros():
+    a = nd.zeros(shape=(LARGE_X, SMALL_Y))
+    assert a[-1][0] == 0
+    assert a.shape == (LARGE_X, SMALL_Y)
+    assert a.size == LARGE_SIZE
+
+def test_ndarray_ones():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    assert a[-1][0] == 1
+    assert nd.sum(a).asnumpy() == LARGE_SIZE
+
+def test_ndarray_random_uniform():
+    a = nd.random.uniform(shape=(LARGE_X, SMALL_Y))
+    assert a[-1][0] != 0
+
+def test_ndarray_empty():
+    a = nd.empty((LARGE_X, SMALL_Y))
+    assert a.shape == (LARGE_X, SMALL_Y)
+
+def test_elementwise():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    b = nd.ones(shape=(LARGE_X, SMALL_Y))
+    res = a + b
+    assert np.sum(res[-1].asnumpy() == 2) == a.shape[1]
+    res = a + 1
+    assert np.sum(res[-1].asnumpy() == 2) == a.shape[1]
+    res = nd.sqrt(a + 3)
+    assert np.sum(res[-1].asnumpy() == 2) == a.shape[1]
+
+def test_reduce():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    assert nd.sum(a).asnumpy() == a.shape[0] * a.shape[1]
+
+def test_dot():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    b = nd.ones(shape=(SMALL_Y, SMALL_Y))
+    res = nd.dot(a, b)
+    assert np.sum(res[-1].asnumpy() == SMALL_Y) == b.shape[1]
+
+def test_FullyConnected():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    b = nd.ones(shape=(SMALL_Y, SMALL_Y))
+    res = nd.FullyConnected(a, b, num_hidden=b.shape[1], no_bias=True)
+    assert np.sum(res[-1].asnumpy() == SMALL_Y) == b.shape[1]
+
+def test_broadcast():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    b = nd.arange(0, LARGE_X).reshape(LARGE_X, 1)
+    res = nd.broadcast_to(b, shape=(b.shape[0], SMALL_Y))
+    assert np.sum(res[-1].asnumpy() == LARGE_X) == res.shape[1]
+    res = mx.nd.broadcast_like(b, a)
+    assert np.sum(res[-1].asnumpy() == LARGE_X) == a.shape[1]
+
+def test_clip():
+    a = nd.arange(0, LARGE_X).reshape(LARGE_X, 1)
+    b = nd.broadcast_to(a, shape=(a.shape[0], SMALL_Y))
+    res = nd.clip(b, a_min=100, a_max=1000)
+    assert np.sum(res[-1].asnumpy() == 1000) == b.shape[1]
+
+def test_take():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    idx = nd.arange(LARGE_X-1000, LARGE_X)
+    res = nd.take(a, idx)
+    assert np.sum(res[-1].asnumpy() == 1) == res.shape[1]
+
+def test_slice():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    res = nd.slice(a, begin=(LARGE_X-1000, 1), end=(LARGE_X, SMALL_Y))
+    assert np.sum(res[-1].asnumpy() == 1) == res.shape[1]
+
+def test_slice_assign():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    a[LARGE_X-1:LARGE_X] = 1000
+    assert np.sum(a[-1].asnumpy() == 1000) == a.shape[1]
+
+def test_expand_dims():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    res = nd.expand_dims(a, axis=1)
+    assert res.shape == (a.shape[0], 1, a.shape[1])
+
+def test_squeeze():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    data = nd.expand_dims(a, axis=1)
+    res = nd.squeeze(data)
+    assert res.shape == a.shape
+
+def test_broadcast_div():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    b = nd.ones(shape=(LARGE_X, 1)) * 2
+    res = a / b
+    assert np.sum(res[-1].asnumpy() == 0.5) == a.shape[1]
+
+def test_Dense(ctx=mx.cpu(0)):
+    data = mx.nd.ones(shape=(50*1000*1000, 100))
+    linear = gluon.nn.Dense(100)
+    linear.initialize(ctx=ctx)
+    res = linear(data)
+    res.wait_to_read()
+    assert res.shape == (50000000, 100)
 
-class TestLargeArray(unittest.TestCase):
-    def test_ndarray2numpy(self):
-        m = gluon.nn.Embedding(14000, 128)
-        m.initialize()
-        ind = nd.zeros((700000, 128))
-        x = m(ind)
-        x.shape
-        test = x.asnumpy()
-        assert (x.shape == test.shape)
 
 if __name__ == '__main__':
-    unittest.main()
+    import nose
+    nose.runmodule()
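
Note on the common thread in the C++ hunks above: nearly every change swaps `int` for `index_t` (or `size_t`) in kernel `Map()` signatures, loop counters, and flattened-index arithmetic, so tensors with more than 2^31 - 1 elements are indexed without overflow; the new `tests/nightly/test_large_array.py` exercises exactly that regime with a 100000000 x 50 ndarray. The sketch below is not MXNet source, only a minimal standalone illustration of the overflow the patch guards against, assuming `index_t` maps to a 64-bit integer as it does in large-tensor builds of mshadow/MXNet.

```cpp
// Standalone sketch: why 32-bit indices break once an ndarray holds more
// than INT32_MAX elements, and why the kernels switch to a 64-bit index_t.
#include <cstdint>
#include <cstdio>

using index_t = int64_t;  // stand-in for mshadow's index type in this sketch

int main() {
  // 100,000,000 x 50 elements, the LARGE_X x SMALL_Y shape used in the tests.
  const index_t num_elements = 100000000LL * 50;      // 5e9, > INT32_MAX
  const int truncated = static_cast<int>(num_elements);  // wraps / truncates

  std::printf("index_t count: %lld\n", static_cast<long long>(num_elements));
  std::printf("int count    : %d (overflowed)\n", truncated);

  // A kernel loop written as `for (int i = 0; i < N; ++i)` cannot visit all
  // 5e9 elements, since `i` overflows before reaching N; the same loop with
  // `index_t i` (as in the patch) covers the full range.
  return 0;
}
```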