From fcefc5ae313d07fe3e40e8128ae9d53957f0f110 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 26 Mar 2019 17:07:13 -0700 Subject: [PATCH] [numpy] Fix test_dynamic_shape.test_dynamic_shape (#14538) * Initial commit * Address comments from Jun --- src/c_api/c_api_common.h | 5 +++-- src/imperative/imperative_utils.cc | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h index 329dc9adc7cf..55608b950866 100644 --- a/src/c_api/c_api_common.h +++ b/src/c_api/c_api_common.h @@ -91,8 +91,9 @@ struct MXAPIThreadLocalEntry { data->resize(shapes.size()); size_t size = 0; for (const auto& s : shapes) { - if (s.ndim() > 0); - size += s.ndim(); + if (s.ndim() > 0) { + size += s.ndim(); + } } buffer->resize(size); int *ptr = buffer->data(); diff --git a/src/imperative/imperative_utils.cc b/src/imperative/imperative_utils.cc index 6cb4a70324b5..733a47bfe6c1 100644 --- a/src/imperative/imperative_utils.cc +++ b/src/imperative/imperative_utils.cc @@ -19,6 +19,7 @@ #include "./imperative_utils.h" #include "./cached_op.h" +#include "../operator/operator_common.h" namespace mxnet { namespace imperative { @@ -186,7 +187,7 @@ void NaiveRunGraph( Imperative::Get()->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, state); for (size_t j = 0; j < ndoutputs.size(); ++j) { - if (ndoutputs[j]->shape().ndim() == 0) { + if (mxnet::op::shape_is_none(ndoutputs[j]->shape())) { ndoutputs[j]->WaitToRead(); ndoutputs[j]->SetShapeFromChunk(); }