training ir torchao migration (pytorch#1006)
Summary:

Migrate capture_pre_autograd_graph to export_for_training.

Differential Revision: D63859678
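
For reference, a minimal sketch of the migration pattern applied throughout this commit. The toy model is hypothetical; the key difference is that capture_pre_autograd_graph returned a GraphModule directly, while export_for_training returns an ExportedProgram, hence the added .module() calls.

import torch
import torch.nn as nn

class M(nn.Module):  # hypothetical toy model
    def forward(self, x):
        return torch.relu(x)

example_inputs = (torch.randn(2, 3),)

# Before (removed private API):
#   from torch._export import capture_pre_autograd_graph
#   m = capture_pre_autograd_graph(M(), example_inputs)  # GraphModule

# After: export_for_training returns an ExportedProgram;
# .module() recovers the GraphModule that downstream passes expect.
m = torch.export.export_for_training(M(), example_inputs).module()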
yushangdi authored and facebook-github-bot committed Oct 3, 2024
1 parent 8945fb3 commit 850d090
Showing 3 changed files with 5 additions and 6 deletions.
test/dtypes/test_uint4.py (6 changes: 3 additions & 3 deletions)
@@ -7,7 +7,7 @@
 from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
 from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
 
-from torch._export import capture_pre_autograd_graph
+from torch.export import export_for_training
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
     QuantizationTestCase,
@@ -203,10 +203,10 @@ def forward(self, x):
 
         # program capture
         m = copy.deepcopy(m_eager)
-        m = capture_pre_autograd_graph(
+        m = export_for_training(
             m,
             example_inputs,
-        )
+        ).module()
 
         m = prepare_pt2e(m, quantizer)
         # Calibrate
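
The .module() call is the substantive change in this hunk: prepare_pt2e consumes a GraphModule, and export_for_training wraps its result in an ExportedProgram. Below is a self-contained sketch of the same PT2E flow; XNNPACKQuantizer stands in for the test's custom uint4 Quantizer, and the toy model is an assumption.

import torch
import torch.nn as nn
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.export import export_for_training

class M(nn.Module):  # hypothetical stand-in for the test model
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

example_inputs = (torch.randn(1, 4),)
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())

m = export_for_training(M().eval(), example_inputs).module()  # GraphModule
m = prepare_pt2e(m, quantizer)   # insert observers
m(*example_inputs)               # calibrate
m = convert_pt2e(m)              # lower to quantized ops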
test/integration/test_integration.py (3 changes: 1 addition & 2 deletions)
@@ -1484,11 +1484,10 @@ def forward(self, x):
 
         # make sure it compiles
         example_inputs = (x,)
-        from torch._export import capture_pre_autograd_graph
         # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
         # we can re-enable this after non-functional IR is enabled in export
         # model = torch.export.export(model, example_inputs).module()
-        model = capture_pre_autograd_graph(model, example_inputs)
+        model = torch.export.export_for_training(model, example_inputs).module()
         after_export = model(x)
         self.assertTrue(torch.equal(after_export, ref))
         if api is _int8da_int8w_api:
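
This hunk drops the local import of the deprecated API and keeps the exact-equality check on outputs. A minimal sketch of that check, with a hypothetical model and input:

import torch

model = torch.nn.Linear(4, 4).eval()  # hypothetical stand-in
x = torch.randn(2, 4)
ref = model(x)

exported = torch.export.export_for_training(model, (x,)).module()
after_export = exported(x)

# the test asserts bit-exact agreement between eager and exported outputs
assert torch.equal(after_export, ref)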
torchao/utils.py (2 changes: 1 addition & 1 deletion)
@@ -180,7 +180,7 @@ def _the_op_that_needs_to_be_preserved(...)
     # after this, `_the_op_that_needs_to_be_preserved` will be preserved as
     # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
-    # torch.export.export / torch._export.capture_pre_autograd_graph
+    # torch.export.export / torch._export.export_for_training
     """
     from torch._inductor.decomposition import register_decomposition
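
The docstring above describes keeping a custom op intact through export. A minimal sketch of that pattern; only the namespace and op name come from the docstring, while the schema and implementation are assumptions:

import torch
from torch.library import Library, impl

lib = Library("my_namespace", "DEF")
lib.define("the_op_that_needs_to_be_preserved(Tensor x) -> Tensor")

@impl(lib, "the_op_that_needs_to_be_preserved", "CompositeExplicitAutograd")
def _the_op_that_needs_to_be_preserved(x):
    return x + 1  # assumed implementation

class M(torch.nn.Module):
    def forward(self, x):
        return torch.ops.my_namespace.the_op_that_needs_to_be_preserved(x)

ep = torch.export.export_for_training(M(), (torch.randn(2),))
# the op should appear as a single call_function node:
# torch.ops.my_namespace.the_op_that_needs_to_be_preserved.default
print(ep.graph)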
