diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index b1ab40e1bf02..161d77579231 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -648,12 +648,13 @@ def convert_pooling(node, **kwargs): p_value = attrs.get('p_value', 'None') pooling_convention = attrs.get('pooling_convention', 'valid') - + ceil_mode = False if pooling_convention == 'full': - pooling_warning = "Pooling: ONNX currently doesn't support pooling_convention. " \ - "This might lead to shape or accuracy issues. " \ - "https://github.com/onnx/onnx/issues/549" - + if onnx.__version__ < "1.5.0": + pooling_warning = "Pooling: ONNX lower than 1.5.0 doesn't support pooling_convention. " \ + "This might lead to shape or accuracy issues. " \ + "https://github.com/onnx/onnx/issues/549" + ceil_mode = True logging.warning(pooling_warning) pad_dims = list(parse_helper(attrs, "pad", [0, 0])) @@ -694,15 +695,27 @@ def convert_pooling(node, **kwargs): name=name ) else: - node = onnx.helper.make_node( - pool_types[pool_type], - input_nodes, # input - [name], - kernel_shape=kernel, - pads=pad_dims, - strides=stride, - name=name - ) + if onnx.__version__ >= "1.5.0": + node = onnx.helper.make_node( + pool_types[pool_type], + input_nodes, # input + [name], + kernel_shape=kernel, + pads=pad_dims, + strides=stride, + name=name, + ceil_mode=ceil_mode + ) + else: + node = onnx.helper.make_node( + pool_types[pool_type], + input_nodes, # input + [name], + kernel_shape=kernel, + pads=pad_dims, + strides=stride, + name=name + ) return [node]