Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

training ir torchao migration #1006

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions test/dtypes/test_uint4.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

from torch._export import capture_pre_autograd_graph
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
QuantizationTestCase,
Expand All @@ -25,6 +24,7 @@
QuantizationAnnotation,
)
import copy
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5


def _apply_weight_only_uint4_quant(model):
Expand Down Expand Up @@ -203,10 +203,16 @@ def forward(self, x):

# program capture
m = copy.deepcopy(m_eager)
m = capture_pre_autograd_graph(
m,
example_inputs,
)
if TORCH_VERSION_AT_LEAST_2_5:
m = torch.export.texport_for_training(
m,
example_inputs,
).module()
else:
m = torch._export.capture_pre_autograd_graph(
m,
example_inputs,
).module()

m = prepare_pt2e(m, quantizer)
# Calibrate
Expand Down
6 changes: 4 additions & 2 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1484,11 +1484,13 @@ def forward(self, x):

# make sure it compiles
example_inputs = (x,)
from torch._export import capture_pre_autograd_graph
# TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
# we can re-enable this after non-functional IR is enabled in export
# model = torch.export.export(model, example_inputs).module()
model = capture_pre_autograd_graph(model, example_inputs)
if TORCH_VERSION_AT_LEAST_2_5:
model = torch.export.export_for_training(model, example_inputs).module()
else:
model = torch._export.capture_pre_autograd_graph(model, example_inputs)
after_export = model(x)
self.assertTrue(torch.equal(after_export, ref))
if api is _int8da_int8w_api:
Expand Down
2 changes: 1 addition & 1 deletion torchao/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def _the_op_that_needs_to_be_preserved(...)
# after this, `_the_op_that_needs_to_be_preserved` will be preserved as
# torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
# torch.export.export / torch._export.capture_pre_autograd_graph
# torch.export.export / torch._export.export_for_training
"""
from torch._inductor.decomposition import register_decomposition
Expand Down
Loading