Commit ffa88a4

Add PyTorch 2.4 tests in CI (#654)
1 parent 0b0192e

File tree

12 files changed: +36 −23 lines


.github/workflows/regression_test.yml

Lines changed: 11 additions & 0 deletions
@@ -31,11 +31,17 @@ jobs:
           torch-spec: 'torch==2.3.0'
           gpu-arch-type: "cuda"
           gpu-arch-version: "12.1"
+        - name: CUDA 2.4
+          runs-on: linux.g5.12xlarge.nvidia.gpu
+          torch-spec: 'torch==2.4.0'
+          gpu-arch-type: "cuda"
+          gpu-arch-version: "12.1"
         - name: CUDA Nightly
           runs-on: linux.g5.12xlarge.nvidia.gpu
           torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
           gpu-arch-type: "cuda"
           gpu-arch-version: "12.1"
+
         - name: CPU 2.2.2
           runs-on: linux.4xlarge
           torch-spec: 'torch==2.2.2 --index-url https://download.pytorch.org/whl/cpu "numpy<2" '
@@ -46,6 +52,11 @@ jobs:
           torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu'
           gpu-arch-type: "cpu"
           gpu-arch-version: ""
+        - name: CPU 2.4
+          runs-on: linux.4xlarge
+          torch-spec: 'torch==2.4.0 --index-url https://download.pytorch.org/whl/cpu'
+          gpu-arch-type: "cpu"
+          gpu-arch-version: ""
         - name: CPU Nightly
           runs-on: linux.4xlarge
           torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu'
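
Note: the new matrix entries only pick which torch wheel is installed; whether an individual test runs is decided inside the suite by version flags from torchao.utils. A minimal sketch of how such a flag can be derived (illustrative only; the real torchao.utils helper may differ in detail):

import re

import torch


def _torch_version_at_least(min_version: str) -> bool:
    # Compare only the leading numeric release, so a nightly such as
    # "2.5.0.dev20240724+cu121" still counts as >= "2.5.0".
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", torch.__version__)
    current = tuple(int(g) for g in match.groups())
    target = tuple(int(p) for p in min_version.split("."))
    return current >= target


TORCH_VERSION_AT_LEAST_2_5 = _torch_version_at_least("2.5.0")

Under this convention, nightly wheels report a 2.5 dev version, which is presumably why the float8 files below can move their guard from 2.4 to 2.5 and still run on the CUDA and CPU Nightly jobs while skipping on the new 2.4 jobs.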

test/float8/test_base.py

Lines changed: 2 additions & 2 deletions
@@ -16,9 +16,9 @@
 import torch
 import torch.nn as nn

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)

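
All of the float8 test modules in this commit use the same guard: a module-level pytest.skip that runs at import time, before any code that needs the newer PyTorch is reached. A self-contained sketch of the pattern (the trailing guarded import is a placeholder):

import pytest

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

# allow_module_level=True lets pytest.skip() be called at module scope,
# marking every test in the file as skipped instead of raising an error
# during collection.
if not TORCH_VERSION_AT_LEAST_2_5:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)

# Version-sensitive imports go after the guard so that collection on an
# older torch never executes them.
import torch  # noqa: E402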

test/float8/test_compile.py

Lines changed: 2 additions & 2 deletions
@@ -11,9 +11,9 @@

 import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)

 import torch

test/float8/test_dtensor.py

Lines changed: 2 additions & 2 deletions
@@ -19,9 +19,9 @@

 import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)

 from torchao.float8 import Float8LinearConfig

test/float8/test_fsdp.py

Lines changed: 2 additions & 2 deletions
@@ -18,9 +18,9 @@

 import fire

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)

 import torch

test/float8/test_fsdp2/test_fsdp2.py

Lines changed: 2 additions & 2 deletions
@@ -5,9 +5,9 @@
 import unittest
 from typing import Any, List

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)

test/float8/test_fsdp_compile.py

Lines changed: 2 additions & 2 deletions
@@ -15,9 +15,9 @@

 import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)

 import torch

test/float8/test_inference_flows.py

Lines changed: 2 additions & 2 deletions
@@ -12,11 +12,11 @@
 import pytest
 from unittest.mock import patch
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
     unwrap_tensor_subclass,
 )

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)

 import torch

test/float8/test_numerics_integration.py

Lines changed: 2 additions & 2 deletions
@@ -11,9 +11,9 @@

 import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)

 import torch

test/integration/test_integration.py

Lines changed: 6 additions & 6 deletions
@@ -913,7 +913,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
         if dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
             self.skipTest("test requires SM capability of at least (8, 0).")
         from torch._inductor import config
-        mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") if TORCH_VERSION_AT_LEAST_2_4 else ("force_mixed_mm", True)
+        mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") if TORCH_VERSION_AT_LEAST_2_5 else ("force_mixed_mm", True)

         with config.patch({
             "epilogue_fusion": True,
@@ -943,7 +943,7 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype):
             self.skipTest("test requires SM capability of at least (8, 0).")
         torch.manual_seed(0)
         from torch._inductor import config
-        mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") if TORCH_VERSION_AT_LEAST_2_4 else ("force_mixed_mm", True)
+        mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") if TORCH_VERSION_AT_LEAST_2_5 else ("force_mixed_mm", True)

         with config.patch({
             "epilogue_fusion": False,
@@ -1222,7 +1222,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
         (1, 32, 128, 128),
         (32, 32, 128, 128),
     ]))
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "autoquant requires 2.4+.")
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.")
     def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
         undo_recommended_configs()
         if device != "cuda" or not torch.cuda.is_available():
@@ -1254,7 +1254,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
         self.assertTrue(sqnr >= 30)

     @parameterized.expand(COMMON_DEVICE_DTYPE)
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "autoquant requires 2.4+.")
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.")
     def test_autoquant_manual(self, device, dtype):
         undo_recommended_configs()
         if device != "cuda" or not torch.cuda.is_available():
@@ -1295,7 +1295,7 @@ def test_autoquant_manual(self, device, dtype):
         (1, 32, 128, 128),
         (32, 32, 128, 128),
     ]))
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "autoquant requires 2.4+.")
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.")
     def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n):
         undo_recommended_configs()
         if device != "cuda" or not torch.cuda.is_available():
@@ -1478,7 +1478,7 @@ def forward(self, x):

 class TestUtils(unittest.TestCase):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "autoquant requires 2.4+.")
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.")
     def test_get_model_size_autoquant(self, device, dtype):
         if device != "cuda" and dtype != torch.bfloat16:
             self.skipTest(f"autoquant currently does not support {device}")
