From f2fca36953d58e6b3d725bc146004c13308125c9 Mon Sep 17 00:00:00 2001 From: Roshani Nagmote Date: Mon, 10 Dec 2018 15:25:17 -0800 Subject: [PATCH 1/5] broadcasting fixes --- .../contrib/onnx/onnx2mx/_op_translations.py | 44 ++++++++++++++----- 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index a7cef7674496..730bd8ce3a8b 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -67,42 +67,62 @@ def sample_multinomial(attrs, inputs, proto_obj): def add(attrs, inputs, proto_obj): """Adding two tensors""" new_attr = {} + op_names = ['batchnorm, convolution, deconvolution'] if 'broadcast' in attrs and attrs['broadcast'] == 1: broadcast_axis = attrs['axis'] - op_value = translation_utils._fix_broadcast('broadcast_add', inputs, - broadcast_axis, proto_obj) - return op_value, new_attr, inputs + for op_name in op_names: + if inputs[0].name.startswith(op_name): + op_value = translation_utils._fix_broadcast('broadcast_add', inputs, + broadcast_axis, proto_obj) + return op_value, new_attr, inputs + else: + return 'broadcast_add', attrs, inputs return 'broadcast_add', new_attr, inputs def subtract(attrs, inputs, proto_obj): """Subtracting two tensors""" new_attr = {} + op_names = ['batchnorm, convolution, deconvolution'] if 'broadcast' in attrs and attrs['broadcast'] == 1: broadcast_axis = attrs['axis'] - op_value = translation_utils._fix_broadcast('broadcast_sub', inputs, - broadcast_axis, proto_obj) - return op_value, new_attr, inputs + for op_name in op_names: + if inputs[0].name.startswith(op_name): + op_value = translation_utils._fix_broadcast('broadcast_sub', inputs, + broadcast_axis, proto_obj) + return op_value, new_attr, inputs + else: + return 'broadcast_sub', attrs, inputs return 'broadcast_sub', new_attr, inputs def multiply(attrs, inputs, proto_obj): """Multiply two tensors""" new_attr = {} + op_names = ['batchnorm, convolution, deconvolution'] if 'broadcast' in attrs and attrs['broadcast'] == 1: broadcast_axis = attrs['axis'] - op_value = translation_utils._fix_broadcast('broadcast_mul', inputs, - broadcast_axis, proto_obj) - return op_value, new_attr, inputs + for op_name in op_names: + if inputs[0].name.startswith(op_name): + op_value = translation_utils._fix_broadcast('broadcast_mul', inputs, + broadcast_axis, proto_obj) + return op_value, new_attr, inputs + else: + return 'broadcast_mul', attrs, inputs return 'broadcast_mul', new_attr, inputs def divide(attrs, inputs, proto_obj): """Divide two tensors""" new_attr = {} + op_names = ['batchnorm, convolution, deconvolution'] if 'broadcast' in attrs and attrs['broadcast'] == 1: broadcast_axis = attrs['axis'] - op_value = translation_utils._fix_broadcast('broadcast_div', inputs, - broadcast_axis, proto_obj) - return op_value, new_attr, inputs + for op_name in op_names: + if inputs[0].name.startswith(op_name): + op_value = translation_utils._fix_broadcast('broadcast_div', inputs, + broadcast_axis, proto_obj) + return op_value, new_attr, inputs + else: + return 'broadcast_div', attrs, inputs return 'broadcast_div', new_attr, inputs def mean(attrs, inputs, proto_obj): From 600a54e3eaae8c5c667b72ab478349025af35895 Mon Sep 17 00:00:00 2001 From: Roshani Nagmote Date: Mon, 10 Dec 2018 15:45:36 -0800 Subject: [PATCH 2/5] fix --- python/mxnet/contrib/onnx/onnx2mx/_op_translations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index 730bd8ce3a8b..a176cdc3180e 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -106,8 +106,8 @@ def multiply(attrs, inputs, proto_obj): op_value = translation_utils._fix_broadcast('broadcast_mul', inputs, broadcast_axis, proto_obj) return op_value, new_attr, inputs - else: - return 'broadcast_mul', attrs, inputs + else: + return 'broadcast_mul', attrs, inputs return 'broadcast_mul', new_attr, inputs def divide(attrs, inputs, proto_obj): From 50ef074231015aa25622682f37bf5ab27c24febe Mon Sep 17 00:00:00 2001 From: Roshani Nagmote Date: Mon, 18 Feb 2019 10:43:29 -0800 Subject: [PATCH 3/5] addressing comments --- .../contrib/onnx/onnx2mx/_op_translations.py | 54 ++----------------- .../onnx/onnx2mx/_translation_utils.py | 18 ++++++- 2 files changed, 21 insertions(+), 51 deletions(-) diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index a176cdc3180e..1a8d2cea9cd6 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -62,68 +62,22 @@ def sample_multinomial(attrs, inputs, proto_obj): new_attrs['dtype'] = TENSOR_TYPE_TO_NP_TYPE[int(attrs.get('dtype', 6))] return 'sample_multinomial', new_attrs, inputs - # Arithmetic Operations def add(attrs, inputs, proto_obj): """Adding two tensors""" - new_attr = {} - op_names = ['batchnorm, convolution, deconvolution'] - if 'broadcast' in attrs and attrs['broadcast'] == 1: - broadcast_axis = attrs['axis'] - for op_name in op_names: - if inputs[0].name.startswith(op_name): - op_value = translation_utils._fix_broadcast('broadcast_add', inputs, - broadcast_axis, proto_obj) - return op_value, new_attr, inputs - else: - return 'broadcast_add', attrs, inputs - return 'broadcast_add', new_attr, inputs + return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_add') def subtract(attrs, inputs, proto_obj): """Subtracting two tensors""" - new_attr = {} - op_names = ['batchnorm, convolution, deconvolution'] - if 'broadcast' in attrs and attrs['broadcast'] == 1: - broadcast_axis = attrs['axis'] - for op_name in op_names: - if inputs[0].name.startswith(op_name): - op_value = translation_utils._fix_broadcast('broadcast_sub', inputs, - broadcast_axis, proto_obj) - return op_value, new_attr, inputs - else: - return 'broadcast_sub', attrs, inputs - return 'broadcast_sub', new_attr, inputs - + return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_sub') def multiply(attrs, inputs, proto_obj): """Multiply two tensors""" - new_attr = {} - op_names = ['batchnorm, convolution, deconvolution'] - if 'broadcast' in attrs and attrs['broadcast'] == 1: - broadcast_axis = attrs['axis'] - for op_name in op_names: - if inputs[0].name.startswith(op_name): - op_value = translation_utils._fix_broadcast('broadcast_mul', inputs, - broadcast_axis, proto_obj) - return op_value, new_attr, inputs - else: - return 'broadcast_mul', attrs, inputs - return 'broadcast_mul', new_attr, inputs + return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_mul') def divide(attrs, inputs, proto_obj): """Divide two tensors""" - new_attr = {} - op_names = ['batchnorm, convolution, deconvolution'] - if 'broadcast' in attrs and attrs['broadcast'] == 1: - broadcast_axis = attrs['axis'] - for op_name in op_names: - if inputs[0].name.startswith(op_name): - op_value = translation_utils._fix_broadcast('broadcast_div', inputs, - broadcast_axis, proto_obj) - return op_value, new_attr, inputs - else: - return 'broadcast_div', attrs, inputs - return 'broadcast_div', new_attr, inputs + return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_div') def mean(attrs, inputs, proto_obj): """Mean of all the input tensors.""" diff --git a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py index 6fd52665ca31..de85d632246f 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py @@ -221,7 +221,7 @@ def get_input_shape(sym, proto_obj): model_input_shape = [data[1] for data in proto_obj.model_metadata.get('input_tensor_data')] data_names = [data[0] for data in proto_obj.model_metadata.get('input_tensor_data')] - #creating dummy inputs + # creating dummy inputs inputs = [] for in_shape in model_input_shape: inputs.append(nd.ones(shape=in_shape)) @@ -245,3 +245,19 @@ def get_input_shape(sym, proto_obj): result = mod.get_outputs()[0].asnumpy() return result.shape + +def broadcast_arithmetic_helper(attrs, inputs, proto_obj, op_name): + """Helper function for broadcast arithmetic ops.""" + new_attr = {} + op_names = ['batchnorm, convolution, deconvolution'] + if 'broadcast' in attrs and attrs['broadcast'] == 1: + broadcast_axis = attrs['axis'] + for op_name in op_names: + # if input is bias which comes after conv, deconv, batchnorm operators + # then only reshape bias term + if inputs[0].name.startswith(op_name): + op_value = _fix_broadcast(op_name, inputs, broadcast_axis, proto_obj) + return op_value, new_attr, inputs + else: + return op_name, attrs, inputs + return op_name, new_attr, inputs From 371f38f47a3b0c6d31d48d36876343487d5301d2 Mon Sep 17 00:00:00 2001 From: Roshani Nagmote Date: Mon, 18 Feb 2019 10:53:46 -0800 Subject: [PATCH 4/5] fix --- python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py index de85d632246f..a29cfee8f546 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py @@ -246,7 +246,7 @@ def get_input_shape(sym, proto_obj): return result.shape -def broadcast_arithmetic_helper(attrs, inputs, proto_obj, op_name): +def broadcast_arithmetic_helper(attrs, inputs, proto_obj, current_op_name): """Helper function for broadcast arithmetic ops.""" new_attr = {} op_names = ['batchnorm, convolution, deconvolution'] @@ -256,8 +256,8 @@ def broadcast_arithmetic_helper(attrs, inputs, proto_obj, op_name): # if input is bias which comes after conv, deconv, batchnorm operators # then only reshape bias term if inputs[0].name.startswith(op_name): - op_value = _fix_broadcast(op_name, inputs, broadcast_axis, proto_obj) + op_value = _fix_broadcast(current_op_name, inputs, broadcast_axis, proto_obj) return op_value, new_attr, inputs else: - return op_name, attrs, inputs - return op_name, new_attr, inputs + return current_op_name, attrs, inputs + return current_op_name, new_attr, inputs From 44eb0cbe7bc8adb4354cbf193c944b25b202b97b Mon Sep 17 00:00:00 2001 From: Roshani Nagmote Date: Mon, 18 Feb 2019 11:59:45 -0800 Subject: [PATCH 5/5] fix --- python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py index a29cfee8f546..0c6730513d4b 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py @@ -258,6 +258,4 @@ def broadcast_arithmetic_helper(attrs, inputs, proto_obj, current_op_name): if inputs[0].name.startswith(op_name): op_value = _fix_broadcast(current_op_name, inputs, broadcast_axis, proto_obj) return op_value, new_attr, inputs - else: - return current_op_name, attrs, inputs return current_op_name, new_attr, inputs