diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 2232ebe7be40..13fb42ce521e 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -956,7 +956,7 @@ class NDArray {
     /*! \brief set the shape for ith aux data, and update storage shape if necessary */
     inline void set_aux_shape(const size_t i, const mxnet::TShape& shape) {
       aux_shapes[i] = shape;
-      if (storage_shape.ndim() > 0) {
+      if (storage_shape.ndim() >= 0) {
         if (storage_type == kRowSparseStorage && i == rowsparse::kIdx) {
           storage_shape[0] = shape[0];
         } else if (storage_type == kCSRStorage && i == csr::kIdx) {
diff --git a/src/c_api/c_predict_api.cc b/src/c_api/c_predict_api.cc
index 3b9f43d86079..7de23ef935ef 100644
--- a/src/c_api/c_predict_api.cc
+++ b/src/c_api/c_predict_api.cc
@@ -436,6 +436,7 @@ int MXPredGetOutputShape(PredictorHandle handle,
       << "Index exceed number of outputs";
 
   const mxnet::TShape& s = p->out_shapes[out_index];
+  CHECK_GE(s.ndim(), 0);
   p->out_shapes_buffer.resize(s.ndim());
   nnvm::ShapeTypeCast(s.begin(), s.end(), p->out_shapes_buffer.data());
   *shape_data = p->out_shapes_buffer.data();
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 6864428d2559..d058df4b3806 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -31,6 +31,7 @@
 #include "../common/utils.h"
 #include "../common/exec_utils.h"
 #include "../operator/nn/mkldnn/mkldnn_base-inl.h"
+#include "../operator/operator_common.h"
 
 #ifndef MXNET_IMPERATIVE_IMPERATIVE_UTILS_H_
 #define MXNET_IMPERATIVE_IMPERATIVE_UTILS_H_
@@ -196,7 +197,7 @@ inline void SetShapeType(const Context& ctx,
   for (size_t i = 0; i < outputs.size(); ++i) {
     NDArrayStorageType storage_type = static_cast<NDArrayStorageType>(out_storage_types[i]);
-    if (outputs[i]->is_none() || outputs[i]->shape().ndim() == 0) {
+    if (outputs[i]->is_none() || mxnet::op::shape_is_none(outputs[i]->shape())) {
       if (is_dynamic_shape_existing) {
         // once there is dynamic shape somewhere, we could not pre-determine the shape.
         *outputs[i] = NDArray(ctx, out_types[i]);
diff --git a/src/kvstore/gradient_compression.cc b/src/kvstore/gradient_compression.cc
index e4a06fa9a1f2..30aaec91e27f 100644
--- a/src/kvstore/gradient_compression.cc
+++ b/src/kvstore/gradient_compression.cc
@@ -100,9 +100,9 @@ int64_t GradientCompression::GetCompressedSize(const int64_t original_size) {
 
 void GradientCompression::Quantize(const mxnet::NDArray &from, mxnet::NDArray *to,
                                    mxnet::NDArray *residual, const int priority) {
-  CHECK(from.shape().ndim() != 0) << "source operand has zero dimension shape";
-  CHECK(to->shape().ndim() != 0) << "destination operand has zero dimension shape";
-  CHECK(residual->shape().ndim() != 0) << "residual operand has zero dimension shape";
+  CHECK(shape_is_known(from.shape())) << "source operand has undefined shape";
+  CHECK(shape_is_known(to->shape())) << "destination operand has undefined shape";
+  CHECK(shape_is_known(residual->shape())) << "residual operand has undefined shape";
   const int a = from.ctx().dev_mask();
   const int b = to->ctx().dev_mask();
   const float threshold = threshold_;
@@ -137,8 +137,8 @@ void GradientCompression::Quantize(const mxnet::NDArray &from, mxnet::NDArray *t
 
 void GradientCompression::Dequantize(const mxnet::NDArray &from, mxnet::NDArray *to,
                                      const int priority) {
-  CHECK(from.shape().ndim() != 0) << "source operands has zero dimension shape";
-  CHECK(to->shape().ndim() != 0) << "destination operand has zero dimension shape";
+  CHECK(shape_is_known(from.shape())) << "source operand has undefined shape";
+  CHECK(shape_is_known(to->shape())) << "destination operand has undefined shape";
   const int a = from.ctx().dev_mask();
   const int b = to->ctx().dev_mask();
   const float threshold = threshold_;
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 04518e0feb77..604000028bf1 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -1191,8 +1191,8 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority, bool is_op
   CHECK(from.shape() == to.shape())
      << "operands shape mismatch"
      << "from.shape = " << from.shape() << " to.shape=" << to.shape();
-  CHECK(from.shape().ndim() != 0)
-      << "source operands have zero dimension shape";
+  CHECK(!mxnet::op::shape_is_none(from.shape()))
+      << "source operands have undefined shape";
   // important: callback must always capture by value
   const Context from_ctx = from.ctx();
   const int a = from_ctx.dev_mask();
@@ -1663,7 +1663,7 @@ bool NDArray::LegacyLoad(dmlc::Stream *strm, const uint32_t magic) {
   // load shape
   mxnet::TShape shape;
   if (!LegacyTShapeLoad(strm, &shape, magic)) return false;
-  if (shape.ndim() == 0) {
+  if (mxnet::op::shape_is_none(shape)) {
     *this = NDArray(); return true;
   }
   // load context
@@ -1711,7 +1711,10 @@ bool NDArray::Load(dmlc::Stream *strm) {
   // load shape
   mxnet::TShape shape;
   if (!shape.Load(strm)) return false;
-  if (shape.ndim() == 0) {
+  if (!Imperative::Get()->is_np_comp()) {
+    common::ConvertToNumpyShape(&shape);
+  }
+  if (mxnet::op::shape_is_none(shape)) {
     *this = NDArray(); return true;
   }
diff --git a/src/ndarray/ndarray_function.h b/src/ndarray/ndarray_function.h
index 70b626dbb9b7..505bd205a8d5 100644
--- a/src/ndarray/ndarray_function.h
+++ b/src/ndarray/ndarray_function.h
@@ -40,7 +40,7 @@ namespace ndarray {
 struct BinaryBase {
   inline static mxnet::TShape GetShape(const mxnet::TShape &lshape, const mxnet::TShape &rshape) {
     CHECK(lshape == rshape) << "operands shape mismatch";
-    CHECK(lshape.ndim() != 0) << "source operand have zero dimension shape";
+    CHECK(!mxnet::op::shape_is_none(lshape)) << "source operand has undefined shape";
     return lshape;
   }
 };
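
The common thread across these hunks is the move away from treating `ndim() == 0` as "shape unknown": under numpy-compatible semantics a 0-dim shape is a legal scalar, and an unknown rank or dimension is encoded as -1, which is what helpers such as `mxnet::op::shape_is_none()` and `shape_is_known()` test for. The sketch below only illustrates that convention; `ToyShape` and `toy_shape_is_known` are invented names for illustration and do not reproduce MXNet's actual helper definitions.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy stand-in for a shape under the numpy-compatible convention assumed by the
// checks in this diff: ndim == -1 means the whole shape is unknown, ndim == 0 is
// a valid scalar, and a dimension value of -1 means that dimension is unknown.
struct ToyShape {
  int ndim = -1;
  std::vector<int64_t> dims;
};

// Rough analogue of the shape_is_known() checks used above: true only when the
// rank and every dimension are known. The real MXNet helper may be defined
// differently; this only demonstrates the convention.
bool toy_shape_is_known(const ToyShape& s) {
  if (s.ndim < 0) return false;
  for (int64_t d : s.dims) {
    if (d < 0) return false;
  }
  return true;
}

int main() {
  ToyShape unknown;        // default: ndim == -1, nothing known

  ToyShape scalar;
  scalar.ndim = 0;         // 0-dim tensor: a valid scalar shape

  ToyShape matrix;
  matrix.ndim = 2;
  matrix.dims = {3, 4};    // fully known 3x4 shape

  // Under the old convention ndim() == 0 doubled as "unknown", so a scalar
  // would have been rejected by checks such as `shape().ndim() != 0`.
  assert(!toy_shape_is_known(unknown));
  assert(toy_shape_is_known(scalar));
  assert(toy_shape_is_known(matrix));
  return 0;
}
```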