training ir torchao migration (#1006)
Summary:

Migrate capture_pre_autograd_graph to export_for_training.

We still need to keep the capture_pre_autograd_graph call because torch/ao's CI tests use an earlier version of PyTorch that does not have export_for_training.

Differential Revision: D63859678
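
For reference, the fallback applied at each call site follows this shape (a minimal sketch assembled from the diff below; the capture() helper name is illustrative and not part of the commit):

# Sketch of the fallback pattern used in this commit: prefer the new
# torch.export.export_for_training API and fall back to the legacy
# torch._export.capture_pre_autograd_graph on older PyTorch builds.
has_export_for_training = False
try:
    from torch.export import export_for_training
    has_export_for_training = True
except ImportError:
    from torch._export import capture_pre_autograd_graph

def capture(model, example_inputs):
    # `capture` is an illustrative helper name, not part of the diff.
    if has_export_for_training:
        # export_for_training returns an ExportedProgram; .module() yields the GraphModule.
        return export_for_training(model, example_inputs).module()
    # capture_pre_autograd_graph already returns a GraphModule.
    return capture_pre_autograd_graph(model, example_inputs)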
yushangdi authored and facebook-github-bot committed Oct 3, 2024
1 parent 9ce7ebb commit 67fa3b4
Showing 3 changed files with 27 additions and 8 deletions.
21 changes: 16 additions & 5 deletions test/dtypes/test_uint4.py
@@ -7,7 +7,6 @@
 from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
 from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
 
-from torch._export import capture_pre_autograd_graph
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
     QuantizationTestCase,
@@ -26,6 +25,12 @@
 )
 import copy
 
+has_export_for_training = False
+try:
+    from torch.export import export_for_training
+    has_export_for_training = True
+except ImportError:
+    from torch._export import capture_pre_autograd_graph
 
 
 def _apply_weight_only_uint4_quant(model):
     def fn(mod):
@@ -203,10 +208,16 @@ def forward(self, x):
 
         # program capture
         m = copy.deepcopy(m_eager)
-        m = capture_pre_autograd_graph(
-            m,
-            example_inputs,
-        )
+        if has_export_for_training:
+            m = export_for_training(
+                m,
+                example_inputs,
+            ).module()
+        else:
+            m = capture_pre_autograd_graph(
+                m,
+                example_inputs,
+            )
 
         m = prepare_pt2e(m, quantizer)
         # Calibrate
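
As context for the test above, the captured module feeds the standard PT2E quantization flow. A minimal sketch of that flow, assuming export_for_training is available; the toy Linear model and XNNPACKQuantizer stand in for the test's own model and quantizer:

# Sketch of how the exported module flows into PT2E quantization. The toy
# Linear model and XNNPACKQuantizer are illustrative stand-ins only.
import torch
from torch.export import export_for_training
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

model = torch.nn.Sequential(torch.nn.Linear(16, 16)).eval()
example_inputs = (torch.randn(1, 16),)

# program capture (replaces the old capture_pre_autograd_graph call)
m = export_for_training(model, example_inputs).module()

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)
m(*example_inputs)   # calibrate with representative inputs
m = convert_pt2e(m)  # produce the quantized graph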
12 changes: 10 additions & 2 deletions test/integration/test_integration.py
@@ -1484,11 +1484,19 @@ def forward(self, x):
 
         # make sure it compiles
         example_inputs = (x,)
-        from torch._export import capture_pre_autograd_graph
         # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
         # we can re-enable this after non-functional IR is enabled in export
         # model = torch.export.export(model, example_inputs).module()
-        model = capture_pre_autograd_graph(model, example_inputs)
+        has_export_for_training = False
+        try:
+            from torch.export import export_for_training
+            has_export_for_training = True
+        except ImportError:
+            from torch._export import capture_pre_autograd_graph
+        if has_export_for_training:
+            model = export_for_training(model, example_inputs).module()
+        else:
+            model = capture_pre_autograd_graph(model, example_inputs)
         after_export = model(x)
         self.assertTrue(torch.equal(after_export, ref))
         if api is _int8da_int8w_api:
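
The TODO in the hunk above notes that full torch.export.export changes numerics because of functionalization, which is why the test keeps a pre-autograd capture and checks exact equality afterwards. A minimal sketch of that equivalence check, with a toy model standing in for the test's own quantized model:

# Sketch of the numerics check the test performs after capture: the captured
# module should reproduce the eager outputs exactly. The toy Linear model is
# an illustrative stand-in only.
import torch
from torch.export import export_for_training

model = torch.nn.Linear(8, 8).eval()
x = torch.randn(2, 8)
ref = model(x)

captured = export_for_training(model, (x,)).module()
after_export = captured(x)
assert torch.equal(after_export, ref)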
2 changes: 1 addition & 1 deletion torchao/utils.py
@@ -180,7 +180,7 @@ def _the_op_that_needs_to_be_preserved(...)
     # after this, `_the_op_that_needs_to_be_preserved` will be preserved as
     # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
-    # torch.export.export / torch._export.capture_pre_autograd_graph
+    # torch.export.export / torch.export.export_for_training
     """
     from torch._inductor.decomposition import register_decomposition
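
The docstring touched above is about keeping a custom op from being decomposed during export. A rough sketch of that idea using the public torch.library.custom_op API (the mylib::my_op name and toy module are illustrative; this is not torchao's own _register_custom_op helper):

# Rough sketch: define a custom op with a fake impl, export a module that
# calls it, and check that the op survives in the exported graph instead of
# being decomposed. Names here (mylib::my_op, M) are illustrative only.
import torch
from torch.export import export_for_training

@torch.library.custom_op("mylib::my_op", mutates_args=())
def my_op(x: torch.Tensor) -> torch.Tensor:
    return x.sin()

@my_op.register_fake
def _(x):
    return torch.empty_like(x)

class M(torch.nn.Module):
    def forward(self, x):
        return torch.ops.mylib.my_op(x)

ep = export_for_training(M(), (torch.randn(3),))
targets = [n.target for n in ep.graph.nodes if n.op == "call_function"]
assert torch.ops.mylib.my_op.default in targets  # op preserved, not decomposed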
