diff --git a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py index 2ceabaec1dcd..5b33f9faac11 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py @@ -23,7 +23,7 @@ from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan from ._op_translations import softplus, shape, gather, lp_pooling, size from ._op_translations import ceil, floor, hardsigmoid, global_lppooling -from ._op_translations import concat +from ._op_translations import concat, hardmax from ._op_translations import leaky_relu, _elu, _prelu, _selu, softmax, fully_connected from ._op_translations import global_avgpooling, global_maxpooling, linalg_gemm from ._op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm @@ -144,5 +144,6 @@ 'HardSigmoid' : hardsigmoid, 'LpPool' : lp_pooling, 'DepthToSpace' : depthtospace, - 'SpaceToDepth' : spacetodepth + 'SpaceToDepth' : spacetodepth, + 'Hardmax' : hardmax } diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index 702832529314..ce0e0e51ef79 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -714,3 +714,29 @@ def spacetodepth(attrs, inputs, proto_obj): new_attrs = translation_utils._fix_attribute_names(attrs, {'blocksize':'block_size'}) return "space_to_depth", new_attrs, inputs + + +def hardmax(attrs, inputs, proto_obj): + """Returns batched one-hot vectors.""" + input_tensor_data = proto_obj.model_metadata.get('input_tensor_data')[0] + input_shape = input_tensor_data[1] + + axis = int(attrs.get('axis', 1)) + axis = axis if axis >= 0 else len(input_shape) + axis + + if axis == len(input_shape) - 1: + amax = symbol.argmax(inputs[0], axis=-1) + one_hot = symbol.one_hot(amax, depth=input_shape[-1]) + return one_hot, attrs, inputs + + # since reshape doesn't take a tensor for shape, + # computing with np.prod. This needs to be changed to + # to use mx.sym.prod() when mx.sym.reshape() is fixed. + # (https://github.com/apache/incubator-mxnet/issues/10789) + new_shape = (int(np.prod(input_shape[:axis])), + int(np.prod(input_shape[axis:]))) + reshape_op = symbol.reshape(inputs[0], new_shape) + amax = symbol.argmax(reshape_op, axis=-1) + one_hot = symbol.one_hot(amax, depth=new_shape[-1]) + hardmax_op = symbol.reshape(one_hot, input_shape) + return hardmax_op, attrs, inputs diff --git a/tests/python-pytest/onnx/test_cases.py b/tests/python-pytest/onnx/test_cases.py index 92e80e056701..6a189b62492d 100644 --- a/tests/python-pytest/onnx/test_cases.py +++ b/tests/python-pytest/onnx/test_cases.py @@ -90,7 +90,8 @@ 'test_averagepool_2d_strides', 'test_averagepool_3d', 'test_LpPool_', - 'test_split_equal' + 'test_split_equal', + 'test_hardmax' ], 'export': ['test_random_uniform', 'test_random_normal',