training ir torchao migration (#1006)
Summary:

Migrate capture_pre_autograd_graph to export_for_training.

We still need to keep the capture_pre_autograd_graph call because torch/ao's CI tests use an earlier version of PyTorch that does not have export_for_training.

See https://github.com/pytorch/ao/blob/main/.github/workflows/regression_test.yml

Differential Revision: D63859678
yushangdi authored and facebook-github-bot committed Oct 4, 2024
1 parent 0ffbf85 commit 567ee57
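
In short, every capture site now gates on the runtime PyTorch version. A minimal sketch of the pattern, assuming a toy model and example inputs (illustrative, not taken from the diff):

import torch
import torch.nn as nn
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

model = nn.Linear(4, 4).eval()         # toy model, illustrative only
example_inputs = (torch.randn(1, 4),)

if TORCH_VERSION_AT_LEAST_2_5:
    # export_for_training returns an ExportedProgram; .module() unwraps the
    # GraphModule that prepare_pt2e/convert_pt2e expect.
    m = torch.export.export_for_training(model, example_inputs).module()
else:
    # Pre-2.5 PyTorch lacks export_for_training; the legacy API returns a
    # GraphModule directly, so no .module() call is needed.
    m = torch._export.capture_pre_autograd_graph(model, example_inputs)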
Showing 3 changed files with 16 additions and 8 deletions.
16 changes: 11 additions & 5 deletions test/dtypes/test_uint4.py
@@ -7,7 +7,6 @@
 from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
 from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
 
-from torch._export import capture_pre_autograd_graph
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
     QuantizationTestCase,
@@ -25,6 +24,7 @@
     QuantizationAnnotation,
 )
 import copy
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 
 def _apply_weight_only_uint4_quant(model):
@@ -203,10 +203,16 @@ def forward(self, x):
 
         # program capture
         m = copy.deepcopy(m_eager)
-        m = capture_pre_autograd_graph(
-            m,
-            example_inputs,
-        )
+        if TORCH_VERSION_AT_LEAST_2_5:
+            m = torch.export.export_for_training(
+                m,
+                example_inputs,
+            ).module()
+        else:
+            m = torch._export.capture_pre_autograd_graph(
+                m,
+                example_inputs,
+            )
 
         m = prepare_pt2e(m, quantizer)
         # Calibrate
6 changes: 4 additions & 2 deletions test/integration/test_integration.py
@@ -1484,11 +1484,13 @@ def forward(self, x):
 
         # make sure it compiles
         example_inputs = (x,)
-        from torch._export import capture_pre_autograd_graph
         # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
         # we can re-enable this after non-functional IR is enabled in export
         # model = torch.export.export(model, example_inputs).module()
-        model = capture_pre_autograd_graph(model, example_inputs)
+        if TORCH_VERSION_AT_LEAST_2_5:
+            model = torch.export.export_for_training(model, example_inputs).module()
+        else:
+            model = torch._export.capture_pre_autograd_graph(model, example_inputs)
         after_export = model(x)
         self.assertTrue(torch.equal(after_export, ref))
         if api is _int8da_int8w_api:
2 changes: 1 addition & 1 deletion torchao/utils.py
@@ -180,7 +180,7 @@ def _the_op_that_needs_to_be_preserved(...)
     # after this, `_the_op_that_needs_to_be_preserved` will be preserved as
     # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
-    # torch.export.export / torch._export.capture_pre_autograd_graph
+    # torch.export.export / torch.export.export_for_training
     """
     from torch._inductor.decomposition import register_decomposition
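
For readers unfamiliar with the helper whose docstring is touched above, here is a hypothetical usage sketch reconstructed from that docstring; the namespace, op name, and op body are illustrative assumptions, not part of the commit:

import torch
from torchao.utils import _register_custom_op

# Illustrative namespace; "FRAGMENT" adds ops to the given library namespace.
lib = torch.library.Library("my_namespace", "FRAGMENT")
register_custom_op = _register_custom_op(lib)

# The leading underscore is stripped on registration, so the op becomes
# torch.ops.my_namespace.the_op_that_needs_to_be_preserved (toy op body).
@register_custom_op
def _the_op_that_needs_to_be_preserved(x: torch.Tensor) -> torch.Tensor:
    return x + 1

# Per the docstring, the registered op is then preserved as a single operator
# node through torch.export.export / torch.export.export_for_training.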
