diff --git a/3rdparty/mshadow b/3rdparty/mshadow index 95ebe0f109ae..6e94643bdf1d 160000 --- a/3rdparty/mshadow +++ b/3rdparty/mshadow @@ -1 +1 @@ -Subproject commit 95ebe0f109ae021d0d66e2a627ccfc55c3253b55 +Subproject commit 6e94643bdf1d51a505b147f28c358fb71070b8fd diff --git a/CMakeLists.txt b/CMakeLists.txt index 9cd68e14093c..09a52be2fc87 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -50,6 +50,7 @@ mxnet_option(USE_SIGNAL_HANDLER "Print stack traces on segfaults." OFF) mxnet_option(USE_TENSORRT "Enable infeference optimization with TensorRT." OFF) mxnet_option(USE_ASAN "Enable Clang/GCC ASAN sanitizers." OFF) mxnet_option(ENABLE_TESTCOVERAGE "Enable compilation with test coverage metric output" OFF) +mxnet_option(USE_INT64_TENSOR_SIZE "Use int64_t to represent the total number of elements in a tensor" OFF) message(STATUS "CMAKE_CROSSCOMPILING ${CMAKE_CROSSCOMPILING}") message(STATUS "CMAKE_HOST_SYSTEM_PROCESSOR ${CMAKE_HOST_SYSTEM_PROCESSOR}") @@ -295,6 +296,13 @@ else() add_definitions(-DMXNET_USE_NCCL=0) endif() +if (USE_INT64_TENSOR_SIZE) + message(STATUS "Using 64-bit integer for tensor size") + add_definitions(-DMSHADOW_INT64_TENSOR_SIZE=1) +else() + add_definitions(-DMSHADOW_INT64_TENSOR_SIZE=0) +endif() + include(cmake/ChooseBlas.cmake) if(USE_CUDA AND FIRST_CUDA) include(3rdparty/mshadow/cmake/Utils.cmake) diff --git a/Makefile b/Makefile index 53998ac31919..29cfd573665c 100644 --- a/Makefile +++ b/Makefile @@ -189,6 +189,11 @@ ifeq ($(USE_OPERATOR_TUNING), 1) CFLAGS += -DMXNET_USE_OPERATOR_TUNING=1 endif +ifeq ($(USE_INT64_TENSOR_SIZE), 1) + CFLAGS += -DMSHADOW_INT64_TENSOR_SIZE=1 +else + CFLAGS += -DMSHADOW_INT64_TENSOR_SIZE=0 +endif # verify existence of separate lapack library when using blas/openblas/atlas # switch off lapack support in case it can't be found # issue covered with this diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index ac8033c794ec..c3610d2452e0 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -755,6 +755,53 @@ build_ubuntu_gpu_cmake() { ninja -v } +build_ubuntu_cpu_large_tensor() { + set -ex + cd /work/build + build_ccache_wrappers + cmake \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \ + -DUSE_SIGNAL_HANDLER=ON \ + -DENABLE_TESTCOVERAGE=ON \ + -DUSE_CUDA=OFF \ + -DUSE_CUDNN=OFF \ + -DUSE_MKLDNN=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -DUSE_INT64_TENSOR_SIZE=ON \ + -G Ninja \ + /work/mxnet + + ninja -v +} + +build_ubuntu_gpu_large_tensor() { + set -ex + cd /work/build + build_ccache_wrappers + cmake \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \ + -DUSE_SIGNAL_HANDLER=ON \ + -DENABLE_TESTCOVERAGE=ON \ + -DUSE_CUDA=ON \ + -DUSE_CUDNN=ON \ + -DUSE_MKL_IF_AVAILABLE=OFF \ + -DUSE_MKLML_MKL=OFF \ + -DUSE_MKLDNN=OFF \ + -DUSE_DIST_KVSTORE=ON \ + -DCMAKE_BUILD_TYPE=Release \ + -DCUDA_ARCH_NAME=Manual \ + -DCUDA_ARCH_BIN=$CI_CMAKE_CUDA_ARCH_BIN \ + -DUSE_INT64_TENSOR_SIZE=ON \ + -G Ninja \ + /work/mxnet + + ninja -v +} + build_ubuntu_blc() { echo "pass" } @@ -1183,6 +1230,13 @@ nightly_test_KVStore_singleNode() { python tests/nightly/test_kvstore.py } +#Test Large Tensor Size +nightly_test_large_tensor() { + set -ex + export PYTHONPATH=./python/ + nosetests-3.4 tests/nightly/test_large_array.py +} + #Tests Amalgamation Build with 5 different sets of flags nightly_test_amalgamation() { set -ex diff --git a/ci/jenkins/Jenkins_steps.groovy b/ci/jenkins/Jenkins_steps.groovy index ac1579a13d96..23230ac0442f 100644 --- a/ci/jenkins/Jenkins_steps.groovy +++ b/ci/jenkins/Jenkins_steps.groovy @@ -119,6 +119,34 @@ def compile_unix_openblas_debug_cpu() { }] } +def compile_unix_int64_cpu() { + return ['CPU: USE_INT64_TENSOR_SIZE': { + node(NODE_LINUX_CPU) { + ws('workspace/build-cpu-int64') { + timeout(time: max_time, unit: 'MINUTES') { + utils.init_git() + utils.docker_run('ubuntu_cpu', 'build_ubuntu_cpu_large_tensor', false) + utils.pack_lib('ubuntu_cpu_int64', mx_cmake_lib, true) + } + } + } + }] +} + +def compile_unix_int64_gpu() { + return ['GPU: USE_INT64_TENSOR_SIZE': { + node(NODE_LINUX_GPU) { + ws('workspace/build-gpu-int64') { + timeout(time: max_time, unit: 'MINUTES') { + utils.init_git() + utils.docker_run('ubuntu_gpu', 'build_ubuntu_gpu_large_tensor', false) + utils.pack_lib('ubuntu_gpu_int64', mx_cmake_lib, true) + } + } + } + }] +} + def compile_unix_mkl_cpu() { return ['CPU: MKL': { node(NODE_LINUX_CPU) { diff --git a/ci/jenkins/Jenkinsfile_unix_cpu b/ci/jenkins/Jenkinsfile_unix_cpu index 919381ebccd4..fa0942988d9c 100644 --- a/ci/jenkins/Jenkinsfile_unix_cpu +++ b/ci/jenkins/Jenkinsfile_unix_cpu @@ -38,7 +38,8 @@ core_logic: { custom_steps.compile_unix_openblas_debug_cpu(), custom_steps.compile_unix_mkl_cpu(), custom_steps.compile_unix_mkldnn_cpu(), - custom_steps.compile_unix_mkldnn_mkl_cpu() + custom_steps.compile_unix_mkldnn_mkl_cpu(), + custom_steps.compile_unix_int64_cpu() ]) utils.parallel_stage('Tests', [ diff --git a/ci/jenkins/Jenkinsfile_unix_gpu b/ci/jenkins/Jenkinsfile_unix_gpu index 5d0a37b82edc..43186012c749 100644 --- a/ci/jenkins/Jenkinsfile_unix_gpu +++ b/ci/jenkins/Jenkinsfile_unix_gpu @@ -40,6 +40,7 @@ core_logic: { custom_steps.compile_unix_cmake_mkldnn_gpu(), custom_steps.compile_unix_cmake_gpu(), custom_steps.compile_unix_tensorrt_gpu(), + custom_steps.compile_unix_int64_gpu() ]) utils.parallel_stage('Tests', [ diff --git a/include/mxnet/libinfo.h b/include/mxnet/libinfo.h index f35d41a9aa8a..8b58a398c673 100644 --- a/include/mxnet/libinfo.h +++ b/include/mxnet/libinfo.h @@ -123,7 +123,9 @@ #define MXNET_USE_SIGNAL_HANDLER 0 #endif - +#ifndef MXNET_USE_INT64_TENSOR_SIZE +#define MXNET_USE_INT64_TENSOR_SIZE MSHADOW_INT64_TENSOR_SIZE +#endif namespace mxnet { namespace features { @@ -177,6 +179,8 @@ enum : unsigned { PROFILER, DIST_KVSTORE, CXX14, + INT64_TENSOR_SIZE, + // Signal handler to print stack traces on exceptions SIGNAL_HANDLER, DEBUG, diff --git a/include/mxnet/tensor_blob.h b/include/mxnet/tensor_blob.h index a7a57266dab8..a08dab1b1b74 100755 --- a/include/mxnet/tensor_blob.h +++ b/include/mxnet/tensor_blob.h @@ -218,15 +218,16 @@ class TBlob { return shape_.ndim(); } /*! - * \brief return size of i-th dimension, start counting from highest dimension + * \brief return size of i-th dimension, start counting from highest dimension. + * return type needs to be a signed integer. * \param idx the dimension count from the highest dimensin - * \return the size + * \return the size. -1 means unknown size to support zero-size tensor. */ inline index_t size(index_t idx) const { return shape_[idx]; } /*! \brief total number of elements in the tensor */ - inline index_t Size(void) const { + inline size_t Size(void) const { return shape_.Size(); } /*! \brief get pointer in dtype */ @@ -443,7 +444,7 @@ class FieldEntry throw dmlc::ParamError(os.str()); } if (enforce_nonzero_) { - for (mxnet::index_t i = 0; i < v.ndim(); ++i) { + for (int i = 0; i < v.ndim(); ++i) { if (v[i] == 0U) { std::ostringstream os; os << "value " << v << "for Parameter " << this->key_ @@ -457,7 +458,7 @@ class FieldEntry this->enforce_nonzero_ = true; return this->self(); } - inline FieldEntry &set_expect_ndim(mxnet::index_t ndim) { + inline FieldEntry &set_expect_ndim(int ndim) { expect_ndim_ = ndim; return this->self(); } @@ -466,7 +467,7 @@ class FieldEntry // whether all the entries need to be nonzero bool enforce_nonzero_; // expected number of dimension, default = 0 means no restriction. - mxnet::index_t expect_ndim_; + int expect_ndim_; }; } // namespace parameter diff --git a/include/mxnet/tuple.h b/include/mxnet/tuple.h index 8431bbb23b96..bc630f153744 100644 --- a/include/mxnet/tuple.h +++ b/include/mxnet/tuple.h @@ -569,7 +569,7 @@ class TShape : public Tuple { * \param axis_end The ending axis specified. * \return the flat 3d shape */ - inline mshadow::Shape<3> FlatTo3D(size_t axis_begin, size_t axis_end) const { + inline mshadow::Shape<3> FlatTo3D(int axis_begin, int axis_end) const { CHECK(axis_end >= axis_begin); mshadow::Shape<3> s; CHECK(ndim_is_known(ndim())) << "shape must have a valid ndim"; @@ -579,10 +579,10 @@ class TShape : public Tuple { s.shape_[1] = 1; s.shape_[2] = 1; - for (size_t i = 0; i < axis_begin; ++i) { + for (int i = 0; i < axis_begin; ++i) { s.shape_[0] *= d[i]; } - for (size_t i = axis_begin; i <= axis_end; ++i) { + for (int i = axis_begin; i <= axis_end; ++i) { s.shape_[1] *= d[i]; } for (int i = axis_end + 1; i < ndim(); ++i) { @@ -595,7 +595,7 @@ class TShape : public Tuple { * \param axis The axis specified. * \return the flat 3d shape */ - inline mshadow::Shape<3> FlatTo3D(size_t axis) const { + inline mshadow::Shape<3> FlatTo3D(int axis) const { return FlatTo3D(axis, axis); } inline bool operator==(const TShape &s) const { @@ -712,8 +712,8 @@ template struct hash > { /*! \brief hash a Tuple into unsigned int */ size_t operator()(const mxnet::Tuple& val) const { - std::hash hash_uint; - size_t res = hash_uint(val.ndim()); + std::hash hash_int; + size_t res = hash_int(val.ndim()); for (int i = 0; i < val.ndim(); ++i) { res = dmlc::HashCombine(res, val[i]); } @@ -726,8 +726,8 @@ template<> struct hash { /*! \brief hash a TShape into unsigned int */ size_t operator()(const mxnet::TShape& val) const { - std::hash hash_uint; - size_t res = hash_uint(val.ndim()); + std::hash hash_int; + size_t res = hash_int(val.ndim()); for (int i = 0; i < val.ndim(); ++i) { res = dmlc::HashCombine(res, val[i]); } diff --git a/make/config.mk b/make/config.mk index d4431a97173d..20834675ecbd 100644 --- a/make/config.mk +++ b/make/config.mk @@ -215,6 +215,12 @@ EXTRA_OPERATORS = # Create C++ interface package USE_CPP_PACKAGE = 0 +# Use int64_t type to represent the total number of elements in a tensor +# This will cause performance degradation reported in issue #14496 +# Set to 1 for large tensor with tensor size greater than INT32_MAX i.e. 2147483647 +# Note: the size of each dimension is still bounded by INT32_MAX +USE_INT64_TENSOR_SIZE = 0 + #---------------------------- # plugins #---------------------------- diff --git a/make/crosscompile.jetson.mk b/make/crosscompile.jetson.mk index f0c89d6239e6..880e2cf5b466 100644 --- a/make/crosscompile.jetson.mk +++ b/make/crosscompile.jetson.mk @@ -192,6 +192,12 @@ EXTRA_OPERATORS = # Create C++ interface package USE_CPP_PACKAGE = 0 +# Use int64_t type to represent the total number of elements in the tensor +# This will cause performance degradation reported in issue #14496 +# Set to 1 for large tensor with tensor size greater than INT32_MAX i.e. 2147483647 +# Note: the size of each dimension is still bounded by INT32_MAX +USE_INT64_TENSOR_SIZE = 0 + #---------------------------- # plugins #---------------------------- diff --git a/make/osx.mk b/make/osx.mk index 7e32d81a5d71..0b5842e59524 100644 --- a/make/osx.mk +++ b/make/osx.mk @@ -135,6 +135,12 @@ EXTRA_OPERATORS = # Create C++ interface package USE_CPP_PACKAGE = 0 +# Use int64_t type to represent the total number of elements in a tensor +# This will cause performance degradation reported in issue #14496 +# Set to 1 for large tensor with tensor size greater than INT32_MAX i.e. 2147483647 +# Note: the size of each dimension is still bounded by INT32_MAX +USE_INT64_TENSOR_SIZE = 0 + #---------------------------- # plugins #---------------------------- diff --git a/src/common/serialization.h b/src/common/serialization.h index 8192ee210a1c..c22d8bc82270 100644 --- a/src/common/serialization.h +++ b/src/common/serialization.h @@ -49,7 +49,7 @@ template inline size_t SerializedSize(const T &obj); template -inline size_t SerializedSize(const nnvm::Tuple &obj); +inline size_t SerializedSize(const mxnet::Tuple &obj); template inline size_t SerializedSize(const std::map &obj); @@ -64,7 +64,7 @@ template inline void Serialize(const T &obj, char **buffer); template -inline void Serialize(const nnvm::Tuple &obj, char **buffer); +inline void Serialize(const mxnet::Tuple &obj, char **buffer); template inline void Serialize(const std::map &obj, char **buffer); @@ -79,7 +79,7 @@ template inline void Deserialize(T *obj, const std::string &buffer, size_t *curr_pos); template -inline void Deserialize(nnvm::Tuple *obj, const std::string &buffer, size_t *curr_pos); +inline void Deserialize(mxnet::Tuple *obj, const std::string &buffer, size_t *curr_pos); template inline void Deserialize(std::map *obj, const std::string &buffer, size_t *curr_pos); @@ -102,7 +102,7 @@ inline size_t SerializedSize(const T &obj) { } template -inline size_t SerializedSize(const nnvm::Tuple &obj) { +inline size_t SerializedSize(const mxnet::Tuple &obj) { if (is_container::value) { size_t sum_val = 4; for (const auto& el : obj) { @@ -180,7 +180,7 @@ inline void Serialize(const T &obj, char **buffer) { } template -inline void Serialize(const nnvm::Tuple &obj, char **buffer) { +inline void Serialize(const mxnet::Tuple &obj, char **buffer) { uint32_t size = obj.ndim(); std::memcpy(*buffer, &size, 4); *buffer += 4; @@ -244,7 +244,7 @@ inline void Deserialize(T *obj, const std::string &buffer, size_t *curr_pos) { } template -inline void Deserialize(nnvm::Tuple *obj, const std::string &buffer, size_t *curr_pos) { +inline void Deserialize(mxnet::Tuple *obj, const std::string &buffer, size_t *curr_pos) { uint32_t size = obj->ndim(); std::memcpy(&size, &buffer[*curr_pos], 4); *curr_pos += 4; diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index b3192dc8281b..14b373edea57 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -36,8 +36,8 @@ struct CachedOpConfig : public dmlc::Parameter { bool static_alloc; bool static_shape; bool is_dynamic; - nnvm::Tuple data_indices; - nnvm::Tuple param_indices; + mxnet::Tuple data_indices; + mxnet::Tuple param_indices; std::string subgraph; DMLC_DECLARE_PARAMETER(CachedOpConfig) { DMLC_DECLARE_FIELD(static_alloc) @@ -59,10 +59,10 @@ struct CachedOpConfig : public dmlc::Parameter { .set_default(Imperative::BulkExecMaxNodeTrainBwd()) .describe("Segment size of bulk execution during backward pass."); DMLC_DECLARE_FIELD(data_indices) - .set_default(nnvm::Tuple()) + .set_default(mxnet::Tuple()) .describe("Position of argument variables."); DMLC_DECLARE_FIELD(param_indices) - .set_default(nnvm::Tuple()) + .set_default(mxnet::Tuple()) .describe("Position of parameters."); DMLC_DECLARE_FIELD(subgraph) .set_default(std::string("")) diff --git a/src/io/image_det_aug_default.cc b/src/io/image_det_aug_default.cc index 74e51b51603b..3bd37200b8e7 100644 --- a/src/io/image_det_aug_default.cc +++ b/src/io/image_det_aug_default.cc @@ -34,7 +34,7 @@ namespace mxnet { namespace io { -using nnvm::Tuple; +using mxnet::Tuple; namespace image_det_aug_default_enum { enum ImageDetAugDefaultCropEmitMode {kCenter, kOverlap}; @@ -462,7 +462,7 @@ class DefaultImageDetAugmenter : public ImageAugmenter { /*! \brief Check number of crop samplers and given parameters */ template - void ValidateCropParameters(nnvm::Tuple *param, const int num_sampler) { + void ValidateCropParameters(mxnet::Tuple *param, const int num_sampler) { if (num_sampler == 1) { CHECK_EQ(param->ndim(), 1); } else if (num_sampler > 1) { diff --git a/src/io/image_io.cc b/src/io/image_io.cc index 965078cb2766..c0357998f31c 100644 --- a/src/io/image_io.cc +++ b/src/io/image_io.cc @@ -295,7 +295,7 @@ struct MakeBorderParam : public dmlc::Parameter { int top, bot, left, right; int type; double value; - nnvm::Tuple values; + mxnet::Tuple values; DMLC_DECLARE_PARAMETER(MakeBorderParam) { DMLC_DECLARE_FIELD(top) .describe("Top margin."); diff --git a/src/io/iter_batchloader.h b/src/io/iter_batchloader.h index 69eb05f7d729..279690b594e6 100644 --- a/src/io/iter_batchloader.h +++ b/src/io/iter_batchloader.h @@ -78,7 +78,7 @@ class BatchLoader : public IIterator { // if overflow from previous round, directly return false, until before first is called if (num_overflow_ != 0) return false; - index_t top = 0; + size_t top = 0; while (base_->Next()) { const DataInst& d = base_->Value(); diff --git a/src/io/iter_sparse_batchloader.h b/src/io/iter_sparse_batchloader.h index 17c509a0f56b..c0d856df89ec 100644 --- a/src/io/iter_sparse_batchloader.h +++ b/src/io/iter_sparse_batchloader.h @@ -67,7 +67,7 @@ class SparseBatchLoader : public BatchLoader, public SparseIIterator this->head_ = 0; // if overflown from previous round, directly return false, until before first is called if (num_overflow_ != 0) return false; - index_t top = 0; + size_t top = 0; offsets_.clear(); while (sparse_base_->Next()) { const DataInst& inst = sparse_base_->Value(); diff --git a/src/libinfo.cc b/src/libinfo.cc index 2af61eac9eca..f67b45ed1c14 100644 --- a/src/libinfo.cc +++ b/src/libinfo.cc @@ -86,7 +86,9 @@ class FeatureSet { // Misc feature_bits.set(CAFFE, MXNET_USE_CAFFE); feature_bits.set(DIST_KVSTORE, MXNET_USE_DIST_KVSTORE); + feature_bits.set(INT64_TENSOR_SIZE, MXNET_USE_INT64_TENSOR_SIZE); feature_bits.set(SIGNAL_HANDLER, MXNET_USE_SIGNAL_HANDLER); + #ifndef NDEBUG feature_bits.set(DEBUG); #endif @@ -154,6 +156,7 @@ const std::vector EnumNames::names = { "PROFILER", "DIST_KVSTORE", "CXX14", + "INT64_TENSOR_SIZE", "SIGNAL_HANDLER", "DEBUG", }; diff --git a/src/operator/contrib/dgl_graph.cc b/src/operator/contrib/dgl_graph.cc index 313b855f0d2d..428899791a5d 100644 --- a/src/operator/contrib/dgl_graph.cc +++ b/src/operator/contrib/dgl_graph.cc @@ -1251,7 +1251,7 @@ void EdgeIDForwardCsrImpl(const OpContext& ctx, CHECK_EQ(req, kWriteTo) << "EdgeID with CSR only supports kWriteTo"; Stream *s = ctx.get_stream(); const NDArray& u = inputs[1]; - const nnvm::dim_t out_elems = u.shape().Size(); + const dim_t out_elems = u.shape().Size(); if (!inputs[0].storage_initialized()) { MSHADOW_TYPE_SWITCH(output.dtype(), DType, { Kernel, xpu>::Launch( @@ -1408,7 +1408,7 @@ the data value of float32. struct SubgraphCompactParam : public dmlc::Parameter { int num_args; bool return_mapping; - nnvm::Tuple graph_sizes; + mxnet::Tuple graph_sizes; DMLC_DECLARE_PARAMETER(SubgraphCompactParam) { DMLC_DECLARE_FIELD(num_args).set_lower_bound(2) .describe("Number of input arguments."); diff --git a/src/operator/contrib/multi_proposal-inl.h b/src/operator/contrib/multi_proposal-inl.h index 4d278fb40645..7010dadfedbc 100644 --- a/src/operator/contrib/multi_proposal-inl.h +++ b/src/operator/contrib/multi_proposal-inl.h @@ -56,8 +56,8 @@ struct MultiProposalParam : public dmlc::Parameter { int rpn_post_nms_top_n; float threshold; int rpn_min_size; - nnvm::Tuple scales; - nnvm::Tuple ratios; + mxnet::Tuple scales; + mxnet::Tuple ratios; int feature_stride; bool output_score; bool iou_loss; @@ -73,10 +73,10 @@ struct MultiProposalParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(rpn_min_size).set_default(16) .describe("Minimum height or width in proposal"); tmp[0] = 4.0f; tmp[1] = 8.0f; tmp[2] = 16.0f; tmp[3] = 32.0f; - DMLC_DECLARE_FIELD(scales).set_default(nnvm::Tuple(tmp, tmp + 4)) + DMLC_DECLARE_FIELD(scales).set_default(mxnet::Tuple(tmp, tmp + 4)) .describe("Used to generate anchor windows by enumerating scales"); tmp[0] = 0.5f; tmp[1] = 1.0f; tmp[2] = 2.0f; - DMLC_DECLARE_FIELD(ratios).set_default(nnvm::Tuple(tmp, tmp + 3)) + DMLC_DECLARE_FIELD(ratios).set_default(mxnet::Tuple(tmp, tmp + 3)) .describe("Used to generate anchor windows by enumerating ratios"); DMLC_DECLARE_FIELD(feature_stride).set_default(16) .describe("The size of the receptive field each unit in the convolution layer of the rpn," @@ -214,11 +214,11 @@ inline void _Transform(float scale, // out_anchors must have shape (n, 5), where n is ratios.size() * scales.size() inline void GenerateAnchors(const std::vector& base_anchor, - const nnvm::Tuple& ratios, - const nnvm::Tuple& scales, + const mxnet::Tuple& ratios, + const mxnet::Tuple& scales, std::vector *out_anchors) { - for (size_t j = 0; j < ratios.ndim(); ++j) { - for (size_t k = 0; k < scales.ndim(); ++k) { + for (int j = 0; j < ratios.ndim(); ++j) { + for (int k = 0; k < scales.ndim(); ++k) { _Transform(scales[k], ratios[j], base_anchor, out_anchors); } } diff --git a/src/operator/contrib/multibox_detection-inl.h b/src/operator/contrib/multibox_detection-inl.h index 1ac14e237f0d..34ad4471dedc 100644 --- a/src/operator/contrib/multibox_detection-inl.h +++ b/src/operator/contrib/multibox_detection-inl.h @@ -52,7 +52,7 @@ struct MultiBoxDetectionParam : public dmlc::Parameter { bool force_suppress; int keep_topk; int nms_topk; - nnvm::Tuple variances; + mxnet::Tuple variances; DMLC_DECLARE_PARAMETER(MultiBoxDetectionParam) { DMLC_DECLARE_FIELD(clip).set_default(true) .describe("Clip out-of-boundary boxes."); diff --git a/src/operator/contrib/multibox_detection.cc b/src/operator/contrib/multibox_detection.cc index 8d1082914df7..65fe5f1208bb 100644 --- a/src/operator/contrib/multibox_detection.cc +++ b/src/operator/contrib/multibox_detection.cc @@ -87,7 +87,7 @@ inline void MultiBoxDetectionForward(const Tensor &out, const Tensor &temp_space, const float threshold, const bool clip, - const nnvm::Tuple &variances, + const mxnet::Tuple &variances, const float nms_threshold, const bool force_suppress, const int nms_topk) { diff --git a/src/operator/contrib/multibox_detection.cu b/src/operator/contrib/multibox_detection.cu index 98151f8b8755..51b2aa7cdc77 100644 --- a/src/operator/contrib/multibox_detection.cu +++ b/src/operator/contrib/multibox_detection.cu @@ -213,7 +213,7 @@ inline void MultiBoxDetectionForward(const Tensor &out, const Tensor &temp_space, const float threshold, const bool clip, - const nnvm::Tuple &variances, + const mxnet::Tuple &variances, const float nms_threshold, const bool force_suppress, const int nms_topk) { diff --git a/src/operator/contrib/multibox_prior-inl.h b/src/operator/contrib/multibox_prior-inl.h index d8929f3deff4..bfc244f77805 100644 --- a/src/operator/contrib/multibox_prior-inl.h +++ b/src/operator/contrib/multibox_prior-inl.h @@ -57,11 +57,11 @@ enum MultiBoxPriorOpOutputs {kOut}; } // namespace mboxprior_enum struct MultiBoxPriorParam : public dmlc::Parameter { - nnvm::Tuple sizes; - nnvm::Tuple ratios; + mxnet::Tuple sizes; + mxnet::Tuple ratios; bool clip; - nnvm::Tuple steps; - nnvm::Tuple offsets; + mxnet::Tuple steps; + mxnet::Tuple offsets; DMLC_DECLARE_PARAMETER(MultiBoxPriorParam) { DMLC_DECLARE_FIELD(sizes).set_default({1.0f}) .describe("List of sizes of generated MultiBoxPriores."); diff --git a/src/operator/contrib/multibox_target-inl.h b/src/operator/contrib/multibox_target-inl.h index f7a92882650c..6034f13ef734 100644 --- a/src/operator/contrib/multibox_target-inl.h +++ b/src/operator/contrib/multibox_target-inl.h @@ -62,7 +62,7 @@ struct MultiBoxTargetParam : public dmlc::Parameter { float negative_mining_ratio; float negative_mining_thresh; int minimum_negative_samples; - nnvm::Tuple variances; + mxnet::Tuple variances; DMLC_DECLARE_PARAMETER(MultiBoxTargetParam) { DMLC_DECLARE_FIELD(overlap_threshold).set_default(0.5f) .describe("Anchor-GT overlap threshold to be regarded as a positive match."); diff --git a/src/operator/contrib/multibox_target.cc b/src/operator/contrib/multibox_target.cc index a1f2aac250ff..a1808c5a7c81 100644 --- a/src/operator/contrib/multibox_target.cc +++ b/src/operator/contrib/multibox_target.cc @@ -81,7 +81,7 @@ inline void MultiBoxTargetForward(const Tensor &loc_target, const float negative_mining_ratio, const float negative_mining_thresh, const int minimum_negative_samples, - const nnvm::Tuple &variances) { + const mxnet::Tuple &variances) { const DType *p_anchor = anchors.dptr_; const int num_batches = labels.size(0); const int num_labels = labels.size(1); diff --git a/src/operator/contrib/multibox_target.cu b/src/operator/contrib/multibox_target.cu index ca0428348a6c..a44c08b08923 100644 --- a/src/operator/contrib/multibox_target.cu +++ b/src/operator/contrib/multibox_target.cu @@ -349,7 +349,7 @@ inline void MultiBoxTargetForward(const Tensor &loc_target, const float negative_mining_ratio, const float negative_mining_thresh, const int minimum_negative_samples, - const nnvm::Tuple &variances) { + const mxnet::Tuple &variances) { const int num_batches = labels.size(0); const int num_labels = labels.size(1); const int label_width = labels.size(2); diff --git a/src/operator/contrib/proposal-inl.h b/src/operator/contrib/proposal-inl.h index 21e9fe198e63..10f1f86806e4 100644 --- a/src/operator/contrib/proposal-inl.h +++ b/src/operator/contrib/proposal-inl.h @@ -54,8 +54,8 @@ struct ProposalParam : public dmlc::Parameter { int rpn_post_nms_top_n; float threshold; int rpn_min_size; - nnvm::Tuple scales; - nnvm::Tuple ratios; + mxnet::Tuple scales; + mxnet::Tuple ratios; int feature_stride; bool output_score; bool iou_loss; @@ -71,10 +71,10 @@ struct ProposalParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(rpn_min_size).set_default(16) .describe("Minimum height or width in proposal"); tmp[0] = 4.0f; tmp[1] = 8.0f; tmp[2] = 16.0f; tmp[3] = 32.0f; - DMLC_DECLARE_FIELD(scales).set_default(nnvm::Tuple(tmp, tmp + 4)) + DMLC_DECLARE_FIELD(scales).set_default(mxnet::Tuple(tmp, tmp + 4)) .describe("Used to generate anchor windows by enumerating scales"); tmp[0] = 0.5f; tmp[1] = 1.0f; tmp[2] = 2.0f; - DMLC_DECLARE_FIELD(ratios).set_default(nnvm::Tuple(tmp, tmp + 3)) + DMLC_DECLARE_FIELD(ratios).set_default(mxnet::Tuple(tmp, tmp + 3)) .describe("Used to generate anchor windows by enumerating ratios"); DMLC_DECLARE_FIELD(feature_stride).set_default(16) .describe("The size of the receptive field each unit in the convolution layer of the rpn," @@ -212,11 +212,11 @@ inline void _Transform(float scale, // out_anchors must have shape (n, 5), where n is ratios.size() * scales.size() inline void GenerateAnchors(const std::vector& base_anchor, - const nnvm::Tuple& ratios, - const nnvm::Tuple& scales, + const mxnet::Tuple& ratios, + const mxnet::Tuple& scales, std::vector *out_anchors) { - for (size_t j = 0; j < ratios.ndim(); ++j) { - for (size_t k = 0; k < scales.ndim(); ++k) { + for (int j = 0; j < ratios.ndim(); ++j) { + for (int k = 0; k < scales.ndim(); ++k) { _Transform(scales[k], ratios[j], base_anchor, out_anchors); } } diff --git a/src/operator/convolution_v1-inl.h b/src/operator/convolution_v1-inl.h index 080c718dc9bf..d2126bd29d80 100644 --- a/src/operator/convolution_v1-inl.h +++ b/src/operator/convolution_v1-inl.h @@ -336,7 +336,7 @@ class ConvolutionV1Op : public Operator { // param_.workspace is in elements of sizeof(DType) // if param_.workspace is set to zero the nstep_ equals ishape[0] (batch) nstep_ = std::max( - std::min(static_cast(param_.workspace) / + std::min(param_.workspace / (shape_colunit_.Size() + shape_dstunit_.Size()), ishape[0]), 1); diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h index 182cd682af8b..aeb189f35b78 100644 --- a/src/operator/image/image_random-inl.h +++ b/src/operator/image/image_random-inl.h @@ -215,16 +215,16 @@ void ToTensorOpForward(const nnvm::NodeAttrs &attrs, } struct NormalizeParam : public dmlc::Parameter { - nnvm::Tuple mean; - nnvm::Tuple std; + mxnet::Tuple mean; + mxnet::Tuple std; DMLC_DECLARE_PARAMETER(NormalizeParam) { DMLC_DECLARE_FIELD(mean) - .set_default(nnvm::Tuple {0.0f, 0.0f, 0.0f, 0.0f}) + .set_default(mxnet::Tuple {0.0f, 0.0f, 0.0f, 0.0f}) .describe("Sequence of means for each channel. " "Default value is 0."); DMLC_DECLARE_FIELD(std) - .set_default(nnvm::Tuple {1.0f, 1.0f, 1.0f, 1.0f}) + .set_default(mxnet::Tuple {1.0f, 1.0f, 1.0f, 1.0f}) .describe("Sequence of standard deviations for each channel. " "Default value is 1."); } @@ -245,7 +245,7 @@ inline bool NormalizeOpShape(const nnvm::NodeAttrs& attrs, << "Input tensor must have shape (channels, height, width), or " << "(N, channels, height, width), but got " << dshape; - uint32_t nchannels; + int nchannels = 0; if (dshape.ndim() == 3) { nchannels = dshape[0]; CHECK(nchannels == 3 || nchannels == 1) @@ -981,7 +981,7 @@ inline void RandomColorJitter(const nnvm::NodeAttrs &attrs, } struct AdjustLightingParam : public dmlc::Parameter { - nnvm::Tuple alpha; + mxnet::Tuple alpha; DMLC_DECLARE_PARAMETER(AdjustLightingParam) { DMLC_DECLARE_FIELD(alpha) .describe("The lighting alphas for the R, G, B channels."); @@ -997,7 +997,7 @@ struct RandomLightingParam : public dmlc::Parameter { } }; -inline void AdjustLightingImpl(const nnvm::Tuple& alpha, +inline void AdjustLightingImpl(const mxnet::Tuple& alpha, const OpContext &ctx, const std::vector &inputs, const std::vector &req, diff --git a/src/operator/nn/deconvolution-inl.h b/src/operator/nn/deconvolution-inl.h index 1eeccb02e030..58f9be702396 100644 --- a/src/operator/nn/deconvolution-inl.h +++ b/src/operator/nn/deconvolution-inl.h @@ -462,7 +462,7 @@ class DeconvolutionOp { oshape[2] * oshape[3]); // See convolution for workspace calculations. nstep_ will be the effective batch size nstep_ = std::max( - std::min(static_cast(param_.workspace) / + std::min(param_.workspace / (shape_colunit_.Size() + shape_dstunit_.Size()), ishape[0]), 1); diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index 49eb96b9f8b2..bd923aebbb80 100644 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -83,8 +83,8 @@ struct SGDParam : public dmlc::Parameter { }; struct MultiSGDParam : public dmlc::Parameter { - nnvm::Tuple lrs; - nnvm::Tuple wds; + mxnet::Tuple lrs; + mxnet::Tuple wds; float rescale_grad; float clip_gradient; int num_weights; @@ -110,8 +110,8 @@ struct MultiSGDParam : public dmlc::Parameter { }; struct MultiSGDMomParam : public dmlc::Parameter { - nnvm::Tuple lrs; - nnvm::Tuple wds; + mxnet::Tuple lrs; + mxnet::Tuple wds; float momentum; float rescale_grad; float clip_gradient; diff --git a/src/operator/swapaxis-inl.h b/src/operator/swapaxis-inl.h index 7335daa48392..b17a81f75bc6 100644 --- a/src/operator/swapaxis-inl.h +++ b/src/operator/swapaxis-inl.h @@ -106,8 +106,8 @@ class SwapAxisOp : public Operator { const std::vector &req) { using namespace mshadow; using namespace mshadow::expr; - uint32_t dim1 = param_.dim1; - uint32_t dim2 = param_.dim2; + int dim1 = param_.dim1; + int dim2 = param_.dim2; TBlob data_in = in_data[swapaxisenum::kData]; TBlob data_out = out_data[swapaxisenum::kData]; diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 6469aae17558..e8c5e884588b 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -1273,7 +1273,9 @@ inline bool GatherNDShape(const nnvm::NodeAttrs& attrs, mxnet::TShape oshape(ishape.ndim() - 1 + dshape.ndim() - ishape[0], -1); - for (int i = 0; i < ishape.ndim() - 1; ++i) oshape[i] = ishape[i+1]; + for (int i = 0; i < ishape.ndim() - 1; ++i) { + oshape[i] = ishape[i+1]; + } for (int i = 0; i < dshape.ndim() - ishape[0]; ++i) { oshape[ishape.ndim()-1+i] = dshape[ishape[0] + i]; } diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 0e7f66240926..e99741b70bb6 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -662,7 +662,7 @@ inline void GetIndexRange(const mxnet::TShape& dshape, << "step and begin must have the same length"; } - for (index_t i = 0; i < param_begin.ndim(); ++i) { + for (int i = 0; i < param_begin.ndim(); ++i) { index_t s = param_step.ndim() != 0U && param_step[i].has_value() ? param_step[i].value() : 1; CHECK_NE(s, 0) << "slice op step[" << i << "] cannot be 0"; @@ -736,11 +736,11 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs, MXNET_NDIM_SWITCH(dshape.ndim(), ndim, { common::StaticArray begin, end, step; GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step); - for (index_t i = 0; i < param.begin.ndim(); ++i) { + for (int i = 0; i < param.begin.ndim(); ++i) { const int b = begin[i], e = end[i], s = step[i]; SetSliceOpOutputDimSize(i, b, e, s, &oshape); } - }); + }) SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); return shape_is_known(oshape); @@ -953,7 +953,7 @@ inline bool SliceAssignOpShape(const nnvm::NodeAttrs& attrs, const int b = begin[i], e = end[i], s = step[i]; SetSliceOpOutputDimSize(i, b, e, s, &vshape); } - }); + }) SHAPE_ASSIGN_CHECK(*in_attrs, 1, vshape); SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape); return true; @@ -1169,7 +1169,7 @@ inline bool SliceAxisShape(const nnvm::NodeAttrs& attrs, } mxnet::TShape shape(ishape.ndim(), -1); for (int i = 0; i < ishape.ndim(); ++i) { - if (static_cast(i) == axis) { + if (i == axis) { shape[i] = static_cast(end - begin); } else { shape[i] = ishape[i]; @@ -1227,7 +1227,7 @@ void SliceAxisGrad_(const nnvm::NodeAttrs& attrs, int axis; index_t begin, end; GetSliceAxisParams(param, outputs[0].shape_, &axis, &begin, &end); - int ndim = static_cast(outputs[0].shape_.ndim()); + int ndim = outputs[0].shape_.ndim(); if (axis + 1 == ndim) { MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { @@ -1293,12 +1293,12 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs, } else { mxnet::TShape shape(ishape); for (int i = 0; i < param.axes.ndim(); ++i) { - int axis = static_cast(param.axes[i]); + int axis = param.axes[i]; if (axis < 0) { - axis += static_cast(ishape.ndim()); + axis += ishape.ndim(); } CHECK_GE(axis, 0) - << "Slice axis: " << static_cast(param.axes[i]) << " too small"; + << "Slice axis: " << param.axes[i] << " too small"; CHECK_GT(ishape.ndim(), axis) << "Slice axis: " << axis << " exceeds first input: " << ishape.ndim(); CHECK_GT(from_shape.ndim(), axis) @@ -1330,15 +1330,15 @@ inline void SliceLikeInferRanges(const mxnet::TShape& dshape, } } else { for (int i = 0; i < axes.ndim(); ++i) { - int axis = static_cast(axes[i]); + int axis = axes[i]; if (axis < 0) { - axis += static_cast(dshape.ndim()); + axis += dshape.ndim(); } CHECK_GE(axis, 0) - << "Slice axis: " << static_cast(axes[i]) << " too small"; - CHECK_LT(axis, static_cast(dshape.ndim())) + << "Slice axis: " << axes[i] << " too small"; + CHECK_LT(axis, dshape.ndim()) << "Slice axis: " << axis << " exceeds first input: " << dshape.ndim(); - CHECK_LT(axis, static_cast(fshape.ndim())) + CHECK_LT(axis, fshape.ndim()) << "Slice axis: " << axis << " exceeds first input: " << fshape.ndim(); pb[axis] = 0; pe[axis] = fshape[axis]; diff --git a/src/operator/tensor/sparse_retain-inl.h b/src/operator/tensor/sparse_retain-inl.h index 951bf80b81b8..04860e6f369f 100644 --- a/src/operator/tensor/sparse_retain-inl.h +++ b/src/operator/tensor/sparse_retain-inl.h @@ -290,7 +290,7 @@ void SparseRetainOpForwardRspImpl(mshadow::Stream *s, Kernel::Launch(s, output_data.Size(), output_data.dptr()); MSHADOW_IDX_TYPE_SWITCH(output_idx.type_flag_, RType, { // row index data type MSHADOW_TYPE_SWITCH(idx_data.type_flag_, IType, { // index array data type - if (input_idx.Size() == input_nd.shape()[0]) { // input rsp is dense + if (input_idx.Size() == static_cast(input_nd.shape()[0])) { // input rsp is dense using namespace mshadow; // copy indices Tensor output_idx_tensor = output_idx.FlatTo1D(s); diff --git a/src/operator/tensor/square_sum-inl.h b/src/operator/tensor/square_sum-inl.h index 016b383117bc..c2e3182c6a1e 100644 --- a/src/operator/tensor/square_sum-inl.h +++ b/src/operator/tensor/square_sum-inl.h @@ -434,14 +434,16 @@ void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs, " when ograd_stype = kRowSparseStorage"; CHECK_EQ(ograd.shape().ndim(), 2U); const TBlob ograd_row_idx = ograd.aux_data(rowsparse::kIdx); - CHECK(ograd_row_idx.Size() == in_row_idx.Size() || in_row_idx.Size() == in_data.shape_[0]); + CHECK(ograd_row_idx.Size() == in_row_idx.Size() || + in_row_idx.Size() == static_cast(in_data.shape_[0])); igrad->CheckAndAlloc({ograd.aux_shape(rowsparse::kIdx)}); const TBlob& igrad_data = igrad->data(); const TBlob igrad_row_idx = igrad->aux_data(rowsparse::kIdx); MSHADOW_IDX_TYPE_SWITCH(igrad_row_idx.type_flag_, IType, { // when ograd_row_idx and in_row_idx have the same size and input is not a full rsp // ograd_row_idx and in_row_idx are expected to have the same elements - if (in_row_idx.Size() != input.shape()[0]) { // if input data is not a full rsp + if (in_row_idx.Size() != static_cast(input.shape()[0])) { + // if input data is not a full rsp CHECK_EQ(ograd_row_idx.Size(), in_row_idx.Size()) << "SquareSumRspGradImpl only supports" " equal ograd_row_idx and" " input_row_idx when ograd and" @@ -452,7 +454,8 @@ void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs, } MSHADOW_TYPE_SWITCH(igrad_data.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req, req_type, { - if (in_row_idx.Size() != input.shape()[0]) { // input data is not a full rsp + if (in_row_idx.Size() != static_cast(input.shape()[0])) { + // input data is not a full rsp Kernel, xpu>::Launch( s, igrad_data.Size(), igrad_row_idx.dptr(), igrad_data.dptr(), ograd_row_idx.dptr(), diff --git a/tests/nightly/JenkinsfileForBinaries b/tests/nightly/JenkinsfileForBinaries index 53e1c30e188f..13bb50e0e484 100755 --- a/tests/nightly/JenkinsfileForBinaries +++ b/tests/nightly/JenkinsfileForBinaries @@ -19,6 +19,7 @@ //This is a Jenkinsfile for nightly tests. The format and some functions have been picked up from the top-level Jenkinsfile mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' +mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so' node('utility') { // Loading the utilities requires a node context unfortunately @@ -39,6 +40,24 @@ core_logic: { utils.pack_lib('gpu', mx_lib) } } + }, + 'CPU: USE_INT64_TENSOR_SIZE': { + node(NODE_LINUX_CPU) { + ws('workspace/build-cpu-int64') { + utils.init_git() + utils.docker_run('ubuntu_nightly_cpu', 'build_ubuntu_cpu_large_tensor', false) + utils.pack_lib('ubuntu_cpu_int64', mx_cmake_lib, true) + } + } + }, + 'GPU: USE_INT64_TENSOR_SIZE': { + node(NODE_LINUX_GPU) { + ws('workspace/build-gpu-int64') { + utils.init_git() + utils.docker_run('ubuntu_nightly_gpu', 'build_ubuntu_gpu_large_tensor', true) + utils.pack_lib('ubuntu_gpu_int64', mx_cmake_lib, true) + } + } } } @@ -59,6 +78,22 @@ core_logic: { } } }, + 'Test Large Tensor Size: CPU': { + node(NODE_LINUX_CPU) { + ws('workspace/large_tensor-cpu') { + utils.unpack_and_init('cpu_int64', mx_cmake_lib) + utils.docker_run('ubuntu_nightly_cpu', 'nightly_test_large_tensor', false) + } + } + }, + 'Test Large Tensor Size: GPU': { + node(NODE_LINUX_GPU) { + ws('workspace/large_tensor-gpu') { + utils.unpack_and_init('gpu_int64', mx_cmake_lib) + utils.docker_run('ubuntu_nightly_gpu', 'nightly_test_large_tensor', true) + } + } + }, 'StraightDope: Python2 Single-GPU': { node(NODE_LINUX_GPU_P3) { ws('workspace/straight_dope-single_gpu') { diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index a627467cb959..1b7dad487a68 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -27,6 +27,7 @@ SMALL_Y = 50 LARGE_SIZE = LARGE_X * SMALL_Y + def test_gluon_embedding(): m = gluon.nn.Embedding(SMALL_Y, MEDIUM_X) m.initialize() @@ -35,22 +36,26 @@ def test_gluon_embedding(): assert b.shape == (MEDIUM_X, SMALL_Y, MEDIUM_X) assert b.asnumpy().size == LARGE_SIZE + def test_ndarray_zeros(): a = nd.zeros(shape=(LARGE_X, SMALL_Y)) assert a[-1][0] == 0 assert a.shape == (LARGE_X, SMALL_Y) assert a.size == LARGE_SIZE + def test_ndarray_ones(): a = nd.ones(shape=(LARGE_X, SMALL_Y)) assert a[-1][0] == 1 assert nd.sum(a).asnumpy() == LARGE_SIZE + @with_seed() def test_ndarray_random_uniform(): a = nd.random.uniform(shape=(LARGE_X, SMALL_Y)) assert a[-1][0] != 0 + @with_seed() def test_ndarray_random_randint(): a = nd.random.randint(100, 10000, shape=(LARGE_X, SMALL_Y)) @@ -59,14 +64,16 @@ def test_ndarray_random_randint(): low_large_value = 2**32 high_large_value = 2**34 a = nd.random.randint(low_large_value,high_large_value) - low = mx.nd.array([low_large_value],dtype='int64') - high = mx.nd.array([high_large_value],dtype='int64') + low = mx.nd.array([low_large_value], dtype='int64') + high = mx.nd.array([high_large_value], dtype='int64') assert a.__gt__(low) & a.__lt__(high) + def test_ndarray_empty(): a = nd.empty((LARGE_X, SMALL_Y)) assert a.shape == (LARGE_X, SMALL_Y) + def test_elementwise(): a = nd.ones(shape=(LARGE_X, SMALL_Y)) b = nd.ones(shape=(LARGE_X, SMALL_Y)) @@ -77,22 +84,26 @@ def test_elementwise(): res = nd.sqrt(a + 3) assert np.sum(res[-1].asnumpy() == 2) == a.shape[1] + def test_reduce(): a = nd.ones(shape=(LARGE_X, SMALL_Y)) assert nd.sum(a).asnumpy() == a.shape[0] * a.shape[1] + def test_dot(): a = nd.ones(shape=(LARGE_X, SMALL_Y)) b = nd.ones(shape=(SMALL_Y, SMALL_Y)) res = nd.dot(a, b) assert np.sum(res[-1].asnumpy() == SMALL_Y) == b.shape[1] + def test_FullyConnected(): a = nd.ones(shape=(LARGE_X, SMALL_Y)) b = nd.ones(shape=(SMALL_Y, SMALL_Y)) res = nd.FullyConnected(a, b, num_hidden=b.shape[1], no_bias=True) assert np.sum(res[-1].asnumpy() == SMALL_Y) == b.shape[1] + def test_broadcast(): a = nd.ones(shape=(LARGE_X, SMALL_Y)) b = nd.arange(0, LARGE_X).reshape(LARGE_X, 1) @@ -101,45 +112,53 @@ def test_broadcast(): res = mx.nd.broadcast_like(b, a) assert np.sum(res[-1].asnumpy() == LARGE_X) == a.shape[1] + def test_clip(): a = nd.arange(0, LARGE_X).reshape(LARGE_X, 1) b = nd.broadcast_to(a, shape=(a.shape[0], SMALL_Y)) res = nd.clip(b, a_min=100, a_max=1000) assert np.sum(res[-1].asnumpy() == 1000) == b.shape[1] + def test_take(): a = nd.ones(shape=(LARGE_X, SMALL_Y)) idx = nd.arange(LARGE_X-1000, LARGE_X) res = nd.take(a, idx) assert np.sum(res[-1].asnumpy() == 1) == res.shape[1] + def test_slice(): a = nd.ones(shape=(LARGE_X, SMALL_Y)) res = nd.slice(a, begin=(LARGE_X-1000, 1), end=(LARGE_X, SMALL_Y)) assert np.sum(res[-1].asnumpy() == 1) == res.shape[1] + def test_slice_assign(): a = nd.ones(shape=(LARGE_X, SMALL_Y)) a[LARGE_X-1:LARGE_X] = 1000 assert np.sum(a[-1].asnumpy() == 1000) == a.shape[1] - + + def test_expand_dims(): a = nd.ones(shape=(LARGE_X, SMALL_Y)) res = nd.expand_dims(a, axis=1) assert res.shape == (a.shape[0], 1, a.shape[1]) + def test_squeeze(): a = nd.ones(shape=(LARGE_X, SMALL_Y)) data = nd.expand_dims(a, axis=1) res = nd.squeeze(data) assert res.shape == a.shape + def test_broadcast_div(): a = nd.ones(shape=(LARGE_X, SMALL_Y)) b = nd.ones(shape=(LARGE_X, 1)) * 2 res = a / b assert np.sum(res[-1].asnumpy() == 0.5) == a.shape[1] + def test_Dense(ctx=mx.cpu(0)): data = mx.nd.ones(shape=(50*1000*1000, 100)) linear = gluon.nn.Dense(100) @@ -148,6 +167,7 @@ def test_Dense(ctx=mx.cpu(0)): res.wait_to_read() assert res.shape == (50000000, 100) + def test_where(): a = nd.ones(shape=(LARGE_X, SMALL_Y)) b = nd.arange(0, LARGE_X).reshape(LARGE_X, 1) @@ -159,12 +179,14 @@ def test_where(): res = nd.sparse.where(csr_cond, a, b) assert np.sum(res[0].asnumpy() == 1) == b.shape[1] + def test_pick(): a = mx.nd.ones(shape=(256*35, 1024*1024)) b = mx.nd.ones(shape=(256*35,)) res = mx.nd.pick(a,b) assert res.shape == b.shape + if __name__ == '__main__': import nose nose.runmodule()