File tree Expand file tree Collapse file tree 2 files changed +24
-2
lines changed
torchao/quantization/pt2e Expand file tree Collapse file tree 2 files changed +24
-2
lines changed Original file line number Diff line number Diff line change 6767 dev-requirements-overrides : " "
6868 - name : CUDA 2.7
6969 runs-on : linux.g5.12xlarge.nvidia.gpu
70- torch-spec : ' torch==2.7.0 '
70+ torch-spec : ' torch==2.7.1 '
7171 gpu-arch-type : " cuda"
7272 gpu-arch-version : " 12.6"
7373 dev-requirements-overrides : " "
7777 gpu-arch-type : " cuda"
7878 gpu-arch-version : " 12.6"
7979 dev-requirements-overrides : " "
80+ - name : CUDA 2.9
81+ runs-on : linux.g5.12xlarge.nvidia.gpu
82+ torch-spec : ' torch==2.9.1'
83+ gpu-arch-type : " cuda"
84+ gpu-arch-version : " 12.6"
85+ dev-requirements-overrides : " "
8086
8187 - name : CPU 2.6
8288 runs-on : linux.4xlarge
8692 dev-requirements-overrides : " "
8793 - name : CPU 2.7
8894 runs-on : linux.4xlarge
89- torch-spec : ' torch==2.7.0 --index-url https://download.pytorch.org/whl/cpu'
95+ torch-spec : ' torch==2.7.1 --index-url https://download.pytorch.org/whl/cpu'
9096 gpu-arch-type : " cpu"
9197 gpu-arch-version : " "
9298 dev-requirements-overrides : " "
@@ -96,6 +102,12 @@ jobs:
96102 gpu-arch-type : " cpu"
97103 gpu-arch-version : " "
98104 dev-requirements-overrides : " "
105+ - name : CPU 2.9
106+ runs-on : linux.4xlarge
107+ torch-spec : ' torch==2.9.1 --index-url https://download.pytorch.org/whl/cpu'
108+ gpu-arch-type : " cpu"
109+ gpu-arch-version : " "
110+ dev-requirements-overrides : " "
99111
100112 uses : pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
101113 with :
Original file line number Diff line number Diff line change @@ -859,6 +859,16 @@ def _get_aten_graph_module_for_pattern(
859859 ):
860860 aten_pattern .graph .erase_node (node ) # type: ignore[operator, union-attr]
861861
862+ if torch .__version__ .startswith ("2.9" ):
863+ # PyTorch 2.9 adds _guards_fn nodes to exported graphs.
864+ # These have errors only on torch 2.9.0 and 2.9.1
865+ for node in list (aten_pattern .graph .nodes ): # type: ignore[union-attr]
866+ if node .op == "call_module" and node .name == "_guards_fn" :
867+ aten_pattern .graph .erase_node (node ) # type: ignore[operator, union-attr]
868+ # Also remove the _guards_fn module from the graph module if it exists
869+ if hasattr (aten_pattern , "_guards_fn" ):
870+ delattr (aten_pattern , "_guards_fn" )
871+
862872 aten_pattern .graph .eliminate_dead_code () # type: ignore[operator, union-attr]
863873 aten_pattern .recompile () # type: ignore[operator]
864874
You can’t perform that action at this time.
0 commit comments