Skip to content

Commit cb0d915

Browse files
authored
Add torch 2.9 in regression tests (#3311)
1 parent ba6f428 commit cb0d915

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

.github/workflows/regression_test.yml

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ jobs:
6767
dev-requirements-overrides: ""
6868
- name: CUDA 2.7
6969
runs-on: linux.g5.12xlarge.nvidia.gpu
70-
torch-spec: 'torch==2.7.0'
70+
torch-spec: 'torch==2.7.1'
7171
gpu-arch-type: "cuda"
7272
gpu-arch-version: "12.6"
7373
dev-requirements-overrides: ""
@@ -77,6 +77,12 @@ jobs:
7777
gpu-arch-type: "cuda"
7878
gpu-arch-version: "12.6"
7979
dev-requirements-overrides: ""
80+
- name: CUDA 2.9
81+
runs-on: linux.g5.12xlarge.nvidia.gpu
82+
torch-spec: 'torch==2.9.1'
83+
gpu-arch-type: "cuda"
84+
gpu-arch-version: "12.6"
85+
dev-requirements-overrides: ""
8086

8187
- name: CPU 2.6
8288
runs-on: linux.4xlarge
@@ -86,7 +92,7 @@ jobs:
8692
dev-requirements-overrides: ""
8793
- name: CPU 2.7
8894
runs-on: linux.4xlarge
89-
torch-spec: 'torch==2.7.0 --index-url https://download.pytorch.org/whl/cpu'
95+
torch-spec: 'torch==2.7.1 --index-url https://download.pytorch.org/whl/cpu'
9096
gpu-arch-type: "cpu"
9197
gpu-arch-version: ""
9298
dev-requirements-overrides: ""
@@ -96,6 +102,12 @@ jobs:
96102
gpu-arch-type: "cpu"
97103
gpu-arch-version: ""
98104
dev-requirements-overrides: ""
105+
- name: CPU 2.9
106+
runs-on: linux.4xlarge
107+
torch-spec: 'torch==2.9.1 --index-url https://download.pytorch.org/whl/cpu'
108+
gpu-arch-type: "cpu"
109+
gpu-arch-version: ""
110+
dev-requirements-overrides: ""
99111

100112
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
101113
with:

torchao/quantization/pt2e/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,16 @@ def _get_aten_graph_module_for_pattern(
859859
):
860860
aten_pattern.graph.erase_node(node) # type: ignore[operator, union-attr]
861861

862+
if torch.__version__.startswith("2.9"):
863+
# PyTorch 2.9 adds _guards_fn nodes to exported graphs.
864+
# These have errors only on torch 2.9.0 and 2.9.1
865+
for node in list(aten_pattern.graph.nodes): # type: ignore[union-attr]
866+
if node.op == "call_module" and node.name == "_guards_fn":
867+
aten_pattern.graph.erase_node(node) # type: ignore[operator, union-attr]
868+
# Also remove the _guards_fn module from the graph module if it exists
869+
if hasattr(aten_pattern, "_guards_fn"):
870+
delattr(aten_pattern, "_guards_fn")
871+
862872
aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr]
863873
aten_pattern.recompile() # type: ignore[operator]
864874

0 commit comments

Comments
 (0)