diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 1153049838bb..56059bc536fe 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -1139,6 +1139,6 @@ def test_onnx_export_take(tmp_path, dtype, axis, mode): @pytest.mark.parametrize('axis', [-3, -2, -1, 0, 1, 2]) def test_onnx_export_take_raise(tmp_path, dtype, axis): x = mx.nd.random.normal(0, 10, (3, 4, 5)).astype(dtype) - y = mx.random.randint(0, 4, (6, 7)).astype(dtype) + y = mx.random.randint(0, 3, (6, 7)).astype(dtype) M = def_model('take', axis=axis, mode='raise') op_export_test('take', M, [x, y], tmp_path) \ No newline at end of file