From 850d090daf4e13638209e93c27566cbce6d6e221 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Thu, 3 Oct 2024 15:00:15 -0700 Subject: [PATCH] training ir torchao migration (#1006) Summary: Migrate capture_pre_autograd_graph to export_for_training. Differential Revision: D63859678 --- test/dtypes/test_uint4.py | 6 +++--- test/integration/test_integration.py | 3 +-- torchao/utils.py | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/test/dtypes/test_uint4.py b/test/dtypes/test_uint4.py index aa9415e51b..db23744a26 100644 --- a/test/dtypes/test_uint4.py +++ b/test/dtypes/test_uint4.py @@ -7,7 +7,7 @@ from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer -from torch._export import capture_pre_autograd_graph +from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, QuantizationTestCase, @@ -203,10 +203,10 @@ def forward(self, x): # program capture m = copy.deepcopy(m_eager) - m = capture_pre_autograd_graph( + m = export_for_training( m, example_inputs, - ) + ).module() m = prepare_pt2e(m, quantizer) # Calibrate diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 5f81858ba0..59862102af 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1484,11 +1484,10 @@ def forward(self, x): # make sure it compiles example_inputs = (x,) - from torch._export import capture_pre_autograd_graph # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu # we can re-enable this after non-functional IR is enabled in export # model = torch.export.export(model, example_inputs).module() - model = capture_pre_autograd_graph(model, example_inputs) + model = torch.export.export_for_training(model, example_inputs).module() after_export = model(x) self.assertTrue(torch.equal(after_export, ref)) if api 
is _int8da_int8w_api: diff --git a/torchao/utils.py b/torchao/utils.py index 4b5409e657..a0302cabe6 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -180,7 +180,7 @@ def _the_op_that_needs_to_be_preserved(...) # after this, `_the_op_that_needs_to_be_preserved` will be preserved as # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after - # torch.export.export / torch._export.capture_pre_autograd_graph + # torch.export.export / torch.export.export_for_training """ from torch._inductor.decomposition import register_decomposition