diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index b3e73d7a0732..1d59f882c680 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -214,7 +214,7 @@ def test_onnx_export_softmax(tmp_path, dtype): op_export_test('softmax_1', M1, [x], tmp_path) M2 = def_model('softmax', use_length=True, axis=0) l2 = mx.nd.array([[2,0,3,1],[1,3,2,0], [0,0,0,1]], dtype=int) - op_export_test('softmax_2', M1, [x, l2], tmp_path) + op_export_test('softmax_2', M2, [x, l2], tmp_path) M3 = def_model('softmax', use_length=True, axis=-1, temperature=0.5) l3 = mx.nd.array([[2,0,1],[0,0,0]], dtype=int) - op_export_test('softmax_3', M1, [x, l3], tmp_path) + op_export_test('softmax_3', M3, [x, l3], tmp_path)