Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[numpy] Support zero-dim and zero-size tensors in MXNet #14661

Merged
merged 32 commits into from
Apr 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
708c933
[numpy] Shape support scalar tensor (#14315)
reminisce Mar 6, 2019
a297a9a
[Numpy] Change semantics of ndim for operators in `src/operator/contr…
junrushao Mar 15, 2019
1179a59
[WIP] Use new shape definition (#14453)
reminisce Mar 18, 2019
3449ef8
[numpy] Fix unit tests after introducing numpy compatible shapes (#14…
reminisce Mar 22, 2019
098d189
Fix a bug to pass the test in test_contrib_rnn (#14520)
zheng-da Mar 26, 2019
9c77e3e
[numpy] Fix test_dynamic_shape.test_dynamic_shape (#14538)
junrushao Mar 27, 2019
ebc7d4d
[numpy] Fix numpy import in python2 (#14537)
reminisce Mar 27, 2019
c63555e
fix concat and slice (#14549)
TaoLv Mar 29, 2019
48cf659
fix R-package (#14536)
hetong007 Apr 3, 2019
b165b35
Fix cpp package build after using new shape definition (#14554)
reminisce Apr 3, 2019
bff0bdc
Fix pooling_v1 and deformable_convolution param initialization (#14577)
reminisce Apr 4, 2019
ebecad1
[Numpy] Misc fix (#14612)
junrushao Apr 4, 2019
4c35ade
[Numpy] fix test_operator_gpu.test_upsampling_bilinear_with_type (#14…
junrushao Apr 4, 2019
cdc9023
[Numpy] Java/Scala modification (#14625)
yzhliu Apr 5, 2019
c7f4ebd
fix shape index bug (#14630)
eric-haibin-lin Apr 5, 2019
c526ac4
fix jni lint (#14634)
yzhliu Apr 5, 2019
d12a1fa
[numpy] Fix numpy branch failing tests in CI (#14639)
reminisce Apr 8, 2019
7a83953
fix invalid ndarray dispose (#14657)
yzhliu Apr 10, 2019
3eab56a
swig fixes for the changes in c_api.h (#14655)
sergeykolychev Apr 10, 2019
4b2244a
Rename np_comp to np_compat for readability
reminisce Apr 10, 2019
04ad087
Fix import error
reminisce Apr 10, 2019
8a023eb
Keep old c apis unchanged
reminisce Apr 11, 2019
c8d2cce
Fix lint
reminisce Apr 11, 2019
a4a2841
Rebase and fix build
reminisce Apr 11, 2019
a07bf84
Fix R build failure
reminisce Apr 12, 2019
c204cd5
Fix Perl build failure
reminisce Apr 12, 2019
c876bc2
Rebase with master
reminisce Apr 13, 2019
b864c4b
Address cr comments
reminisce Apr 15, 2019
69791f1
Use just one scope to represent numpy compatibility
reminisce Apr 15, 2019
ab08f35
Add code comment to NumpyScope object in Scala
reminisce Apr 15, 2019
b0cb2f1
Add use_np_compat decorator
reminisce Apr 15, 2019
bd7c4fb
Fix pylint
reminisce Apr 15, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions R-package/src/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@ Rcpp::RObject NDArrayPacker::CreateNDArrayPacker() {
}

Rcpp::Dimension NDArray::dim() const {
mx_uint ndim;
const mx_uint *pshape;
MX_CALL(MXNDArrayGetShape(
int ndim;
const int *pshape;
MX_CALL(MXNDArrayGetShapeEx(
ptr_->handle, &ndim, &pshape));
Rcpp::IntegerVector dat(pshape, pshape + ndim);
std::reverse(dat.begin(), dat.end());
Expand Down
20 changes: 10 additions & 10 deletions R-package/src/symbol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ Symbol::RObjectType Symbol::GetOutput(mx_uint index) const {

// helper function to convert shape into Rcpp vector
inline Rcpp::List BuildShapeData(mx_uint shape_size,
const mx_uint *shape_ndim,
const mx_uint **shape_data,
const int *shape_ndim,
const int **shape_data,
const std::vector<std::string> &names) {
Rcpp::List ret(shape_size);
for (mx_uint i = 0; i < shape_size; ++i) {
Expand All @@ -185,7 +185,7 @@ SEXP Symbol::InferShape(const Rcpp::List& kwargs) const {
<< "Need to pass parameters in key=value style.\n";
std::vector<std::string> keys = kwargs.names();
std::vector<mx_uint> arg_ind_ptr(1, 0);
std::vector<mx_uint> arg_shape_data;
std::vector<int> arg_shape_data;

for (size_t i = 0; i < kwargs.size(); ++i) {
RCHECK(keys[i].length() != 0)
Expand All @@ -197,17 +197,17 @@ SEXP Symbol::InferShape(const Rcpp::List& kwargs) const {
std::vector<const char*> c_keys = CKeys(keys);

mx_uint in_shape_size;
const mx_uint *in_shape_ndim;
const mx_uint **in_shape_data;
const int *in_shape_ndim;
const int **in_shape_data;
mx_uint out_shape_size;
const mx_uint *out_shape_ndim;
const mx_uint **out_shape_data;
const int *out_shape_ndim;
const int **out_shape_data;
mx_uint aux_shape_size;
const mx_uint *aux_shape_ndim;
const mx_uint **aux_shape_data;
const int *aux_shape_ndim;
const int **aux_shape_data;
int complete;

MX_CALL(MXSymbolInferShape(
MX_CALL(MXSymbolInferShapeEx(
handle_, static_cast<mx_uint>(kwargs.size()), dmlc::BeginPtr(c_keys),
dmlc::BeginPtr(arg_ind_ptr), dmlc::BeginPtr(arg_shape_data),
&in_shape_size, &in_shape_ndim, &in_shape_data,
Expand Down
8 changes: 4 additions & 4 deletions cpp-package/include/mxnet-cpp/ndarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,11 +397,11 @@ inline size_t NDArray::Size() const {
}

inline std::vector<mx_uint> NDArray::GetShape() const {
const mx_uint *out_pdata;
mx_uint out_dim;
MXNDArrayGetShape(blob_ptr_->handle_, &out_dim, &out_pdata);
const int *out_pdata;
int out_dim;
MXNDArrayGetShapeEx(blob_ptr_->handle_, &out_dim, &out_pdata);
std::vector<mx_uint> ret;
for (mx_uint i = 0; i < out_dim; ++i) {
for (int i = 0; i < out_dim; ++i) {
ret.push_back(out_pdata[i]);
}
return ret;
Expand Down
32 changes: 16 additions & 16 deletions cpp-package/include/mxnet-cpp/symbol.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ inline void Symbol::InferShape(

std::vector<const char *> keys;
std::vector<mx_uint> arg_ind_ptr;
std::vector<mx_uint> arg_shape_data;
std::vector<int> arg_shape_data;

for (const auto &arg : arg_shapes) {
keys.push_back(arg.first.c_str());
Expand All @@ -200,40 +200,40 @@ inline void Symbol::InferShape(
arg_ind_ptr.push_back(arg_shape_data.size());

mx_uint in_shape_size;
const mx_uint *in_shape_ndim;
const mx_uint **in_shape_data;
const int *in_shape_ndim;
const int **in_shape_data;
mx_uint out_shape_size;
const mx_uint *out_shape_ndim;
const mx_uint **out_shape_data;
const int *out_shape_ndim;
const int **out_shape_data;
mx_uint aux_shape_size;
const mx_uint *aux_shape_ndim;
const mx_uint **aux_shape_data;
const int *aux_shape_ndim;
const int **aux_shape_data;
int complete;

CHECK_EQ(MXSymbolInferShape(GetHandle(), keys.size(), keys.data(),
arg_ind_ptr.data(), arg_shape_data.data(),
&in_shape_size, &in_shape_ndim, &in_shape_data,
&out_shape_size, &out_shape_ndim, &out_shape_data,
&aux_shape_size, &aux_shape_ndim, &aux_shape_data,
&complete),
CHECK_EQ(MXSymbolInferShapeEx(GetHandle(), keys.size(), keys.data(),
arg_ind_ptr.data(), arg_shape_data.data(),
&in_shape_size, &in_shape_ndim, &in_shape_data,
&out_shape_size, &out_shape_ndim, &out_shape_data,
&aux_shape_size, &aux_shape_ndim, &aux_shape_data,
&complete),
0);

if (complete) {
for (mx_uint i = 0; i < in_shape_size; ++i) {
in_shape->push_back(std::vector<mx_uint>());
for (mx_uint j = 0; j < in_shape_ndim[i]; ++j) {
for (int j = 0; j < in_shape_ndim[i]; ++j) {
(*in_shape)[i].push_back(in_shape_data[i][j]);
}
}
for (mx_uint i = 0; i < aux_shape_size; ++i) {
aux_shape->push_back(std::vector<mx_uint>());
for (mx_uint j = 0; j < aux_shape_ndim[i]; ++j) {
for (int j = 0; j < aux_shape_ndim[i]; ++j) {
(*aux_shape)[i].push_back(aux_shape_data[i][j]);
}
}
for (mx_uint i = 0; i < out_shape_size; ++i) {
out_shape->push_back(std::vector<mx_uint>());
for (mx_uint j = 0; j < out_shape_ndim[i]; ++j) {
for (int j = 0; j < out_shape_ndim[i]; ++j) {
(*out_shape)[i].push_back(out_shape_data[i][j]);
}
}
Expand Down
Loading