From 587e9825d2c9c3ba901635419488514dbe86789f Mon Sep 17 00:00:00 2001 From: reminisce Date: Wed, 3 Apr 2019 13:40:11 -0700 Subject: [PATCH] Fix cpp package build after using new shape definition (#14554) --- cpp-package/include/mxnet-cpp/ndarray.hpp | 6 +++--- cpp-package/include/mxnet-cpp/symbol.hpp | 20 ++++++++++---------- include/mxnet/tensor_blob.h | 2 ++ include/mxnet/tuple.h | 6 +++++- 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/cpp-package/include/mxnet-cpp/ndarray.hpp b/cpp-package/include/mxnet-cpp/ndarray.hpp index b667542bffb5..bf1d82ca33b4 100644 --- a/cpp-package/include/mxnet-cpp/ndarray.hpp +++ b/cpp-package/include/mxnet-cpp/ndarray.hpp @@ -397,11 +397,11 @@ inline size_t NDArray::Size() const { } inline std::vector NDArray::GetShape() const { - const mx_uint *out_pdata; - mx_uint out_dim; + const int *out_pdata; + int out_dim; MXNDArrayGetShape(blob_ptr_->handle_, &out_dim, &out_pdata); std::vector ret; - for (mx_uint i = 0; i < out_dim; ++i) { + for (int i = 0; i < out_dim; ++i) { ret.push_back(out_pdata[i]); } return ret; diff --git a/cpp-package/include/mxnet-cpp/symbol.hpp b/cpp-package/include/mxnet-cpp/symbol.hpp index aed963949060..d82b7abaf614 100644 --- a/cpp-package/include/mxnet-cpp/symbol.hpp +++ b/cpp-package/include/mxnet-cpp/symbol.hpp @@ -188,7 +188,7 @@ inline void Symbol::InferShape( std::vector keys; std::vector arg_ind_ptr; - std::vector arg_shape_data; + std::vector arg_shape_data; for (const auto &arg : arg_shapes) { keys.push_back(arg.first.c_str()); @@ -200,14 +200,14 @@ inline void Symbol::InferShape( arg_ind_ptr.push_back(arg_shape_data.size()); mx_uint in_shape_size; - const mx_uint *in_shape_ndim; - const mx_uint **in_shape_data; + const int*in_shape_ndim; + const int **in_shape_data; mx_uint out_shape_size; - const mx_uint *out_shape_ndim; - const mx_uint **out_shape_data; + const int *out_shape_ndim; + const int **out_shape_data; mx_uint aux_shape_size; - const mx_uint *aux_shape_ndim; - const mx_uint **aux_shape_data; + const int *aux_shape_ndim; + const int **aux_shape_data; int complete; CHECK_EQ(MXSymbolInferShape(GetHandle(), keys.size(), keys.data(), @@ -221,19 +221,19 @@ inline void Symbol::InferShape( if (complete) { for (mx_uint i = 0; i < in_shape_size; ++i) { in_shape->push_back(std::vector()); - for (mx_uint j = 0; j < in_shape_ndim[i]; ++j) { + for (int j = 0; j < in_shape_ndim[i]; ++j) { (*in_shape)[i].push_back(in_shape_data[i][j]); } } for (mx_uint i = 0; i < aux_shape_size; ++i) { aux_shape->push_back(std::vector()); - for (mx_uint j = 0; j < aux_shape_ndim[i]; ++j) { + for (int j = 0; j < aux_shape_ndim[i]; ++j) { (*aux_shape)[i].push_back(aux_shape_data[i][j]); } } for (mx_uint i = 0; i < out_shape_size; ++i) { out_shape->push_back(std::vector()); - for (mx_uint j = 0; j < out_shape_ndim[i]; ++j) { + for (int j = 0; j < out_shape_ndim[i]; ++j) { (*out_shape)[i].push_back(out_shape_data[i][j]); } } diff --git a/include/mxnet/tensor_blob.h b/include/mxnet/tensor_blob.h index 45d4c7fda639..a7a57266dab8 100755 --- a/include/mxnet/tensor_blob.h +++ b/include/mxnet/tensor_blob.h @@ -418,6 +418,8 @@ class TBlob { namespace dmlc { // Add a few patches to support mxnet::TShape in dmlc/parameter. DMLC_DECLARE_TYPE_NAME(mxnet::TShape, "Shape(tuple)"); +DMLC_DECLARE_TYPE_NAME(mxnet::Tuple, "Shape(tuple)"); +DMLC_DECLARE_TYPE_NAME(mxnet::Tuple>, "Shape(tuple)"); DMLC_DECLARE_TYPE_NAME(nnvm::Tuple, "Shape(tuple)"); DMLC_DECLARE_TYPE_NAME(nnvm::Tuple>, "Shape(tuple)"); diff --git a/include/mxnet/tuple.h b/include/mxnet/tuple.h index d83e843033e3..c5a358628ccd 100644 --- a/include/mxnet/tuple.h +++ b/include/mxnet/tuple.h @@ -236,7 +236,10 @@ class Tuple { */ friend std::ostream &operator<<(std::ostream &os, const Tuple &t) { if (t.ndim() == -1) { - os << "UNKNOWN_SHAPE"; + // If t is an unknown shape, return string "None". + // This is consistent with returning unknown shape in Python and generating + // C++ operator APIs by OpWrapperGenerator.py (defaultString) in cpp-package. + os << "None"; return os; } os << '['; @@ -727,6 +730,7 @@ struct hash { namespace dmlc { /*! \brief description for optional TShape */ DMLC_DECLARE_TYPE_NAME(optional, "Shape or None"); +DMLC_DECLARE_TYPE_NAME(optional>, "Shape or None"); // avoid low version of MSVC #if !defined(_MSC_VER) template