From a9ca9f8b748311adeb1ebcfe74b580f34c7952b9 Mon Sep 17 00:00:00 2001 From: Roshani Nagmote Date: Wed, 20 Feb 2019 16:45:44 -0800 Subject: [PATCH] onnx broadcast ops fixes (#13604) * broadcasting fixes * fix * addressing comments * fix * fix --- .../contrib/onnx/onnx2mx/_op_translations.py | 34 +++---------------- .../onnx/onnx2mx/_translation_utils.py | 16 ++++++++- 2 files changed, 19 insertions(+), 31 deletions(-) diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index a7cef7674496..1a8d2cea9cd6 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -62,48 +62,22 @@ def sample_multinomial(attrs, inputs, proto_obj): new_attrs['dtype'] = TENSOR_TYPE_TO_NP_TYPE[int(attrs.get('dtype', 6))] return 'sample_multinomial', new_attrs, inputs - # Arithmetic Operations def add(attrs, inputs, proto_obj): """Adding two tensors""" - new_attr = {} - if 'broadcast' in attrs and attrs['broadcast'] == 1: - broadcast_axis = attrs['axis'] - op_value = translation_utils._fix_broadcast('broadcast_add', inputs, - broadcast_axis, proto_obj) - return op_value, new_attr, inputs - return 'broadcast_add', new_attr, inputs + return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_add') def subtract(attrs, inputs, proto_obj): """Subtracting two tensors""" - new_attr = {} - if 'broadcast' in attrs and attrs['broadcast'] == 1: - broadcast_axis = attrs['axis'] - op_value = translation_utils._fix_broadcast('broadcast_sub', inputs, - broadcast_axis, proto_obj) - return op_value, new_attr, inputs - return 'broadcast_sub', new_attr, inputs - + return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_sub') def multiply(attrs, inputs, proto_obj): """Multiply two tensors""" - new_attr = {} - if 'broadcast' in attrs and attrs['broadcast'] == 1: - broadcast_axis = attrs['axis'] - op_value = translation_utils._fix_broadcast('broadcast_mul', inputs, - broadcast_axis, proto_obj) - return op_value, new_attr, inputs - return 'broadcast_mul', new_attr, inputs + return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_mul') def divide(attrs, inputs, proto_obj): """Divide two tensors""" - new_attr = {} - if 'broadcast' in attrs and attrs['broadcast'] == 1: - broadcast_axis = attrs['axis'] - op_value = translation_utils._fix_broadcast('broadcast_div', inputs, - broadcast_axis, proto_obj) - return op_value, new_attr, inputs - return 'broadcast_div', new_attr, inputs + return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_div') def mean(attrs, inputs, proto_obj): """Mean of all the input tensors.""" diff --git a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py index 6fd52665ca31..0c6730513d4b 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py @@ -221,7 +221,7 @@ def get_input_shape(sym, proto_obj): model_input_shape = [data[1] for data in proto_obj.model_metadata.get('input_tensor_data')] data_names = [data[0] for data in proto_obj.model_metadata.get('input_tensor_data')] - #creating dummy inputs + # creating dummy inputs inputs = [] for in_shape in model_input_shape: inputs.append(nd.ones(shape=in_shape)) @@ -245,3 +245,17 @@ def get_input_shape(sym, proto_obj): result = mod.get_outputs()[0].asnumpy() return result.shape + +def broadcast_arithmetic_helper(attrs, inputs, proto_obj, current_op_name): + """Helper function for broadcast arithmetic ops.""" + new_attr = {} + op_names = ['batchnorm, convolution, deconvolution'] + if 'broadcast' in attrs and attrs['broadcast'] == 1: + broadcast_axis = attrs['axis'] + for op_name in op_names: + # if input is bias which comes after conv, deconv, batchnorm operators + # then only reshape bias term + if inputs[0].name.startswith(op_name): + op_value = _fix_broadcast(current_op_name, inputs, broadcast_axis, proto_obj) + return op_value, new_attr, inputs + return current_op_name, new_attr, inputs