diff --git a/test/dtypes/test_uint4.py b/test/dtypes/test_uint4.py
index aa9415e51b..d368df189f 100644
--- a/test/dtypes/test_uint4.py
+++ b/test/dtypes/test_uint4.py
@@ -7,7 +7,6 @@
 
 from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
 from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
-from torch._export import capture_pre_autograd_graph
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
     QuantizationTestCase,
@@ -26,6 +25,18 @@
 )
 import copy
 
+from packaging import version
+torch_version = torch.__version__
+
+has_export_for_training = False
+
+if version.parse(torch_version) > version.parse('2.5.0rc1'):
+    from torch.export import export_for_training
+    has_export_for_training = True
+else:
+    # capture_pre_autograd_graph is deprecated, it's
+    # left here to work with previous versions of pytorch
+    from torch._export import capture_pre_autograd_graph
 
 def _apply_weight_only_uint4_quant(model):
     def fn(mod):
@@ -203,10 +214,16 @@ def forward(self, x):
 
         # program capture
         m = copy.deepcopy(m_eager)
-        m = capture_pre_autograd_graph(
-            m,
-            example_inputs,
-        )
+        if has_export_for_training:
+            m = export_for_training(
+                m,
+                example_inputs,
+            ).module()
+        else:
+            m = capture_pre_autograd_graph(
+                m,
+                example_inputs,
+            ).module()
 
         m = prepare_pt2e(m, quantizer)
         # Calibrate
diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 5f81858ba0..8275bb4ad3 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -89,6 +89,18 @@
     benchmark_model
 )
 
+has_export_for_training = False
+
+from packaging import version
+torch_version = torch.__version__
+if version.parse(torch_version) > version.parse('2.5.0rc1'):
+    from torch.export import export_for_training
+    has_export_for_training = True
+else:
+    # capture_pre_autograd_graph is deprecated, it's
+    # left here to work with previous versions of pytorch
+    from torch._export import capture_pre_autograd_graph
+
 logger = logging.getLogger("INFO")
 
 torch.manual_seed(0)
@@ -1484,11 +1496,13 @@ def forward(self, x):
 
         # make sure it compiles
         example_inputs = (x,)
-        from torch._export import capture_pre_autograd_graph
         # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
         # we can re-enable this after non-functional IR is enabled in export
         # model = torch.export.export(model, example_inputs).module()
-        model = capture_pre_autograd_graph(model, example_inputs)
+        if has_export_for_training:
+            model = export_for_training(model, example_inputs).module()
+        else:
+            model = capture_pre_autograd_graph(model, example_inputs)
         after_export = model(x)
         self.assertTrue(torch.equal(after_export, ref))
         if api is _int8da_int8w_api:
diff --git a/torchao/utils.py b/torchao/utils.py
index 4b5409e657..a0302cabe6 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -180,7 +180,7 @@
         def _the_op_that_needs_to_be_preserved(...)
 
     # after this, `_the_op_that_needs_to_be_preserved` will be preserved as
     # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
-    # torch.export.export / torch._export.capture_pre_autograd_graph
+    # torch.export.export / torch._export.export_for_training
     """
     from torch._inductor.decomposition import register_decomposition
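
For reference, below is a minimal, self-contained sketch of the version-gated capture pattern this diff applies in both test files: prefer `torch.export.export_for_training` on PyTorch newer than 2.5.0rc1, and fall back to the deprecated `torch._export.capture_pre_autograd_graph` otherwise. The `SmallModel` module and the `capture` helper are hypothetical names used only for illustration; the version threshold and the `.module()` call on the exported program are taken from the diff above.

```python
from packaging import version

import torch
import torch.nn as nn

# Same gating as in the diff: export_for_training only exists on newer PyTorch.
if version.parse(torch.__version__) > version.parse("2.5.0rc1"):
    from torch.export import export_for_training

    has_export_for_training = True
else:
    # capture_pre_autograd_graph is deprecated; kept for older PyTorch versions.
    from torch._export import capture_pre_autograd_graph

    has_export_for_training = False


class SmallModel(nn.Module):
    """Hypothetical stand-in for the models used in the tests."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


def capture(model: nn.Module, example_inputs):
    """Return an FX GraphModule suitable for prepare_pt2e on either PyTorch version."""
    if has_export_for_training:
        # export_for_training returns an ExportedProgram; .module() yields the GraphModule.
        return export_for_training(model, example_inputs).module()
    # capture_pre_autograd_graph already returns a GraphModule directly.
    return capture_pre_autograd_graph(model, example_inputs)


if __name__ == "__main__":
    m = SmallModel().eval()
    example_inputs = (torch.randn(2, 8),)
    gm = capture(m, example_inputs)
    print(type(gm))
```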