diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index ab1fd6c78b64..3dd6123a5914 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2086,7 +2086,31 @@ def convert_power(node, **kwargs): "Pow", [input_node_a, input_node_b], [name], - name=None + name=name + ) + return [node] + +@mx_op.register("broadcast_power") +def convert_broadcast_power(node, **kwargs): + """Map MXNet's _power operator attributes to onnx's Pow operator + and return the created node. + """ + onnx = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_a_id = kwargs["index_lookup"][inputs[0][0]] + input_node_b_id = kwargs["index_lookup"][inputs[1][0]] + + input_node_a = proc_nodes[input_node_a_id].name + input_node_b = proc_nodes[input_node_b_id].name + + node = onnx.helper.make_node( + "Pow", + [input_node_a, input_node_b], + [name], + name=name ) return [node] diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index 7040103c005a..fedd7134c3d4 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -534,10 +534,15 @@ def squareroot(attrs, inputs, proto_obj): def power(attrs, inputs, proto_obj): """Returns element-wise result of base element raised to powers from exp element.""" new_attrs = translation_utils._fix_attribute_names(attrs, {'exponent':'exp'}) - if 'broadcast' in attrs and attrs['broadcast'] == 1: + if 'broadcast' in attrs: new_attrs = translation_utils._remove_attributes(new_attrs, ['broadcast']) - return 'broadcast_power', new_attrs, inputs - return 'pow', new_attrs, inputs + if attrs['broadcast'] == 1: + return 'broadcast_power', new_attrs, inputs + else: + mxnet_op = symbol.pow(inputs[0], inputs[1]) + return mxnet_op, new_attrs, inputs + mxnet_op = symbol.broadcast_power(inputs[0], inputs[1]) + return mxnet_op, new_attrs, inputs def exponent(attrs, inputs, proto_obj): """Elementwise exponent of input array."""