Skip to content

Commit

Permalink
Move EmbeddingBag resolving back to front phase
Browse files Browse the repository at this point in the history
  • Loading branch information
mvafin committed Jun 5, 2020
1 parent 9dcee86 commit 4da1a30
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 263 deletions.
1 change: 0 additions & 1 deletion model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,6 @@ extensions/middle/DilatedConvolution.py
extensions/middle/EltwiseChecker.py
extensions/middle/EltwiseInputNormalization.py
extensions/middle/EltwiseInputReshape.py
extensions/middle/EmbeddingBagResolver.py
extensions/middle/FakeSplitOutputs.py
extensions/middle/FusedBatchNormNonConstant.py
extensions/middle/FusedBatchNormTraining.py
Expand Down
59 changes: 53 additions & 6 deletions model-optimizer/extensions/front/ATenToEmbeddingBag.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,18 @@
limitations under the License.
"""

from extensions.ops.embedding_bag import ATenEmbeddingBag
from extensions.ops.embedding_bag import EmbeddingBagOffsetsSum, EmbeddingBagPackedSum
from extensions.ops.rank import Rank
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementPattern
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, rename_node
from mo.ops.broadcast import Broadcast
from mo.ops.concat import Concat
from mo.ops.shape import Shape
from mo.ops.unsqueeze import Unsqueeze
from mo.utils.shape import node_to_get_shape_value_of_indices, get_canonical_axis_index_node, \
get_shape_values_by_indices_node


class AtenToEmbeddingBag(FrontReplacementPattern):
Expand All @@ -27,14 +36,52 @@ class AtenToEmbeddingBag(FrontReplacementPattern):

def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(op='ATen', operator='embedding_bag'):
assert node.soft_get('mode') == 0, 'ATen::embedding_bag has unsupported mode, only "sum" ' \
'mode is supported for node {}.'.format(node.id)
node_name = node.soft_get('name', node.id)
rename_node(node, node_name + '/TBR')
embedding_bag = ATenEmbeddingBag(graph, {'name': node_name, 'mode': node.soft_get('mode', 1)}).create_node()
is_packed = False
if len(node.in_ports()) < 3 or node.in_port(2).disconnected():
is_packed = True
embedding_bag = EmbeddingBagPackedSum(graph, {'name': node_name}).create_node()
else:
embedding_bag = EmbeddingBagOffsetsSum(graph, {'name': node_name}).create_node()
node.in_port(2).get_connection().set_destination(embedding_bag.in_port(2))
rename_node(embedding_bag, node_name)
node.in_port(0).get_connection().set_destination(embedding_bag.in_port(0))
node.in_port(1).get_connection().set_destination(embedding_bag.in_port(1))
if node.is_in_port_connected(2):
node.in_port(2).get_connection().set_destination(embedding_bag.in_port(2))
if node.is_in_port_connected(3):
node.in_port(3).get_connection().set_destination(embedding_bag.in_port(3))
node.out_port(0).get_connection().set_source(embedding_bag.out_port(0))
if len(node.in_ports()) == 4 and not node.in_port(3).disconnected():
if is_packed:
node.in_port(3).get_connection().set_destination(embedding_bag.in_port(2))
else:
# connect per_sample_weights
node.in_port(3).get_connection().set_destination(embedding_bag.in_port(4))

weights_shape_node = Shape(graph, {'name': node_name + '/WeightsShape'}).create_node()

weights_rank_node = Rank(graph, {'name': node_name + '/WeightsRank'}).create_node()
last_dim_node = get_canonical_axis_index_node(weights_rank_node, -1)
weights_last_dim = get_shape_values_by_indices_node(weights_shape_node, last_dim_node)

weights_first_dim = node_to_get_shape_value_of_indices(weights_shape_node, [0])

zero_col_node = create_op_with_const_inputs(graph, Broadcast, {0: int64_array([0])},
{'name': node_name + '/Broadcast'})
zero_col_node.in_port(1).connect(weights_last_dim.out_port(0))

default_embeddings_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)},
{'name': node_name + '/Unsqueeze'})
default_embeddings_node.in_port(0).connect(zero_col_node.out_port(0))

# expand embedding table with zeros
weights_concat = Concat(graph, {'axis': 0, 'in_ports_count': 2,
'name': node_name + '/Concat'}).create_node()
embedding_bag.in_port(0).get_connection().set_destination(weights_concat.in_port(0))
weights_concat.in_port(0).get_connection().add_destination(weights_shape_node.in_port(0))
weights_concat.in_port(0).get_connection().add_destination(weights_rank_node.in_port(0))
weights_concat.in_port(1).connect(default_embeddings_node.out_port(0))
weights_concat.out_port(0).connect(embedding_bag.in_port(0))

# point default index to expanded part of embedding table
weights_first_dim.out_port(0).connect(embedding_bag.in_port(3))
163 changes: 132 additions & 31 deletions model-optimizer/extensions/front/ATenToEmbeddingBag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,46 +16,147 @@

import unittest

import numpy as np

from extensions.front.ATenToEmbeddingBag import AtenToEmbeddingBag
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph

nodes_attributes = {
'weights_inp': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'indices_inp': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'offsets_inp': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'aten': {'type': None, 'kind': 'op', 'op': 'ATen', 'mode': 0, 'operator': 'embedding_bag', 'name': 'my_aten'},
'result': {'type': 'Result', 'value': None, 'kind': 'op', 'op': 'Result'},

# new EmbeddingBag layer
'emb_bag': {'type': None, 'kind': 'op', 'op': 'ATenEmbeddingBag', 'mode': 0},
}
from mo.utils.unittest.graph import build_graph, result, \
regular_op, const


class AtenToEmbeddingBagTest(unittest.TestCase):
def test(self):
graph = build_graph(nodes_attributes,
[('weights_inp', 'aten', {'in': 0, 'out': 0}),
('indices_inp', 'aten', {'in': 1, 'out': 0}),
('offsets_inp', 'aten', {'in': 2, 'out': 0}),
('aten', 'result', {'in': 0, 'out': 0}),
],
{}, nodes_with_edges_only=True)

graph_ref = build_graph(nodes_attributes,
[('weights_inp', 'emb_bag', {'in': 0, 'out': 0}),
('indices_inp', 'emb_bag', {'in': 1, 'out': 0}),
('offsets_inp', 'emb_bag', {'in': 2, 'out': 0}),
('emb_bag', 'result', {'in': 0, 'out': 0}),
],
{}, nodes_with_edges_only=True)
nodes = {
**const('weights_inp', np.random.randn(100, 2)),
**regular_op('indices_inp', {'type': 'Parameter'}),
**regular_op('offsets_inp', {'type': 'Parameter'}),
**regular_op('aten', {'type': None, 'kind': 'op', 'op': 'ATen', 'operator': 'embedding_bag', 'mode': 0,
'name': 'my_aten'}),

**regular_op('emb_bag', {'type': 'EmbeddingBagOffsetsSum', 'kind': 'op', 'op': 'EmbeddingBagOffsetsSum'}),
**result('result'),
}
edges = [('weights_inp', 'aten'),
('indices_inp', 'aten'),
('offsets_inp', 'aten'),
('aten', 'result'),
]
graph = build_graph(nodes, edges)

graph.graph['layout'] = 'NCHW'
graph.stage = 'front'

edges_ref = [('weights_inp', 'emb_bag'),
('indices_inp', 'emb_bag'),
('offsets_inp', 'emb_bag'),
('emb_bag', 'result'),
]

graph_ref = build_graph(nodes, edges_ref)

AtenToEmbeddingBag().find_and_replace_pattern(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)

def test_packed(self):
nodes = {
**const('weights_inp', np.random.randn(100, 4)),
**regular_op('indices_inp', {'type': 'Parameter'}),
**regular_op('aten', {'type': None, 'kind': 'op', 'op': 'ATen', 'operator': 'embedding_bag', 'mode': 0,
'name': 'my_aten'}),

**regular_op('emb_bag', {'type': 'EmbeddingBagPackedSum', 'kind': 'op',
'op': 'EmbeddingBagPackedSum'}),
**result('result'),
}
edges = [('weights_inp', 'aten'),
('indices_inp', 'aten'),
('aten', 'result'),
]
graph = build_graph(nodes, edges)

graph.graph['layout'] = 'NCHW'
graph.stage = 'front'

replacer = AtenToEmbeddingBag()
replacer.find_and_replace_pattern(graph)
edges_ref = [('weights_inp', 'emb_bag'),
('indices_inp', 'emb_bag'),
('emb_bag', 'result'),
]

graph_ref = build_graph(nodes, edges_ref)

AtenToEmbeddingBag().find_and_replace_pattern(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)

def test_per_sample_weights(self):
nodes = {
**const('weights_inp', np.random.randn(100, 2)),
**regular_op('indices_inp', {'type': 'Parameter'}),
**regular_op('offsets_inp', {'type': 'Parameter'}),
**regular_op('per_sample_weights', {'type': 'Parameter'}),
**regular_op('aten', {'type': None, 'kind': 'op', 'op': 'ATen', 'operator': 'embedding_bag', 'mode': 0,
'name': 'my_aten'}),

**regular_op('emb_bag', {'type': 'EmbeddingBagOffsetsSum', 'kind': 'op',
'op': 'EmbeddingBagOffsetsSum'}),
**regular_op('WeightsRank', {'type': None, 'kind': 'op', 'op': 'Rank'}),
**regular_op('WeightsRank/axis', {'type': 'Add', 'kind': 'op', 'op': 'Add'}),
**regular_op('gather1', {'type': 'Gather', 'kind': 'op', 'op': 'Gather'}),
**regular_op('gather2', {'type': 'Gather', 'kind': 'op', 'op': 'Gather'}),
**regular_op('WeightsShape', {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'}),
**regular_op('Broadcast', {'type': 'Broadcast', 'kind': 'op', 'op': 'Broadcast'}),
**regular_op('Unsqueeze', {'type': 'Unsqueeze', 'kind': 'op', 'op': 'Unsqueeze'}),
**const('WeightsShape/Axis', int64_array(0)),
**const('zero1', int64_array(0)),
**const('zero2', int64_array(0)),
**const('Unsqueeze/value', int64_array(0)),
**const('Broadcast/value', int64_array(0)),
**const('neg', int64_array(-1)),
**regular_op('Concat', {'type': 'Concat', 'kind': 'op', 'op': 'Concat'}),
**result('result'),
}
edges = [('weights_inp', 'aten'),
('indices_inp', 'aten'),
('offsets_inp', 'aten'),
('per_sample_weights', 'aten'),
('aten', 'result'),
]
graph = build_graph(nodes, edges, nodes_with_edges_only=True)

graph.graph['layout'] = 'NCHW'
graph.stage = 'front'

edges_ref = [('weights_inp', 'Concat', {'in': 0, 'out': 0}),
('weights_inp', 'WeightsShape', {'in': 0, 'out': 0}),
('weights_inp', 'WeightsRank', {'in': 0, 'out': 0}),
('WeightsRank', 'WeightsRank/axis'),
('neg', 'WeightsRank/axis'),
('WeightsShape', 'gather1', {'in': 0, 'out': 0}),
('WeightsRank/axis', 'gather1'),
('WeightsShape/Axis', 'gather1'),
('WeightsShape', 'gather2', {'in': 0, 'out': 0}),
('zero1', 'gather2'),
('zero2', 'gather2'),
('Broadcast/value', 'Broadcast'),
('gather1', 'Broadcast'),
('Broadcast', 'Unsqueeze'),
('Unsqueeze/value', 'Unsqueeze'),
('Unsqueeze', 'Concat'),
('Concat', 'emb_bag'),
('indices_inp', 'emb_bag'),
('offsets_inp', 'emb_bag'),
('gather2', 'emb_bag'),
('per_sample_weights', 'emb_bag'),
('emb_bag', 'result'),
]

graph_ref = build_graph(nodes, edges_ref, nodes_with_edges_only=True)

AtenToEmbeddingBag().find_and_replace_pattern(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
self.assertTrue(graph.node[graph.get_nodes_with_attributes(op='ATenEmbeddingBag')[0]]['name'] == 'my_aten')
69 changes: 0 additions & 69 deletions model-optimizer/extensions/middle/EmbeddingBagResolver.py

This file was deleted.

Loading

0 comments on commit 4da1a30

Please sign in to comment.