Skip to content

Commit b991c34

Browse files
authored
Merge branch 'main' into export-D87400561
2 parents c850844 + b4d72f1 commit b991c34

File tree

8 files changed

+155
-38
lines changed

8 files changed

+155
-38
lines changed

backends/arm/_passes/decompose_int16_activation_conv2d_pass.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from executorch.backends.arm._passes.arm_pass import ArmPass
1111
from executorch.backends.arm._passes.quant_args import QuantArgs
1212

13-
from executorch.backends.arm.tosa.specification import get_context_spec, Tosa_1_00
13+
from executorch.backends.arm.tosa.specification import get_context_spec
1414
from executorch.exir.dialects._ops import ops as exir_ops
1515
from executorch.exir.pass_base import ExportPass
1616

@@ -40,9 +40,7 @@ def call_operator(self, op, args, kwargs, meta):
4040
if args[0].data.dtype == torch.int8:
4141
return super().call_operator(op, args, kwargs, meta)
4242
elif args[0].data.dtype == torch.int16:
43-
if isinstance(tosa_spec, Tosa_1_00) and not tosa_spec.support_extension(
44-
"int16"
45-
):
43+
if not tosa_spec.support_extension("int16"):
4644
raise ValueError(
4745
"int16 activation for convolution requires TOSA int16 extension"
4846
)

backends/arm/common/arm_compile_spec.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,22 @@ class DebugMode(Enum):
3535
_OUTPUT_FORMAT_KEY = "output_format"
3636
_DEBUG_ARTIFACT_KEY = "debug_artifact_path"
3737
_DEBUG_MODE_KEY = "dump_debug_info"
38+
_OUTPUT_REORDER_KEY = "ouput_reorder_workaround"
3839

3940
def _set_compile_specs(
4041
self,
4142
tosa_spec: TosaSpecification,
4243
compiler_flags: list[str],
4344
path_for_intermediates: str | None = None,
4445
tosa_debug_mode: DebugMode | None = None,
46+
output_order_workaround: bool = True,
4547
):
4648
"""Set all values of dataclass directly."""
4749
self.tosa_spec = tosa_spec
4850
self.compiler_flags = compiler_flags
4951
self.path_for_intermediates = path_for_intermediates
5052
self.tosa_debug_mode = tosa_debug_mode
53+
self.output_order_workaround = output_order_workaround
5154

5255
@classmethod
5356
def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
@@ -56,10 +59,15 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
5659
compiler_flags: list[str] | None = None
5760
path_for_intermediates: str | None = None
5861
tosa_debug_mode: ArmCompileSpec.DebugMode | None = None
62+
output_order_workaround: bool = True
5963
unknown_specs: dict[str, str] = {}
6064
for spec in compile_specs:
6165
key = spec.key
62-
val = spec.value.decode()
66+
val = (
67+
spec.value.decode()
68+
if isinstance(spec.value, (bytes, bytearray))
69+
else spec.value
70+
)
6371
if key == ArmCompileSpec._TOSA_SPEC_KEY:
6472
if tosa_spec is not None:
6573
raise ValueError("More than one tosa_spec entry in compile spec.")
@@ -88,6 +96,8 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
8896
"More than one tosa_debug_mode entry in compile spec."
8997
)
9098
tosa_debug_mode = ArmCompileSpec.DebugMode[val]
99+
elif key == ArmCompileSpec._OUTPUT_REORDER_KEY:
100+
output_order_workaround = val # type: ignore[assignment]
91101
else:
92102
unknown_specs[key] = val
93103

@@ -109,6 +119,7 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
109119
compiler_flags=compiler_flags,
110120
path_for_intermediates=path_for_intermediates,
111121
tosa_debug_mode=tosa_debug_mode,
122+
output_order_workaround=output_order_workaround,
112123
)
113124
cls.from_list_hook(compile_spec, unknown_specs)
114125
compile_spec.validate()
@@ -170,6 +181,14 @@ def to_list(self):
170181
)
171182
)
172183

184+
if not self.output_order_workaround:
185+
compile_spec.append(
186+
CompileSpec(
187+
ArmCompileSpec._OUTPUT_REORDER_KEY,
188+
self.output_order_workaround,
189+
)
190+
)
191+
173192
return compile_spec
174193

175194
def get_intermediate_path(self) -> str | None:
@@ -201,6 +220,13 @@ def dump_debug_info(self, debug_mode: DebugMode | None):
201220
self.tosa_debug_mode = debug_mode
202221
return self
203222

223+
def set_output_order_workaround(self, output_order_workaround: bool):
224+
self.output_order_workaround = output_order_workaround
225+
return self
226+
227+
def get_output_order_workaround(self) -> bool:
228+
return self.output_order_workaround
229+
204230
@classmethod
205231
@abstractmethod
206232
def get_output_format(cls) -> str:

backends/arm/ethosu/backend.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# backends. Converts via TOSA as an intermediate form supported by AoT and
1010
# JIT compiler flows.
1111
#
12+
"""Ahead-of-time Arm Ethos-U backend built on the shared TOSA pipeline."""
1213

1314
import logging
1415
from typing import final, List
@@ -27,19 +28,28 @@
2728

2829
@final
2930
class EthosUBackend(BackendDetails):
30-
"""
31-
BackendDetails subclass for delegation to Ethos-U. Deduce the TOSA lowering from
32-
the compile spec list by filtering out the compile spec values that are of interest
33-
for the TOSABackend.
31+
"""BackendDetails subclass for delegation to Ethos-U.
32+
33+
Deduce the TOSA lowering from the compile spec list by filtering out the
34+
compile spec values that are of interest for the TOSABackend.
35+
3436
"""
3537

3638
@staticmethod
3739
def _compile_tosa_flatbuffer(
3840
tosa_flatbuffer: bytes, compile_spec: EthosUCompileSpec
3941
) -> bytes:
40-
"""
41-
Static helper method to do the compilation of the TOSA flatbuffer
42-
representation to a target specific binary stream.
42+
"""Compile a TOSA flatbuffer into a target-specific binary stream.
43+
44+
Args:
45+
tosa_flatbuffer (bytes): Serialized TOSA graph produced by
46+
``TOSABackend``.
47+
compile_spec (EthosUCompileSpec): Compile specification providing
48+
Vela flags and intermediate paths.
49+
50+
Returns:
51+
bytes: Target-specific binary stream produced by Vela.
52+
4353
"""
4454
compile_flags = compile_spec.compiler_flags
4555

@@ -73,6 +83,17 @@ def preprocess(
7383
edge_program: ExportedProgram,
7484
compile_specs: List[CompileSpec],
7585
) -> PreprocessResult:
86+
"""Lower the exported program and compile it for an Ethos-U target.
87+
88+
Args:
89+
edge_program (ExportedProgram): Program to lower to Ethos-U.
90+
compile_specs (List[CompileSpec]): Serialized Ethos-U compile specs
91+
supplied by the frontend.
92+
93+
Returns:
94+
PreprocessResult: Result containing the compiled Ethos-U binary.
95+
96+
"""
7697
logger.info(f"{EthosUBackend.__name__} preprocess")
7798

7899
compile_spec = EthosUCompileSpec.from_list(compile_specs)

backends/arm/test/misc/test_outputs_order.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,18 @@ def _read_tosa_outputs(tosa_path: Path):
7878
return shapes
7979

8080

81+
# TODO: MLETORCH-1266 Investigate output order issue
8182
@pytest.mark.parametrize("batch_size", [1, 4])
82-
def test_network_output_order_and_restore(batch_size):
83+
@pytest.mark.parametrize("output_order_workaround", [True, False])
84+
def test_network_output_order_and_restore(batch_size, output_order_workaround):
8385
model = Network(batch_norm=True).eval()
8486
# Prepare spec
8587
spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
86-
compile_spec = TosaCompileSpec(tosa_spec=spec)
88+
tosa_compile_spec = TosaCompileSpec(spec).set_output_order_workaround(
89+
output_order_workaround
90+
)
8791
# Setup quantizer
88-
quantizer = TOSAQuantizer(compile_spec)
92+
quantizer = TOSAQuantizer(tosa_compile_spec)
8993
quantizer.set_global(
9094
get_symmetric_quantization_config(is_qat=True, is_per_channel=False)
9195
)
@@ -100,7 +104,7 @@ def test_network_output_order_and_restore(batch_size):
100104
with tempfile.TemporaryDirectory(dir="") as tmpdir:
101105
art_dir = Path(tmpdir)
102106
part = TOSAPartitioner(
103-
TosaCompileSpec(spec).dump_intermediate_artifacts_to(str(art_dir))
107+
tosa_compile_spec.dump_intermediate_artifacts_to(str(art_dir))
104108
)
105109
_ = to_edge_transform_and_lower(aten_gm, partitioner=[part])
106110
# Expect exactly one .tosa file in the artefact dir

backends/arm/tosa/backend.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ def _preprocess_module( # noqa: C901
283283
output_node.update_arg(0, [output_node.args[0]])
284284
node_to_id_map = _annotate_external_ids(graph_module.graph)
285285
artifact_path = compile_spec.get_intermediate_path()
286+
output_order_workaround = compile_spec.get_output_order_workaround()
286287

287288
# TODO: Fix the need to lazily import this.
288289
from executorch.backends.arm._passes import ArmPassManager
@@ -295,7 +296,12 @@ def _preprocess_module( # noqa: C901
295296
from executorch.backends.arm.operators.node_visitor import get_node_visitors
296297

297298
node_visitors = get_node_visitors(edge_program, tosa_spec, debug_hook)
298-
graph_module = _sort_outputs(graph_module, node_to_id_map)
299+
300+
if output_order_workaround:
301+
logger.debug("Re-sorting outputs during TOSA lowering.")
302+
graph_module = _sort_outputs(graph_module, node_to_id_map)
303+
else:
304+
logger.debug("No re-sorting outputs (workaround) during TOSA lowering.")
299305

300306
if submodule_name is not None:
301307
tosa_graph.startRegion(submodule_name)
@@ -375,4 +381,5 @@ def filter_tosa_compile_specs(
375381
TosaCompileSpec(compile_spec.tosa_spec)
376382
.dump_intermediate_artifacts_to(compile_spec.get_intermediate_path())
377383
.dump_debug_info(compile_spec.tosa_debug_mode)
384+
.set_output_order_workaround(compile_spec.output_order_workaround)
378385
)

backends/arm/tosa/specification.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,18 @@ def support_float(self) -> bool:
105105
"""Return True if floating-point operations are supported."""
106106
raise NotImplementedError
107107

108+
def support_extension(self, extension: str) -> bool:
109+
"""Return True if an extension is supported and enabled.
110+
111+
Args:
112+
extension (str): Extension name (for example ``int4``, ``bf16``).
113+
114+
Returns:
115+
bool: True if the extension is valid for the active profiles and selected.
116+
117+
"""
118+
raise NotImplementedError
119+
108120
def __init__(self, version: Version, extras: List[str]):
109121
"""Initialize the base specification.
110122

backends/arm/vgf/backend.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# this form is used where the final JIT compile is performed on target (in the
1111
# runtime delegate executorch::runtime::BackendInterface::init
1212
#
13+
"""Ahead-of-time Arm VGF backend built on the shared TOSA pipeline."""
1314

1415
import logging
1516
import os
@@ -43,9 +44,11 @@
4344

4445
@final
4546
class VgfBackend(BackendDetails):
46-
"""
47-
BackendDetails subclass for delegation to VGF compatible devices. This enables
48-
encapsulated TOSA on target device and JIT compilation on suitable platforms.
47+
"""BackendDetails subclass for delegation to VGF compatible devices.
48+
49+
This enables encapsulated TOSA on target device and JIT compilation on
50+
suitable platforms.
51+
4952
"""
5053

5154
@staticmethod
@@ -54,9 +57,18 @@ def _compile_tosa_flatbuffer(
5457
compile_spec: VgfCompileSpec,
5558
tag_name: str = "",
5659
) -> bytes:
57-
"""
58-
Static helper method to do the compilation of the TOSA flatbuffer
59-
representation to a target specific binary stream.
60+
"""Compile a TOSA flatbuffer into a target-specific binary stream.
61+
62+
Args:
63+
tosa_flatbuffer (bytes): Serialized TOSA graph produced by
64+
``TOSABackend``.
65+
compile_spec (VgfCompileSpec): Compile specification providing
66+
converter flags and artifact paths.
67+
tag_name (str): Optional suffix used when producing debug outputs.
68+
69+
Returns:
70+
bytes: Target-specific VGF binary stream.
71+
6072
"""
6173
compile_flags = compile_spec.compiler_flags
6274
artifact_path = compile_spec.get_intermediate_path()
@@ -69,6 +81,17 @@ def preprocess(
6981
edge_program: ExportedProgram,
7082
compile_specs: List[CompileSpec],
7183
) -> PreprocessResult:
84+
"""Lower the exported program and compile it for a VGF target.
85+
86+
Args:
87+
edge_program (ExportedProgram): Program to lower to VGF.
88+
compile_specs (List[CompileSpec]): Serialized VGF compile specs
89+
supplied by the frontend.
90+
91+
Returns:
92+
PreprocessResult: Result containing the compiled VGF binary.
93+
94+
"""
7295
logger.info(f"{VgfBackend.__name__} preprocess")
7396

7497
compile_spec = VgfCompileSpec.from_list(compile_specs)
@@ -98,6 +121,20 @@ def vgf_compile(
98121
artifact_path: str | None = None,
99122
tag_name: str = "",
100123
):
124+
"""Invoke the VGF compiler to convert a TOSA flatbuffer.
125+
126+
Args:
127+
tosa_flatbuffer (bytes): Serialized TOSA graph produced by
128+
``TOSABackend``.
129+
compile_flags (List[str]): Command-line flags forwarded to
130+
``model-converter``.
131+
artifact_path (str | None): Directory where debug artifacts are saved.
132+
tag_name (str): Optional suffix used when producing debug outputs.
133+
134+
Returns:
135+
bytes: Compiled VGF binary emitted by ``model-converter``.
136+
137+
"""
101138
with tempfile.TemporaryDirectory() as tmpdir:
102139

103140
# We currently write out a flatbuffer as input to the converter

backends/nxp/backend/neutron_converter_manager.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -78,23 +78,35 @@ def convert(self, tflite_model: bytes, target: str) -> bytes:
7878
cctx.compilationOpts.minNumOpsPerGraph = 1
7979
cctx.compilationOpts.excludeGraphPasses = "MergeTranspose"
8080

81-
logger = multiprocessing.log_to_stderr()
82-
logger.setLevel(logging.WARNING)
83-
queue = multiprocessing.Manager().Queue()
81+
# Try to use multiprocessing for isolation, but fall back to direct execution
82+
# if the environment doesn't support it (e.g., in sandcastle/build environments)
83+
try:
84+
logger = multiprocessing.log_to_stderr()
85+
logger.setLevel(logging.WARNING)
86+
queue = multiprocessing.Manager().Queue()
87+
88+
process = multiprocessing.Process(
89+
target=convert_unsafe,
90+
args=(self.neutron_converter, tflite_model, cctx, queue),
91+
)
92+
process.start()
93+
process.join() # waits until the subprocess is complete
8494

85-
process = multiprocessing.Process(
86-
target=convert_unsafe,
87-
args=(self.neutron_converter, tflite_model, cctx, queue),
88-
)
89-
process.start()
90-
process.join() # waits until the subprocess is complete
95+
if queue.empty(): # signals the unsafe task did not run till the end
96+
raise RuntimeError(
97+
f"Neutron converter module terminated unexpectedly with exit code {process.exitcode}"
98+
)
9199

92-
if queue.empty(): # signals the unsafe task did not run till the end
93-
raise RuntimeError(
94-
f"Neutron converter module terminated unexpectedly with exit code {process.exitcode}"
100+
model_converted = queue.get()
101+
process.close()
102+
except (EOFError, OSError) as e:
103+
# Multiprocessing failed (likely due to environment restrictions)
104+
# Fall back to direct execution
105+
logging.warning(
106+
f"Multiprocessing not available ({e}), running neutron converter directly"
107+
)
108+
model_converted = self.neutron_converter.convertModel(
109+
list(tflite_model), cctx
95110
)
96111

97-
model_converted = queue.get()
98-
99-
process.close()
100112
return bytes(model_converted)

0 commit comments

Comments
 (0)