From bcae4b2ecde16b4c3ddc9f082d3469c131860a30 Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Mon, 10 Dec 2018 19:03:02 -0800 Subject: [PATCH] ONNX export: Gather --- .../contrib/onnx/mx2onnx/_op_translations.py | 20 +++++++++++++++++++ .../onnx/export/onnx_backend_test.py | 3 ++- tests/python-pytest/onnx/import/test_cases.py | 3 ++- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 5f09e8cc9fe7..d9c3d8ecf7ab 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1677,3 +1677,23 @@ def convert_multinomial(node, **kwargs): name=name, ) return [node] + + +@mx_op.register("take") +def convert_gather(node, **kwargs): + """Map MXNet's size_array operator attributes to onnx's Size operator + and return the created node. + """ + name, input_nodes, attrs = get_inputs(node, kwargs) + + axis = int(attrs.get('axis', '0')) + + node = onnx.helper.make_node( + "Gather", + input_nodes, + [name], + axis=axis, + name=name + ) + + return [node] diff --git a/tests/python-pytest/onnx/export/onnx_backend_test.py b/tests/python-pytest/onnx/export/onnx_backend_test.py index c9926c4d5e15..3facd6dfa128 100644 --- a/tests/python-pytest/onnx/export/onnx_backend_test.py +++ b/tests/python-pytest/onnx/export/onnx_backend_test.py @@ -98,7 +98,8 @@ 'test_hardsigmoid', 'test_instancenorm', 'test_shape', - 'test_size' + 'test_size', + 'test_gather' ] BASIC_MODEL_TESTS = [ diff --git a/tests/python-pytest/onnx/import/test_cases.py b/tests/python-pytest/onnx/import/test_cases.py index e0b26cc49830..04cfe93b7ed7 100644 --- a/tests/python-pytest/onnx/import/test_cases.py +++ b/tests/python-pytest/onnx/import/test_cases.py @@ -86,7 +86,8 @@ 'test_operator_params', 'test_operator_permute2', 'test_depthtospace', - 'test_size' + 'test_size', + 'test_gather' ] BASIC_MODEL_TESTS = [