Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

onnx broadcast ops fixes #13604

Merged
merged 5 commits into from
Feb 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 4 additions & 30 deletions python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,48 +62,22 @@ def sample_multinomial(attrs, inputs, proto_obj):
new_attrs['dtype'] = TENSOR_TYPE_TO_NP_TYPE[int(attrs.get('dtype', 6))]
return 'sample_multinomial', new_attrs, inputs


# Arithmetic Operations
def add(attrs, inputs, proto_obj):
"""Adding two tensors"""
new_attr = {}
if 'broadcast' in attrs and attrs['broadcast'] == 1:
broadcast_axis = attrs['axis']
op_value = translation_utils._fix_broadcast('broadcast_add', inputs,
broadcast_axis, proto_obj)
return op_value, new_attr, inputs
return 'broadcast_add', new_attr, inputs
return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_add')

def subtract(attrs, inputs, proto_obj):
"""Subtracting two tensors"""
new_attr = {}
if 'broadcast' in attrs and attrs['broadcast'] == 1:
broadcast_axis = attrs['axis']
op_value = translation_utils._fix_broadcast('broadcast_sub', inputs,
broadcast_axis, proto_obj)
return op_value, new_attr, inputs
return 'broadcast_sub', new_attr, inputs

return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_sub')

def multiply(attrs, inputs, proto_obj):
"""Multiply two tensors"""
new_attr = {}
if 'broadcast' in attrs and attrs['broadcast'] == 1:
broadcast_axis = attrs['axis']
op_value = translation_utils._fix_broadcast('broadcast_mul', inputs,
broadcast_axis, proto_obj)
return op_value, new_attr, inputs
return 'broadcast_mul', new_attr, inputs
return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_mul')

def divide(attrs, inputs, proto_obj):
"""Divide two tensors"""
new_attr = {}
if 'broadcast' in attrs and attrs['broadcast'] == 1:
broadcast_axis = attrs['axis']
op_value = translation_utils._fix_broadcast('broadcast_div', inputs,
broadcast_axis, proto_obj)
return op_value, new_attr, inputs
return 'broadcast_div', new_attr, inputs
return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_div')

def mean(attrs, inputs, proto_obj):
"""Mean of all the input tensors."""
Expand Down
16 changes: 15 additions & 1 deletion python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def get_input_shape(sym, proto_obj):
model_input_shape = [data[1] for data in proto_obj.model_metadata.get('input_tensor_data')]
data_names = [data[0] for data in proto_obj.model_metadata.get('input_tensor_data')]

#creating dummy inputs
# creating dummy inputs
inputs = []
for in_shape in model_input_shape:
inputs.append(nd.ones(shape=in_shape))
Expand All @@ -245,3 +245,17 @@ def get_input_shape(sym, proto_obj):
result = mod.get_outputs()[0].asnumpy()

return result.shape

def broadcast_arithmetic_helper(attrs, inputs, proto_obj, current_op_name):
"""Helper function for broadcast arithmetic ops."""
new_attr = {}
op_names = ['batchnorm, convolution, deconvolution']
if 'broadcast' in attrs and attrs['broadcast'] == 1:
broadcast_axis = attrs['axis']
for op_name in op_names:
# if input is bias which comes after conv, deconv, batchnorm operators
# then only reshape bias term
if inputs[0].name.startswith(op_name):
op_value = _fix_broadcast(current_op_name, inputs, broadcast_axis, proto_obj)
return op_value, new_attr, inputs
return current_op_name, new_attr, inputs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we are losing info about attrs here? We always return empty from this function. Sorry couldn't understand this logic.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

onnx attrs can have "broadcast/consumed_inputs" as input according to old onnx opset.. If broadcast is mentioned, we check that in line 253. But MXNet doesnt have these inputs.. broadcast ops in mxnet have empty attr. So by default passing empty.