This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

add a compiler flag to use int64 as tensor size #14570

Merged: 37 commits, merged Apr 23, 2019
Changes shown are from 34 of the 37 commits.

Commits
41351f3
use a compile flag to use int64 tensor size
apeforest Mar 29, 2019
e9bd3cc
use personal mshadow repo
apeforest Mar 29, 2019
d8d21ed
Merge remote-tracking branch 'upstream/master' into perf/large-tensor
apeforest Apr 2, 2019
caf8e7f
update data type
apeforest Apr 2, 2019
0ea2cbc
update make config
apeforest Apr 2, 2019
3a3c02f
change size_t to index_t and add documentation
apeforest Apr 9, 2019
b1ca6dd
update mshadow submodule to master
apeforest Apr 15, 2019
5443fd5
fix compilation warning
apeforest Apr 15, 2019
872255f
fix compiler warning
apeforest Apr 15, 2019
4bd1805
fix compiler warning
apeforest Apr 15, 2019
08e9b10
fix compiler warning
apeforest Apr 15, 2019
3a4661a
Merge remote-tracking branch 'upstream/master' into perf/large-tensor
apeforest Apr 15, 2019
d3d6cc6
fix compiler warning
apeforest Apr 15, 2019
7e3ed63
fix compiler error
apeforest Apr 15, 2019
54735db
change nnvm::Tuple to mxnet::Tuple
apeforest Apr 16, 2019
5fd9ad1
Merge remote-tracking branch 'upstream/master' into perf/large-tensor
apeforest Apr 16, 2019
0758d0c
fix compiler warning
apeforest Apr 16, 2019
a503ec5
fix compiler warning
apeforest Apr 16, 2019
cd9aa53
fix compiler warning
apeforest Apr 16, 2019
12559b1
fix compiler warning
apeforest Apr 16, 2019
a4e4a0c
fix compiler warning
apeforest Apr 16, 2019
2399864
fix lint
apeforest Apr 17, 2019
334d775
update CI runtime_functions
apeforest Apr 17, 2019
826613a
update runtime function
apeforest Apr 17, 2019
4412b90
correct runtime_functions
apeforest Apr 17, 2019
1047eb5
update runtime functions
apeforest Apr 17, 2019
97a1c08
add nightly test for large tensor
apeforest Apr 17, 2019
861b95e
update Jenkins files to test new compiler flag
apeforest Apr 17, 2019
5054f8d
Merge remote-tracking branch 'upstream/master' into perf/large-tensor
apeforest Apr 17, 2019
935389d
Merge remote-tracking branch 'upstream/master' into perf/large-tensor
apeforest Apr 18, 2019
b86e630
fix CI
apeforest Apr 18, 2019
f7540d1
Merge remote-tracking branch 'upstream/master' into perf/large-tensor
apeforest Apr 18, 2019
d8b04b3
add runtime feature detect for the compiler flag
apeforest Apr 19, 2019
20221d6
change build from make to cmake
apeforest Apr 19, 2019
bc95113
fix CI
apeforest Apr 19, 2019
9c672b7
move tests to nightly
apeforest Apr 20, 2019
27584ea
Merge remote-tracking branch 'upstream/master' into perf/large-tensor
apeforest Apr 20, 2019
2 changes: 1 addition & 1 deletion 3rdparty/mshadow
8 changes: 8 additions & 0 deletions CMakeLists.txt
@@ -50,6 +50,7 @@ mxnet_option(USE_SIGNAL_HANDLER "Print stack traces on segfaults." OFF)
mxnet_option(USE_TENSORRT "Enable infeference optimization with TensorRT." OFF)
mxnet_option(USE_ASAN "Enable Clang/GCC ASAN sanitizers." OFF)
mxnet_option(ENABLE_TESTCOVERAGE "Enable compilation with test coverage metric output" OFF)
mxnet_option(USE_INT64_TENSOR_SIZE "Use int64_t to represent the total number of elements in a tensor" OFF)

message(STATUS "CMAKE_CROSSCOMPILING ${CMAKE_CROSSCOMPILING}")
message(STATUS "CMAKE_HOST_SYSTEM_PROCESSOR ${CMAKE_HOST_SYSTEM_PROCESSOR}")
@@ -295,6 +296,13 @@ else()
add_definitions(-DMXNET_USE_NCCL=0)
endif()

if (USE_INT64_TENSOR_SIZE)
message(STATUS "Using 64-bit integer for tensor size")
add_definitions(-DMSHADOW_INT64_TENSOR_SIZE=1)
else()
add_definitions(-DMSHADOW_INT64_TENSOR_SIZE=0)
endif()

include(cmake/ChooseBlas.cmake)
if(USE_CUDA AND FIRST_CUDA)
include(3rdparty/mshadow/cmake/Utils.cmake)
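The new CMake option toggles the underlying MSHADOW_INT64_TENSOR_SIZE macro. A minimal local-build sketch with the flag enabled (the build directory, generator, and the other options are illustrative defaults, not part of this PR):

```bash
# Configure and build MXNet with 64-bit tensor size support.
# USE_INT64_TENSOR_SIZE is the option added above; the remaining flags are
# just a typical CPU-only configuration and may differ on your machine.
mkdir -p build && cd build
cmake \
    -DUSE_CUDA=OFF \
    -DUSE_MKLDNN=OFF \
    -DUSE_INT64_TENSOR_SIZE=ON \
    -DCMAKE_BUILD_TYPE=Release \
    -G Ninja \
    ..
ninja
```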
5 changes: 5 additions & 0 deletions Makefile
@@ -189,6 +189,11 @@ ifeq ($(USE_OPERATOR_TUNING), 1)
CFLAGS += -DMXNET_USE_OPERATOR_TUNING=1
endif

ifeq ($(USE_INT64_TENSOR_SIZE), 1)
CFLAGS += -DMSHADOW_INT64_TENSOR_SIZE=1
else
CFLAGS += -DMSHADOW_INT64_TENSOR_SIZE=0
endif
# verify existence of separate lapack library when using blas/openblas/atlas
# switch off lapack support in case it can't be found
# issue covered with this
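For Make-based builds the same macro is controlled by USE_INT64_TENSOR_SIZE, either on the command line or in make/config.mk (see the config changes further down). A hedged sketch of a command-line build:

```bash
# Build from the repository root with the int64 tensor-size flag enabled.
# USE_INT64_TENSOR_SIZE=1 maps to -DMSHADOW_INT64_TENSOR_SIZE=1 per the
# Makefile change above; the -j level is only an example.
make -j"$(nproc)" USE_INT64_TENSOR_SIZE=1
```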
54 changes: 54 additions & 0 deletions ci/docker/runtime_functions.sh
@@ -755,6 +755,53 @@ build_ubuntu_gpu_cmake() {
ninja -v
}

build_ubuntu_cpu_large_tensor() {
set -ex
cd /work/build
build_ccache_wrappers
cmake \
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
-DCMAKE_C_COMPILER_LAUNCHER=ccache \
-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \
-DUSE_SIGNAL_HANDLER=ON \
-DENABLE_TESTCOVERAGE=ON \
-DUSE_CUDA=OFF \
-DUSE_CUDNN=OFF \
-DUSE_MKLDNN=OFF \
-DCMAKE_BUILD_TYPE=Release \
-DUSE_INT64_TENSOR_SIZE=ON \
-G Ninja \
/work/mxnet

ninja -v
}

build_ubuntu_gpu_large_tensor() {
set -ex
cd /work/build
build_ccache_wrappers
cmake \
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
-DCMAKE_C_COMPILER_LAUNCHER=ccache \
-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \
-DUSE_SIGNAL_HANDLER=ON \
-DENABLE_TESTCOVERAGE=ON \
-DUSE_CUDA=ON \
-DUSE_CUDNN=ON \
-DUSE_MKL_IF_AVAILABLE=OFF \
-DUSE_MKLML_MKL=OFF \
-DUSE_MKLDNN=OFF \
-DUSE_DIST_KVSTORE=ON \
-DCMAKE_BUILD_TYPE=Release \
-DCUDA_ARCH_NAME=Manual \
-DCUDA_ARCH_BIN=$CI_CMAKE_CUDA_ARCH_BIN \
-DUSE_INT64_TENSOR_SIZE=ON \
-G Ninja \
/work/mxnet

ninja -v
}

build_ubuntu_blc() {
echo "pass"
}
@@ -1183,6 +1230,13 @@ nightly_test_KVStore_singleNode() {
python tests/nightly/test_kvstore.py
}

#Test Large Tensor Size
nightly_test_large_tensor() {
set -ex
export PYTHONPATH=./python/
python tests/nightly/test_large_array.py
}

#Tests Amalgamation Build with 5 different sets of flags
nightly_test_amalgamation() {
set -ex
28 changes: 28 additions & 0 deletions ci/jenkins/Jenkins_steps.groovy
@@ -119,6 +119,34 @@ def compile_unix_openblas_debug_cpu() {
}]
}

def compile_unix_int64_cpu() {
return ['CPU: USE_INT64_TENSOR_SIZE': {
node(NODE_LINUX_CPU) {
ws('workspace/build-cpu-int64') {
timeout(time: max_time, unit: 'MINUTES') {
utils.init_git()
utils.docker_run('ubuntu_cpu', 'build_ubuntu_cpu_large_tensor', false)
utils.pack_lib('ubuntu_cpu_int64', mx_lib, true)
}
}
}
}]
}

def compile_unix_int64_gpu() {
return ['GPU: USE_INT64_TENSOR_SIZE': {
node(NODE_LINUX_GPU) {
ws('workspace/build-gpu-int64') {
timeout(time: max_time, unit: 'MINUTES') {
utils.init_git()
utils.docker_run('ubuntu_gpu', 'build_ubuntu_gpu_large_tensor', false)
utils.pack_lib('ubuntu_gpu_int64', mx_lib, true)
}
}
}
}]
}

def compile_unix_mkl_cpu() {
return ['CPU: MKL': {
node(NODE_LINUX_CPU) {
3 changes: 2 additions & 1 deletion ci/jenkins/Jenkinsfile_unix_cpu
@@ -38,7 +38,8 @@ core_logic: {
custom_steps.compile_unix_openblas_debug_cpu(),
custom_steps.compile_unix_mkl_cpu(),
custom_steps.compile_unix_mkldnn_cpu(),
custom_steps.compile_unix_mkldnn_mkl_cpu()
custom_steps.compile_unix_mkldnn_mkl_cpu(),
custom_steps.compile_unix_int64_cpu()
])

utils.parallel_stage('Tests', [
1 change: 1 addition & 0 deletions ci/jenkins/Jenkinsfile_unix_gpu
@@ -40,6 +40,7 @@ core_logic: {
custom_steps.compile_unix_cmake_mkldnn_gpu(),
custom_steps.compile_unix_cmake_gpu(),
custom_steps.compile_unix_tensorrt_gpu(),
custom_steps.compile_unix_int64_gpu()
])

utils.parallel_stage('Tests', [
6 changes: 5 additions & 1 deletion include/mxnet/libinfo.h
@@ -123,7 +123,9 @@
#define MXNET_USE_SIGNAL_HANDLER 0
#endif


#ifndef MXNET_USE_INT64_TENSOR_SIZE
#define MXNET_USE_INT64_TENSOR_SIZE MSHADOW_INT64_TENSOR_SIZE
#endif

namespace mxnet {
namespace features {
@@ -177,6 +179,8 @@ enum : unsigned {
PROFILER,
DIST_KVSTORE,
CXX14,
INT64_TENSOR_SIZE,

// Signal handler to print stack traces on exceptions
SIGNAL_HANDLER,
DEBUG,
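The new INT64_TENSOR_SIZE entry in the feature enum lets a binary report at runtime whether it was built with the flag (see the "add runtime feature detect for the compiler flag" commit). A sketch assuming the Python runtime-feature API MXNet exposed around this time; the module and method names here are an assumption, not part of this diff:

```bash
# Query the installed libmxnet for the int64 tensor-size feature.
python -c "
from mxnet.runtime import Features       # assumed runtime feature-detection API
feats = Features()
print('INT64_TENSOR_SIZE:', feats.is_enabled('INT64_TENSOR_SIZE'))
"
```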
13 changes: 7 additions & 6 deletions include/mxnet/tensor_blob.h
@@ -218,15 +218,16 @@ class TBlob {
return shape_.ndim();
}
/*!
* \brief return size of i-th dimension, start counting from highest dimension
* \brief return size of i-th dimension, start counting from highest dimension.
* return type needs to be a signed integer.
* \param idx the dimension count from the highest dimensin
* \return the size
* \return the size. -1 means unknown size to support zero-size tensor.
*/
inline index_t size(index_t idx) const {
return shape_[idx];
}
/*! \brief total number of elements in the tensor */
inline index_t Size(void) const {
inline size_t Size(void) const {
return shape_.Size();
}
/*! \brief get pointer in dtype */
@@ -443,7 +444,7 @@ class FieldEntry<mxnet::TShape>
throw dmlc::ParamError(os.str());
}
if (enforce_nonzero_) {
for (mxnet::index_t i = 0; i < v.ndim(); ++i) {
for (int i = 0; i < v.ndim(); ++i) {
if (v[i] == 0U) {
std::ostringstream os;
os << "value " << v << "for Parameter " << this->key_
@@ -457,7 +458,7 @@ class FieldEntry<mxnet::TShape>
this->enforce_nonzero_ = true;
return this->self();
}
inline FieldEntry<mxnet::TShape> &set_expect_ndim(mxnet::index_t ndim) {
inline FieldEntry<mxnet::TShape> &set_expect_ndim(int ndim) {
expect_ndim_ = ndim;
return this->self();
}
@@ -466,7 +467,7 @@ class FieldEntry<mxnet::TShape>
// whether all the entries need to be nonzero
bool enforce_nonzero_;
// expected number of dimension, default = 0 means no restriction.
mxnet::index_t expect_ndim_;
int expect_ndim_;
};

} // namespace parameter
16 changes: 8 additions & 8 deletions include/mxnet/tuple.h
@@ -569,7 +569,7 @@ class TShape : public Tuple<dim_t> {
* \param axis_end The ending axis specified.
* \return the flat 3d shape
*/
inline mshadow::Shape<3> FlatTo3D(size_t axis_begin, size_t axis_end) const {
inline mshadow::Shape<3> FlatTo3D(int axis_begin, int axis_end) const {
CHECK(axis_end >= axis_begin);
mshadow::Shape<3> s;
CHECK(ndim_is_known(ndim())) << "shape must have a valid ndim";
@@ -579,10 +579,10 @@
s.shape_[1] = 1;
s.shape_[2] = 1;

for (size_t i = 0; i < axis_begin; ++i) {
for (int i = 0; i < axis_begin; ++i) {
s.shape_[0] *= d[i];
}
for (size_t i = axis_begin; i <= axis_end; ++i) {
for (int i = axis_begin; i <= axis_end; ++i) {
s.shape_[1] *= d[i];
}
for (int i = axis_end + 1; i < ndim(); ++i) {
@@ -595,7 +595,7 @@
* \param axis The axis specified.
* \return the flat 3d shape
*/
inline mshadow::Shape<3> FlatTo3D(size_t axis) const {
inline mshadow::Shape<3> FlatTo3D(int axis) const {
return FlatTo3D(axis, axis);
}
inline bool operator==(const TShape &s) const {
@@ -712,8 +712,8 @@ template<typename T>
struct hash<mxnet::Tuple<T> > {
/*! \brief hash a Tuple into unsigned int */
size_t operator()(const mxnet::Tuple<T>& val) const {
std::hash<uint32_t> hash_uint;
size_t res = hash_uint(val.ndim());
std::hash<int> hash_int;
size_t res = hash_int(val.ndim());
for (int i = 0; i < val.ndim(); ++i) {
res = dmlc::HashCombine(res, val[i]);
}
@@ -726,8 +726,8 @@ template<>
struct hash<mxnet::TShape> {
/*! \brief hash a TShape into unsigned int */
size_t operator()(const mxnet::TShape& val) const {
std::hash<uint32_t> hash_uint;
size_t res = hash_uint(val.ndim());
std::hash<int> hash_int;
size_t res = hash_int(val.ndim());
for (int i = 0; i < val.ndim(); ++i) {
res = dmlc::HashCombine(res, val[i]);
}
6 changes: 6 additions & 0 deletions make/config.mk
@@ -215,6 +215,12 @@ EXTRA_OPERATORS =
# Create C++ interface package
USE_CPP_PACKAGE = 0

# Use int64_t type to represent the total number of elements in a tensor
# This will cause performance degradation reported in issue #14496
# Set to 1 for large tensor with tensor size greater than INT32_MAX i.e. 2147483647
# Note: the size of each dimension is still bounded by INT32_MAX
USE_INT64_TENSOR_SIZE = 0

#----------------------------
# plugins
#----------------------------
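The INT32_MAX bound in the comment applies to the total element count; with the flag on, the total may exceed it while each individual dimension stays below it. A hedged illustration of what an int64-enabled build is expected to accept (the shape is illustrative, and the array needs roughly 8–9 GB of RAM as float32):

```bash
# Allocate a tensor whose total element count exceeds INT32_MAX (2147483647)
# while every dimension stays below that bound. Requires a libmxnet built
# with USE_INT64_TENSOR_SIZE=1.
python -c "
import mxnet as mx
a = mx.nd.ones((2**16, 2**15 + 1))   # 65536 * 32769 = 2,147,549,184 elements
print(a.size)                        # total number of elements
"
```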
6 changes: 6 additions & 0 deletions make/crosscompile.jetson.mk
@@ -192,6 +192,12 @@ EXTRA_OPERATORS =
# Create C++ interface package
USE_CPP_PACKAGE = 0

# Use int64_t type to represent the total number of elements in the tensor
# This will cause performance degradation reported in issue #14496
# Set to 1 for large tensor with tensor size greater than INT32_MAX i.e. 2147483647
# Note: the size of each dimension is still bounded by INT32_MAX
USE_INT64_TENSOR_SIZE = 0

#----------------------------
# plugins
#----------------------------
6 changes: 6 additions & 0 deletions make/osx.mk
@@ -135,6 +135,12 @@ EXTRA_OPERATORS =
# Create C++ interface package
USE_CPP_PACKAGE = 0

# Use int64_t type to represent the total number of elements in a tensor
# This will cause performance degradation reported in issue #14496
# Set to 1 for large tensor with tensor size greater than INT32_MAX i.e. 2147483647
# Note: the size of each dimension is still bounded by INT32_MAX
USE_INT64_TENSOR_SIZE = 0

#----------------------------
# plugins
#----------------------------
12 changes: 6 additions & 6 deletions src/common/serialization.h
@@ -49,7 +49,7 @@ template<typename T>
inline size_t SerializedSize(const T &obj);

template<typename T>
inline size_t SerializedSize(const nnvm::Tuple <T> &obj);
inline size_t SerializedSize(const mxnet::Tuple <T> &obj);

template<typename K, typename V>
inline size_t SerializedSize(const std::map <K, V> &obj);
@@ -64,7 +64,7 @@ template<typename T>
inline void Serialize(const T &obj, char **buffer);

template<typename T>
inline void Serialize(const nnvm::Tuple <T> &obj, char **buffer);
inline void Serialize(const mxnet::Tuple <T> &obj, char **buffer);

template<typename K, typename V>
inline void Serialize(const std::map <K, V> &obj, char **buffer);
@@ -79,7 +79,7 @@ template<typename T>
inline void Deserialize(T *obj, const std::string &buffer, size_t *curr_pos);

template<typename T>
inline void Deserialize(nnvm::Tuple <T> *obj, const std::string &buffer, size_t *curr_pos);
inline void Deserialize(mxnet::Tuple <T> *obj, const std::string &buffer, size_t *curr_pos);

template<typename K, typename V>
inline void Deserialize(std::map <K, V> *obj, const std::string &buffer, size_t *curr_pos);
@@ -102,7 +102,7 @@ inline size_t SerializedSize(const T &obj) {
}

template<typename T>
inline size_t SerializedSize(const nnvm::Tuple <T> &obj) {
inline size_t SerializedSize(const mxnet::Tuple <T> &obj) {
if (is_container<T>::value) {
size_t sum_val = 4;
for (const auto& el : obj) {
@@ -180,7 +180,7 @@ inline void Serialize(const T &obj, char **buffer) {
}

template<typename T>
inline void Serialize(const nnvm::Tuple <T> &obj, char **buffer) {
inline void Serialize(const mxnet::Tuple <T> &obj, char **buffer) {
uint32_t size = obj.ndim();
std::memcpy(*buffer, &size, 4);
*buffer += 4;
@@ -244,7 +244,7 @@ inline void Deserialize(T *obj, const std::string &buffer, size_t *curr_pos) {
}

template<typename T>
inline void Deserialize(nnvm::Tuple <T> *obj, const std::string &buffer, size_t *curr_pos) {
inline void Deserialize(mxnet::Tuple <T> *obj, const std::string &buffer, size_t *curr_pos) {
uint32_t size = obj->ndim();
std::memcpy(&size, &buffer[*curr_pos], 4);
*curr_pos += 4;
8 changes: 4 additions & 4 deletions src/imperative/cached_op.h
@@ -36,8 +36,8 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
bool static_alloc;
bool static_shape;
bool is_dynamic;
nnvm::Tuple<uint32_t> data_indices;
nnvm::Tuple<uint32_t> param_indices;
mxnet::Tuple<uint32_t> data_indices;
mxnet::Tuple<uint32_t> param_indices;
std::string subgraph;
DMLC_DECLARE_PARAMETER(CachedOpConfig) {
DMLC_DECLARE_FIELD(static_alloc)
@@ -59,10 +59,10 @@
.set_default(Imperative::BulkExecMaxNodeTrainBwd())
.describe("Segment size of bulk execution during backward pass.");
DMLC_DECLARE_FIELD(data_indices)
.set_default(nnvm::Tuple<uint32_t>())
.set_default(mxnet::Tuple<uint32_t>())
.describe("Position of argument variables.");
DMLC_DECLARE_FIELD(param_indices)
.set_default(nnvm::Tuple<uint32_t>())
.set_default(mxnet::Tuple<uint32_t>())
.describe("Position of parameters.");
DMLC_DECLARE_FIELD(subgraph)
.set_default(std::string(""))