|  | 
| 10 | 10 | from typing import Any, Dict, List, Optional, Tuple | 
| 11 | 11 | 
 | 
| 12 | 12 | import torch | 
|  | 13 | +import torch._export as export | 
| 13 | 14 | from executorch import exir | 
| 14 | 15 | from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( | 
| 15 | 16 |     XnnpackFloatingPointPartitioner, | 
| @@ -145,28 +146,28 @@ def __init__( | 
| 145 | 146 | 
 | 
| 146 | 147 |         self.quantizer.set_global(self.quantization_config) | 
| 147 | 148 | 
 | 
| 148 |  | -        self.converted_program = None | 
|  | 149 | +        self.converted_graph = None | 
| 149 | 150 | 
 | 
| 150 | 151 |     def run( | 
| 151 |  | -        self, artifact: ExirExportedProgram, inputs: Optional[Tuple[torch.Tensor]] | 
|  | 152 | +        self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]] | 
| 152 | 153 |     ) -> None: | 
| 153 |  | -        prepared = prepare_pt2e(artifact.exported_program.graph_module, self.quantizer) | 
|  | 154 | +        captured_graph = export.capture_pre_autograd_graph(artifact, inputs) | 
|  | 155 | +        prepared = prepare_pt2e(captured_graph, self.quantizer) | 
| 154 | 156 |         converted = convert_pt2e(prepared) | 
| 155 |  | -        artifact.exported_program._graph_module = converted | 
| 156 |  | -        self.converted_program = artifact | 
|  | 157 | +        self.converted_graph = converted | 
| 157 | 158 | 
 | 
| 158 | 159 |     @property | 
| 159 |  | -    def artifact(self) -> ExirExportedProgram: | 
| 160 |  | -        return self.converted_program | 
|  | 160 | +    def artifact(self) -> torch.fx.GraphModule: | 
|  | 161 | +        return self.converted_graph | 
| 161 | 162 | 
 | 
| 162 | 163 |     @property | 
| 163 | 164 |     def graph_module(self) -> str: | 
| 164 |  | -        return self.converted_program.exported_program.graph_module | 
|  | 165 | +        return self.converted_graph | 
| 165 | 166 | 
 | 
| 166 | 167 | 
 | 
| 167 | 168 | @register_stage | 
| 168 | 169 | class Export(Stage): | 
| 169 |  | -    def __init__(self, capture_config: Optional[CaptureConfig] = None): | 
|  | 170 | +    def __init__(self, for_quant=False, capture_config: Optional[CaptureConfig] = None): | 
| 170 | 171 |         self.capture_conf = capture_config or get_xnnpack_capture_config() | 
| 171 | 172 |         self.exir_exported_program = None | 
| 172 | 173 | 
 | 
| @@ -274,12 +275,11 @@ def __init__( | 
| 274 | 275 |         self.inputs = inputs | 
| 275 | 276 |         self.stages: Dict[str, Stage] = OrderedDict.fromkeys(list(_stages_.keys())) | 
| 276 | 277 |         self.pipeline = { | 
|  | 278 | +            self._stage_name(Quantize2): [self._stage_name(Export)], | 
| 277 | 279 |             self._stage_name(Quantize): [self._stage_name(Export)], | 
| 278 | 280 |             self._stage_name(Export): [ | 
| 279 |  | -                self._stage_name(Quantize2), | 
| 280 | 281 |                 self._stage_name(ToEdge), | 
| 281 | 282 |             ], | 
| 282 |  | -            self._stage_name(Quantize2): [self._stage_name(ToEdge)], | 
| 283 | 283 |             self._stage_name(ToEdge): [self._stage_name(Partition)], | 
| 284 | 284 |             # TODO Make this Stage optional | 
| 285 | 285 |             self._stage_name(Partition): [self._stage_name(ToExecutorch)], | 
|  | 
0 commit comments