Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[numpy] Fix cpp-package build after using new shape definition #14554

Merged
merged 1 commit into from
Apr 3, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cpp-package/include/mxnet-cpp/ndarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,11 +397,11 @@ inline size_t NDArray::Size() const {
}

inline std::vector<mx_uint> NDArray::GetShape() const {
const mx_uint *out_pdata;
mx_uint out_dim;
const int *out_pdata;
int out_dim;
MXNDArrayGetShape(blob_ptr_->handle_, &out_dim, &out_pdata);
std::vector<mx_uint> ret;
for (mx_uint i = 0; i < out_dim; ++i) {
for (int i = 0; i < out_dim; ++i) {
ret.push_back(out_pdata[i]);
}
return ret;
Expand Down
20 changes: 10 additions & 10 deletions cpp-package/include/mxnet-cpp/symbol.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ inline void Symbol::InferShape(

std::vector<const char *> keys;
std::vector<mx_uint> arg_ind_ptr;
std::vector<mx_uint> arg_shape_data;
std::vector<int> arg_shape_data;

for (const auto &arg : arg_shapes) {
keys.push_back(arg.first.c_str());
Expand All @@ -200,14 +200,14 @@ inline void Symbol::InferShape(
arg_ind_ptr.push_back(arg_shape_data.size());

mx_uint in_shape_size;
const mx_uint *in_shape_ndim;
const mx_uint **in_shape_data;
const int*in_shape_ndim;
const int **in_shape_data;
mx_uint out_shape_size;
const mx_uint *out_shape_ndim;
const mx_uint **out_shape_data;
const int *out_shape_ndim;
const int **out_shape_data;
mx_uint aux_shape_size;
const mx_uint *aux_shape_ndim;
const mx_uint **aux_shape_data;
const int *aux_shape_ndim;
const int **aux_shape_data;
int complete;

CHECK_EQ(MXSymbolInferShape(GetHandle(), keys.size(), keys.data(),
Expand All @@ -221,19 +221,19 @@ inline void Symbol::InferShape(
if (complete) {
for (mx_uint i = 0; i < in_shape_size; ++i) {
in_shape->push_back(std::vector<mx_uint>());
for (mx_uint j = 0; j < in_shape_ndim[i]; ++j) {
for (int j = 0; j < in_shape_ndim[i]; ++j) {
(*in_shape)[i].push_back(in_shape_data[i][j]);
}
}
for (mx_uint i = 0; i < aux_shape_size; ++i) {
aux_shape->push_back(std::vector<mx_uint>());
for (mx_uint j = 0; j < aux_shape_ndim[i]; ++j) {
for (int j = 0; j < aux_shape_ndim[i]; ++j) {
(*aux_shape)[i].push_back(aux_shape_data[i][j]);
}
}
for (mx_uint i = 0; i < out_shape_size; ++i) {
out_shape->push_back(std::vector<mx_uint>());
for (mx_uint j = 0; j < out_shape_ndim[i]; ++j) {
for (int j = 0; j < out_shape_ndim[i]; ++j) {
(*out_shape)[i].push_back(out_shape_data[i][j]);
}
}
Expand Down
2 changes: 2 additions & 0 deletions include/mxnet/tensor_blob.h
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,8 @@ class TBlob {
namespace dmlc {
// Add a few patches to support mxnet::TShape in dmlc/parameter.
DMLC_DECLARE_TYPE_NAME(mxnet::TShape, "Shape(tuple)");
DMLC_DECLARE_TYPE_NAME(mxnet::Tuple<int>, "Shape(tuple)");
DMLC_DECLARE_TYPE_NAME(mxnet::Tuple<dmlc::optional<int>>, "Shape(tuple)");
DMLC_DECLARE_TYPE_NAME(nnvm::Tuple<int>, "Shape(tuple)");
DMLC_DECLARE_TYPE_NAME(nnvm::Tuple<dmlc::optional<int>>, "Shape(tuple)");

Expand Down
6 changes: 5 additions & 1 deletion include/mxnet/tuple.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,10 @@ class Tuple {
*/
friend std::ostream &operator<<(std::ostream &os, const Tuple<ValueType> &t) {
if (t.ndim() == -1) {
os << "UNKNOWN_SHAPE";
// If t is an unknown shape, return string "None".
// This is consistent with returning unknown shape in Python and generating
// C++ operator APIs by OpWrapperGenerator.py (defaultString) in cpp-package.
os << "None";
return os;
}
os << '[';
Expand Down Expand Up @@ -727,6 +730,7 @@ struct hash<mxnet::TShape> {
namespace dmlc {
/*! \brief description for optional TShape */
DMLC_DECLARE_TYPE_NAME(optional<mxnet::TShape>, "Shape or None");
DMLC_DECLARE_TYPE_NAME(optional<mxnet::Tuple<int>>, "Shape or None");
// avoid low version of MSVC
#if !defined(_MSC_VER)
template<typename T>
Expand Down