Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX import/export: Make pow backward compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Oct 16, 2018
1 parent ff470d7 commit a20b496
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 4 deletions.
26 changes: 25 additions & 1 deletion python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2086,7 +2086,31 @@ def convert_power(node, **kwargs):
"Pow",
[input_node_a, input_node_b],
[name],
name=None
name=name
)
return [node]

@mx_op.register("broadcast_power")
def convert_broadcast_power(node, **kwargs):
"""Map MXNet's _power operator attributes to onnx's Pow operator
and return the created node.
"""
onnx = import_onnx_modules()
name = node["name"]
proc_nodes = kwargs["proc_nodes"]
inputs = node["inputs"]

input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
input_node_b_id = kwargs["index_lookup"][inputs[1][0]]

input_node_a = proc_nodes[input_node_a_id].name
input_node_b = proc_nodes[input_node_b_id].name

node = onnx.helper.make_node(
"Pow",
[input_node_a, input_node_b],
[name],
name=name
)
return [node]

Expand Down
11 changes: 8 additions & 3 deletions python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,10 +534,15 @@ def squareroot(attrs, inputs, proto_obj):
def power(attrs, inputs, proto_obj):
"""Returns element-wise result of base element raised to powers from exp element."""
new_attrs = translation_utils._fix_attribute_names(attrs, {'exponent':'exp'})
if 'broadcast' in attrs and attrs['broadcast'] == 1:
if 'broadcast' in attrs:
new_attrs = translation_utils._remove_attributes(new_attrs, ['broadcast'])
return 'broadcast_power', new_attrs, inputs
return 'pow', new_attrs, inputs
if attrs['broadcast'] == 1:
return 'broadcast_power', new_attrs, inputs
else:
mxnet_op = symbol.pow(inputs[0], inputs[1])
return mxnet_op, new_attrs, inputs
mxnet_op = symbol.broadcast_power(inputs[0], inputs[1])
return mxnet_op, new_attrs, inputs

def exponent(attrs, inputs, proto_obj):
"""Elementwise exponent of input array."""
Expand Down

0 comments on commit a20b496

Please sign in to comment.