33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
55
6- import copy
7-
86from executorch .backends .nxp .edge_passes .move_auxiliary_operator_into_separate_qdq_cluster_pass import (
97 MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass ,
108 MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass ,
119)
1210from executorch .backends .nxp .edge_passes .neutron_edge_pass import NeutronEdgePass
13-
14- from executorch .backends .nxp .edge_passes .remove_io_quant_ops_pass import (
15- RemoveIOQuantOpsPass ,
16- )
17- from executorch .exir import EdgeProgramManager
18- from executorch .exir .program ._program import (
19- _get_updated_graph_signature ,
20- _get_updated_range_constraints ,
21- )
22-
23- from torch import nn
24- from torch .export import ExportedProgram
25- from torch .fx .passes .infra .pass_base import PassResult
2611from torch .fx .passes .infra .pass_manager import PassManager
2712
2813
2914class NeutronEdgePassManager (PassManager ):
3015
31- def __init__ (
32- self , passes : list [NeutronEdgePass ] = None , remove_io_quant_ops : bool = False
33- ):
16+ def __init__ (self , passes : list [NeutronEdgePass ] = None ):
3417 passes : list [NeutronEdgePass ] = passes or [
3518 MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass (),
3619 MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass (),
@@ -40,63 +23,3 @@ def __init__(
4023 passes ,
4124 steps = 10 , # Empirical value. At most 10 cycles of passes will be run.
4225 )
43-
44- self .remove_io_quant_ops = remove_io_quant_ops
45-
46- def _transform_graph_module (self , module : nn .Module ) -> PassResult :
47- """Apply the passes to a single graph module."""
48- pass_result : PassResult = super ().__call__ (module )
49-
50- graph_module = pass_result .graph_module
51- graph_module .graph .eliminate_dead_code ()
52- graph_module .recompile ()
53-
54- return pass_result
55-
56- def __call__ (self , epm : EdgeProgramManager ) -> EdgeProgramManager :
57- """Apply the passes to all graph modules in the edge program."""
58- new_programs : dict [str , ExportedProgram ] = {}
59-
60- for name , program in epm ._edge_programs .items ():
61- pass_result = self ._transform_graph_module (program .graph_module )
62-
63- if pass_result .modified :
64- # Create a new exported program.
65- new_program = ExportedProgram (
66- root = pass_result .graph_module ,
67- graph = pass_result .graph_module .graph ,
68- graph_signature = _get_updated_graph_signature (
69- program .graph_signature , pass_result .graph_module
70- ),
71- state_dict = program .state_dict ,
72- range_constraints = _get_updated_range_constraints (
73- pass_result .graph_module
74- ),
75- module_call_graph = copy .deepcopy (program ._module_call_graph ),
76- example_inputs = program .example_inputs ,
77- constants = program .constants ,
78- verifiers = [program .verifier ],
79- )
80- new_program .graph_module .meta .update (program .graph_module .meta )
81- new_program .graph_module .meta .update (pass_result .graph_module .meta )
82-
83- else :
84- # Keep the old exported program.
85- new_program = program
86-
87- new_programs [name ] = new_program
88-
89- result = epm
90-
91- if len (new_programs ) > 0 :
92- # Use a new EdgeProgramManager with the updated programs if any update was performed.
93- result = EdgeProgramManager (
94- new_programs , copy .deepcopy (epm ._config_methods ), epm .compile_config
95- )
96-
97- if self .remove_io_quant_ops :
98- result = result .transform (
99- [RemoveIOQuantOpsPass (edge_program_manager = result )]
100- )
101-
102- return result
0 commit comments