|  | 
| 7 | 7 | 
 | 
| 8 | 8 | import logging | 
| 9 | 9 | 
 | 
| 10 |  | -import os | 
| 11 | 10 | from collections import Counter | 
| 12 | 11 | from pprint import pformat | 
| 13 | 12 | from typing import ( | 
|  | 
| 42 | 41 | ) | 
| 43 | 42 | from executorch.backends.arm.test.runner_utils import ( | 
| 44 | 43 |     dbg_tosa_fb_to_json, | 
| 45 |  | -    get_elf_path, | 
| 46 | 44 |     get_output_quantization_params, | 
| 47 |  | -    get_target_board, | 
| 48 |  | -    run_target, | 
| 49 | 45 |     TosaReferenceModelDispatch, | 
| 50 | 46 | ) | 
| 51 | 47 | 
 | 
| 52 | 48 | from executorch.backends.arm.test.tester.analyze_output_utils import ( | 
| 53 | 49 |     dump_error_output, | 
| 54 | 50 |     print_error_diffs, | 
| 55 | 51 | ) | 
|  | 52 | +from executorch.backends.arm.test.tester.serialize import Serialize | 
| 56 | 53 | from executorch.backends.arm.tosa import TosaSpecification | 
| 57 | 54 | from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec | 
| 58 | 55 | from executorch.backends.arm.tosa.mapping import extract_tensor_meta | 
|  | 
| 90 | 87 | 
 | 
| 91 | 88 | from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec | 
| 92 | 89 | from torch.fx import Graph | 
| 93 |  | -from torch.utils._pytree import tree_flatten | 
| 94 | 90 | 
 | 
| 95 | 91 | 
 | 
| 96 | 92 | logger = logging.getLogger(__name__) | 
| @@ -179,43 +175,6 @@ def run( | 
| 179 | 175 |         ) | 
| 180 | 176 | 
 | 
| 181 | 177 | 
 | 
| 182 |  | -class Serialize(tester.Serialize): | 
| 183 |  | -    def __init__(self, compile_spec: ArmCompileSpec, timeout): | 
| 184 |  | -        super().__init__() | 
| 185 |  | -        self.timeout = timeout | 
| 186 |  | -        self.executorch_program_manager: ExecutorchProgramManager | None | 
| 187 |  | -        self.compile_spec = compile_spec | 
| 188 |  | - | 
| 189 |  | -    def run(self, artifact: ExecutorchProgramManager, inputs=None) -> None: | 
| 190 |  | -        super().run(artifact, inputs) | 
| 191 |  | -        # Keep the entire ExecutorchProgramManager for execution. | 
| 192 |  | -        self.executorch_program_manager = artifact | 
| 193 |  | - | 
| 194 |  | -    def run_artifact(self, inputs): | 
| 195 |  | -        if self.executorch_program_manager is None: | 
| 196 |  | -            raise RuntimeError( | 
| 197 |  | -                "Tried running artifact from Serialize stage without running the stage." | 
| 198 |  | -            ) | 
| 199 |  | -        inputs_flattened, _ = tree_flatten(inputs) | 
| 200 |  | -        intermediate_path = self.compile_spec.get_intermediate_path() | 
| 201 |  | -        target_board = get_target_board(self.compile_spec) | 
| 202 |  | -        elf_path = get_elf_path(target_board) | 
| 203 |  | - | 
| 204 |  | -        if not os.path.exists(elf_path): | 
| 205 |  | -            raise FileNotFoundError( | 
| 206 |  | -                f"Did not find build arm_executor_runner in path {elf_path}, run setup_testing.sh?" | 
| 207 |  | -            ) | 
| 208 |  | - | 
| 209 |  | -        return run_target( | 
| 210 |  | -            self.executorch_program_manager, | 
| 211 |  | -            inputs_flattened, | 
| 212 |  | -            intermediate_path, | 
| 213 |  | -            target_board, | 
| 214 |  | -            elf_path, | 
| 215 |  | -            self.timeout, | 
| 216 |  | -        ) | 
| 217 |  | - | 
| 218 |  | - | 
| 219 | 178 | class ToExecutorch(tester.ToExecutorch): | 
| 220 | 179 |     def run_artifact(self, inputs): | 
| 221 | 180 |         with TosaReferenceModelDispatch(): | 
| @@ -419,7 +378,11 @@ def serialize( | 
| 419 | 378 |         self, serialize_stage: Optional[Serialize] = None, timeout: int = 480 | 
| 420 | 379 |     ): | 
| 421 | 380 |         if serialize_stage is None: | 
| 422 |  | -            serialize_stage = Serialize(self.compile_spec, timeout) | 
|  | 381 | +            serialize_stage = Serialize( | 
|  | 382 | +                compile_spec=self.compile_spec, | 
|  | 383 | +                module=self.original_module, | 
|  | 384 | +                timeout=timeout, | 
|  | 385 | +            ) | 
| 423 | 386 |         assert ( | 
| 424 | 387 |             self.compile_spec.get_intermediate_path() is not None | 
| 425 | 388 |         ), "Can't dump serialized file when compile specs do not contain an artifact path." | 
|  | 
0 commit comments