From a98bbe1ea8796b29f691954d80e48ce3715747a1 Mon Sep 17 00:00:00 2001 From: Roshani Nagmote Date: Fri, 15 Jun 2018 11:14:28 -0700 Subject: [PATCH] added comments and warning --- .../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 91a72fbc56ac..5f5561ab32b6 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -55,11 +55,12 @@ from __future__ import division from __future__ import print_function from __future__ import unicode_literals + import re +import logging import numpy as np from .export_onnx import MXNetGraph as mx_op - def import_onnx_modules(): """ To make sure ONNX is runtime dependency, it is imported used only when needed""" try: @@ -94,6 +95,7 @@ def transform_padding(pad_width): onnx_pad_width = [0]*num_pad_values start_index = 0 + # num_pad_values will always be multiple of 2 end_index = int(num_pad_values/2) for idx in range(0, num_pad_values): if idx % 2 == 0: @@ -552,6 +554,15 @@ def convert_pooling(node, **kwargs): input_node = proc_nodes[input_node_idx] name = node["name"] + pooling_convention = attrs.get('pooling_convention', 'valid') + + if pooling_convention == 'full': + pooling_warning = "Pooling: ONNX currently doesn't support pooling_convention. " \ + "This might lead to shape or accuracy issues. " \ + "https://github.com/onnx/onnx/issues/549" + + logging.warning(pooling_warning) + pad_dims = list(parse_helper(attrs, "pad", [0, 0])) pad_dims = pad_dims + pad_dims pool_types = {"max": "MaxPool", "avg": "AveragePool"}