Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add test for raise
Browse files Browse the repository at this point in the history
  • Loading branch information
Wei Chu committed Feb 12, 2021
1 parent a094577 commit 7ede696
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,3 +1133,12 @@ def test_onnx_export_take(tmp_path, dtype, axis, mode):
op_export_test('take1', M1, [x, y], tmp_path)
M2 = def_model('take', axis=axis, mode=mode)
op_export_test('take2', M2, [x, y], tmp_path)


@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
@pytest.mark.parametrize('axis', [-3, -2, -1, 0, 1, 2])
def test_onnx_export_take_raise(tmp_path, dtype, axis):
x = mx.nd.random.normal(0, 10, (3, 4, 5)).astype(dtype)
y = mx.random.randint(0, 4, (6, 7)).astype(dtype)
M = def_model('take', axis=axis, mode='raise')
op_export_test('take', M, [x, y], tmp_path)

0 comments on commit 7ede696

Please sign in to comment.