diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index f6d06f0395fa..cffe7a7b964b 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -717,6 +717,7 @@ def spacetodepth(attrs, inputs, proto_obj): def hardmax(attrs, inputs, proto_obj): + """Returns batched one-hot vectors.""" input_tensor_data = proto_obj.model_metadata.get('input_tensor_data')[0] input_shape = input_tensor_data[1] @@ -724,8 +725,8 @@ def hardmax(attrs, inputs, proto_obj): axis = axis if axis >= 0 else len(input_shape) + axis if axis == len(input_shape) - 1: - argmax = symbol.argmax(inputs[0], axis=-1) - one_hot = symbol.one_hot(argmax, depth=input_shape[-1]) + amax = symbol.argmax(inputs[0], axis=-1) + one_hot = symbol.one_hot(amax, depth=input_shape[-1]) return one_hot, attrs, inputs # since reshape doesn't take a tensor for shape, @@ -734,7 +735,7 @@ def hardmax(attrs, inputs, proto_obj): new_shape = (int(np.prod(input_shape[:axis])), int(np.prod(input_shape[axis:]))) reshape_op = symbol.reshape(inputs[0], new_shape) - argmax = symbol.argmax(reshape_op, axis=-1) - one_hot = symbol.one_hot(argmax, depth=new_shape[-1]) + amax = symbol.argmax(reshape_op, axis=-1) + one_hot = symbol.one_hot(amax, depth=new_shape[-1]) hardmax_op = symbol.reshape(one_hot, input_shape) return hardmax_op, attrs, inputs