|
7 | 7 |
|
8 | 8 | import logging |
9 | 9 |
|
10 | | -import os |
11 | 10 | from collections import Counter |
12 | 11 | from pprint import pformat |
13 | 12 | from typing import ( |
|
42 | 41 | ) |
43 | 42 | from executorch.backends.arm.test.runner_utils import ( |
44 | 43 | dbg_tosa_fb_to_json, |
45 | | - get_elf_path, |
46 | 44 | get_output_quantization_params, |
47 | | - get_target_board, |
48 | | - run_target, |
49 | 45 | TosaReferenceModelDispatch, |
50 | 46 | ) |
51 | 47 |
|
52 | 48 | from executorch.backends.arm.test.tester.analyze_output_utils import ( |
53 | 49 | dump_error_output, |
54 | 50 | print_error_diffs, |
55 | 51 | ) |
| 52 | +from executorch.backends.arm.test.tester.serialize import Serialize |
56 | 53 | from executorch.backends.arm.tosa import TosaSpecification |
57 | 54 | from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec |
58 | 55 | from executorch.backends.arm.tosa.mapping import extract_tensor_meta |
|
90 | 87 |
|
91 | 88 | from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec |
92 | 89 | from torch.fx import Graph |
93 | | -from torch.utils._pytree import tree_flatten |
94 | 90 |
|
95 | 91 |
|
96 | 92 | logger = logging.getLogger(__name__) |
@@ -179,43 +175,6 @@ def run( |
179 | 175 | ) |
180 | 176 |
|
181 | 177 |
|
182 | | -class Serialize(tester.Serialize): |
183 | | - def __init__(self, compile_spec: ArmCompileSpec, timeout): |
184 | | - super().__init__() |
185 | | - self.timeout = timeout |
186 | | - self.executorch_program_manager: ExecutorchProgramManager | None |
187 | | - self.compile_spec = compile_spec |
188 | | - |
189 | | - def run(self, artifact: ExecutorchProgramManager, inputs=None) -> None: |
190 | | - super().run(artifact, inputs) |
191 | | - # Keep the entire ExecutorchProgramManager for execution. |
192 | | - self.executorch_program_manager = artifact |
193 | | - |
194 | | - def run_artifact(self, inputs): |
195 | | - if self.executorch_program_manager is None: |
196 | | - raise RuntimeError( |
197 | | - "Tried running artifact from Serialize stage without running the stage." |
198 | | - ) |
199 | | - inputs_flattened, _ = tree_flatten(inputs) |
200 | | - intermediate_path = self.compile_spec.get_intermediate_path() |
201 | | - target_board = get_target_board(self.compile_spec) |
202 | | - elf_path = get_elf_path(target_board) |
203 | | - |
204 | | - if not os.path.exists(elf_path): |
205 | | - raise FileNotFoundError( |
206 | | - f"Did not find build arm_executor_runner in path {elf_path}, run setup_testing.sh?" |
207 | | - ) |
208 | | - |
209 | | - return run_target( |
210 | | - self.executorch_program_manager, |
211 | | - inputs_flattened, |
212 | | - intermediate_path, |
213 | | - target_board, |
214 | | - elf_path, |
215 | | - self.timeout, |
216 | | - ) |
217 | | - |
218 | | - |
219 | 178 | class ToExecutorch(tester.ToExecutorch): |
220 | 179 | def run_artifact(self, inputs): |
221 | 180 | with TosaReferenceModelDispatch(): |
@@ -419,7 +378,11 @@ def serialize( |
419 | 378 | self, serialize_stage: Optional[Serialize] = None, timeout: int = 480 |
420 | 379 | ): |
421 | 380 | if serialize_stage is None: |
422 | | - serialize_stage = Serialize(self.compile_spec, timeout) |
| 381 | + serialize_stage = Serialize( |
| 382 | + compile_spec=self.compile_spec, |
| 383 | + module=self.original_module, |
| 384 | + timeout=timeout, |
| 385 | + ) |
423 | 386 | assert ( |
424 | 387 | self.compile_spec.get_intermediate_path() is not None |
425 | 388 | ), "Can't dump serialized file when compile specs do not contain an artifact path." |
|
0 commit comments