This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

np.broadcast_to extension #17358

Merged: 1 commit, Jan 23, 2020
14 changes: 9 additions & 5 deletions src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -467,17 +467,21 @@ bool NumpyBroadcastToShape(const nnvm::NodeAttrs& attrs,
   mxnet::TShape& ishape = (*in_attrs)[0];
   if (!mxnet::shape_is_known(ishape)) return false;
   const BroadcastToParam& param = nnvm::get<BroadcastToParam>(attrs.parsed);
-  CHECK(mxnet::shape_is_known(param.shape))
-      << "the objective shape for broadcasting array must be known";
   CHECK_LE(ishape.ndim(), param.shape.ndim())
       << "shape " << ishape << " is not broadcastable to " << param.shape;
+  TShape pshape = param.shape;
   for (int i = param.shape.ndim() - 1; i >= 0; --i) {
     int j = i - param.shape.ndim() + ishape.ndim();
     if (j < 0) break;
-    CHECK(ishape[j] == param.shape[i] || ishape[j] == 1)
-        << "shape " << ishape << " is not broadcastable to " << param.shape;
+    if (pshape[i] == -2) {
+      pshape[i] = ishape[j];
+    }
+    CHECK(ishape[j] == pshape[i] || ishape[j] == 1)
+        << "shape " << ishape << " is not broadcastable to " << pshape;
   }
-  SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape);
+  CHECK(mxnet::shape_is_known(pshape))
+      << "the objective shape for broadcasting array must be known";
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, pshape);
   return true;
 }

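In short, the shape-inference change above lets the target shape passed to `broadcast_to` contain `-2` entries, which are resolved to the corresponding dimension of the input before the usual broadcasting check runs. A minimal sketch of the resulting behaviour, assuming an MXNet build that includes this change (the shapes mirror one of the test cases below):

```python
from mxnet import np, npx
npx.set_np()  # enable NumPy-compatible semantics

a = np.ones((3, 4))
# The trailing -2 is a placeholder that keeps the input's last dimension (4),
# so the resolved target shape is (1, 2, 3, 4).
b = np.broadcast_to(a, (1, 2, 3, -2))
assert b.shape == (1, 2, 3, 4)
```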
27 changes: 27 additions & 0 deletions tests/python/unittest/test_numpy_op.py
@@ -1559,6 +1559,7 @@ def hybrid_forward(self, F, x):
         ((4, 1), (1, 2, 3, 4, 5)),
         ((4, 1), (1, 0, 3, 4, 5))
     ]
+
     for src_shape, dst_shape in shapes:
         for hybridize in [True, False]:
             test_broadcast_to = TestBroadcastTo(dst_shape)
@@ -1587,6 +1588,32 @@ def hybrid_forward(self, F, x):
             ret = test_scalar_broadcast_to(np.empty(()))
             assert_almost_equal(ret.asnumpy(), expected_ret, rtol=1e-5, atol=1e-6, use_broadcast=False)
 
+    # Test npx functionality
+    shapes = [
+        ((5,), (3, 4, -2), (3, 4, 5)),
+        ((5,), (0, -2), (0, 5)),
+        ((1, 0), (2, -2, -2), (2, 1, 0)),
+        ((3, 4), (1, 2, 3, -2), (1, 2, 3, 4)),
+        ((3, 4), (1, 0, -2, 4), (1, 0, 3, 4))
+    ]
+
+    for src_shape, npx_dst_shape, np_dst_shape in shapes:
+        for hybridize in [True, False]:
+            test_broadcast_to = TestBroadcastTo(npx_dst_shape)
+            if hybridize:
+                test_broadcast_to.hybridize()
+
+            a = _np.random.uniform(size=src_shape).astype(np.float32)
+            expected_ret = _np.broadcast_to(a, np_dst_shape)
+            a_mx = np.array(a, dtype=a.dtype)
+            a_mx.attach_grad()
+            with mx.autograd.record():
+                ret = test_broadcast_to(a_mx)
+            assert_almost_equal(ret.asnumpy(), expected_ret, rtol=1e-5, atol=1e-6, use_broadcast=False)
+            ret.backward()
+            expected_grad = collapse_sum_like(_np.ones_like(expected_ret), src_shape)
+            assert_almost_equal(a_mx.grad.asnumpy(), expected_grad, rtol=1e-5, atol=1e-6, use_broadcast=False)
+
 
 @with_seed()
 @use_np
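The gradient assertion in the new test leans on the fact that the backward pass of `broadcast_to` sums the incoming gradient back over the broadcast axes, which is what the `collapse_sum_like` helper computes for the expected value. A small plain-NumPy sketch of that expectation, using one of the shape pairs from the test:

```python
import numpy as _np

src_shape, np_dst_shape = (3, 4), (1, 2, 3, 4)
out_grad = _np.ones(np_dst_shape)

# Sum over the axes introduced by broadcasting and reshape back to the
# source shape; each source element was replicated 1 * 2 = 2 times.
in_grad = out_grad.sum(axis=(0, 1)).reshape(src_shape)
assert in_grad.shape == src_shape
assert (in_grad == 2).all()
```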