3 files changed: +31 −8 lines changed

File 1 of 3:

@@ -7,7 +7,6 @@
 from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
 from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

-from torch._export import capture_pre_autograd_graph
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
     QuantizationTestCase,
@@ -26,6 +25,14 @@
 )
 import copy

+has_export_for_training = False
+try:
+    from torch.export import export_for_training
+    has_export_for_training = True
+except ImportError:
+    # capture_pre_autograd_graph is deprecated, it's
+    # left here to work with previous versions of pytorch
+    from torch._export import capture_pre_autograd_graph

 def _apply_weight_only_uint4_quant(model):
     def fn(mod):
@@ -203,10 +210,16 @@ def forward(self, x):
 
         # program capture
         m = copy.deepcopy(m_eager)
-        m = capture_pre_autograd_graph(
-            m,
-            example_inputs,
-        )
+        if has_export_for_training:
+            m = export_for_training(
+                m,
+                example_inputs,
+            ).module()
+        else:
+            m = capture_pre_autograd_graph(
+                m,
+                example_inputs,
+            )

         m = prepare_pt2e(m, quantizer)
         # Calibrate
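The try/except added in this file is a version-compatibility shim: it prefers the newer torch.export.export_for_training API and falls back to the deprecated torch._export.capture_pre_autograd_graph on older PyTorch builds. A standalone sketch of the same pattern, for reference (the helper name capture_graph is hypothetical, not part of the diff); note that only export_for_training needs the .module() unwrap:

import torch

has_export_for_training = False
try:
    # Available in newer PyTorch releases.
    from torch.export import export_for_training
    has_export_for_training = True
except ImportError:
    # Deprecated, kept for older PyTorch versions.
    from torch._export import capture_pre_autograd_graph


def capture_graph(model: torch.nn.Module, example_inputs: tuple) -> torch.nn.Module:
    """Capture a pre-autograd graph regardless of the PyTorch version."""
    if has_export_for_training:
        # export_for_training returns an ExportedProgram; .module()
        # unwraps it into a runnable torch.fx.GraphModule.
        return export_for_training(model, example_inputs).module()
    # capture_pre_autograd_graph returns a GraphModule directly.
    return capture_pre_autograd_graph(model, example_inputs)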
File 2 of 3:

@@ -1484,11 +1484,21 @@ def forward(self, x):

 # make sure it compiles
 example_inputs = (x,)
-from torch._export import capture_pre_autograd_graph
 # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
 # we can re-enable this after non-functional IR is enabled in export
 # model = torch.export.export(model, example_inputs).module()
-model = capture_pre_autograd_graph(model, example_inputs)
+has_export_for_training = False
+try:
+    from torch.export import export_for_training
+    has_export_for_training = True
+except ImportError:
+    # capture_pre_autograd_graph is deprecated, it's
+    # left here to work with previous versions of pytorch
+    from torch._export import capture_pre_autograd_graph
+if has_export_for_training:
+    model = export_for_training(model, example_inputs).module()
+else:
+    model = capture_pre_autograd_graph(model, example_inputs)
 after_export = model(x)
 self.assertTrue(torch.equal(after_export, ref))
 if api is _int8da_int8w_api:
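Note the asymmetry between the two branches above: export_for_training returns an ExportedProgram that must be unwrapped with .module() before it can be called, whereas capture_pre_autograd_graph already returns a runnable torch.fx.GraphModule. A minimal sketch of the unwrap step (the toy module Double is illustrative, not from the diff):

import torch
from torch.export import export_for_training

class Double(torch.nn.Module):
    def forward(self, x):
        return x * 2

example_inputs = (torch.randn(3),)
ep = export_for_training(Double(), example_inputs)  # an ExportedProgram
model = ep.module()  # unwrap into a callable torch.fx.GraphModule

# Graph capture should not change numerics on the example input.
assert torch.equal(model(*example_inputs), Double()(*example_inputs))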
File 3 of 3:

@@ -180,7 +180,7 @@ def _the_op_that_needs_to_be_preserved(...)

 # after this, `_the_op_that_needs_to_be_preserved` will be preserved as
 # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
-# torch.export.export / torch._export.capture_pre_autograd_graph
+# torch.export.export / torch.export.export_for_training

 """
 from torch._inductor.decomposition import register_decomposition
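The docstring above concerns custom ops that must survive graph capture without being decomposed. A hedged sketch of checking that property with export_for_training (the registration below uses torch.library.custom_op with the docstring's placeholder names; the wrapper module is illustrative, not from the diff):

import torch

# Register a custom op under the docstring's placeholder name. Ops defined
# via torch.library.custom_op are opaque to the compiler and should be
# preserved through export rather than decomposed.
@torch.library.custom_op("my_namespace::the_op_that_needs_to_be_preserved", mutates_args=())
def the_op(x: torch.Tensor) -> torch.Tensor:
    return x + 1

@the_op.register_fake
def _(x):
    # Shape/dtype propagation rule used while tracing with fake tensors.
    return torch.empty_like(x)

class UsesOp(torch.nn.Module):
    def forward(self, x):
        return the_op(x)

gm = torch.export.export_for_training(UsesOp(), (torch.randn(2),)).module()
# The captured graph should still call the namespaced op directly.
assert any(
    node.op == "call_function"
    and node.target is torch.ops.my_namespace.the_op_that_needs_to_be_preserved.default
    for node in gm.graph.nodes
)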