4343
4444
4545def _annotate_external_ids (ep_graph : Graph ) -> Dict [str , int ]:
46- """
47- Returns dictionary: node name -> external ids
46+ """Assign deterministic output IDs to nodes reachable from graph outputs.
47+
48+ Args:
49+ ep_graph (Graph): FX graph produced by export preprocessing.
50+
51+ Returns:
52+ dict[str, int]: Mapping from node name to external output index.
4853
49- Assign id to an output node of the model so we can trace it.
5054 """
5155 node2external_id = {}
5256
5357 def bfs_mark (start_nodes : List [Node ], idx : int , seen : Set [Node ]):
58+ """Walk producer graph from ``start_nodes`` and record external IDs."""
5459 q = deque (start_nodes )
5560 while q :
5661 n = q .popleft ()
@@ -71,7 +76,19 @@ def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]):
7176
7277
7378def _sort_outputs (graph_module : GraphModule , node_to_id_map : dict [str , int ]):
79+ """Reorder graph outputs to match ascending external IDs.
80+
81+ Args:
82+ graph_module (GraphModule): Graph to reorder in place.
83+ node_to_id_map (dict[str, int]): Mapping from node name to output index.
84+
85+ Returns:
86+ GraphModule: Updated graph module with deterministic output ordering.
87+
88+ """
89+
7490 def _external_id (n : Node , node_2_id , fallback : int ) -> int :
91+ """Return the external ID for ``n`` or ``fallback`` when absent."""
7592 return node_2_id .get (n .name , fallback )
7693
7794 out_node = graph_module .graph .output_node ()
@@ -80,6 +97,7 @@ def _external_id(n: Node, node_2_id, fallback: int) -> int:
8097
8198 # sort nodes by the key that is id
8299 def _sort_key (t : Node ) -> int :
100+ """Key function that orders outputs by external ID or position."""
83101 return _external_id (t , node_to_id_map , next (_counter ))
84102
85103 orig_ord = tuple (sorted (out_list , key = _sort_key ))
@@ -95,14 +113,14 @@ def _sort_key(t: Node) -> int:
95113
96114
97115def arm_get_first_delegation_tag (graph_module ) -> str :
98- """Return the first delegation tag from the FX graph.
116+ """Return the first delegation tag discovered in the FX graph.
99117
100118 Args:
101- graph_module: FX GraphModule produced by the Arm passes .
119+ graph_module ( GraphModule): Module produced by Arm partitioning .
102120
103121 Returns:
104- str: The first non-empty delegation tag found on any node, or an empty
105- string if none is present .
122+ str: First non-empty delegation tag or an empty string when no tag is
123+ recorded .
106124
107125 """
108126 for node in graph_module .graph .nodes :
@@ -125,6 +143,17 @@ class TOSABackend(BackendDetails):
125143
126144 @staticmethod
127145 def preprocess (edge_program : ExportedProgram , compile_specs : List [CompileSpec ]):
146+ """Convert an exported program using the provided compile specs.
147+
148+ Args:
149+ edge_program (ExportedProgram): Program generated by Torch export.
150+ compile_specs (List[CompileSpec]): Raw compile specifications from
151+ ``executorch.apply_backend``.
152+
153+ Returns:
154+ PreprocessResult: Result containing serialized TOSA bytes.
155+
156+ """
128157 return TOSABackend ._preprocess (
129158 edge_program , TosaCompileSpec .from_list (compile_specs )
130159 )
@@ -142,7 +171,7 @@ def _preprocess( # noqa: C901
142171
143172 Args:
144173 edge_program (ExportedProgram): Program to lower to TOSA.
145- compile_spec (List[CompileSpec] ): Backend options. Recognized keys:
174+ compile_spec (TosaCompileSpec ): Backend options. Recognized keys:
146175 - output_format: Must be "tosa".
147176 - tosa_spec: Target TOSA version/capabilities.
148177 - debug_artifact_path: Directory for debug outputs.
@@ -233,7 +262,20 @@ def _preprocess_module( # noqa: C901
233262 debug_hook : DebugHook | None ,
234263 submodule_name : str | None = None ,
235264 ):
236- """Convert 'graph_module' to a tosa_graph"""
265+ """Convert an FX ``graph_module`` to TOSA serializer calls.
266+
267+ Args:
268+ graph_module (GraphModule): Module to lower recursively.
269+ edge_program (ExportedProgram): Original exported program.
270+ compile_spec (TosaCompileSpec): Backend options with TOSA settings.
271+ tosa_graph (ts.TosaSerializer): Serializer receiving operators.
272+ debug_hook (DebugHook | None): Optional debug instrumentation.
273+ submodule_name (str | None): Name used when visiting nested blocks.
274+
275+ Raises:
276+ RuntimeError: If an FX node with an unsupported op kind is found.
277+
278+ """
237279 tosa_spec = compile_spec .tosa_spec
238280 node_to_id_map = _annotate_external_ids (graph_module .graph )
239281 artifact_path = compile_spec .get_intermediate_path ()
@@ -305,24 +347,17 @@ def _preprocess_module( # noqa: C901
305347 def filter_tosa_compile_specs (
306348 compile_spec : ArmCompileSpec ,
307349 ) -> TosaCompileSpec :
308- """
309- Filter out the CompileSpec elements relevant for the TOSA backend.
310- This is needed to compose a backend targetting hardware IP with the
311- TOSABackend, since we first want to use the TOSABackend to generate
312- the TOSA flatbuffer representation as an intermediate step. The TOSA
313- flatbuffer can then be consumed by the backend targetting specific
314- hardware.
350+ """Extract the TOSA-specific settings from a composite compile spec.
315351
316352 Args:
317353 compile_spec (ArmCompileSpec): Compile specification that may
318354 include both TOSA and hardware-specific options.
319355
320356 Returns:
321357 TosaCompileSpec: TOSA-only specification ready for
322- ``TOSABackend.preprocess``.
358+ ``TOSABackend.preprocess``.
323359
324360 """
325-
326361 return (
327362 TosaCompileSpec (compile_spec .tosa_spec )
328363 .dump_intermediate_artifacts_to (compile_spec .get_intermediate_path ())
0 commit comments