Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Handle default p_value
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Dec 19, 2018
1 parent 1247006 commit 866fae0
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ def convert_pooling(node, **kwargs):
pool_type = attrs["pool_type"]
stride = eval(attrs["stride"]) if attrs.get("stride") else None
global_pool = get_boolean_attribute_value(attrs, "global_pool")
p_value = int(attrs.get('p_value', '2'))
p_value = attrs.get('p_value', 'None')

pooling_convention = attrs.get('pooling_convention', 'valid')

Expand All @@ -592,13 +592,16 @@ def convert_pooling(node, **kwargs):
global_pool_types = {"max": "GlobalMaxPool", "avg": "GlobalAveragePool",
"lp": "GlobalLpPool"}

if pool_type == 'lp' and p_value == 'None':
raise AttributeError('ONNX requires a p value for LpPool and GlobalLpPool')

if global_pool:
if pool_type == 'lp':
node = onnx.helper.make_node(
global_pool_types[pool_type],
input_nodes, # input
[name],
p=p_value,
p=int(p_value),
name=name
)
else:
Expand All @@ -614,7 +617,7 @@ def convert_pooling(node, **kwargs):
pool_types[pool_type],
input_nodes, # input
[name],
p=p_value,
p=int(p_value),
kernel_shape=kernel,
pads=pad_dims,
strides=stride,
Expand Down

0 comments on commit 866fae0

Please sign in to comment.