Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
change test
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiyan66 committed Dec 18, 2019
1 parent 33a272e commit e69c91e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
1 change: 1 addition & 0 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,7 @@ def _add_workload_mean(array_pool):
OpArgMngr.add_workload('mean', np.array([[1, 2, 3], [4, 5, 6]]), axis=0)
OpArgMngr.add_workload('mean', np.array([[1, 2, 3], [4, 5, 6]]), axis=1)


def _add_workload_median(array_pool):
OpArgMngr.add_workload('median', array_pool['4x1'])
OpArgMngr.add_workload('median', array_pool['4x1'], axis=0, keepdims=True)
Expand Down
3 changes: 3 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,7 @@ def hybrid_forward(self, F, a):
((2, 3, 4), (0, 2)),
((2, 3, 4), 1)
]

for hybridize, keepdims, (a_shape, axis), dtype in \
itertools.product(flags, flags, tensor_shapes, dtypes):
atol = 3e-4 if dtype == 'float16' else 1e-4
Expand All @@ -741,11 +742,13 @@ def hybrid_forward(self, F, a):
a = np.random.uniform(-1.0, 1.0, size=a_shape)
np_out = _np.median(a.asnumpy(), axis=axis, keepdims=keepdims)
mx_out = test_median(a)

assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, atol=atol, rtol=rtol)

mx_out = np.median(a, axis=axis, keepdims=keepdims)
np_out = _np.median(a.asnumpy(), axis=axis, keepdims=keepdims)

assert_almost_equal(mx_out.asnumpy(), np_out, atol=atol, rtol=rtol)

@with_seed()
Expand Down

0 comments on commit e69c91e

Please sign in to comment.