From 628444a5f23e9a62d750e262cbebe3489b13e380 Mon Sep 17 00:00:00 2001 From: Sebastian Larsson Date: Thu, 11 Sep 2025 08:47:39 +0200 Subject: [PATCH] Arm backend: Add docstrings for operator_support/embedding_support.py Change-Id: I533687c4bd309dfd2155de1e362644ac65cb6106 Signed-off-by: Sebastian Larsson --- .../arm/operator_support/embedding_support.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/backends/arm/operator_support/embedding_support.py b/backends/arm/operator_support/embedding_support.py index 24395d56cbf..3ad17012cbb 100644 --- a/backends/arm/operator_support/embedding_support.py +++ b/backends/arm/operator_support/embedding_support.py @@ -2,7 +2,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Declare operator support for ``aten.embedding`` in TOSA. +Permit embeddings with int32 indices (TOSA lacks int64 support); other dtypes +are rejected by this check. + +""" import torch @@ -17,6 +22,8 @@ @register_tosa_support_check class EmbeddingSupported(SupportedTOSAOperatorCheck): + """Provide TOSA support check for ``aten.embedding``.""" + targets = [exir_ops.edge.aten.embedding.default] tosa_specs = [ @@ -27,16 +34,20 @@ class EmbeddingSupported(SupportedTOSAOperatorCheck): def is_node_tosa_supported( self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: # type: ignore[override, misc] - # Note aten.embedding.default requires int64 indices and TOSA does not - # support it. Int32 indices here for aten.embedding.default is ok since - # it will be decomposed into ops that can handle it. + """Return True if the node is supported by TOSA. + PyTorch's ``aten.embedding`` typically takes int64 indices, but for + TOSA we only allow int32 indices. The export path decomposes the op so + that int32 indices are ok. + + """ if len(node.all_input_nodes) != 2: self.reporter.report_reject( node, (f"Expected exactly two input nodes, got {len(node.all_input_nodes)}"), ) return False + indices_val = node.all_input_nodes[1].meta["val"] indices_dtype = indices_val.dtype