Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
Wei Chu committed Dec 17, 2020
1 parent 693f7cb commit ea33835
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def export_to_onnx(model, model_name, inputs):

def onnx_rt(onnx_file, inputs):
sess = rt.InferenceSession(onnx_file)
input_dict = dict((sess.get_inputs()[i].name, inputs[i].asnumpy()) for i in range(len(inputs)))
dtype_0 = inputs[0].asnumpy().dtype
input_dict = dict((sess.get_inputs()[i].name, inputs[i].asnumpy().astype(dtype_0)) for i in range(len(inputs)))
pred = sess.run(None, input_dict)[0]
return pred

Expand Down Expand Up @@ -205,3 +206,15 @@ def test_onnx_export_fully_connected(tmp_path, dtype, num_hidden, no_bias, flatt
if not no_bias:
args.append(mx.nd.random.uniform(0,1,(num_hidden,)))
op_export_test('FullyConnected', M, args, tmp_path)

@pytest.mark.parametrize('dtype', ['float32', 'float64'])
def test_onnx_export_softmax(tmp_path, dtype):
x = mx.nd.random.uniform(0, 1, (2, 3, 4), dtype=dtype)
M1 = def_model('softmax')
op_export_test('softmax_1', M1, [x], tmp_path)
M2 = def_model('softmax', use_length=True, axis=0)
l2 = mx.nd.array([[2,0,3,1],[1,3,2,0], [0,0,0,1]], dtype=int)
op_export_test('softmax_2', M1, [x, l2], tmp_path)
M3 = def_model('softmax', use_length=True, axis=-1, temperature=0.5)
l3 = mx.nd.array([[2,0,1],[0,0,0]], dtype=int)
op_export_test('softmax_3', M1, [x, l3], tmp_path)

0 comments on commit ea33835

Please sign in to comment.