-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-344] Add more operators to onnx import #11856
Conversation
Thanks for working on this. LGTM |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@zhreshold - Can you please take a look once? |
for op_input in inputs: | ||
concat_input.append(symbol.expand_dims(op_input, axis=0)) | ||
concat_sym = symbol.concat(*concat_input, dim=0) | ||
mean_sym = symbol.mean(concat_sym, axis=0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is mean on axis 0, are you sure it is the desired behavior?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, this is correct. We are doing an unsqueeze along axis=0, hence the mean along axis=0. It has also passed all the operator tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, so it's elem-wise mean
@@ -348,6 +361,14 @@ def global_avgpooling(attrs, inputs, proto_obj): | |||
'pool_type': 'avg'}) | |||
return 'Pooling', new_attrs, inputs | |||
|
|||
def global_lppooling(attrs, inputs, proto_obj): | |||
"""Performs global lp pooling on the input.""" | |||
p_value = attrs['p'] if 'p' in attrs else 2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
p_value = attrs.get('p', 2)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is lp pooling?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@@ -505,6 +529,30 @@ def exponent(attrs, inputs, proto_obj): | |||
"""Elementwise exponent of input array.""" | |||
return 'exp', attrs, inputs | |||
|
|||
def _cos(attrs, inputs, proto_obj): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why _cos, _sin for example need to start with leading underscore? in python they are under math namescope, so I guess there's no conflict
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there will be a namespace conflict if there is an import math statement. probably better practice to is the leading underscore for names that are very common and prevent namespace collision.
@@ -578,6 +637,20 @@ def avg_pooling(attrs, inputs, proto_obj): | |||
|
|||
return new_op, new_attrs, inputs | |||
|
|||
def lp_pooling(attrs, inputs, proto_obj): | |||
"""LP Pooling""" | |||
p_value = attrs['p'] if 'p' in attrs else 2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use get rather than if else
"""Mean of two input tensors.""" | ||
concat_input = [] | ||
for op_input in inputs: | ||
concat_input.append(symbol.expand_dims(op_input, axis=0)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use list comprehensive is as simple as one line
concat_input = [symbol.expand_dims(op_input, axis=0) for op_input in inputs]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please address this, the rest looks good to me now.
def global_lppooling(attrs, inputs, proto_obj): | ||
"""Performs global lp pooling on the input.""" | ||
p_value = attrs.get('p', 2) | ||
new_attrs = translation_utils._add_extra_attributes(attrs, {'global_pool': True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the indentation is bad here, two suggested way from PEP8 is
a = super_long_function_name(
xxx, xxx, xxx, xxx, xx, xxx)
# or
a = super_long_function_name(xxx, xxx, xxx,
xxx, xxx, xxx)
I suggest to use the first style here to avoid too many lines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the one I have used is the second style and it has passed pylint and it is good code readability-wise. I would ideally want the dict to be printed as name, value pairs rather than as a single long list.
def lp_pooling(attrs, inputs, proto_obj): | ||
"""LP Pooling""" | ||
p_value = attrs.get('p', 2) | ||
new_attrs = translation_utils._fix_attribute_names(attrs, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here for the indentation
* add more ops * use dict.get * add list comprehensive * retrigger CI due to unrelated flaky test failure
Description
Add more operators to onnx import - Mean, Acos, Asin, Atan, Cos, Sin, Softplus, Tan, Shape, Gather. HardSigmoid, LpPool, GlobalLpPool, ReduceL1
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments