Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Added support L2Normalization op
Browse files Browse the repository at this point in the history
Added some error checking
  • Loading branch information
rajanksin committed Jun 15, 2018
1 parent 20bb2fd commit f4902e1
Showing 1 changed file with 29 additions and 1 deletion.
30 changes: 29 additions & 1 deletion python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,31 @@ def convert_lrn(node, **kwargs):
return [lrn_node]


@mx_op.register("L2Normalization")
def convert_l2normalization(node, **kwargs):
"""Map MXNet's L2Normalization operator attributes to onnx's LpNormalization operator
and return the created node.
"""
helper, _, _ = import_onnx_modules()
name = node["name"]
input_id = kwargs["index_lookup"][node["inputs"][0][0]]
input_name = kwargs["proc_nodes"][input_id].name
attrs = node["attrs"]
mode = attrs.get("mode", "instance")

if mode != "channel":
raise AttributeError("ONNX currently supports channel mode only")

l2norm_node = helper.make_node(
"LpNormalization",
[input_name],
[name],
axis=1, # channel only
name=name
)
return [l2norm_node]


@mx_op.register("Dropout")
def convert_dropout(node, **kwargs):
"""Map MXNet's Dropout operator attributes to onnx's Dropout operator
Expand Down Expand Up @@ -1625,7 +1650,10 @@ def convert_slice_axis(node, **kwargs):
inputs = node["inputs"]
axes = int(node["attrs"]["axis"])
starts = int(node["attrs"]["begin"])
ends = int(node["attrs"]["end"])
if node["attrs"]["end"] == 'None':
raise ValueError("Slice: ONNX doesnt't support 'None' in 'end' attribute")
else:
ends = int(node["attrs"]["end"])

input_node_id = kwargs["index_lookup"][inputs[0][0]]
input_node = proc_nodes[input_node_id].name
Expand Down

0 comments on commit f4902e1

Please sign in to comment.