Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Assert equal for boolean ndarrays
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed Oct 21, 2019
1 parent 1bb3517 commit a24cd05
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
3 changes: 3 additions & 0 deletions python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,9 @@ def assert_almost_equal(a, b, rtol=None, atol=None, names=('a', 'b'), equal_nan=
b = b.asnumpy()

if use_np_allclose:
if a.dtype == np.bool_ and b.dtype == np.bool_:
np.testing.assert_equal(a, b)
return
if almost_equal(a, b, rtol, atol, equal_nan=equal_nan):
return
else:
Expand Down
18 changes: 12 additions & 6 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,15 +975,17 @@ def _add_workload_equal(array_pool):
# TODO(junwu): fp16 does not work yet with TVM generated ops
# OpArgMngr.add_workload('equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
OpArgMngr.add_workload('equal', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
OpArgMngr.add_workload('equal', np.array([np.nan]), np.array([np.nan]))
# TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
# OpArgMngr.add_workload('equal', np.array([np.nan]), np.array([np.nan]))
OpArgMngr.add_workload('equal', array_pool['4x1'], array_pool['1x2'])


def _add_workload_not_equal(array_pool):
# TODO(junwu): fp16 does not work yet with TVM generated ops
# OpArgMngr.add_workload('not_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
OpArgMngr.add_workload('not_equal', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
OpArgMngr.add_workload('not_equal', np.array([np.nan]), np.array([np.nan]))
# TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
# OpArgMngr.add_workload('not_equal', np.array([np.nan]), np.array([np.nan]))
OpArgMngr.add_workload('not_equal', array_pool['4x1'], array_pool['1x2'])


Expand All @@ -992,31 +994,35 @@ def _add_workload_greater(array_pool):
# OpArgMngr.add_workload('greater', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
OpArgMngr.add_workload('greater', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
OpArgMngr.add_workload('greater', array_pool['4x1'], array_pool['1x2'])
OpArgMngr.add_workload('greater', np.array([np.nan]), np.array([np.nan]))
# TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
# OpArgMngr.add_workload('greater', np.array([np.nan]), np.array([np.nan]))


def _add_workload_greater_equal(array_pool):
# TODO(junwu): fp16 does not work yet with TVM generated ops
# OpArgMngr.add_workload('greater_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
OpArgMngr.add_workload('greater_equal', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
OpArgMngr.add_workload('greater_equal', array_pool['4x1'], array_pool['1x2'])
OpArgMngr.add_workload('greater_equal', np.array([np.nan]), np.array([np.nan]))
# TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
# OpArgMngr.add_workload('greater_equal', np.array([np.nan]), np.array([np.nan]))


def _add_workload_less(array_pool):
# TODO(junwu): fp16 does not work yet with TVM generated ops
# OpArgMngr.add_workload('less', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
OpArgMngr.add_workload('less', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
OpArgMngr.add_workload('less', array_pool['4x1'], array_pool['1x2'])
OpArgMngr.add_workload('less', np.array([np.nan]), np.array([np.nan]))
# TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
# OpArgMngr.add_workload('less', np.array([np.nan]), np.array([np.nan]))


def _add_workload_less_equal(array_pool):
# TODO(junwu): fp16 does not work yet with TVM generated ops
# OpArgMngr.add_workload('less_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
OpArgMngr.add_workload('less_equal', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
OpArgMngr.add_workload('less_equal', array_pool['4x1'], array_pool['1x2'])
OpArgMngr.add_workload('less_equal', np.array([np.nan]), np.array([np.nan]))
# TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
# OpArgMngr.add_workload('less_equal', np.array([np.nan]), np.array([np.nan]))


@use_np
Expand Down

0 comments on commit a24cd05

Please sign in to comment.