@@ -913,7 +913,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
913
913
if dtype == torch .bfloat16 and torch .cuda .get_device_capability () < (8 , 0 ):
914
914
self .skipTest ("test requires SM capability of at least (8, 0)." )
915
915
from torch ._inductor import config
916
- mixed_mm_key , mixed_mm_val = ("mixed_mm_choice" , "triton" ) if TORCH_VERSION_AT_LEAST_2_4 else ("force_mixed_mm" , True )
916
+ mixed_mm_key , mixed_mm_val = ("mixed_mm_choice" , "triton" ) if TORCH_VERSION_AT_LEAST_2_5 else ("force_mixed_mm" , True )
917
917
918
918
with config .patch ({
919
919
"epilogue_fusion" : True ,
@@ -943,7 +943,7 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype):
943
943
self .skipTest ("test requires SM capability of at least (8, 0)." )
944
944
torch .manual_seed (0 )
945
945
from torch ._inductor import config
946
- mixed_mm_key , mixed_mm_val = ("mixed_mm_choice" , "triton" ) if TORCH_VERSION_AT_LEAST_2_4 else ("force_mixed_mm" , True )
946
+ mixed_mm_key , mixed_mm_val = ("mixed_mm_choice" , "triton" ) if TORCH_VERSION_AT_LEAST_2_5 else ("force_mixed_mm" , True )
947
947
948
948
with config .patch ({
949
949
"epilogue_fusion" : False ,
@@ -1222,7 +1222,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
1222
1222
(1 , 32 , 128 , 128 ),
1223
1223
(32 , 32 , 128 , 128 ),
1224
1224
]))
1225
- @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "autoquant requires 2.4 +." )
1225
+ @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_5 , "autoquant requires 2.5 +." )
1226
1226
def test_autoquant_compile (self , device , dtype , m1 , m2 , k , n ):
1227
1227
undo_recommended_configs ()
1228
1228
if device != "cuda" or not torch .cuda .is_available ():
@@ -1254,7 +1254,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
1254
1254
self .assertTrue (sqnr >= 30 )
1255
1255
1256
1256
@parameterized .expand (COMMON_DEVICE_DTYPE )
1257
- @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "autoquant requires 2.4 +." )
1257
+ @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_5 , "autoquant requires 2.5 +." )
1258
1258
def test_autoquant_manual (self , device , dtype ):
1259
1259
undo_recommended_configs ()
1260
1260
if device != "cuda" or not torch .cuda .is_available ():
@@ -1295,7 +1295,7 @@ def test_autoquant_manual(self, device, dtype):
1295
1295
(1 , 32 , 128 , 128 ),
1296
1296
(32 , 32 , 128 , 128 ),
1297
1297
]))
1298
- @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "autoquant requires 2.4 +." )
1298
+ @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_5 , "autoquant requires 2.5 +." )
1299
1299
def test_autoquant_kwargs (self , device , dtype , m1 , m2 , k , n ):
1300
1300
undo_recommended_configs ()
1301
1301
if device != "cuda" or not torch .cuda .is_available ():
@@ -1478,7 +1478,7 @@ def forward(self, x):
1478
1478
1479
1479
class TestUtils (unittest .TestCase ):
1480
1480
@parameterized .expand (COMMON_DEVICE_DTYPE )
1481
- @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "autoquant requires 2.4 +." )
1481
+ @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_5 , "autoquant requires 2.5 +." )
1482
1482
def test_get_model_size_autoquant (self , device , dtype ):
1483
1483
if device != "cuda" and dtype != torch .bfloat16 :
1484
1484
self .skipTest (f"autoquant currently does not support { device } " )
0 commit comments