From d6dff415287bf0fffb263898bc829be58037d92f Mon Sep 17 00:00:00 2001
From: Hao Jin
Date: Thu, 16 Jan 2020 21:05:05 +0000
Subject: [PATCH] np.broadcast_to extension

---
 .../numpy/np_broadcast_reduce_op_value.cc | 14 ++++++----
 tests/python/unittest/test_numpy_op.py    | 27 +++++++++++++++++++
 2 files changed, 36 insertions(+), 5 deletions(-)

diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc
index 4bc8e6737bcf..8b6e8b7fc775 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cc
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -467,17 +467,21 @@ bool NumpyBroadcastToShape(const nnvm::NodeAttrs& attrs,
   mxnet::TShape& ishape = (*in_attrs)[0];
   if (!mxnet::shape_is_known(ishape)) return false;
   const BroadcastToParam& param = nnvm::get<BroadcastToParam>(attrs.parsed);
-  CHECK(mxnet::shape_is_known(param.shape))
-      << "the objective shape for broadcasting array must be known";
   CHECK_LE(ishape.ndim(), param.shape.ndim())
       << "shape " << ishape << " is not broadcastable to " << param.shape;
+  TShape pshape = param.shape;
   for (int i = param.shape.ndim() - 1; i >= 0; --i) {
     int j = i - param.shape.ndim() + ishape.ndim();
     if (j < 0) break;
-    CHECK(ishape[j] == param.shape[i] || ishape[j] == 1)
-        << "shape " << ishape << " is not broadcastable to " << param.shape;
+    if (pshape[i] == -2) {
+      pshape[i] = ishape[j];
+    }
+    CHECK(ishape[j] == pshape[i] || ishape[j] == 1)
+        << "shape " << ishape << " is not broadcastable to " << pshape;
   }
-  SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape);
+  CHECK(mxnet::shape_is_known(pshape))
+      << "the objective shape for broadcasting array must be known";
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, pshape);
   return true;
 }
 
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index f0b62e280ab5..2c5120c11edb 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1559,6 +1559,7 @@ def hybrid_forward(self, F, x):
         ((4, 1), (1, 2, 3, 4, 5)),
         ((4, 1), (1, 0, 3, 4, 5))
     ]
+
     for src_shape, dst_shape in shapes:
         for hybridize in [True, False]:
             test_broadcast_to = TestBroadcastTo(dst_shape)
@@ -1587,6 +1588,32 @@ def hybrid_forward(self, F, x):
         ret = test_scalar_broadcast_to(np.empty(()))
         assert_almost_equal(ret.asnumpy(), expected_ret, rtol=1e-5, atol=1e-6, use_broadcast=False)
 
+    # Test npx functionality
+    shapes = [
+        ((5,), (3, 4, -2), (3, 4, 5)),
+        ((5,), (0, -2), (0, 5)),
+        ((1, 0), (2, -2, -2), (2, 1, 0)),
+        ((3, 4), (1, 2, 3, -2), (1, 2, 3, 4)),
+        ((3, 4), (1, 0, -2, 4), (1, 0, 3, 4))
+    ]
+
+    for src_shape, npx_dst_shape, np_dst_shape in shapes:
+        for hybridize in [True, False]:
+            test_broadcast_to = TestBroadcastTo(npx_dst_shape)
+            if hybridize:
+                test_broadcast_to.hybridize()
+
+            a = _np.random.uniform(size=src_shape).astype(np.float32)
+            expected_ret = _np.broadcast_to(a, np_dst_shape)
+            a_mx = np.array(a, dtype=a.dtype)
+            a_mx.attach_grad()
+            with mx.autograd.record():
+                ret = test_broadcast_to(a_mx)
+            assert_almost_equal(ret.asnumpy(), expected_ret, rtol=1e-5, atol=1e-6, use_broadcast=False)
+            ret.backward()
+            expected_grad = collapse_sum_like(_np.ones_like(expected_ret), src_shape)
+            assert_almost_equal(a_mx.grad.asnumpy(), expected_grad, rtol=1e-5, atol=1e-6, use_broadcast=False)
+
 
 @with_seed()
 @use_np
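
Note on the semantics (illustrative sketch, not part of the patch): the C++ change above lets a -2 entry in the target shape stand for "keep the corresponding input dimension". The -2 entries are resolved right-aligned against the input shape before the usual broadcast-compatibility check runs, and only the fully resolved shape must be known. A minimal pure-Python emulation of that resolution rule follows; the helper name resolve_broadcast_shape is hypothetical and exists only for this sketch.

    def resolve_broadcast_shape(ishape, pshape):
        # Mirror of the patched NumpyBroadcastToShape: walk the target
        # shape from the right; a -2 entry inherits the matching input
        # dimension, then the standard broadcast check applies.
        assert len(ishape) <= len(pshape), \
            "shape %s is not broadcastable to %s" % (ishape, pshape)
        pshape = list(pshape)
        offset = len(pshape) - len(ishape)
        for i in range(len(pshape) - 1, -1, -1):
            j = i - offset
            if j < 0:
                break
            if pshape[i] == -2:
                pshape[i] = ishape[j]
            assert ishape[j] == pshape[i] or ishape[j] == 1, \
                "shape %s is not broadcastable to %s" % (ishape, tuple(pshape))
        return tuple(pshape)

    # Mirrors the new test cases:
    print(resolve_broadcast_shape((5,), (3, 4, -2)))       # (3, 4, 5)
    print(resolve_broadcast_shape((1, 0), (2, -2, -2)))    # (2, 1, 0)
    print(resolve_broadcast_shape((3, 4), (1, 0, -2, 4)))  # (1, 0, 3, 4)

Assuming a build that carries this patch, the same rule is what the new test cases exercise through the frontend via TestBroadcastTo: broadcasting an array of shape (5,) to target shape (3, 4, -2) yields an output of shape (3, 4, 5).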