diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 3576242e0d77..5f6f8b2af486 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2341,7 +2341,8 @@ def convert_topk(node, **kwargs): axis = int(attrs.get('axis', '-1')) k = int(attrs.get('k', '1')) ret_type = attrs.get('ret_typ', 'indices') - is_ascend = int(attrs.get('is_ascend', '0')) + is_ascend = attrs.get('is_ascend', 'False') + is_ascend = is_ascend in ['1', 'True'] dtype = attrs.get('dtype', 'float32') if ret_type == 'mask': @@ -2355,30 +2356,30 @@ def convert_topk(node, **kwargs): if dtype == 'int64': nodes += [ make_node('TopK', [input_nodes[0], name+'_k'], [name+'0', name+'1'], axis=axis, - largest=(0 if is_ascend else 1), sorted=1), + largest=(not is_ascend), sorted=1), ] else: nodes += [ make_node('TopK', [input_nodes[0], name+'_k'], [name+'0', name+'_1_i'], axis=axis, - largest=(0 if is_ascend else 1), sorted=1), + largest=(not is_ascend), sorted=1), make_node('Cast', [name+'_1_i'], [name+'1'], to=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]) ] elif ret_type == 'value': nodes += [ make_node('TopK', [input_nodes[0], name+'_k'], [name+'0', name+'_'], axis=axis, - largest=(0 if is_ascend else 1), sorted=1), + largest=(not is_ascend), sorted=1), ] else: if dtype == 'int64': nodes += [ make_node('TopK', [input_nodes[0], name+'_k'], [name+'_', name], axis=axis, - largest=(0 if is_ascend else 1), sorted=1), + largest=(not is_ascend), sorted=1), ] else: nodes += [ make_node('TopK', [input_nodes[0], name+'_k'], [name+'__', name+'_tmp'], axis=axis, - largest=(0 if is_ascend else 1), sorted=1), + largest=(not is_ascend), sorted=1), make_node('Cast', [name+'_tmp'], [name], to=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]) ] @@ -3871,3 +3872,40 @@ def convert_log2(node, **kwargs): ] return nodes + + +@mx_op.register('argsort') +def convert_argsort(node, **kwargs): + """Map MXNet's argsort operator attributes to onnx's TopK operator + """ + from onnx.helper import make_node + name, input_nodes, attrs = get_inputs(node, kwargs) + + opset_version = kwargs['opset_version'] + if opset_version < 11: + raise AttributeError('ONNX opset 11 or greater is required to export this operator') + + axis = int(attrs.get('axis', '-1')) + is_ascend = attrs.get('is_ascend', 'True') + is_ascend = is_ascend in ['True', '1'] + dtype = attrs.get('dtype', 'float32') + + create_tensor([axis], name+'_axis', kwargs['initializer']) + nodes = [ + make_node('Shape', [input_nodes[0]], [name+'_shape']), + make_node('Gather', [name+'_shape', name+'_axis'], [name+'_k']) + ] + if dtype == 'int64': + nodes += [ + make_node('TopK', [input_nodes[0], name+'_k'], [name+'_', name], axis=axis, + largest=(not is_ascend), sorted=1), + ] + else: + nodes += [ + make_node('TopK', [input_nodes[0], name+'_k'], [name+'_', name+'_temp'], axis=axis, + largest=(not is_ascend), sorted=1), + make_node('Cast', [name+'_temp'], [name], + to=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]) + ] + + return nodes diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index f4b44e58f188..79a2b4d6c34b 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -548,13 +548,16 @@ def test_onnx_export_where(tmp_path, dtype, shape): # onnxruntime does not seem to support float64 and int32 @pytest.mark.parametrize('dtype', ['float16', 'float32', 'int64']) @pytest.mark.parametrize('axis', [0, 2, -1, -2, -3]) -@pytest.mark.parametrize('is_ascend', [0, 1]) +@pytest.mark.parametrize('is_ascend', [True, False, 0, 1, None]) @pytest.mark.parametrize('k', [1, 4]) @pytest.mark.parametrize('dtype_i', ['float32', 'int32', 'int64']) @pytest.mark.parametrize('ret_typ', ['value', 'indices', 'both']) def test_onnx_export_topk(tmp_path, dtype, axis, is_ascend, k, dtype_i, ret_typ): A = mx.random.uniform(0, 100, (4, 5, 6)).astype(dtype) - M = def_model('topk', axis=axis, is_ascend=is_ascend, k=k, dtype=dtype_i, ret_typ=ret_typ) + kwargs = {} + if is_ascend is not None: + kwargs['is_ascend'] = is_ascend + M = def_model('topk', axis=axis, k=k, dtype=dtype_i, ret_typ=ret_typ, **kwargs) op_export_test('topk', M, [A], tmp_path) @@ -939,3 +942,17 @@ def test_onnx_export_batchnorm(tmp_path, dtype, momentum): moving_var = mx.nd.abs(mx.nd.random.normal(0, 10, (3))).astype(dtype) M = def_model('BatchNorm', eps=1e-5, momentum=momentum, fix_gamma=False, use_global_stats=False) op_export_test('BatchNorm1', M, [x, gamma, beta, moving_mean, moving_var], tmp_path) + + +# onnxruntime does not seem to support float64 and int32 +@pytest.mark.parametrize('dtype', ['float32', 'int64']) +@pytest.mark.parametrize('axis', [0, 2, -1, -2, -3]) +@pytest.mark.parametrize('is_ascend', [True, False, 0, 1, None]) +@pytest.mark.parametrize('dtype_i', ['float32', 'int32', 'int64']) +def test_onnx_export_argsort(tmp_path, dtype, axis, is_ascend, dtype_i): + A = mx.random.uniform(0, 100, (4, 5, 6)).astype(dtype) + kwargs = {} + if is_ascend is not None: + kwargs['is_ascend'] = is_ascend + M = def_model('argsort', axis=axis, dtype=dtype_i, **kwargs) + op_export_test('argsort', M, [A], tmp_path)