Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1185] [WIP] Support large array in several operators #13191

Closed
wants to merge 46 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
a572198
Support large integer in operators
apeforest Nov 6, 2018
e37a06b
fix large array in sum
apeforest Nov 7, 2018
e48b274
Fix large array issue in slice operation
apeforest Nov 7, 2018
b183c3f
fix bug in shape
apeforest Nov 8, 2018
fcebf5a
fix getitem with large index
apeforest Nov 8, 2018
3c7557b
fix bug in slice operator
apeforest Nov 8, 2018
904f09b
fix bug in random uniform op
apeforest Nov 8, 2018
08bd8ab
add nightly test
apeforest Nov 8, 2018
244f386
fix lint error
apeforest Nov 9, 2018
3ecd257
fix compilation error on gpu
apeforest Nov 9, 2018
c70afe8
fix gpu compilation
apeforest Nov 9, 2018
ffcd175
fix build issue
apeforest Nov 9, 2018
dbe0e6c
fix windows build error
apeforest Nov 9, 2018
0680184
fix build issue in windows
apeforest Nov 9, 2018
8fda02a
fix omp build issue
apeforest Nov 9, 2018
87cd144
fix cpp-package build error
apeforest Nov 9, 2018
7afc7a8
fix mkldnn build
apeforest Nov 9, 2018
862be24
fix an array size bound
apeforest Nov 10, 2018
22213fa
add constants in tests
apeforest Nov 10, 2018
cbaa553
fix sparse array
apeforest Nov 13, 2018
7eca035
fix unit test
apeforest Nov 13, 2018
1b48d4a
fix unit test
apeforest Nov 14, 2018
cb2ee1e
fix R and scala package build
apeforest Nov 14, 2018
08471f2
Fix build error in scala, julia, perl
apeforest Nov 14, 2018
e5a3b32
fix a typo
apeforest Nov 14, 2018
629a7c5
fix R-package scala-package compiation error
apeforest Nov 14, 2018
2dd990a
fix scala unit test
apeforest Nov 14, 2018
5b0cd3a
fix python2 unit test
apeforest Nov 14, 2018
43ba3aa
fix scala unit test
apeforest Nov 15, 2018
5286f63
fix scala unit test
apeforest Nov 15, 2018
7247e6b
fix scala build
apeforest Nov 16, 2018
f8839b3
fix python unit test
apeforest Nov 16, 2018
e1cd1cd
update scala-package to fix unittest
apeforest Nov 17, 2018
e0f4e2d
Merge remote-tracking branch 'upstream/master' into bugfix/large-array
apeforest Nov 17, 2018
e0fe05c
fix scala unit test
apeforest Nov 19, 2018
024a0ce
fix array typecode for python 2 and python 3
apeforest Nov 19, 2018
1f3361b
lint it
apeforest Nov 19, 2018
23579e1
Merge remote-tracking branch 'upstream/master' into bugfix/large-array
apeforest Nov 19, 2018
1cd9b88
lint it again
apeforest Nov 20, 2018
335e896
fix python include error
apeforest Nov 20, 2018
a08c79e
fix unit test
apeforest Nov 20, 2018
69703fc
lint me in
apeforest Nov 20, 2018
01952c5
fix python unit test in python2 windows
apeforest Nov 20, 2018
a68bd97
fix perl-package unit test
apeforest Nov 20, 2018
c1b14d1
fix perl package
apeforest Nov 20, 2018
a3daa9b
Merge remote-tracking branch 'upstream/master' into bugfix/large-array
apeforest Nov 27, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions R-package/src/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,8 @@ inline std::vector<std::string> SafeGetListNames(const Rcpp::List& src) {
* \param rshape The dimension in R
* \return A internal vector representation of shapes in mxnet.
*/
inline std::vector<mx_uint> Dim2InternalShape(const Rcpp::Dimension &rshape) {
std::vector<mx_uint> shape(rshape.size());
inline std::vector<dim_t> Dim2InternalShape(const Rcpp::Dimension &rshape) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: what is the type of mx_uint and dim_t?

std::vector<dim_t> shape(rshape.size());
for (size_t i = 0; i < rshape.size(); ++i) {
shape[rshape.size() - i - 1] = rshape[i];
}
Expand Down
6 changes: 3 additions & 3 deletions R-package/src/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ Rcpp::RObject NDArrayPacker::CreateNDArrayPacker() {

Rcpp::Dimension NDArray::dim() const {
mx_uint ndim;
const mx_uint *pshape;
const dim_t *pshape;
MX_CALL(MXNDArrayGetShape(
ptr_->handle, &ndim, &pshape));
Rcpp::IntegerVector dat(pshape, pshape + ndim);
Expand All @@ -190,7 +190,7 @@ Rcpp::Dimension NDArray::dim() const {
}

NDArray NDArray::Clone() const {
std::vector<mx_uint> shape = Dim2InternalShape(this->dim());
std::vector<dim_t> shape = Dim2InternalShape(this->dim());
Context ctx = this->ctx();
NDArrayHandle handle;
MX_CALL(MXNDArrayCreate(dmlc::BeginPtr(shape),
Expand Down Expand Up @@ -276,7 +276,7 @@ Rcpp::List NDArray::Load(const std::string& filename) {
NDArray::RObjectType NDArray::Empty(
const Rcpp::Dimension& rshape,
const Context::RObjectType& rctx) {
std::vector<mx_uint> shape = Dim2InternalShape(rshape);
std::vector<dim_t> shape = Dim2InternalShape(rshape);
Context ctx(rctx);
NDArrayHandle handle;
MX_CALL(MXNDArrayCreate(dmlc::BeginPtr(shape),
Expand Down
12 changes: 6 additions & 6 deletions R-package/src/symbol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ Symbol::RObjectType Symbol::GetOutput(mx_uint index) const {
// helper function to convert shape into Rcpp vector
inline Rcpp::List BuildShapeData(mx_uint shape_size,
const mx_uint *shape_ndim,
const mx_uint **shape_data,
const dim_t **shape_data,
const std::vector<std::string> &names) {
Rcpp::List ret(shape_size);
for (mx_uint i = 0; i < shape_size; ++i) {
Expand All @@ -185,26 +185,26 @@ SEXP Symbol::InferShape(const Rcpp::List& kwargs) const {
<< "Need to pass parameters in key=value style.\n";
std::vector<std::string> keys = kwargs.names();
std::vector<mx_uint> arg_ind_ptr(1, 0);
std::vector<mx_uint> arg_shape_data;
std::vector<dim_t> arg_shape_data;

for (size_t i = 0; i < kwargs.size(); ++i) {
RCHECK(keys[i].length() != 0)
<< "Need to pass parameters in key=value style.\n";
std::vector<mx_uint> dim = Dim2InternalShape(kwargs[i]);
std::vector<dim_t> dim = Dim2InternalShape(kwargs[i]);
arg_shape_data.insert(arg_shape_data.end(), dim.begin(), dim.end());
arg_ind_ptr.push_back(static_cast<mx_uint>(arg_shape_data.size()));
}
std::vector<const char*> c_keys = CKeys(keys);

mx_uint in_shape_size;
const mx_uint *in_shape_ndim;
const mx_uint **in_shape_data;
const dim_t **in_shape_data;
mx_uint out_shape_size;
const mx_uint *out_shape_ndim;
const mx_uint **out_shape_data;
const dim_t **out_shape_data;
mx_uint aux_shape_size;
const mx_uint *aux_shape_ndim;
const mx_uint **aux_shape_data;
const dim_t **aux_shape_data;
int complete;

MX_CALL(MXSymbolInferShape(
Expand Down
2 changes: 1 addition & 1 deletion cpp-package/include/mxnet-cpp/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
namespace mxnet {
namespace cpp {

typedef unsigned index_t;
typedef int64_t index_t;

enum OpReqType {
/*! \brief no operation, do not write anything */
Expand Down
2 changes: 1 addition & 1 deletion cpp-package/include/mxnet-cpp/initializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ class Xavier : public Initializer {
Shape shape(arr->GetShape());
float hw_scale = 1.0f;
if (shape.ndim() > 2) {
for (size_t i = 2; i < shape.ndim(); ++i) {
for (index_t i = 2; i < shape.ndim(); ++i) {
hw_scale *= shape[i];
}
}
Expand Down
4 changes: 2 additions & 2 deletions cpp-package/include/mxnet-cpp/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class NDArray {
* \param constext context of NDArray
* \param delay_alloc whether delay the allocation
*/
NDArray(const std::vector<mx_uint> &shape, const Context &context,
NDArray(const std::vector<index_t> &shape, const Context &context,
bool delay_alloc = true);
/*!
* \brief construct a new dynamic NDArray
Expand Down Expand Up @@ -444,7 +444,7 @@ class NDArray {
/*!
* \return the shape of current NDArray, in the form of mx_uint vector
*/
std::vector<mx_uint> GetShape() const;
std::vector<index_t> GetShape() const;
/*!
* \return the data type of current NDArray
*/
Expand Down
8 changes: 4 additions & 4 deletions cpp-package/include/mxnet-cpp/ndarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ inline NDArray::NDArray() {
inline NDArray::NDArray(const NDArrayHandle &handle) {
blob_ptr_ = std::make_shared<NDBlob>(handle);
}
inline NDArray::NDArray(const std::vector<mx_uint> &shape, const Context &context,
inline NDArray::NDArray(const std::vector<index_t> &shape, const Context &context,
bool delay_alloc) {
NDArrayHandle handle;
CHECK_EQ(MXNDArrayCreate(shape.data(), shape.size(), context.GetDeviceType(),
Expand Down Expand Up @@ -396,11 +396,11 @@ inline size_t NDArray::Size() const {
return ret;
}

inline std::vector<mx_uint> NDArray::GetShape() const {
const mx_uint *out_pdata;
inline std::vector<index_t> NDArray::GetShape() const {
const index_t *out_pdata;
mx_uint out_dim;
MXNDArrayGetShape(blob_ptr_->handle_, &out_dim, &out_pdata);
std::vector<mx_uint> ret;
std::vector<index_t> ret;
for (mx_uint i = 0; i < out_dim; ++i) {
ret.push_back(out_pdata[i]);
}
Expand Down
8 changes: 4 additions & 4 deletions cpp-package/include/mxnet-cpp/symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,10 @@ class Symbol {
* \param aux_shapes use to store the infered shapes of auxiliary states
*/
void InferShape(
const std::map<std::string, std::vector<mx_uint> > &arg_shapes,
std::vector<std::vector<mx_uint> > *in_shape,
std::vector<std::vector<mx_uint> > *aux_shape,
std::vector<std::vector<mx_uint> > *out_shape) const;
const std::map<std::string, std::vector<index_t> > &arg_shapes,
std::vector<std::vector<index_t> > *in_shape,
std::vector<std::vector<index_t> > *aux_shape,
std::vector<std::vector<index_t> > *out_shape) const;
/*!
* \brief List the arguments names.
*
Expand Down
30 changes: 15 additions & 15 deletions cpp-package/include/mxnet-cpp/symbol.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,14 @@ inline std::string Symbol::GetName() const {
}

inline void Symbol::InferShape(
const std::map<std::string, std::vector<mx_uint> > &arg_shapes,
std::vector<std::vector<mx_uint> > *in_shape,
std::vector<std::vector<mx_uint> > *aux_shape,
std::vector<std::vector<mx_uint> > *out_shape) const {
const std::map<std::string, std::vector<index_t> > &arg_shapes,
std::vector<std::vector<index_t> > *in_shape,
std::vector<std::vector<index_t> > *aux_shape,
std::vector<std::vector<index_t> > *out_shape) const {

std::vector<const char *> keys;
std::vector<mx_uint> arg_ind_ptr;
std::vector<mx_uint> arg_shape_data;
std::vector<index_t> arg_shape_data;

for (const auto &arg : arg_shapes) {
keys.push_back(arg.first.c_str());
Expand All @@ -201,13 +201,13 @@ inline void Symbol::InferShape(

mx_uint in_shape_size;
const mx_uint *in_shape_ndim;
const mx_uint **in_shape_data;
const index_t **in_shape_data;
mx_uint out_shape_size;
const mx_uint *out_shape_ndim;
const mx_uint **out_shape_data;
const index_t **out_shape_data;
mx_uint aux_shape_size;
const mx_uint *aux_shape_ndim;
const mx_uint **aux_shape_data;
const index_t **aux_shape_data;
int complete;

CHECK_EQ(MXSymbolInferShape(GetHandle(), keys.size(), keys.data(),
Expand All @@ -220,19 +220,19 @@ inline void Symbol::InferShape(

if (complete) {
for (mx_uint i = 0; i < in_shape_size; ++i) {
in_shape->push_back(std::vector<mx_uint>());
in_shape->push_back(std::vector<index_t>());
for (mx_uint j = 0; j < in_shape_ndim[i]; ++j) {
(*in_shape)[i].push_back(in_shape_data[i][j]);
}
}
for (mx_uint i = 0; i < aux_shape_size; ++i) {
aux_shape->push_back(std::vector<mx_uint>());
aux_shape->push_back(std::vector<index_t>());
for (mx_uint j = 0; j < aux_shape_ndim[i]; ++j) {
(*aux_shape)[i].push_back(aux_shape_data[i][j]);
}
}
for (mx_uint i = 0; i < out_shape_size; ++i) {
out_shape->push_back(std::vector<mx_uint>());
out_shape->push_back(std::vector<index_t>());
for (mx_uint j = 0; j < out_shape_ndim[i]; ++j) {
(*out_shape)[i].push_back(out_shape_data[i][j]);
}
Expand All @@ -250,8 +250,8 @@ inline void Symbol::InferExecutorArrays(
const std::map<std::string, NDArray> &aux_map) const {

const auto arg_name_list = ListArguments();
std::vector<std::vector<mx_uint> > in_shapes, aux_shapes, out_shapes;
std::map<std::string, std::vector<mx_uint> > arg_shapes;
std::vector<std::vector<index_t> > in_shapes, aux_shapes, out_shapes;
std::map<std::string, std::vector<index_t> > arg_shapes;

for (const auto &arg_name : arg_name_list) {
auto iter = args_map.find(arg_name);
Expand Down Expand Up @@ -307,8 +307,8 @@ inline void Symbol::InferArgsMap(
const std::map<std::string, NDArray> &known_args) const {

const auto arg_name_list = ListArguments();
std::vector<std::vector<mx_uint> > in_shapes, aux_shapes, out_shapes;
std::map<std::string, std::vector<mx_uint> > arg_shapes;
std::vector<std::vector<index_t> > in_shapes, aux_shapes, out_shapes;
std::map<std::string, std::vector<index_t> > arg_shapes;

for (const auto &arg_name : arg_name_list) {
auto iter = known_args.find(arg_name);
Expand Down
34 changes: 17 additions & 17 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ MXNET_DLL int MXNDArrayCreateNone(NDArrayHandle *out);
* \param out the returning handle
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayCreate(const mx_uint *shape,
MXNET_DLL int MXNDArrayCreate(const dim_t *shape,
mx_uint ndim,
int dev_type,
int dev_id,
Expand All @@ -506,7 +506,7 @@ MXNET_DLL int MXNDArrayCreate(const mx_uint *shape,
* \param out the returning handle
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape,
MXNET_DLL int MXNDArrayCreateEx(const dim_t *shape,
mx_uint ndim,
int dev_type,
int dev_id,
Expand All @@ -533,7 +533,7 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape,
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
const mx_uint *shape,
const dim_t *shape,
mx_uint ndim,
int dev_type,
int dev_id,
Expand All @@ -542,7 +542,7 @@ MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
mx_uint num_aux,
int *aux_type,
mx_uint *aux_ndims,
const mx_uint *aux_shape,
const dim_t *aux_shape,
NDArrayHandle *out);

/*!
Expand Down Expand Up @@ -650,7 +650,7 @@ MXNET_DLL int MXNDArraySyncCopyToCPU(NDArrayHandle handle,
*/
MXNET_DLL int MXNDArraySyncCopyFromNDArray(NDArrayHandle handle_dst,
const NDArrayHandle handle_src,
const int i);
const dim_t i);

/*!
* \brief check whether the NDArray format is valid
Expand Down Expand Up @@ -693,8 +693,8 @@ MXNET_DLL int MXNDArrayFree(NDArrayHandle handle);
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArraySlice(NDArrayHandle handle,
mx_uint slice_begin,
mx_uint slice_end,
dim_t slice_begin,
dim_t slice_end,
NDArrayHandle *out);

/*!
Expand All @@ -705,7 +705,7 @@ MXNET_DLL int MXNDArraySlice(NDArrayHandle handle,
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayAt(NDArrayHandle handle,
mx_uint idx,
dim_t idx,
NDArrayHandle *out);

/*!
Expand Down Expand Up @@ -749,7 +749,7 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
*/
MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
mx_uint *out_dim,
const mx_uint **out_pdata);
const dim_t **out_pdata);
/*!
* \brief get the content of the data in NDArray
* \param handle the handle to the ndarray
Expand Down Expand Up @@ -1466,16 +1466,16 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const mx_uint *arg_ind_ptr,
const mx_uint *arg_shape_data,
const dim_t *arg_shape_data,
mx_uint *in_shape_size,
const mx_uint **in_shape_ndim,
const mx_uint ***in_shape_data,
const dim_t ***in_shape_data,
mx_uint *out_shape_size,
const mx_uint **out_shape_ndim,
const mx_uint ***out_shape_data,
const dim_t ***out_shape_data,
mx_uint *aux_shape_size,
const mx_uint **aux_shape_ndim,
const mx_uint ***aux_shape_data,
const dim_t ***aux_shape_data,
int *complete);
/*!
* \brief partially infer shape of unknown input shapes given the known one.
Expand Down Expand Up @@ -1505,16 +1505,16 @@ MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const mx_uint *arg_ind_ptr,
const mx_uint *arg_shape_data,
const dim_t *arg_shape_data,
mx_uint *in_shape_size,
const mx_uint **in_shape_ndim,
const mx_uint ***in_shape_data,
const dim_t ***in_shape_data,
mx_uint *out_shape_size,
const mx_uint **out_shape_ndim,
const mx_uint ***out_shape_data,
const dim_t ***out_shape_data,
mx_uint *aux_shape_size,
const mx_uint **aux_shape_ndim,
const mx_uint ***aux_shape_data,
const dim_t ***aux_shape_data,
int *complete);

/*!
Expand Down
2 changes: 1 addition & 1 deletion include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ class NDArray {
/*!
* \brief Copy from src.data()/aux_data(i) to this->data()/aux_data(j)
*/
void SyncCopyFromNDArray(const NDArray &src, int i = -1, int j = -1);
void SyncCopyFromNDArray(const NDArray &src, index_t i = -1, index_t j = -1);

/*!
* \brief Do a synchronize copy to a continugous CPU memory region.
Expand Down
1 change: 1 addition & 0 deletions julia/src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Base.show(io::IO, e::MXError) = print(io, e.msg)
# Common types used in MXNet API
################################################################################
const MX_uint = Cuint
const MX_long = Clonglong
const MX_float = Cfloat
const MX_handle = Ptr{Void}

Expand Down
8 changes: 4 additions & 4 deletions julia/src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ See also the notes on NDArray shapes [`NDArray`](@ref).
"""
function size(x::NDArray)
ref_ndim = Ref{MX_uint}(0)
ref_shape = Ref{Ptr{MX_uint}}(0)
@mxcall(:MXNDArrayGetShape, (MX_handle, Ref{MX_uint}, Ref{Ptr{MX_uint}}),
ref_shape = Ref{Ptr{MX_long}}(0)
@mxcall(:MXNDArrayGetShape, (MX_handle, Ref{MX_uint}, Ref{Ptr{MX_long}}),
x, ref_ndim, ref_shape)
tuple(map(Int, flipdim(unsafe_wrap(Array, ref_shape[], ref_ndim[]),1))...)
end
Expand Down Expand Up @@ -278,8 +278,8 @@ ndims(x::NDArray) = ndims(x.handle)

function ndims(x::MX_NDArrayHandle)::Int
ref_ndim = Ref{MX_uint}(0)
ref_shape = Ref{Ptr{MX_uint}}(0)
@mxcall(:MXNDArrayGetShape, (MX_handle, Ref{MX_uint}, Ref{Ptr{MX_uint}}),
ref_shape = Ref{Ptr{MX_long}}(0)
@mxcall(:MXNDArrayGetShape, (MX_handle, Ref{MX_uint}, Ref{Ptr{MX_long}}),
x, ref_ndim, ref_shape)
ref_ndim[]
end
Expand Down
Loading