     int8_dynamic_activation_int8_semi_sparse_weight,
     float8_weight_only,
 )
+from torch.testing._internal import common_utils
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 import torch
 import unittest
 import tempfile
 
+is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+
+
+def get_quantization_functions(do_sparse: bool, do_int4: bool):
+    base_functions = [
+        int8_weight_only(),
+        int8_dynamic_activation_int4_weight(),
+        int8_dynamic_activation_int8_weight(),
+    ]
+    if do_int4:
+        base_functions.append(int4_weight_only(group_size=32))
+
+    if do_sparse:
+        base_functions.append(int8_dynamic_activation_int8_semi_sparse_weight())
+
+    if is_cuda_8_9:
+        base_functions.append(float8_weight_only())
+
+    return base_functions
+
 
 class TestAffineQuantized(TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -38,36 +59,36 @@ def test_tensor_core_layout_transpose(self):
         self.assertEqual(aqt_shape, shape)
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_weights_only(self):
-        for apply_quant in [int4_weight_only(group_size=32), int8_weight_only(), int8_dynamic_activation_int4_weight(),
-                            int8_dynamic_activation_int8_weight(), int8_dynamic_activation_int8_semi_sparse_weight(), float8_weight_only()]:
-            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-            ql = apply_quant(l)
-            with tempfile.NamedTemporaryFile() as f:
-                torch.save(ql.state_dict(), f)
-                f.seek(0)
-                # `weights_only=True` is enabled for torch 2.5+
-                if TORCH_VERSION_AT_LEAST_2_5:
-                    _ = torch.load(f, weights_only=True)
-                else:
-                    _ = torch.load(f, weights_only=False)
+    @common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
+    def test_weights_only(self, apply_quant):
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        ql = apply_quant(l)
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(ql.state_dict(), f)
+            f.seek(0)
+            # `weights_only=True` is enabled for torch 2.5+
+            if TORCH_VERSION_AT_LEAST_2_5:
+                _ = torch.load(f, weights_only=True)
+            else:
+                _ = torch.load(f, weights_only=False)
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_to_device(self):
-        from torchao.quantization import quantize_
-        for apply_quant in [int8_weight_only(), int8_dynamic_activation_int4_weight(), int8_dynamic_activation_int8_weight()]:
-            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            ql = apply_quant(l)
-            ql.to("cuda")
+    @common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
+    def test_to_device(self, apply_quant):
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        ql = apply_quant(l)
+        ql.to("cuda")
+
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        ql = apply_quant(l)
+        ql.to(device="cuda")
 
-            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            ql = apply_quant(l)
-            ql.to(device="cuda")
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        ql = apply_quant(l)
+        ql.cuda()
 
-            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            ql = apply_quant(l)
-            ql.cuda()
 
+common_utils.instantiate_parametrized_tests(TestAffineQuantized)
 
 if __name__ == "__main__":
     run_tests()
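For readers unfamiliar with the parametrization utilities this PR adopts, here is a minimal standalone sketch of the same pattern: `common_utils.parametrize` marks a test method with a list of argument values, and `instantiate_parametrized_tests` expands it into one named test per value, so each quantization config passes or fails individually instead of aborting a shared for loop. `ExampleTest`, `identity`, and `to_bfloat16` are hypothetical stand-ins, not part of the patch:

```python
# Minimal sketch of the parametrize/instantiate pattern used in the diff;
# ExampleTest, identity, and to_bfloat16 are hypothetical stand-ins.
import torch
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import TestCase, run_tests


def identity(m):
    # Stand-in "quantization" function: leaves the module unchanged.
    return m


def to_bfloat16(m):
    # Stand-in transform: casts the module's parameters to bfloat16.
    return m.to(torch.bfloat16)


class ExampleTest(TestCase):
    @common_utils.parametrize("apply_fn", [identity, to_bfloat16])
    def test_apply(self, apply_fn):
        m = apply_fn(torch.nn.Linear(4, 4))
        self.assertIsInstance(m, torch.nn.Linear)


# Expands test_apply into one separately named test per apply_fn, mirroring
# how instantiate_parametrized_tests(TestAffineQuantized) is used above.
common_utils.instantiate_parametrized_tests(ExampleTest)

if __name__ == "__main__":
    run_tests()
```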