From 67fa3b45a8cf5a64e8447e7d053513045d91856b Mon Sep 17 00:00:00 2001
From: Shangdi Yu
Date: Thu, 3 Oct 2024 16:54:12 -0700
Subject: [PATCH] training ir torchao migration (#1006)

Summary:
Migrate capture_pre_autograd_graph to export_for_training. We still need to
keep the capture_pre_autograd_graph call because torch/ao's CI tests use an
earlier version of pytorch that does not have export_for_training.

Differential Revision: D63859678
---
 test/dtypes/test_uint4.py            | 21 ++++++++++++++++-----
 test/integration/test_integration.py | 12 ++++++++++--
 torchao/utils.py                     |  2 +-
 3 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/test/dtypes/test_uint4.py b/test/dtypes/test_uint4.py
index aa9415e51b..f16e843c4b 100644
--- a/test/dtypes/test_uint4.py
+++ b/test/dtypes/test_uint4.py
@@ -7,7 +7,6 @@
 from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
 from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
-from torch._export import capture_pre_autograd_graph
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
     QuantizationTestCase,
@@ -26,6 +25,12 @@
 )
 import copy
 
+has_export_for_training = False
+try:
+    from torch.export import export_for_training
+    has_export_for_training = True
+except ImportError:
+    from torch._export import capture_pre_autograd_graph
 
 def _apply_weight_only_uint4_quant(model):
     def fn(mod):
@@ -203,10 +208,16 @@ def forward(self, x):
         # program capture
         m = copy.deepcopy(m_eager)
-        m = capture_pre_autograd_graph(
-            m,
-            example_inputs,
-        )
+        if has_export_for_training:
+            m = export_for_training(
+                m,
+                example_inputs,
+            ).module()
+        else:
+            m = capture_pre_autograd_graph(
+                m,
+                example_inputs,
+            ).module()
 
         m = prepare_pt2e(m, quantizer)
         # Calibrate
diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 5f81858ba0..ea310d1a1b 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -1484,11 +1484,19 @@ def forward(self, x):
         # make sure it compiles
         example_inputs = (x,)
-        from torch._export import capture_pre_autograd_graph
         # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
         # we can re-enable this after non-functional IR is enabled in export
         # model = torch.export.export(model, example_inputs).module()
-        model = capture_pre_autograd_graph(model, example_inputs)
+        has_export_for_training = False
+        try:
+            from torch.export import export_for_training
+            has_export_for_training = True
+        except ImportError:
+            from torch._export import capture_pre_autograd_graph
+        if has_export_for_training:
+            model = export_for_training(model, example_inputs).module()
+        else:
+            model = capture_pre_autograd_graph(model, example_inputs)
         after_export = model(x)
         self.assertTrue(torch.equal(after_export, ref))
         if api is _int8da_int8w_api:
diff --git a/torchao/utils.py b/torchao/utils.py
index 4b5409e657..a0302cabe6 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -180,7 +180,7 @@ def _the_op_that_needs_to_be_preserved(...)
         # after this, `_the_op_that_needs_to_be_preserved` will be preserved as
         # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
-        # torch.export.export / torch._export.capture_pre_autograd_graph
+        # torch.export.export / torch._export.export_for_training
     """
     from torch._inductor.decomposition import register_decomposition
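
Note (not part of the patch): the version gate repeated at each call site above can be read as one small capture helper. The sketch below is only an illustration of that pattern; the helper name _capture_graph and the module-level flag are made up here, and it assumes torch.export.export_for_training exists on newer PyTorch builds while older builds only provide torch._export.capture_pre_autograd_graph.

    # Minimal sketch of the fallback pattern applied in this patch; the helper
    # name and flag are illustrative, not part of the torchao code base.
    try:
        from torch.export import export_for_training  # newer PyTorch
        _HAS_EXPORT_FOR_TRAINING = True
    except ImportError:
        from torch._export import capture_pre_autograd_graph  # older PyTorch
        _HAS_EXPORT_FOR_TRAINING = False


    def _capture_graph(model, example_inputs):
        """Return a pre-autograd graph module for the PT2E quantization flow."""
        if _HAS_EXPORT_FOR_TRAINING:
            # export_for_training returns an ExportedProgram; .module() unwraps
            # the GraphModule that prepare_pt2e / convert_pt2e expect.
            return export_for_training(model, example_inputs).module()
        # capture_pre_autograd_graph returns a graph module directly, so the
        # fallback branch needs no unwrapping.
        return capture_pre_autograd_graph(model, example_inputs)

The sketch follows the test_integration.py call site, which passes the capture_pre_autograd_graph result straight to the rest of the test; either branch then yields a module that prepare_pt2e(m, quantizer) can consume, which is why both call sites in the patch converge on the same subsequent quantization steps.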