@@ -39,7 +39,6 @@
 from torch._inductor.decomposition import remove_decompositions
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
-from torch.export import export
 from torch.export.exported_program import ExportedProgram
 
 from .passes import get_cadence_passes
@@ -55,27 +54,24 @@
 # however useful for unit tests to separate the converted model from the fused
 # model, to be able to get reference numerics.
 # If this does not apply, please use quantize_and_fuse_pt2 instead.
-def prepare_and_convert_pt2(
+def trace(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
-    quantizer: CadenceQuantizer,
-    calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
-) -> torch.fx.GraphModule:
+) -> ExportedProgram:
     """
-    Prepare and convert a model using the given quantizer.
-    The quantizer must be supplied and be the same as the one used to
-    fuse the model later, if applicable. If you do not expect that behavior,
-    please use quantize_and_fuse_pt2 instead, which will instantiate a
-    default quantizer for you if needed.
-    If calibration data is provided, it will be used to calibrate the model. If
-    not, the inputs will be used for calibration instead, which is useful for
-    unit tests but should not be used for end-to-end use cases.
-    Returns a GraphModule with the converted model.
+    Trace the model with export_for_training and return an ExportedProgram.
     """
 
+    # Put the model in inference mode by calling model.eval()
+    model.eval()
+
+    # Prevent mkldnn decompositions
+    torch._C._set_mkldnn_enabled(False)
+
     # Get default decompositions
     decomp_table = torch.export.default_decompositions()
+
     # Select ops to keep
     ops_to_keep = [
         torch.ops.aten.conv1d.default,
@@ -85,19 +81,46 @@ def prepare_and_convert_pt2(
         torch.ops.aten.matmul.default,
         torch.ops.aten.rms_norm.default,
     ]
+
     # Remove decompositions for the ops we want to keep
     # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
     remove_decompositions(decomp_table, ops_to_keep)
+
     # Export with dynamo
-    model_gm = (
-        torch.export.export_for_training(model, inputs, strict=True)
-        .run_decompositions(decomp_table)
-        .module()
-    )
+    program = torch.export.export_for_training(
+        model, inputs, strict=True
+    ).run_decompositions(decomp_table)
 
     if dump_graphs:
         logging.info("Graph before quantization:")
-        logging.info(model_gm.graph.print_tabular())
+        logging.info(program.module().graph.print_tabular())
+
+    return program
+
+
+def prepare_and_convert_pt2(
+    program: ExportedProgram,
+    inputs: tuple[object, ...],
+    quantizer: CadenceQuantizer,
+    calibration_data: Optional[list[tuple[object, ...]]] = None,
+    dump_graphs: bool = False,
+) -> torch.fx.GraphModule:
108+ """
109+ Prepare and convert a model using the given quantizer.
110+ The quantizer must be supplied and be the same as the one used to
111+ fuse the model later, if applicable. If you do not expect that behavior,
112+ please use quantize_and_fuse_pt2 instead, which will instantiate a
113+ default quantizer for you if needed.
114+ If calibration data is provided, it will be used to calibrate the model. If
115+ not, the inputs will be used for calibration instead, which is useful for
116+ unit tests but should not be used for end-to-end use cases.
117+ Returns a GraphModule with the converted model.
118+ """
+
+    # Get the graph module from the ExportedProgram
+    model_gm = program.module()
+
+    assert isinstance(model_gm, torch.fx.GraphModule)
 
     # Prepare
     prepared_model = prepare_pt2e(model_gm, quantizer)
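
With this hunk, tracing and quantization are decoupled into two calls. A minimal sketch of the new two-step flow; the toy module, input shapes, and import paths are illustrative assumptions, not taken from this diff:

```python
import torch

from executorch.backends.cadence.aot.compiler import (  # assumed module path
    prepare_and_convert_pt2,
    trace,
)
from executorch.backends.cadence.aot.quantizer.quantizer import (  # assumed path
    CadenceDefaultQuantizer,
)


class Toy(torch.nn.Module):  # illustrative toy module
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(x)


model = Toy()
inputs = (torch.randn(1, 16),)

# Step 1 (new): trace the eager module to an ExportedProgram.
program = trace(model, inputs)

# Step 2: prepare/convert now consumes the ExportedProgram, not the module.
quantizer = CadenceDefaultQuantizer()
converted_gm = prepare_and_convert_pt2(program, inputs, quantizer)
```
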
@@ -121,10 +144,10 @@ def prepare_and_convert_pt2(
 
 
 # Note: this is not meant as a primary API since it can create inconsistencies
-# if the quantizer here is different from the quantizer used to convert. It is
-# however useful for unit tests to separate the converted model from the fused
-# model, to be able to get reference numerics.
-# If this does not apply, please use quantize_and_fuse_pt2 instead.
+# if the quantizer here is different from the quantizer used to prepare/convert.
+# It is however useful for unit tests to separate the converted model from the
+# fused model, to be able to get reference numerics.
+# If this does not apply, please use quantize_pt2 instead.
 def fuse_pt2(
     converted_graph_module: torch.fx.GraphModule,
     quantizer: CadenceQuantizer,
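
The updated comment matches the split API. For reference, the unit-test pattern it describes, continuing the sketch above and reusing the same quantizer for both steps (names carried over from that sketch):

```python
from executorch.backends.cadence.aot.compiler import fuse_pt2  # assumed path

# `converted_gm` and `quantizer` come from the previous sketch. Fusing with a
# different quantizer than the one used to prepare/convert is exactly the
# inconsistency the comment warns about.
fused_gm = fuse_pt2(converted_gm, quantizer)
```
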
@@ -166,9 +189,15 @@ def quantize_pt2(
     if not quantizer:
         quantizer = CadenceDefaultQuantizer()
 
+    program = trace(model, inputs, dump_graphs=dump_graphs)
+
+    if dump_graphs:
+        logging.info("Graph after trace:")
+        logging.info(program.graph.print_tabular())
+
     # Get converted graph module
     converted_gm = prepare_and_convert_pt2(
-        model, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
+        program, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
     )
 
     # Get fused model
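
For callers that do not need the intermediate ExportedProgram or converted module, quantize_pt2 still wraps the whole pipeline. A sketch under the same assumptions as the earlier examples:

```python
from executorch.backends.cadence.aot.compiler import quantize_pt2  # assumed path

# One call runs trace -> prepare/convert -> fuse; a CadenceDefaultQuantizer is
# instantiated internally when no quantizer is passed.
fused_gm = quantize_pt2(model, inputs)
```
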
@@ -181,22 +210,6 @@ def quantize_pt2(
     return fused_gm
 
 
-# Export the model and lower it to an ExportedProgram (in aten IR)
-def export_program(
-    model: torch.nn.Module,
-    inputs: tuple[object, ...],
-) -> ExportedProgram:
-    assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
-
-    # Prevent mkldnn decompositions
-    torch._C._set_mkldnn_enabled(False)
-
-    # Export the model and return it.
-    expo_program = export(model, inputs, strict=True)
-
-    return expo_program
-
-
 def lower_ep_to_edge(
     expo_program: ExportedProgram,
     dump_graphs: bool = False,
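
Since export_program is deleted rather than aliased, any out-of-tree caller would need a small migration. A sketch; note the semantic difference that trace returns the ExportedProgram after run_decompositions, whereas export_program returned the raw torch.export.export result:

```python
# Before this change (no longer available):
#   ep = export_program(model, inputs)
# After this change:
ep = trace(model, inputs)
```
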
@@ -245,7 +258,7 @@ def export_to_edge(
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
 
     # Export the model into an ExportedProgram.
-    expo_program = export_program(model, inputs)
+    expo_program = trace(model, inputs)
 
     # Lower the model to edge IR.
     edge_prog_manager = lower_ep_to_edge(expo_program, dump_graphs, constant_methods)
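
For the unquantized path, export_to_edge is unchanged from the caller's perspective; only its internals now go through trace. A sketch under the same assumptions as the earlier examples:

```python
from executorch.backends.cadence.aot.compiler import export_to_edge  # assumed path

# Trace the model and lower the resulting ExportedProgram to edge IR in one call.
edge_manager = export_to_edge(model, inputs)
```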