Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
resolve conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
Wei Chu committed Dec 18, 2020
1 parent 5a350af commit d960c03
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,14 @@ def test_onnx_export_dropout(tmp_path, dtype, p):
x = mx.nd.array([[3,0.5,-0.5,2,7],[2,-0.4,7,3,0.2]], dtype=dtype)
op_export_test('Dropout', M, [x], tmp_path)

@pytest.mark.parametrize('src_dtype', ['float16', 'float32', 'float64'])
@pytest.mark.parametrize('dst_dtype', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'])
@pytest.mark.parametrize('shape', [(2,3), (4,5,6)])
def test_onnx_export_cast(tmp_path, src_dtype, dst_dtype, shape):
M = def_model('Cast', dtype=dst_dtype)
x = mx.nd.ones(shape, dtype=src_dtype)
op_export_test('Cast', M, [x], tmp_path)

@pytest.mark.parametrize('dtype', ['float16', 'float32'])
@pytest.mark.parametrize('temperature', [0.3, 0.5, 1.0])
def test_onnx_export_softmax(tmp_path, dtype, temperature):
Expand Down

0 comments on commit d960c03

Please sign in to comment.