Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Numpy] More numpy dispatch tests (#16426)
Browse files Browse the repository at this point in the history
* tests added

* remove not equal

* fix tiny bug

* remove meshgrid test

* modify meshgrid return type, add test
  • Loading branch information
xidulu authored and reminisce committed Oct 18, 2019
1 parent 63fbfb1 commit f01bcaa
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/mxnet/numpy/stride_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,6 @@ def broadcast_arrays(*args):

if all(array.shape == shape for array in args):
# Common case where nothing needs to be broadcasted.
return args
return list(args)

return [_mx_np_op.broadcast_to(array, shape) for array in args]
3 changes: 3 additions & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'transpose',
'var',
'zeros_like',
'meshgrid',
'outer'
]


Expand Down Expand Up @@ -196,6 +198,7 @@ def _register_array_function():
'ceil',
'trunc',
'floor',
'logical_not',
]


Expand Down
21 changes: 21 additions & 0 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def _prepare_workloads():
OpArgMngr.add_workload('min', array_pool['4x1'])
OpArgMngr.add_workload('mean', array_pool['4x1'])
OpArgMngr.add_workload('mean', array_pool['4x1'], axis=0, keepdims=True)
OpArgMngr.add_workload('mean', np.array([[1, 2, 3], [4, 5, 6]]))
OpArgMngr.add_workload('mean', np.array([[1, 2, 3], [4, 5, 6]]), axis=0)
OpArgMngr.add_workload('mean', np.array([[1, 2, 3], [4, 5, 6]]), axis=1)
OpArgMngr.add_workload('ones_like', array_pool['4x1'])
OpArgMngr.add_workload('prod', array_pool['4x1'])

Expand Down Expand Up @@ -157,6 +160,10 @@ def _prepare_workloads():
OpArgMngr.add_workload('transpose', array_pool['4x1'])
OpArgMngr.add_workload('var', array_pool['4x1'])
OpArgMngr.add_workload('zeros_like', array_pool['4x1'])
OpArgMngr.add_workload('outer', np.ones((5)), np.ones((2)))
OpArgMngr.add_workload('meshgrid', np.array([1, 2, 3]))
OpArgMngr.add_workload('meshgrid', np.array([1, 2, 3]), np.array([4, 5, 6, 7]))
OpArgMngr.add_workload('meshgrid', np.array([1, 2, 3]), np.array([4, 5, 6, 7]), indexing='ij')

# workloads for array ufunc protocol
OpArgMngr.add_workload('add', array_pool['4x1'], array_pool['1x2'])
Expand All @@ -175,6 +182,9 @@ def _prepare_workloads():
OpArgMngr.add_workload('power', array_pool['4x1'], 2)
OpArgMngr.add_workload('power', 2, array_pool['4x1'])
OpArgMngr.add_workload('power', array_pool['4x1'], array_pool['1x1x0'])
OpArgMngr.add_workload('power', np.array([1, 2, 3], np.int32), 2.00001)
OpArgMngr.add_workload('power', np.array([15, 15], np.int64), np.array([15, 15], np.int64))
OpArgMngr.add_workload('power', 0, np.arange(1, 10))
OpArgMngr.add_workload('mod', array_pool['4x1'], array_pool['1x2'])
OpArgMngr.add_workload('mod', array_pool['4x1'], 2)
OpArgMngr.add_workload('mod', 2, array_pool['4x1'])
Expand Down Expand Up @@ -256,6 +266,12 @@ def _signs(dt):
OpArgMngr.add_workload('exp', array_pool['4x1'])
OpArgMngr.add_workload('log', array_pool['4x1'])
OpArgMngr.add_workload('log2', array_pool['4x1'])
OpArgMngr.add_workload('log2', np.array(2.**65))
OpArgMngr.add_workload('log2', np.array(np.inf))
OpArgMngr.add_workload('log2', np.array(1.))
OpArgMngr.add_workload('log1p', np.array(-1.))
OpArgMngr.add_workload('log1p', np.array(np.inf))
OpArgMngr.add_workload('log1p', np.array(1e-6))
OpArgMngr.add_workload('log10', array_pool['4x1'])
OpArgMngr.add_workload('expm1', array_pool['4x1'])
OpArgMngr.add_workload('sqrt', array_pool['4x1'])
Expand All @@ -282,6 +298,11 @@ def _signs(dt):
OpArgMngr.add_workload('ceil', array_pool['4x1'])
OpArgMngr.add_workload('trunc', array_pool['4x1'])
OpArgMngr.add_workload('floor', array_pool['4x1'])
OpArgMngr.add_workload('logical_not', np.ones(10, dtype=np.int32))
OpArgMngr.add_workload('logical_not', array_pool['4x1'])
OpArgMngr.add_workload('logical_not', np.array([True, False, True, False], dtype=np.bool))




_prepare_workloads()
Expand Down

0 comments on commit f01bcaa

Please sign in to comment.