Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions backends/arm/operator_support/embedding_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Declare operator support for ``aten.embedding`` in TOSA.

Permit embeddings with int32 indices (TOSA lacks int64 support); other dtypes
are rejected by this check.

"""

import torch

Expand All @@ -17,6 +22,8 @@

@register_tosa_support_check
class EmbeddingSupported(SupportedTOSAOperatorCheck):
"""Provide TOSA support check for ``aten.embedding``."""

targets = [exir_ops.edge.aten.embedding.default]

tosa_specs = [
Expand All @@ -27,16 +34,20 @@ class EmbeddingSupported(SupportedTOSAOperatorCheck):
def is_node_tosa_supported(
self, node: fx.Node, tosa_spec: TosaSpecification
) -> bool: # type: ignore[override, misc]
# Note aten.embedding.default requires int64 indices and TOSA does not
# support it. Int32 indices here for aten.embedding.default is ok since
# it will be decomposed into ops that can handle it.
"""Return True if the node is supported by TOSA.

PyTorch's ``aten.embedding`` typically takes int64 indices, but for
TOSA we only allow int32 indices. The export path decomposes the op so
that int32 indices are ok.

"""
if len(node.all_input_nodes) != 2:
self.reporter.report_reject(
node,
(f"Expected exactly two input nodes, got {len(node.all_input_nodes)}"),
)
return False

indices_val = node.all_input_nodes[1].meta["val"]
indices_dtype = indices_val.dtype

Expand Down
Loading