Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] Misc fix #14612

Merged
merged 5 commits into from
Apr 4, 2019
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,7 @@ class NDArray {
/*! \brief set the shape for ith aux data, and update storage shape if necessary */
inline void set_aux_shape(const size_t i, const mxnet::TShape& shape) {
aux_shapes[i] = shape;
if (storage_shape.ndim() > 0) {
if (storage_shape.ndim() >= 0) {
if (storage_type == kRowSparseStorage && i == rowsparse::kIdx) {
storage_shape[0] = shape[0];
} else if (storage_type == kCSRStorage && i == csr::kIdx) {
Expand Down
1 change: 1 addition & 0 deletions src/c_api/c_predict_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ int MXPredGetOutputShape(PredictorHandle handle,
<< "Index exceed number of outputs";

const mxnet::TShape& s = p->out_shapes[out_index];
CHECK_GE(s.ndim(), 0);
p->out_shapes_buffer.resize(s.ndim());
nnvm::ShapeTypeCast(s.begin(), s.end(), p->out_shapes_buffer.data());
*shape_data = p->out_shapes_buffer.data();
Expand Down
3 changes: 2 additions & 1 deletion src/imperative/imperative_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "../common/utils.h"
#include "../common/exec_utils.h"
#include "../operator/nn/mkldnn/mkldnn_base-inl.h"
#include "../operator/operator_common.h"

#ifndef MXNET_IMPERATIVE_IMPERATIVE_UTILS_H_
#define MXNET_IMPERATIVE_IMPERATIVE_UTILS_H_
Expand Down Expand Up @@ -196,7 +197,7 @@ inline void SetShapeType(const Context& ctx,

for (size_t i = 0; i < outputs.size(); ++i) {
NDArrayStorageType storage_type = static_cast<NDArrayStorageType>(out_storage_types[i]);
if (outputs[i]->is_none() || outputs[i]->shape().ndim() == 0) {
if (outputs[i]->is_none() || mxnet::op::shape_is_none(outputs[i]->shape())) {
if (is_dynamic_shape_existing) {
// once there is dynamic shape somewhere, we could not pre-determine the shape.
*outputs[i] = NDArray(ctx, out_types[i]);
Expand Down
10 changes: 5 additions & 5 deletions src/kvstore/gradient_compression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ int64_t GradientCompression::GetCompressedSize(const int64_t original_size) {

void GradientCompression::Quantize(const mxnet::NDArray &from, mxnet::NDArray *to,
mxnet::NDArray *residual, const int priority) {
CHECK(from.shape().ndim() != 0) << "source operand has zero dimension shape";
CHECK(to->shape().ndim() != 0) << "destination operand has zero dimension shape";
CHECK(residual->shape().ndim() != 0) << "residual operand has zero dimension shape";
CHECK(shape_is_known(from.shape())) << "source operand has undefined shape";
CHECK(shape_is_known(to->shape())) << "destination operand has undefined shape";
CHECK(shape_is_known(residual->shape())) << "residual operand has undefined shape";
const int a = from.ctx().dev_mask();
const int b = to->ctx().dev_mask();
const float threshold = threshold_;
Expand Down Expand Up @@ -137,8 +137,8 @@ void GradientCompression::Quantize(const mxnet::NDArray &from, mxnet::NDArray *t

void GradientCompression::Dequantize(const mxnet::NDArray &from, mxnet::NDArray *to,
const int priority) {
CHECK(from.shape().ndim() != 0) << "source operands has zero dimension shape";
CHECK(to->shape().ndim() != 0) << "destination operand has zero dimension shape";
CHECK(shape_is_known(from.shape())) << "source operand has undefined shape";
CHECK(shape_is_known(to->shape())) << "destination operand has undefined shape";
const int a = from.ctx().dev_mask();
const int b = to->ctx().dev_mask();
const float threshold = threshold_;
Expand Down
11 changes: 7 additions & 4 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1191,8 +1191,8 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority, bool is_op
CHECK(from.shape() == to.shape())
<< "operands shape mismatch"
<< "from.shape = " << from.shape() << " to.shape=" << to.shape();
CHECK(from.shape().ndim() != 0)
<< "source operands have zero dimension shape";
CHECK(!mxnet::op::shape_is_none(from.shape()))
<< "source operands have undefined shape";
// important: callback must always capture by value
const Context from_ctx = from.ctx();
const int a = from_ctx.dev_mask();
Expand Down Expand Up @@ -1663,7 +1663,7 @@ bool NDArray::LegacyLoad(dmlc::Stream *strm, const uint32_t magic) {
// load shape
mxnet::TShape shape;
if (!LegacyTShapeLoad(strm, &shape, magic)) return false;
if (shape.ndim() == 0) {
if (mxnet::op::shape_is_none(shape)) {
junrushao marked this conversation as resolved.
Show resolved Hide resolved
*this = NDArray(); return true;
}
// load context
Expand Down Expand Up @@ -1711,7 +1711,10 @@ bool NDArray::Load(dmlc::Stream *strm) {
// load shape
mxnet::TShape shape;
if (!shape.Load(strm)) return false;
if (shape.ndim() == 0) {
if (Imperative::Get()->is_np_comp()) {
junrushao marked this conversation as resolved.
Show resolved Hide resolved
common::ConvertToNumpyShape(&shape);
}
if (mxnet::op::shape_is_none(shape)) {
junrushao marked this conversation as resolved.
Show resolved Hide resolved
*this = NDArray(); return true;
}

Expand Down
2 changes: 1 addition & 1 deletion src/ndarray/ndarray_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace ndarray {
struct BinaryBase {
inline static mxnet::TShape GetShape(const mxnet::TShape &lshape, const mxnet::TShape &rshape) {
CHECK(lshape == rshape) << "operands shape mismatch";
CHECK(lshape.ndim() != 0) << "source operand have zero dimension shape";
CHECK(!mxnet::op::shape_is_none(lshape)) << "source operand have zero dimension shape";
return lshape;
}
};
Expand Down