|
10 | 10 | from typing import Any, Dict, List, Optional, Tuple |
11 | 11 |
|
12 | 12 | import torch |
| 13 | +import torch._export as export |
13 | 14 | from executorch import exir |
14 | 15 | from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( |
15 | 16 | XnnpackFloatingPointPartitioner, |
@@ -145,23 +146,23 @@ def __init__( |
145 | 146 |
|
146 | 147 | self.quantizer.set_global(self.quantization_config) |
147 | 148 |
|
148 | | - self.converted_program = None |
| 149 | + self.converted_graph = None |
149 | 150 |
|
150 | 151 | def run( |
151 | | - self, artifact: ExirExportedProgram, inputs: Optional[Tuple[torch.Tensor]] |
| 152 | + self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]] |
152 | 153 | ) -> None: |
153 | | - prepared = prepare_pt2e(artifact.exported_program.graph_module, self.quantizer) |
| 154 | + captured_graph = export.capture_pre_autograd_graph(artifact, inputs) |
| 155 | + prepared = prepare_pt2e(captured_graph, self.quantizer) |
154 | 156 | converted = convert_pt2e(prepared) |
155 | | - artifact.exported_program._graph_module = converted |
156 | | - self.converted_program = artifact |
| 157 | + self.converted_graph = converted |
157 | 158 |
|
158 | 159 | @property |
159 | | - def artifact(self) -> ExirExportedProgram: |
160 | | - return self.converted_program |
| 160 | + def artifact(self) -> torch.fx.GraphModule: |
| 161 | + return self.converted_graph |
161 | 162 |
|
162 | 163 | @property |
163 | 164 | def graph_module(self) -> str: |
164 | | - return self.converted_program.exported_program.graph_module |
| 165 | + return self.converted_graph |
165 | 166 |
|
166 | 167 |
|
167 | 168 | @register_stage |
@@ -274,12 +275,11 @@ def __init__( |
274 | 275 | self.inputs = inputs |
275 | 276 | self.stages: Dict[str, Stage] = OrderedDict.fromkeys(list(_stages_.keys())) |
276 | 277 | self.pipeline = { |
| 278 | + self._stage_name(Quantize2): [self._stage_name(Export)], |
277 | 279 | self._stage_name(Quantize): [self._stage_name(Export)], |
278 | 280 | self._stage_name(Export): [ |
279 | | - self._stage_name(Quantize2), |
280 | 281 | self._stage_name(ToEdge), |
281 | 282 | ], |
282 | | - self._stage_name(Quantize2): [self._stage_name(ToEdge)], |
283 | 283 | self._stage_name(ToEdge): [self._stage_name(Partition)], |
284 | 284 | # TODO Make this Stage optional |
285 | 285 | self._stage_name(Partition): [self._stage_name(ToExecutorch)], |
|
0 commit comments