From 4da1a3024b8e81b7134decc6b759013c48d7da90 Mon Sep 17 00:00:00 2001
From: "Maxim, Vafin"
Date: Fri, 5 Jun 2020 15:24:16 +0300
Subject: [PATCH] Move EmbeddingBag resolving back to front phase

---
 model-optimizer/automation/package_BOM.txt    |   1 -
 .../extensions/front/ATenToEmbeddingBag.py    |  59 ++++++-
 .../front/ATenToEmbeddingBag_test.py          | 163 ++++++++++++++----
 .../extensions/middle/EmbeddingBagResolver.py |  69 --------
 .../middle/EmbeddingBagResolver_test.py       | 133 --------------
 .../extensions/ops/embedding_bag.py           |  26 +--
 6 files changed, 188 insertions(+), 263 deletions(-)
 delete mode 100644 model-optimizer/extensions/middle/EmbeddingBagResolver.py
 delete mode 100644 model-optimizer/extensions/middle/EmbeddingBagResolver_test.py

diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt
index aaf8ce75f45a0f..630ef2abe1f1cd 100644
--- a/model-optimizer/automation/package_BOM.txt
+++ b/model-optimizer/automation/package_BOM.txt
@@ -520,7 +520,6 @@ extensions/middle/DilatedConvolution.py
 extensions/middle/EltwiseChecker.py
 extensions/middle/EltwiseInputNormalization.py
 extensions/middle/EltwiseInputReshape.py
-extensions/middle/EmbeddingBagResolver.py
 extensions/middle/FakeSplitOutputs.py
 extensions/middle/FusedBatchNormNonConstant.py
 extensions/middle/FusedBatchNormTraining.py
diff --git a/model-optimizer/extensions/front/ATenToEmbeddingBag.py b/model-optimizer/extensions/front/ATenToEmbeddingBag.py
index 6ff4d52a82dad5..8da70dc5c0d164 100644
--- a/model-optimizer/extensions/front/ATenToEmbeddingBag.py
+++ b/model-optimizer/extensions/front/ATenToEmbeddingBag.py
@@ -14,9 +14,18 @@
  limitations under the License.
 """
 
-from extensions.ops.embedding_bag import ATenEmbeddingBag
+from extensions.ops.embedding_bag import EmbeddingBagOffsetsSum, EmbeddingBagPackedSum
+from extensions.ops.rank import Rank
+from mo.front.common.partial_infer.utils import int64_array
 from mo.front.common.replacement import FrontReplacementPattern
+from mo.front.tf.graph_utils import create_op_with_const_inputs
 from mo.graph.graph import Graph, rename_node
+from mo.ops.broadcast import Broadcast
+from mo.ops.concat import Concat
+from mo.ops.shape import Shape
+from mo.ops.unsqueeze import Unsqueeze
+from mo.utils.shape import node_to_get_shape_value_of_indices, get_canonical_axis_index_node, \
+    get_shape_values_by_indices_node
 
 
 class AtenToEmbeddingBag(FrontReplacementPattern):
@@ -27,14 +36,52 @@ class AtenToEmbeddingBag(FrontReplacementPattern):
 
     def find_and_replace_pattern(self, graph: Graph):
         for node in graph.get_op_nodes(op='ATen', operator='embedding_bag'):
+            assert node.soft_get('mode') == 0, 'ATen::embedding_bag has unsupported mode, only "sum" ' \
+                                               'mode is supported for node {}.'.format(node.id)
             node_name = node.soft_get('name', node.id)
             rename_node(node, node_name + '/TBR')
-            embedding_bag = ATenEmbeddingBag(graph, {'name': node_name, 'mode': node.soft_get('mode', 1)}).create_node()
+            is_packed = False
+            if len(node.in_ports()) < 3 or node.in_port(2).disconnected():
+                is_packed = True
+                embedding_bag = EmbeddingBagPackedSum(graph, {'name': node_name}).create_node()
+            else:
+                embedding_bag = EmbeddingBagOffsetsSum(graph, {'name': node_name}).create_node()
+                node.in_port(2).get_connection().set_destination(embedding_bag.in_port(2))
             rename_node(embedding_bag, node_name)
             node.in_port(0).get_connection().set_destination(embedding_bag.in_port(0))
             node.in_port(1).get_connection().set_destination(embedding_bag.in_port(1))
-            if node.is_in_port_connected(2):
-                node.in_port(2).get_connection().set_destination(embedding_bag.in_port(2))
-            if node.is_in_port_connected(3):
-                node.in_port(3).get_connection().set_destination(embedding_bag.in_port(3))
             node.out_port(0).get_connection().set_source(embedding_bag.out_port(0))
+            if len(node.in_ports()) == 4 and not node.in_port(3).disconnected():
+                if is_packed:
+                    node.in_port(3).get_connection().set_destination(embedding_bag.in_port(2))
+                else:
+                    # connect per_sample_weights
+                    node.in_port(3).get_connection().set_destination(embedding_bag.in_port(4))
+
+                    weights_shape_node = Shape(graph, {'name': node_name + '/WeightsShape'}).create_node()
+
+                    weights_rank_node = Rank(graph, {'name': node_name + '/WeightsRank'}).create_node()
+                    last_dim_node = get_canonical_axis_index_node(weights_rank_node, -1)
+                    weights_last_dim = get_shape_values_by_indices_node(weights_shape_node, last_dim_node)
+
+                    weights_first_dim = node_to_get_shape_value_of_indices(weights_shape_node, [0])
+
+                    zero_col_node = create_op_with_const_inputs(graph, Broadcast, {0: int64_array([0])},
+                                                                {'name': node_name + '/Broadcast'})
+                    zero_col_node.in_port(1).connect(weights_last_dim.out_port(0))
+
+                    default_embeddings_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)},
+                                                                          {'name': node_name + '/Unsqueeze'})
+                    default_embeddings_node.in_port(0).connect(zero_col_node.out_port(0))
+
+                    # expand embedding table with zeros
+                    weights_concat = Concat(graph, {'axis': 0, 'in_ports_count': 2,
+                                                    'name': node_name + '/Concat'}).create_node()
+                    embedding_bag.in_port(0).get_connection().set_destination(weights_concat.in_port(0))
+                    weights_concat.in_port(0).get_connection().add_destination(weights_shape_node.in_port(0))
+                    weights_concat.in_port(0).get_connection().add_destination(weights_rank_node.in_port(0))
+                    weights_concat.in_port(1).connect(default_embeddings_node.out_port(0))
+                    weights_concat.out_port(0).connect(embedding_bag.in_port(0))
+
+                    # point default index to expanded part of embedding table
+                    weights_first_dim.out_port(0).connect(embedding_bag.in_port(3))
diff --git a/model-optimizer/extensions/front/ATenToEmbeddingBag_test.py b/model-optimizer/extensions/front/ATenToEmbeddingBag_test.py
index 9783c7cd08cb35..bade2da38dbc9d 100644
--- a/model-optimizer/extensions/front/ATenToEmbeddingBag_test.py
+++ b/model-optimizer/extensions/front/ATenToEmbeddingBag_test.py
@@ -16,46 +16,147 @@
 
 import unittest
 
+import numpy as np
+
 from extensions.front.ATenToEmbeddingBag import AtenToEmbeddingBag
+from mo.front.common.partial_infer.utils import int64_array
 from mo.utils.ir_engine.compare_graphs import compare_graphs
-from mo.utils.unittest.graph import build_graph
-
-nodes_attributes = {
-    'weights_inp': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
-    'indices_inp': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
-    'offsets_inp': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
-    'aten': {'type': None, 'kind': 'op', 'op': 'ATen', 'mode': 0, 'operator': 'embedding_bag', 'name': 'my_aten'},
-    'result': {'type': 'Result', 'value': None, 'kind': 'op', 'op': 'Result'},
-
-    # new EmbeddingBag layer
-    'emb_bag': {'type': None, 'kind': 'op', 'op': 'ATenEmbeddingBag', 'mode': 0},
-}
+from mo.utils.unittest.graph import build_graph, result, \
+    regular_op, const
 
 
 class AtenToEmbeddingBagTest(unittest.TestCase):
     def test(self):
-        graph = build_graph(nodes_attributes,
-                            [('weights_inp', 'aten', {'in': 0, 'out': 0}),
-                             ('indices_inp', 'aten', {'in': 1, 'out': 0}),
-                             ('offsets_inp', 'aten', {'in': 2, 'out': 0}),
-                             ('aten', 'result', {'in': 0, 'out': 0}),
-                             ],
-                            {}, nodes_with_edges_only=True)
-
-        graph_ref = build_graph(nodes_attributes,
-                                [('weights_inp', 'emb_bag', {'in': 0, 'out': 0}),
-                                 ('indices_inp', 'emb_bag', {'in': 1, 'out': 0}),
-                                 ('offsets_inp', 'emb_bag', {'in': 2, 'out': 0}),
-                                 ('emb_bag', 'result', {'in': 0, 'out': 0}),
-                                 ],
-                                {}, nodes_with_edges_only=True)
+        nodes = {
+            **const('weights_inp', np.random.randn(100, 2)),
+            **regular_op('indices_inp', {'type': 'Parameter'}),
+            **regular_op('offsets_inp', {'type': 'Parameter'}),
+            **regular_op('aten', {'type': None, 'kind': 'op', 'op': 'ATen', 'operator': 'embedding_bag', 'mode': 0,
+                                  'name': 'my_aten'}),
+
+            **regular_op('emb_bag', {'type': 'EmbeddingBagOffsetsSum', 'kind': 'op', 'op': 'EmbeddingBagOffsetsSum'}),
+            **result('result'),
+        }
+        edges = [('weights_inp', 'aten'),
+                 ('indices_inp', 'aten'),
+                 ('offsets_inp', 'aten'),
+                 ('aten', 'result'),
+                 ]
+        graph = build_graph(nodes, edges)
+
+        graph.graph['layout'] = 'NCHW'
+        graph.stage = 'front'
+
+        edges_ref = [('weights_inp', 'emb_bag'),
+                     ('indices_inp', 'emb_bag'),
+                     ('offsets_inp', 'emb_bag'),
+                     ('emb_bag', 'result'),
+                     ]
+
+        graph_ref = build_graph(nodes, edges_ref)
+
+        AtenToEmbeddingBag().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
+        self.assertTrue(flag, resp)
+
+    def test_packed(self):
+        nodes = {
+            **const('weights_inp', np.random.randn(100, 4)),
+            **regular_op('indices_inp', {'type': 'Parameter'}),
+            **regular_op('aten', {'type': None, 'kind': 'op', 'op': 'ATen', 'operator': 'embedding_bag', 'mode': 0,
+                                  'name': 'my_aten'}),
+
+            **regular_op('emb_bag', {'type': 'EmbeddingBagPackedSum', 'kind': 'op',
+                                     'op': 'EmbeddingBagPackedSum'}),
+            **result('result'),
+        }
+        edges = [('weights_inp', 'aten'),
+                 ('indices_inp', 'aten'),
+                 ('aten', 'result'),
+                 ]
+        graph = build_graph(nodes, edges)
 
         graph.graph['layout'] = 'NCHW'
         graph.stage = 'front'
 
-        replacer = AtenToEmbeddingBag()
-        replacer.find_and_replace_pattern(graph)
+        edges_ref = [('weights_inp', 'emb_bag'),
+                     ('indices_inp', 'emb_bag'),
+                     ('emb_bag', 'result'),
+                     ]
+
+        graph_ref = build_graph(nodes, edges_ref)
+
+        AtenToEmbeddingBag().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
+        self.assertTrue(flag, resp)
+
+    def test_per_sample_weights(self):
+        nodes = {
+            **const('weights_inp', np.random.randn(100, 2)),
+            **regular_op('indices_inp', {'type': 'Parameter'}),
+            **regular_op('offsets_inp', {'type': 'Parameter'}),
+            **regular_op('per_sample_weights', {'type': 'Parameter'}),
+            **regular_op('aten', {'type': None, 'kind': 'op', 'op': 'ATen', 'operator': 'embedding_bag', 'mode': 0,
+                                  'name': 'my_aten'}),
+
+            **regular_op('emb_bag', {'type': 'EmbeddingBagOffsetsSum', 'kind': 'op',
+                                     'op': 'EmbeddingBagOffsetsSum'}),
+            **regular_op('WeightsRank', {'type': None, 'kind': 'op', 'op': 'Rank'}),
+            **regular_op('WeightsRank/axis', {'type': 'Add', 'kind': 'op', 'op': 'Add'}),
+            **regular_op('gather1', {'type': 'Gather', 'kind': 'op', 'op': 'Gather'}),
+            **regular_op('gather2', {'type': 'Gather', 'kind': 'op', 'op': 'Gather'}),
+            **regular_op('WeightsShape', {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'}),
+            **regular_op('Broadcast', {'type': 'Broadcast', 'kind': 'op', 'op': 'Broadcast'}),
+            **regular_op('Unsqueeze', {'type': 'Unsqueeze', 'kind': 'op', 'op': 'Unsqueeze'}),
+            **const('WeightsShape/Axis', int64_array(0)),
+            **const('zero1', int64_array(0)),
+            **const('zero2', int64_array(0)),
+            **const('Unsqueeze/value', int64_array(0)),
+            **const('Broadcast/value', int64_array(0)),
+            **const('neg', int64_array(-1)),
+            **regular_op('Concat', {'type': 'Concat', 'kind': 'op', 'op': 'Concat'}),
+            **result('result'),
+        }
+        edges = [('weights_inp', 'aten'),
+                 ('indices_inp', 'aten'),
+                 ('offsets_inp', 'aten'),
+                 ('per_sample_weights', 'aten'),
+                 ('aten', 'result'),
+                 ]
+        graph = build_graph(nodes, edges, nodes_with_edges_only=True)
+
+        graph.graph['layout'] = 'NCHW'
+        graph.stage = 'front'
+
+        edges_ref = [('weights_inp', 'Concat', {'in': 0, 'out': 0}),
+                     ('weights_inp', 'WeightsShape', {'in': 0, 'out': 0}),
+                     ('weights_inp', 'WeightsRank', {'in': 0, 'out': 0}),
+                     ('WeightsRank', 'WeightsRank/axis'),
+                     ('neg', 'WeightsRank/axis'),
+                     ('WeightsShape', 'gather1', {'in': 0, 'out': 0}),
+                     ('WeightsRank/axis', 'gather1'),
+                     ('WeightsShape/Axis', 'gather1'),
+                     ('WeightsShape', 'gather2', {'in': 0, 'out': 0}),
+                     ('zero1', 'gather2'),
+                     ('zero2', 'gather2'),
+                     ('Broadcast/value', 'Broadcast'),
+                     ('gather1', 'Broadcast'),
+                     ('Broadcast', 'Unsqueeze'),
+                     ('Unsqueeze/value', 'Unsqueeze'),
+                     ('Unsqueeze', 'Concat'),
+                     ('Concat', 'emb_bag'),
+                     ('indices_inp', 'emb_bag'),
+                     ('offsets_inp', 'emb_bag'),
+                     ('gather2', 'emb_bag'),
+                     ('per_sample_weights', 'emb_bag'),
+                     ('emb_bag', 'result'),
+                     ]
+
+        graph_ref = build_graph(nodes, edges_ref, nodes_with_edges_only=True)
+
+        AtenToEmbeddingBag().find_and_replace_pattern(graph)
 
-        (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
         self.assertTrue(flag, resp)
-        self.assertTrue(graph.node[graph.get_nodes_with_attributes(op='ATenEmbeddingBag')[0]]['name'] == 'my_aten')
diff --git a/model-optimizer/extensions/middle/EmbeddingBagResolver.py b/model-optimizer/extensions/middle/EmbeddingBagResolver.py
deleted file mode 100644
index 67e4418095a72b..00000000000000
--- a/model-optimizer/extensions/middle/EmbeddingBagResolver.py
+++ /dev/null
@@ -1,69 +0,0 @@
-"""
- Copyright (C) 2018-2020 Intel Corporation
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-      http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-"""
-
-import numpy as np
-
-from extensions.ops.embedding_bag import EmbeddingBagOffsetsSum, EmbeddingBagPackedSum
-from mo.front.common.partial_infer.utils import int64_array
-from mo.front.tf.graph_utils import create_op_with_const_inputs
-from mo.graph.graph import Graph, rename_node
-from mo.middle.replacement import MiddleReplacementPattern
-from mo.ops.concat import Concat
-from mo.ops.const import Const
-
-
-class EmbeddingBagResolver(MiddleReplacementPattern):
-    """
-    Converts the ATenEmbeddingBag layer to correct internal EmbeddingBag layer.
-    """
-    enabled = True
-
-    def find_and_replace_pattern(self, graph: Graph):
-        for node in graph.get_op_nodes(op='ATenEmbeddingBag'):
-            assert node.soft_get('mode') == 0, 'ATen::embedding_bag has unsupported mode, only "sum" ' \
-                                               'mode is supported for node {}.'.format(node.id)
-            node_name = node.soft_get('name', node.id)
-            rename_node(node, node_name + '/TBR')
-            is_packed = False
-            if len(node.in_ports()) < 3 or node.in_port(2).disconnected():
-                is_packed = True
-                embedding_bag = EmbeddingBagPackedSum(graph, {'name': node_name}).create_node()
-            else:
-                embedding_bag = EmbeddingBagOffsetsSum(graph, {'name': node_name}).create_node()
-                node.in_port(2).get_connection().set_destination(embedding_bag.in_port(2))
-            rename_node(embedding_bag, node_name)
-            node.in_port(0).get_connection().set_destination(embedding_bag.in_port(0))
-            node.in_port(1).get_connection().set_destination(embedding_bag.in_port(1))
-            node.out_port(0).get_connection().set_source(embedding_bag.out_port(0))
-            if len(node.in_ports()) == 4 and not node.in_port(3).disconnected():
-                if is_packed:
-                    node.in_port(3).get_connection().set_destination(embedding_bag.in_port(2))
-                else:
-                    # connect per_sample_weights
-                    node.in_port(3).get_connection().set_destination(embedding_bag.in_port(4))
-
-                    weights_shape = embedding_bag.in_port(0).data.get_shape()
-
-                    # expand embedding table with zeros
-                    default_embeddings = np.zeros([1, weights_shape[-1]])
-                    weights_concat = create_op_with_const_inputs(graph, Concat, {1: default_embeddings},
-                                                                 {'axis': 0, 'in_ports_count': 2})
-                    embedding_bag.in_port(0).get_connection().set_destination(weights_concat.in_port(0))
-                    weights_concat.out_port(0).connect(embedding_bag.in_port(0))
-
-                    # point default index to expanded part of embedding table
-                    default_index = Const(graph, {'value': int64_array(weights_shape[0])}).create_node()
-                    default_index.out_port(0).connect(embedding_bag.in_port(3))
diff --git a/model-optimizer/extensions/middle/EmbeddingBagResolver_test.py b/model-optimizer/extensions/middle/EmbeddingBagResolver_test.py
deleted file mode 100644
index d964da69ccd1e0..00000000000000
--- a/model-optimizer/extensions/middle/EmbeddingBagResolver_test.py
+++ /dev/null
@@ -1,133 +0,0 @@
-"""
- Copyright (C) 2020 Intel Corporation
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-      http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-"""
-
-import unittest
-
-import numpy as np
-
-from extensions.middle.EmbeddingBagResolver import EmbeddingBagResolver
-from mo.front.common.partial_infer.utils import int64_array
-from mo.utils.ir_engine.compare_graphs import compare_graphs
-from mo.utils.unittest.graph import build_graph, regular_op_with_shaped_data, valued_const_with_data, result, \
-    connect
-
-
-class AtenToEmbeddingBagTest(unittest.TestCase):
-    def test(self):
-        nodes = {
-            **valued_const_with_data('weights_inp', np.random.randn(100, 2)),
-            **regular_op_with_shaped_data('indices_inp', [20], {'type': 'Parameter'}),
-            **regular_op_with_shaped_data('offsets_inp', [10], {'type': 'Parameter'}),
-            **regular_op_with_shaped_data('aten', [10, 2],
-                                          {'type': None, 'kind': 'op', 'op': 'ATenEmbeddingBag', 'mode': 0,
-                                           'name': 'my_aten'}),
-
-            **regular_op_with_shaped_data('emb_bag', [10, 2], {'type': 'EmbeddingBagOffsetsSum', 'kind': 'op',
-                                                               'op': 'EmbeddingBagOffsetsSum'}),
-            **result('result'),
-        }
-        edges = [*connect('weights_inp', '0:aten'),
-                 *connect('indices_inp', '1:aten'),
-                 *connect('offsets_inp', '2:aten'),
-                 *connect('aten', 'result'),
-                 ]
-        graph = build_graph(nodes, edges)
-
-        edges_ref = [*connect('weights_inp', '0:emb_bag'),
-                     *connect('indices_inp', '1:emb_bag'),
-                     *connect('offsets_inp', '2:emb_bag'),
-                     *connect('emb_bag', 'result'),
-                     ]
-
-        graph_ref = build_graph(nodes, edges_ref)
-
-        EmbeddingBagResolver().find_and_replace_pattern(graph)
-
-        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
-        self.assertTrue(flag, resp)
-
-    def test_packed(self):
-        nodes = {
-            **valued_const_with_data('weights_inp', np.random.randn(100, 4)),
-            **regular_op_with_shaped_data('indices_inp', [10, 2], {'type': 'Parameter'}),
-            **regular_op_with_shaped_data('aten', [10, 4],
-                                          {'type': None, 'kind': 'op', 'op': 'ATenEmbeddingBag', 'mode': 0,
-                                           'name': 'my_aten'}),
-
-            **regular_op_with_shaped_data('emb_bag', [10, 4], {'type': 'EmbeddingBagPackedSum', 'kind': 'op',
-                                                               'op': 'EmbeddingBagPackedSum'}),
-            **result('result'),
-        }
-        edges = [*connect('weights_inp', '0:aten'),
-                 *connect('indices_inp', '1:aten'),
-                 *connect('aten', 'result'),
-                 ]
-        graph = build_graph(nodes, edges)
-
-        edges_ref = [*connect('weights_inp', '0:emb_bag'),
-                     *connect('indices_inp', '1:emb_bag'),
-                     *connect('emb_bag', 'result'),
-                     ]
-
-        graph_ref = build_graph(nodes, edges_ref)
-
-        EmbeddingBagResolver().find_and_replace_pattern(graph)
-
-        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
-        self.assertTrue(flag, resp)
-
-    def test_per_sample_weights(self):
-        nodes = {
-            **valued_const_with_data('weights_inp', np.random.randn(100, 2)),
-            **regular_op_with_shaped_data('indices_inp', [20], {'type': 'Parameter'}),
-            **regular_op_with_shaped_data('offsets_inp', [10], {'type': 'Parameter'}),
-            **regular_op_with_shaped_data('per_sample_weights', [20], {'type': 'Parameter'}),
-            **regular_op_with_shaped_data('aten', [10, 2],
-                                          {'type': None, 'kind': 'op', 'op': 'ATenEmbeddingBag', 'mode': 0,
-                                           'name': 'my_aten'}),
-
-            **regular_op_with_shaped_data('emb_bag', [10, 2], {'type': 'EmbeddingBagOffsetsSum', 'kind': 'op',
-                                                               'op': 'EmbeddingBagOffsetsSum'}),
-            **valued_const_with_data('zeros', np.zeros([1, 2])),
-            **regular_op_with_shaped_data('concat', None, {'type': 'Concat', 'kind': 'op', 'op': 'Concat'}),
-            'def_index': {'kind': 'op', 'value': int64_array(100), 'shape': None, 'type': 'Const'},
-            'def_index_d': {'kind': 'data', 'value': None, 'shape': None},
-            **result('result'),
-        }
-        edges = [*connect('weights_inp', '0:aten'),
-                 *connect('indices_inp', '1:aten'),
-                 *connect('offsets_inp', '2:aten'),
-                 *connect('per_sample_weights', '3:aten'),
-                 *connect('aten', 'result'),
-                 ]
-        graph = build_graph(nodes, edges)
-
-        edges_ref = [*connect('weights_inp', '0:concat'),
-                     *connect('zeros', '1:concat'),
-                     *connect('concat', '0:emb_bag'),
-                     *connect('indices_inp', '1:emb_bag'),
-                     *connect('offsets_inp', '2:emb_bag'),
-                     *connect('def_index', '3:emb_bag'),
-                     *connect('per_sample_weights', '4:emb_bag'),
-                     *connect('emb_bag', 'result'),
-                     ]
-
-        graph_ref = build_graph(nodes, edges_ref)
-
-        EmbeddingBagResolver().find_and_replace_pattern(graph)
-
-        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
-        self.assertTrue(flag, resp)
diff --git a/model-optimizer/extensions/ops/embedding_bag.py b/model-optimizer/extensions/ops/embedding_bag.py
index fc13dde08ae060..b65b039aaa23ac 100644
--- a/model-optimizer/extensions/ops/embedding_bag.py
+++ b/model-optimizer/extensions/ops/embedding_bag.py
@@ -65,7 +65,7 @@ def infer(node: Node):
         offsets_shape = node.in_port(2).data.get_shape()
         assert offsets_shape is not None and len(offsets_shape) == 1
 
-        node.out_port(0).data.set_shape(np.concatenate((offsets_shape[:1], weights_shape[1:])).astype(np.int64))
+        node.out_port(0).data.set_shape(np.concatenate((offsets_shape[:1], weights_shape[1:])))
 
 
 class EmbeddingBagPackedSum(EmbeddingBagBase):
@@ -87,7 +87,7 @@ def infer(node: Node):
         input_shape = node.in_port(1).data.get_shape()
         assert input_shape is not None
 
-        node.out_port(0).data.set_shape(np.concatenate((input_shape[:1], weights_shape[1:])).astype(np.int64))
+        node.out_port(0).data.set_shape(np.concatenate((input_shape[:1], weights_shape[1:])))
 
 
 class EmbeddingSegmentsSum(EmbeddingBagBase):
@@ -113,25 +113,5 @@ def infer(node: Node):
         num_segments = node.in_port(3).data.get_value()
         assert num_segments is not None, "EmbeddingSegmentsSum should have a constant num_segments provided, but it " \
                                          "doesn't for node: `{}`.".format(name)
-        output_shape = np.concatenate(([num_segments], weights_shape[1:])).astype(np.int64)
+        output_shape = np.concatenate(([num_segments], weights_shape[1:]))
         node.out_port(0).data.set_shape(output_shape)
-
-
-class ATenEmbeddingBag(EmbeddingBagBase):
-    op = 'ATenEmbeddingBag'
-    op_type = None
-    version = None
-    in_ports_count = 4
-
-    @staticmethod
-    def infer(node: Node):
-        weights_shape = node.in_port(0).data.get_shape()
-        assert len(weights_shape) >= 2
-        indices_shape = node.in_port(1).data.get_shape()
-        assert indices_shape is not None
-        if len(indices_shape) == 2:
-            node.out_port(0).data.set_shape(np.concatenate((indices_shape[:1], weights_shape[1:])).astype(np.int64))
-        elif len(indices_shape) == 1:
-            offsets_shape = node.in_port(2).data.get_shape()
-            assert offsets_shape is not None and len(offsets_shape) == 1
-            node.out_port(0).data.set_shape(np.concatenate((offsets_shape[:1], weights_shape[1:])).astype(np.int64))