new unit test
haojin2 committed May 30, 2019
commit 49947a3, 1 parent d183768
Showing 2 changed files with 50 additions and 34 deletions.
12 changes: 12 additions & 0 deletions src/operator/numpy/np_matrix_op.cc
@@ -221,6 +221,17 @@ bool ConcatType(const nnvm::NodeAttrs& attrs,
                std::vector<int> *in_type,
                std::vector<int> *out_type);

+struct NumpyConcatGrad {
+  const char *op_name;
+  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
+                                          const std::vector<nnvm::NodeEntry>& ograds) const {
+    CHECK_EQ(ograds.size(), 1);
+    std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
+    return MakeGradNode(op_name, n, heads, n->attrs.dict);
+  }
+};
+
+
NNVM_REGISTER_OP(_npi_concatenate)
.describe(R"code(Join a sequence of arrays along an existing axis.)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
@@ -246,6 +257,7 @@ NNVM_REGISTER_OP(_npi_concatenate)
.set_attr<nnvm::FInferType>("FInferType", ConcatType)
.set_attr<mxnet::FInferShape>("FInferShape", ConcatShape)
.set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", NumpyConcatGrad{"_backward_Concat"})
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
.add_arguments(ConcatParam::__FIELDS__());

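The new FGradient hook is what makes _npi_concatenate differentiable: NumpyConcatGrad checks that exactly one output gradient arrives, then hands it to the existing _backward_Concat op along with the forward node's attributes. Conceptually, that backward op splits the output gradient back into one slice per input along the concatenation axis. Below is a minimal sketch of that rule in plain NumPy, with a hypothetical helper name and signature (this is not MXNet code):

import numpy as np

def concat_backward(ograd, input_shapes, axis):
    # Hypothetical helper: split the output gradient into one piece
    # per input along the concatenation axis.
    sizes = [shape[axis] for shape in input_shapes]
    split_points = np.cumsum(sizes)[:-1]
    return np.split(ograd, split_points, axis=axis)

# y.backward() seeds the head gradient with ones, so every input gets
# back all-ones in its own shape, which is exactly what the unit test
# below asserts with _np.ones(a.shape) and friends.
grads = concat_backward(np.ones((3, 3)), [(2, 3), (1, 3)], axis=0)
assert [g.shape for g in grads] == [(2, 3), (1, 3)]
assert all((g == 1.0).all() for g in grads)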
72 changes: 38 additions & 34 deletions tests/python/unittest/test_numpy_op.py
@@ -324,40 +324,44 @@ def __init__(self, axis=None):
        def hybrid_forward(self, F, a, *args):
            return F.np.concatenate([a] + list(args), axis=self._axis)

-    in_data_dim = random.choice([2, 3, 4])
-    shape = rand_shape_nd(in_data_dim, dim=3)
-
-    for hybridize in [True, False]:
-        for axis in [i for i in range(in_data_dim)]:
-            # test gluon
-            test_concat = TestConcat(axis=axis)
-            if hybridize:
-                test_concat.hybridize()
-            a = mx.nd.random.uniform(-1.0, 1.0, shape=shape).as_np_ndarray()
-            a.attach_grad()
-            b = mx.nd.random.uniform(-1.0, 1.0, shape=shape).as_np_ndarray()
-            b.attach_grad()
-            c = mx.nd.random.uniform(-1.0, 1.0, shape=shape).as_np_ndarray()
-            c.attach_grad()
-            d = mx.nd.random.uniform(-1.0, 1.0, shape=shape).as_np_ndarray()
-            d.attach_grad()
-            expected_ret = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)
-            with mx.autograd.record():
-                y = test_concat(a, b, c, d)
-            assert y.shape == expected_ret.shape
-            assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)
-
-            y.backward()
-
-            assert_almost_equal(a.grad.asnumpy(), _np.ones(shape), rtol=1e-3, atol=1e-5)
-            assert_almost_equal(b.grad.asnumpy(), _np.ones(shape), rtol=1e-3, atol=1e-5)
-            assert_almost_equal(c.grad.asnumpy(), _np.ones(shape), rtol=1e-3, atol=1e-5)
-            assert_almost_equal(d.grad.asnumpy(), _np.ones(shape), rtol=1e-3, atol=1e-5)
-
-            # test imperative
-            mx_out = np.concatenate([a, b, c, d], axis=axis)
-            np_out = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)
-            assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+    def get_new_shape(shape, axis):
+        shape_lst = list(shape)
+        shape_lst[axis] = random.randint(0, 3)
+        return tuple(shape_lst)
+
+    for shape in [(0, 0), (2, 3)]:
+        for hybridize in [True, False]:
+            for axis in [i for i in range(2)]:
+                # test gluon
+                test_concat = TestConcat(axis=axis)
+                if hybridize:
+                    test_concat.hybridize()
+
+                a = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
+                a.attach_grad()
+                b = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
+                b.attach_grad()
+                c = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
+                c.attach_grad()
+                d = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
+                d.attach_grad()
+                expected_ret = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)
+                with mx.autograd.record():
+                    y = test_concat(a, b, c, d)
+                assert y.shape == expected_ret.shape
+                assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)
+
+                y.backward()
+
+                assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)
+                assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5)
+                assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5)
+                assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5)
+
+                # test imperative
+                mx_out = np.concatenate([a, b, c, d], axis=axis)
+                np_out = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


if __name__ == '__main__':
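Relative to the old test, the rewrite broadens coverage in one specific way: get_new_shape can assign size 0 to the concatenation axis (random.randint(0, 3)), and the (0, 0) base shape exercises entirely empty inputs, whereas the old rand_shape_nd shapes always had sides of at least 1. A minimal sketch in plain NumPy (not MXNet) of the edge case this guards:

import numpy as np

empty = np.empty((0, 3))
full = np.ones((2, 3))
out = np.concatenate([empty, full, empty], axis=0)
assert out.shape == (2, 3)   # zero-size pieces contribute nothing
assert (out == full).all()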
