Skip to content

Commit

Permalink
Fix cpp package build after using new shape definition (apache#14554)
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed Apr 15, 2019
1 parent 48cf659 commit b165b35
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 14 deletions.
6 changes: 3 additions & 3 deletions cpp-package/include/mxnet-cpp/ndarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,11 +397,11 @@ inline size_t NDArray::Size() const {
}

inline std::vector<mx_uint> NDArray::GetShape() const {
const mx_uint *out_pdata;
mx_uint out_dim;
const int *out_pdata;
int out_dim;
MXNDArrayGetShape(blob_ptr_->handle_, &out_dim, &out_pdata);
std::vector<mx_uint> ret;
for (mx_uint i = 0; i < out_dim; ++i) {
for (int i = 0; i < out_dim; ++i) {
ret.push_back(out_pdata[i]);
}
return ret;
Expand Down
20 changes: 10 additions & 10 deletions cpp-package/include/mxnet-cpp/symbol.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ inline void Symbol::InferShape(

std::vector<const char *> keys;
std::vector<mx_uint> arg_ind_ptr;
std::vector<mx_uint> arg_shape_data;
std::vector<int> arg_shape_data;

for (const auto &arg : arg_shapes) {
keys.push_back(arg.first.c_str());
Expand All @@ -200,14 +200,14 @@ inline void Symbol::InferShape(
arg_ind_ptr.push_back(arg_shape_data.size());

mx_uint in_shape_size;
const mx_uint *in_shape_ndim;
const mx_uint **in_shape_data;
const int*in_shape_ndim;
const int **in_shape_data;
mx_uint out_shape_size;
const mx_uint *out_shape_ndim;
const mx_uint **out_shape_data;
const int *out_shape_ndim;
const int **out_shape_data;
mx_uint aux_shape_size;
const mx_uint *aux_shape_ndim;
const mx_uint **aux_shape_data;
const int *aux_shape_ndim;
const int **aux_shape_data;
int complete;

CHECK_EQ(MXSymbolInferShape(GetHandle(), keys.size(), keys.data(),
Expand All @@ -221,19 +221,19 @@ inline void Symbol::InferShape(
if (complete) {
for (mx_uint i = 0; i < in_shape_size; ++i) {
in_shape->push_back(std::vector<mx_uint>());
for (mx_uint j = 0; j < in_shape_ndim[i]; ++j) {
for (int j = 0; j < in_shape_ndim[i]; ++j) {
(*in_shape)[i].push_back(in_shape_data[i][j]);
}
}
for (mx_uint i = 0; i < aux_shape_size; ++i) {
aux_shape->push_back(std::vector<mx_uint>());
for (mx_uint j = 0; j < aux_shape_ndim[i]; ++j) {
for (int j = 0; j < aux_shape_ndim[i]; ++j) {
(*aux_shape)[i].push_back(aux_shape_data[i][j]);
}
}
for (mx_uint i = 0; i < out_shape_size; ++i) {
out_shape->push_back(std::vector<mx_uint>());
for (mx_uint j = 0; j < out_shape_ndim[i]; ++j) {
for (int j = 0; j < out_shape_ndim[i]; ++j) {
(*out_shape)[i].push_back(out_shape_data[i][j]);
}
}
Expand Down
2 changes: 2 additions & 0 deletions include/mxnet/tensor_blob.h
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,8 @@ class TBlob {
namespace dmlc {
// Add a few patches to support mxnet::TShape in dmlc/parameter.
DMLC_DECLARE_TYPE_NAME(mxnet::TShape, "Shape(tuple)");
DMLC_DECLARE_TYPE_NAME(mxnet::Tuple<int>, "Shape(tuple)");
DMLC_DECLARE_TYPE_NAME(mxnet::Tuple<dmlc::optional<int>>, "Shape(tuple)");
DMLC_DECLARE_TYPE_NAME(nnvm::Tuple<int>, "Shape(tuple)");
DMLC_DECLARE_TYPE_NAME(nnvm::Tuple<dmlc::optional<int>>, "Shape(tuple)");

Expand Down
6 changes: 5 additions & 1 deletion include/mxnet/tuple.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,10 @@ class Tuple {
*/
friend std::ostream &operator<<(std::ostream &os, const Tuple<ValueType> &t) {
if (t.ndim() == -1) {
os << "UNKNOWN_SHAPE";
// If t is an unknown shape, return string "None".
// This is consistent with returning unknown shape in Python and generating
// C++ operator APIs by OpWrapperGenerator.py (defaultString) in cpp-package.
os << "None";
return os;
}
os << '[';
Expand Down Expand Up @@ -727,6 +730,7 @@ struct hash<mxnet::TShape> {
namespace dmlc {
/*! \brief description for optional TShape */
DMLC_DECLARE_TYPE_NAME(optional<mxnet::TShape>, "Shape or None");
DMLC_DECLARE_TYPE_NAME(optional<mxnet::Tuple<int>>, "Shape or None");
// avoid low version of MSVC
#if !defined(_MSC_VER)
template<typename T>
Expand Down

0 comments on commit b165b35

Please sign in to comment.