
Commit b5f2869

yushangdi authored and facebook-github-bot committed
training ir torchao migration (#1006)
Summary: Migrate capture_pre_autograd_graph to export_for_training. We still need to keep the capture_pre_autograd_graph call because torch/ao's CI tests use an earlier version of pytorch that does not have export_for_training. See https://github.com/pytorch/ao/blob/main/.github/workflows/regression_test.yml

Differential Revision: D63859678
1 parent 9ce7ebb commit b5f2869
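
The change in both test files follows the same guarded-import pattern: prefer torch.export.export_for_training when it exists, and fall back to the deprecated torch._export.capture_pre_autograd_graph on older PyTorch builds. A minimal, self-contained sketch of that pattern (the toy module and example inputs are illustrative only, not part of this commit):

    import torch

    has_export_for_training = False
    try:
        from torch.export import export_for_training
        has_export_for_training = True
    except ImportError:
        # Older PyTorch: fall back to the deprecated capture API.
        from torch._export import capture_pre_autograd_graph

    class ToyModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x)

    example_inputs = (torch.randn(2, 4),)
    m = ToyModel().eval()

    if has_export_for_training:
        # export_for_training returns an ExportedProgram; .module() unwraps the GraphModule.
        m = export_for_training(m, example_inputs).module()
    else:
        # capture_pre_autograd_graph returns the captured module directly.
        m = capture_pre_autograd_graph(m, example_inputs)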

File tree

3 files changed (+31 -8 lines changed):

test/dtypes/test_uint4.py
test/integration/test_integration.py
torchao/utils.py


test/dtypes/test_uint4.py

+18 -5

@@ -7,7 +7,6 @@
 from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
 from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
 
-from torch._export import capture_pre_autograd_graph
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
     QuantizationTestCase,
@@ -26,6 +25,14 @@
 )
 import copy
 
+has_export_for_training = False
+try:
+    from torch.export import export_for_training
+    has_export_for_training = True
+except ImportError:
+    # capture_pre_autograd_graph is deprecated, it's
+    # left here to work with previous versions of pytorch
+    from torch._export import capture_pre_autograd_graph
 
 def _apply_weight_only_uint4_quant(model):
     def fn(mod):
@@ -203,10 +210,16 @@ def forward(self, x):
 
         # program capture
         m = copy.deepcopy(m_eager)
-        m = capture_pre_autograd_graph(
-            m,
-            example_inputs,
-        )
+        if has_export_for_training:
+            m = export_for_training(
+                m,
+                example_inputs,
+            ).module()
+        else:
+            m = capture_pre_autograd_graph(
+                m,
+                example_inputs,
+            ).module()
 
         m = prepare_pt2e(m, quantizer)
         # Calibrate
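
For context, the captured module in this test then flows through the standard PT2E quantization steps (prepare_pt2e, calibration, convert_pt2e). A sketch of that downstream flow, assuming export_for_training is available and substituting XNNPACKQuantizer for the test's own quantizer:

    import torch
    from torch.export import export_for_training
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    class ToyLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(8, 8)

        def forward(self, x):
            return self.linear(x)

    example_inputs = (torch.randn(1, 8),)

    # program capture (new path)
    m = export_for_training(ToyLinear().eval(), example_inputs).module()

    # insert observers, calibrate with representative data, then convert
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    m = prepare_pt2e(m, quantizer)
    m(*example_inputs)
    m = convert_pt2e(m)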

test/integration/test_integration.py

+12 -2

@@ -1484,11 +1484,21 @@ def forward(self, x):
 
         # make sure it compiles
         example_inputs = (x,)
-        from torch._export import capture_pre_autograd_graph
         # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
         # we can re-enable this after non-functional IR is enabled in export
         # model = torch.export.export(model, example_inputs).module()
-        model = capture_pre_autograd_graph(model, example_inputs)
+        has_export_for_training = False
+        try:
+            from torch.export import export_for_training
+            has_export_for_training = True
+        except ImportError:
+            # capture_pre_autograd_graph is deprecated, it's
+            # left here to work with previous versions of pytorch
+            from torch._export import capture_pre_autograd_graph
+        if has_export_for_training:
+            model = export_for_training(model, example_inputs).module()
+        else:
+            model = capture_pre_autograd_graph(model, example_inputs)
         after_export = model(x)
         self.assertTrue(torch.equal(after_export, ref))
         if api is _int8da_int8w_api:

torchao/utils.py

+1 -1

@@ -180,7 +180,7 @@ def _the_op_that_needs_to_be_preserved(...)
 
         # after this, `_the_op_that_needs_to_be_preserved` will be preserved as
         # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
-        # torch.export.export / torch._export.capture_pre_autograd_graph
+        # torch.export.export / torch._export.export_for_training
 
     """
     from torch._inductor.decomposition import register_decomposition
