|  | 
| 7 | 7 | 
 | 
| 8 | 8 | import logging | 
| 9 | 9 | 
 | 
| 10 |  | -import os | 
| 11 | 10 | from collections import Counter | 
| 12 | 11 | from pprint import pformat | 
| 13 | 12 | from typing import ( | 
|  | 
| 47 | 46 | ) | 
| 48 | 47 | from executorch.backends.arm.test.runner_utils import ( | 
| 49 | 48 |     dbg_tosa_fb_to_json, | 
| 50 |  | -    get_elf_path, | 
| 51 | 49 |     get_output_quantization_params, | 
| 52 |  | -    get_target_board, | 
| 53 |  | -    run_target, | 
| 54 | 50 |     TosaReferenceModelDispatch, | 
| 55 | 51 | ) | 
| 56 | 52 | 
 | 
| 57 | 53 | from executorch.backends.arm.test.tester.analyze_output_utils import ( | 
| 58 | 54 |     dump_error_output, | 
| 59 | 55 |     print_error_diffs, | 
| 60 | 56 | ) | 
|  | 57 | +from executorch.backends.arm.test.tester.serialize import Serialize | 
| 61 | 58 | from executorch.backends.arm.tosa import TosaSpecification | 
| 62 | 59 | from executorch.backends.arm.tosa.mapping import extract_tensor_meta | 
| 63 | 60 | from executorch.backends.arm.tosa.partitioner import TOSAPartitioner | 
|  | 
| 96 | 93 | 
 | 
| 97 | 94 | from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec | 
| 98 | 95 | from torch.fx import Graph | 
| 99 |  | -from torch.utils._pytree import tree_flatten | 
| 100 | 96 | 
 | 
| 101 | 97 | 
 | 
| 102 | 98 | logger = logging.getLogger(__name__) | 
| @@ -184,44 +180,6 @@ def run( | 
| 184 | 180 |             generate_etrecord=generate_etrecord, | 
| 185 | 181 |         ) | 
| 186 | 182 | 
 | 
| 187 |  | - | 
| 188 |  | -class Serialize(tester.Serialize): | 
| 189 |  | -    def __init__(self, compile_spec: list[CompileSpec], timeout): | 
| 190 |  | -        super().__init__() | 
| 191 |  | -        self.timeout = timeout | 
| 192 |  | -        self.executorch_program_manager: ExecutorchProgramManager | None | 
| 193 |  | -        self.compile_spec = compile_spec | 
| 194 |  | - | 
| 195 |  | -    def run(self, artifact: ExecutorchProgramManager, inputs=None) -> None: | 
| 196 |  | -        super().run(artifact, inputs) | 
| 197 |  | -        # Keep the entire ExecutorchProgramManager for execution. | 
| 198 |  | -        self.executorch_program_manager = artifact | 
| 199 |  | - | 
| 200 |  | -    def run_artifact(self, inputs): | 
| 201 |  | -        if self.executorch_program_manager is None: | 
| 202 |  | -            raise RuntimeError( | 
| 203 |  | -                "Tried running artifact from Serialize stage without running the stage." | 
| 204 |  | -            ) | 
| 205 |  | -        inputs_flattened, _ = tree_flatten(inputs) | 
| 206 |  | -        intermediate_path = get_intermediate_path(self.compile_spec) | 
| 207 |  | -        target_board = get_target_board(self.compile_spec) | 
| 208 |  | -        elf_path = get_elf_path(target_board) | 
| 209 |  | - | 
| 210 |  | -        if not os.path.exists(elf_path): | 
| 211 |  | -            raise FileNotFoundError( | 
| 212 |  | -                f"Did not find build arm_executor_runner in path {elf_path}, run setup_testing.sh?" | 
| 213 |  | -            ) | 
| 214 |  | - | 
| 215 |  | -        return run_target( | 
| 216 |  | -            self.executorch_program_manager, | 
| 217 |  | -            inputs_flattened, | 
| 218 |  | -            intermediate_path, | 
| 219 |  | -            target_board, | 
| 220 |  | -            elf_path, | 
| 221 |  | -            self.timeout, | 
| 222 |  | -        ) | 
| 223 |  | - | 
| 224 |  | - | 
| 225 | 183 | class ToExecutorch(tester.ToExecutorch): | 
| 226 | 184 |     def run_artifact(self, inputs): | 
| 227 | 185 |         with TosaReferenceModelDispatch(): | 
| @@ -423,7 +381,11 @@ def serialize( | 
| 423 | 381 |         self, serialize_stage: Optional[Serialize] = None, timeout: int = 480 | 
| 424 | 382 |     ): | 
| 425 | 383 |         if serialize_stage is None: | 
| 426 |  | -            serialize_stage = Serialize(self.compile_spec, timeout) | 
|  | 384 | +            serialize_stage = Serialize( | 
|  | 385 | +                compile_spec=self.compile_spec, | 
|  | 386 | +                module=self.original_module, | 
|  | 387 | +                timeout=timeout | 
|  | 388 | +            ) | 
| 427 | 389 |         assert ( | 
| 428 | 390 |             get_intermediate_path(self.compile_spec) is not None | 
| 429 | 391 |         ), "Can't dump serialized file when compile specs do not contain an artifact path." | 
|  | 
0 commit comments