from torch._inductor.decomposition import remove_decompositions
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

-from torch.export import export
from torch.export.exported_program import ExportedProgram

from .passes import get_cadence_passes

from .utils import print_ops_info


-def prepare_and_convert_pt2(
+def trace(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
-    quantizer: CadenceQuantizer,
-    calibration_data: Optional[list[tuple[object, ...]]] = None,
    dump_graphs: bool = False,
-) -> torch.fx.GraphModule:
+) -> ExportedProgram:
    """
-    Prepare and convert a model using the given quantizer.
-    The quantizer must be supplied and be the same as the one used to
-    fuse the model later, if applicable. If you do not expect that behavior,
-    please use quantize_and_fuse_pt2 instead, which will instantiate a
-    default quantizer for you if needed.
-    If calibration data is provided, it will be used to calibrate the model. If
-    not, the inputs will be used for calibration instead, which is useful for
-    unit tests but should not be used for end-to-end use cases.
-    Returns a GraphModule with the converted model.
+    Trace the model with export_for_training and return an ExportedProgram.
    """

+    # Put the model in inference mode by calling model.eval()
+    model.eval()
+
+    # Prevent mkldnn decompositions
+    torch._C._set_mkldnn_enabled(False)
+
    # Get default decompositions
    decomp_table = torch.export.default_decompositions()
+
    # Select ops to keep
    ops_to_keep = [
        torch.ops.aten.conv1d.default,
@@ -77,19 +73,47 @@ def prepare_and_convert_pt2(
        torch.ops.aten.matmul.default,
        torch.ops.aten.rms_norm.default,
    ]
+
    # Remove decompositions for the ops we want to keep
    # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
    remove_decompositions(decomp_table, ops_to_keep)
+
    # Export with dynamo
-    model_gm = (
+    ep = (
        torch.export.export_for_training(model, inputs, strict=True)
        .run_decompositions(decomp_table)
-        .module()
    )

    if dump_graphs:
        logging.info("Graph before quantization:")
-        logging.info(model_gm.graph.print_tabular())
+        logging.info(ep.module().graph.print_tabular())
+
+    return ep
+
+
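For orientation, here is a minimal usage sketch of the new `trace` entry point. It is not part of the diff; `SmallModel` and `example_inputs` are hypothetical stand-ins:

```python
import torch

# Hypothetical toy model; any nn.Module traceable by torch.export works.
class SmallModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

model = SmallModel()
example_inputs = (torch.randn(1, 16),)

# trace() puts the model in eval mode, disables mkldnn decompositions,
# and exports with export_for_training, keeping the listed aten ops intact.
ep = trace(model, example_inputs)
```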
+def prepare_and_convert_pt2(
+    ep: ExportedProgram,
+    inputs: tuple[object, ...],
+    quantizer: CadenceQuantizer,
+    calibration_data: Optional[list[tuple[object, ...]]] = None,
+    dump_graphs: bool = False,
+) -> torch.fx.GraphModule:
+    """
+    Prepare and convert a model using the given quantizer.
+    The quantizer must be supplied and must be the same as the one used to
+    fuse the model later, if applicable. If you do not expect that behavior,
+    please use quantize_pt2 instead, which will instantiate a default
+    quantizer for you if needed.
+    If calibration data is provided, it will be used to calibrate the model. If
+    not, the inputs will be used for calibration instead, which is useful for
+    unit tests but should not be used for end-to-end use cases.
+    Returns a GraphModule with the converted model.
+    """
+
+    # Get the graph module from the ExportedProgram
+    model_gm = ep.module()
+
+    assert isinstance(model_gm, torch.fx.GraphModule)

    # Prepare
    prepared_model = prepare_pt2e(model_gm, quantizer)
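The calibration and conversion steps of `prepare_and_convert_pt2` fall outside the visible hunks; they follow the standard pt2e flow sketched below. Names are those in the function scope above; the calibration loop is illustrative, not copied from the commit:

```python
# Observers are inserted by prepare_pt2e, calibrated by running data
# through the model, then lowered to quant/dequant ops by convert_pt2e.
prepared_model = prepare_pt2e(model_gm, quantizer)

# Use the supplied calibration set, falling back to the example inputs.
for data in calibration_data if calibration_data is not None else [inputs]:
    prepared_model(*data)

converted_model = convert_pt2e(prepared_model)
```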
@@ -113,10 +137,10 @@ def prepare_and_convert_pt2(


# Note: this is not meant as a primary API since it can create inconsistencies
-# if the quantizer here is different from the quantizer used to convert. It is
-# however useful for unit tests to separate the converted model from the fused
-# model, to be able to get reference numerics.
-# If this does not apply, please use quantize_and_fuse_pt2 instead.
+# if the quantizer here is different from the quantizer used to prepare/convert.
+# It is however useful for unit tests to separate the converted model from the
+# fused model, to be able to get reference numerics.
+# If this does not apply, please use quantize_pt2 instead.
def fuse_pt2(
    converted_graph_module: torch.fx.GraphModule,
    quantizer: CadenceQuantizer,
@@ -151,16 +175,20 @@ def quantize_pt2(
    unit tests but should not be used for end-to-end use cases.
    Returns a GraphModule with the quantized model.
    """
-    # Make the model inference mode by calling model.eval()
-    model.eval()

    # Instantiate the quantizer to CadenceQuantizer if not supplied
    if not quantizer:
        quantizer = CadenceDefaultQuantizer()

+    ep = trace(model, inputs, dump_graphs=dump_graphs)
+
+    if dump_graphs:
+        logging.info("Graph after trace:")
+        logging.info(ep.graph.print_tabular())
+
    # Get converted graph module
    converted_gm = prepare_and_convert_pt2(
-        model, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
+        ep, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
    )

    # Get fused model
@@ -173,22 +201,6 @@ def quantize_pt2(
    return fused_gm


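Taken together, the high-level API is unchanged for callers. A sketch of both call styles, reusing the hypothetical model above (the `CadenceDefaultQuantizer` import path is assumed, not shown in this diff):

```python
# One call: trace -> prepare/convert -> fuse, with a default quantizer.
quantized_gm = quantize_pt2(model, example_inputs)

# Or with an explicit quantizer and a dedicated calibration set.
quantizer = CadenceDefaultQuantizer()
calibration_data = [(torch.randn(1, 16),) for _ in range(8)]
quantized_gm = quantize_pt2(
    model,
    example_inputs,
    quantizer=quantizer,
    calibration_data=calibration_data,
)
```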
-# Export the model and lower it to an ExportedProgram (in aten IR)
-def export_program(
-    model: torch.nn.Module,
-    inputs: tuple[object, ...],
-) -> ExportedProgram:
-    assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
-
-    # Prevent mkldnn decompositions
-    torch._C._set_mkldnn_enabled(False)
-
-    # Export the model and return it.
-    expo_program = export(model, inputs, strict=True)
-
-    return expo_program
-
-
def lower_ep_to_edge(
    expo_program: ExportedProgram,
    dump_graphs: bool = False,
@@ -237,7 +249,7 @@ def export_to_edge(
    assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

    # Export the model into an ExportedProgram.
-    expo_program = export_program(model, inputs)
+    expo_program = trace(model, inputs)

    # Lower the model to edge IR.
    edge_prog_manager = lower_ep_to_edge(expo_program, dump_graphs, constant_methods)
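With `export_program` removed, the non-quantized lowering path also goes through `trace`. A sketch of the equivalent call sequences, under the signatures shown in the hunks above:

```python
# Single call: trace the model and lower it to edge IR.
edge_prog_manager = export_to_edge(model, example_inputs)

# Equivalent explicit two-step form.
ep = trace(model, example_inputs)
edge_prog_manager = lower_ep_to_edge(ep)
```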