Skip to content

Commit 5745777

Browse files
Arm backend: Align tosa backend, partitioner and mapping with doc style (#15827)
While the files mostly had docstrings, they did not adhere to the docstring style in the rest of the arm backend. Align them with this patch. Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 3ea17c9 commit 5745777

File tree

3 files changed

+88
-36
lines changed

3 files changed

+88
-36
lines changed

backends/arm/tosa/backend.py

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,19 @@
4343

4444

4545
def _annotate_external_ids(ep_graph: Graph) -> Dict[str, int]:
46-
"""
47-
Returns dictionary: node name -> external ids
46+
"""Assign deterministic output IDs to nodes reachable from graph outputs.
47+
48+
Args:
49+
ep_graph (Graph): FX graph produced by export preprocessing.
50+
51+
Returns:
52+
dict[str, int]: Mapping from node name to external output index.
4853
49-
Assign id to an output node of the model so we can trace it.
5054
"""
5155
node2external_id = {}
5256

5357
def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]):
58+
"""Walk producer graph from ``start_nodes`` and record external IDs."""
5459
q = deque(start_nodes)
5560
while q:
5661
n = q.popleft()
@@ -71,7 +76,19 @@ def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]):
7176

7277

7378
def _sort_outputs(graph_module: GraphModule, node_to_id_map: dict[str, int]):
79+
"""Reorder graph outputs to match ascending external IDs.
80+
81+
Args:
82+
graph_module (GraphModule): Graph to reorder in place.
83+
node_to_id_map (dict[str, int]): Mapping from node name to output index.
84+
85+
Returns:
86+
GraphModule: Updated graph module with deterministic output ordering.
87+
88+
"""
89+
7490
def _external_id(n: Node, node_2_id, fallback: int) -> int:
91+
"""Return the external ID for ``n`` or ``fallback`` when absent."""
7592
return node_2_id.get(n.name, fallback)
7693

7794
out_node = graph_module.graph.output_node()
@@ -80,6 +97,7 @@ def _external_id(n: Node, node_2_id, fallback: int) -> int:
8097

8198
# sort nodes by the key that is id
8299
def _sort_key(t: Node) -> int:
100+
"""Key function that orders outputs by external ID or position."""
83101
return _external_id(t, node_to_id_map, next(_counter))
84102

85103
orig_ord = tuple(sorted(out_list, key=_sort_key))
@@ -95,14 +113,14 @@ def _sort_key(t: Node) -> int:
95113

96114

97115
def arm_get_first_delegation_tag(graph_module) -> str:
98-
"""Return the first delegation tag from the FX graph.
116+
"""Return the first delegation tag discovered in the FX graph.
99117
100118
Args:
101-
graph_module: FX GraphModule produced by the Arm passes.
119+
graph_module (GraphModule): Module produced by Arm partitioning.
102120
103121
Returns:
104-
str: The first non-empty delegation tag found on any node, or an empty
105-
string if none is present.
122+
str: First non-empty delegation tag or an empty string when no tag is
123+
recorded.
106124
107125
"""
108126
for node in graph_module.graph.nodes:
@@ -125,6 +143,17 @@ class TOSABackend(BackendDetails):
125143

126144
@staticmethod
127145
def preprocess(edge_program: ExportedProgram, compile_specs: List[CompileSpec]):
146+
"""Convert an exported program using the provided compile specs.
147+
148+
Args:
149+
edge_program (ExportedProgram): Program generated by Torch export.
150+
compile_specs (List[CompileSpec]): Raw compile specifications from
151+
``executorch.apply_backend``.
152+
153+
Returns:
154+
PreprocessResult: Result containing serialized TOSA bytes.
155+
156+
"""
128157
return TOSABackend._preprocess(
129158
edge_program, TosaCompileSpec.from_list(compile_specs)
130159
)
@@ -142,7 +171,7 @@ def _preprocess( # noqa: C901
142171
143172
Args:
144173
edge_program (ExportedProgram): Program to lower to TOSA.
145-
compile_spec (List[CompileSpec]): Backend options. Recognized keys:
174+
compile_spec (TosaCompileSpec): Backend options. Recognized keys:
146175
- output_format: Must be "tosa".
147176
- tosa_spec: Target TOSA version/capabilities.
148177
- debug_artifact_path: Directory for debug outputs.
@@ -233,7 +262,20 @@ def _preprocess_module( # noqa: C901
233262
debug_hook: DebugHook | None,
234263
submodule_name: str | None = None,
235264
):
236-
"""Convert 'graph_module' to a tosa_graph"""
265+
"""Convert an FX ``graph_module`` to TOSA serializer calls.
266+
267+
Args:
268+
graph_module (GraphModule): Module to lower recursively.
269+
edge_program (ExportedProgram): Original exported program.
270+
compile_spec (TosaCompileSpec): Backend options with TOSA settings.
271+
tosa_graph (ts.TosaSerializer): Serializer receiving operators.
272+
debug_hook (DebugHook | None): Optional debug instrumentation.
273+
submodule_name (str | None): Name used when visiting nested blocks.
274+
275+
Raises:
276+
RuntimeError: If an FX node with an unsupported op kind is found.
277+
278+
"""
237279
tosa_spec = compile_spec.tosa_spec
238280
node_to_id_map = _annotate_external_ids(graph_module.graph)
239281
artifact_path = compile_spec.get_intermediate_path()
@@ -305,24 +347,17 @@ def _preprocess_module( # noqa: C901
305347
def filter_tosa_compile_specs(
306348
compile_spec: ArmCompileSpec,
307349
) -> TosaCompileSpec:
308-
"""
309-
Filter out the CompileSpec elements relevant for the TOSA backend.
310-
This is needed to compose a backend targetting hardware IP with the
311-
TOSABackend, since we first want to use the TOSABackend to generate
312-
the TOSA flatbuffer representation as an intermediate step. The TOSA
313-
flatbuffer can then be consumed by the backend targetting specific
314-
hardware.
350+
"""Extract the TOSA-specific settings from a composite compile spec.
315351
316352
Args:
317353
compile_spec (ArmCompileSpec): Compile specification that may
318354
include both TOSA and hardware-specific options.
319355
320356
Returns:
321357
TosaCompileSpec: TOSA-only specification ready for
322-
``TOSABackend.preprocess``.
358+
``TOSABackend.preprocess``.
323359
324360
"""
325-
326361
return (
327362
TosaCompileSpec(compile_spec.tosa_spec)
328363
.dump_intermediate_artifacts_to(compile_spec.get_intermediate_path())

backends/arm/tosa/mapping.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5-
65
"""Provide PyTorch-to-TOSA mapping helpers.
76
8-
Use these utilities to translate PyTorch dtypes and FX node metadata into
9-
the TOSA serializer types and shapes used during initial compilation.
7+
Use these utilities to translate PyTorch dtypes and FX node metadata into the
8+
TOSA serializer types and shapes used during initial compilation.
109
1110
"""
1211

@@ -34,18 +33,27 @@
3433

3534

3635
class TosaSpecialDtype(Enum):
37-
"""
38-
Special TOSA data types that are not natively supported in PyTorch, to be
39-
used in specific scenarios as a value in the key from meta_key().
40-
"""
36+
"""Special TOSA dtypes not natively expressed in PyTorch."""
4137

4238
INT48 = ts.DType.INT48
4339

4440
def get_tosa_dtype(self) -> ts.DType:
41+
"""Return the underlying ``ts.DType`` enumerant.
42+
43+
Returns:
44+
ts.DType: Serializer dtype associated with the enum entry.
45+
46+
"""
4547
return self.value
4648

4749
@staticmethod
4850
def meta_key() -> str:
51+
"""Return the FX ``meta`` key that stores special dtypes.
52+
53+
Returns:
54+
str: Metadata key used to encode :class:`TosaSpecialDtype`.
55+
56+
"""
4957
return "tosa_special_dtype"
5058

5159

@@ -57,7 +65,7 @@ def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any:
5765
tosa_spec (TosaSpecification): Active spec (reserved for future checks).
5866
5967
Returns:
60-
Any: Matching ``ts.DType`` enum value.
68+
ts.DType: Matching serializer dtype.
6169
6270
Raises:
6371
ValueError: If the dtype is unsupported or unknown.
@@ -95,8 +103,8 @@ def extract_tensor_meta(meta, tosa_spec: TosaSpecification):
95103
tosa_spec (TosaSpecification): Active TOSA spec for dtype mapping.
96104
97105
Returns:
98-
tuple: ``(dtype, shape, dim_order)`` where ``dtype`` is ``ts.DType``,
99-
``shape`` is ``Tuple[int, ...]``, and ``dim_order`` is ``Tuple[int, ...]``.
106+
tuple[ts.DType, tuple[int, ...], tuple[int, ...]]: Tuple containing
107+
tensor dtype, shape, and dimension order.
100108
101109
Raises:
102110
ValueError: If ``meta['val']`` is not a ``FakeTensor``.
@@ -130,12 +138,14 @@ class TosaArg:
130138
consistent structure suitable for TOSA serialization.
131139
132140
Attributes:
133-
name (str): Node name when argument is a ``torch.fx.Node``; empty otherwise.
141+
name (str): Node name when argument is a ``torch.fx.Node``; empty
142+
otherwise.
134143
dtype (ts.DType | None): Inferred dtype when available.
135144
shape (tuple[int, ...] | None): Inferred shape when available.
136-
dim_order (tuple[int, ...] | None): Dimension order, defaulting to ``range(len(shape))``.
145+
dim_order (tuple[int, ...] | None): Dimension order, defaulting to
146+
``range(len(shape))``.
137147
special (list | None): Captured list when the argument is a sequence.
138-
number (float | int | None): Captured numeric value when given.
148+
number (float | int | None): Captured numeric value when provided.
139149
tosa_spec (TosaSpecification): Active specification used for mapping.
140150
multiple_output_name (list[str]): Output node names when node has multiple outputs; empty otherwise.
141151
"""
@@ -174,7 +184,7 @@ def __process_list(self, argument):
174184
"""Capture a sequence argument as ``special``.
175185
176186
Args:
177-
argument (Sequence): Sequence to store.
187+
argument (Sequence[Any]): Sequence to store.
178188
179189
"""
180190
self.special: list = list(argument)
@@ -194,10 +204,13 @@ def __init__(
194204
"""Initialize the argument wrapper and populate fields.
195205
196206
Args:
197-
argument (Any): One of ``torch.fx.Node``, ``Sequence``, ``int``, ``float``, ``torch.dtype``, or ``None``.
198-
tosa_spec (Optional[TosaSpecification]): Active specification; required.
207+
argument (Any): One of ``torch.fx.Node``, ``Sequence``, ``int``,
208+
``float``, ``torch.dtype``, or ``None``.
209+
tosa_spec (Optional[TosaSpecification]): Active specification;
210+
required for metadata extraction.
199211
200212
Raises:
213+
ValueError: If ``tosa_spec`` is missing or has the wrong type.
201214
RuntimeError: If ``argument`` is of an unsupported type.
202215
203216
"""

backends/arm/tosa/partitioner.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5-
65
"""Provide a partitioner for delegating subgraphs to the TOSA backend.
76
87
Implement logic to identify and tag regions of an ``ExportedProgram`` that can
@@ -11,6 +10,7 @@
1110
- Partition graphs based on operator support and additional checks.
1211
- Prune trivial no-op partitions that would lower to empty TOSA graphs.
1312
- Tag constant data and report reasons for rejected nodes.
13+
1414
"""
1515

1616
import logging
@@ -142,6 +142,7 @@ def reject_partition(
142142
partition (object): Proposed partition object from the
143143
capability partitioner.
144144
reporter (WhyNoPartitionReporter): used to report why nodes were rejected.
145+
145146
"""
146147
for node in partition.nodes:
147148
if "delegation_tag" in node.meta:
@@ -158,6 +159,7 @@ class TOSAPartitioner(Partitioner):
158159
Construct this partitioner for compile specs targeting TOSA. The partition
159160
algorithm uses capability checks and optional additional operator-support
160161
rules to tag nodes with a delegation tag per subgraph.
162+
161163
"""
162164

163165
def __init__(
@@ -191,14 +193,16 @@ def _tag_module( # noqa
191193
reporter: WhyNoPartitionReporter,
192194
tag_iterator: count | None = None,
193195
) -> set[str]:
194-
"""Tag nodes in a module, possibly a submodule, from the containing program.
196+
"""Tag nodes in a module or submodule from the containing program.
195197
196198
Args:
197199
module: A GraphModule from `containing_program` to tag nodes in.
198200
containing_program: The ExportedProgram that contains the module.
199201
reporter: A reporter to report why nodes were rejected.
202+
200203
Returns:
201204
A set of strings with the partition tags.
205+
202206
"""
203207
tags: set[str] = set()
204208
if tag_iterator is None:

0 commit comments

Comments
 (0)