diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 34abddd9dba3..c60e9f2ca997 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -162,6 +162,7 @@ def _register_array_function(): 'arctan2', 'copysign', 'degrees', + 'equal', 'subtract', 'multiply', # Uncomment divide when mxnet.numpy.true_divide is added diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index f9b1cf6eaba3..094e558bb600 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -146,6 +146,8 @@ def _prepare_workloads(): OpArgMngr.add_workload('copysign', np.array([-2, 5, 1, 4, 3], dtype=np.float16), np.array([0, 1, 2, 4, 2], dtype=np.float16)) OpArgMngr.add_workload('degrees', np.array(np.pi)) OpArgMngr.add_workload('degrees', np.array(-0.5*np.pi)) + OpArgMngr.add_workload('equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16)) + OpArgMngr.add_workload('equal', np.array([np.nan])) OpArgMngr.add_workload('subtract', array_pool['4x1'], array_pool['1x2']) OpArgMngr.add_workload('subtract', array_pool['4x1'], 2) OpArgMngr.add_workload('subtract', 2, array_pool['4x1'])