diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index cc6aedaac2a6..f52fa257374e 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -389,11 +389,38 @@ NDArray NDArray::At(index_t idx) const { NDArray NDArray::AtWithRecord(index_t idx) { CHECK(storage_type() == kDefaultStorage) << "Storage type " << storage_type() << " doesn't support At()"; - NDArray ret = this->SliceWithRecord(idx, idx+1); + NDArray sliced = this->SliceWithRecord(idx, idx+1); if (shape_.ndim() > 1 || Imperative::Get()->is_np_shape()) { - return ret.ReshapeWithRecord(mxnet::TShape(shape_.data()+1, shape_.data()+shape_.ndim())); + // Imperative reshape with concrete shape + NDArray reshaped = sliced.Reshape(mxnet::TShape(shape_.data()+1, shape_.data()+shape_.ndim())); + + // Record reshape with magic numbers + nnvm::NodeAttrs attrs; + std::ostringstream os; + if (!Imperative::Get()->is_np_shape()) { + os << mxnet::TShape({-3, -2}); // See ndarray.py reshape for definition of magic numbers + attrs.op = nnvm::Op::Get("Reshape");; + attrs.dict.insert({"shape", os.str()}); + } else { + // See NumpyXReshapeInferShape for definition of magic numbers + os << mxnet::TShape({-3, -4}); + attrs.op = nnvm::Op::Get("_npx_reshape");; + attrs.dict.insert({"newshape", os.str()}); + } + attrs.op->attr_parser(&attrs); + std::vector inputs(1, &sliced), outputs(1, &reshaped); + + bool is_recording = Imperative::Get()->is_recording(); + bool is_deferred_compute = Imperative::Get()->is_deferred_compute(); + if (is_recording) { + Imperative::Get()->RecordOp(std::move(attrs), inputs, outputs); + } else if (is_deferred_compute) { + Imperative::Get()->RecordDeferredCompute(std::move(attrs), inputs, outputs); + } + + return reshaped; } else { - return ret; + return sliced; } } diff --git a/tests/python/unittest/test_deferred_compute.py b/tests/python/unittest/test_deferred_compute.py index 28a347ce72a8..dc28dc86823e 100644 --- a/tests/python/unittest/test_deferred_compute.py +++ b/tests/python/unittest/test_deferred_compute.py @@ -588,5 +588,6 @@ def forward(self, x): try: mx.npx.set_np() net(mx.np.zeros((2, 2, 4, 0, 128))) + net(mx.np.zeros((2, 2, 4, 2, 128))) # test indexing after input shape change finally: mx.npx.reset_np()