Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX export: Gather
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Dec 11, 2018
1 parent aca5ee5 commit 1f3f409
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 2 deletions.
20 changes: 20 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1677,3 +1677,23 @@ def convert_multinomial(node, **kwargs):
name=name,
)
return [node]


@mx_op.register("take")
def convert_gather(node, **kwargs):
"""Map MXNet's size_array operator attributes to onnx's Size operator
and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)

axis = int(attrs.get('axis', '0'))

node = onnx.helper.make_node(
"Gather",
input_nodes,
[name],
axis=axis,
name=name
)

return [node]
3 changes: 2 additions & 1 deletion tests/python-pytest/onnx/export/onnx_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@
'test_hardsigmoid',
'test_instancenorm',
'test_shape',
'test_size'
'test_size',
'test_gather'
]

BASIC_MODEL_TESTS = [
Expand Down
3 changes: 2 additions & 1 deletion tests/python-pytest/onnx/import/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@
'test_operator_params',
'test_operator_permute2',
'test_depthtospace',
'test_size'
'test_size',
'test_gather'
]

BASIC_MODEL_TESTS = [
Expand Down

0 comments on commit 1f3f409

Please sign in to comment.