@@ -20,12 +20,17 @@
    DynamicallyPerAxisQuantizedLinear,
)
from torchao.quantization.quant_api import (
-    apply_dynamic_quant,
-    apply_weight_only_int8_quant,
+    int4wo,
+    int8wo,
+    int8da_int8w,
+    quantize,
+    _replace_with_custom_fn_if_matches_filter,
+)
+# APIs to be deprecated (used for torch 2.2.2 and 2.3)
+from torchao.quantization.quant_api import (
    change_linear_weights_to_int8_dqtensors,
    change_linear_weights_to_int8_woqtensors,
    change_linear_weights_to_int4_woqtensors,
-    _replace_with_custom_fn_if_matches_filter,
)
from torchao.quantization.quant_primitives import (
    safe_int_mm,
@@ -73,26 +78,53 @@
from parameterized import parameterized
import itertools
import logging
-from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4, is_fbcode
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_3,
+    TORCH_VERSION_AFTER_2_4,
+    unwrap_tensor_subclass,
+    is_fbcode,
+)

logger = logging.getLogger("INFO")

torch.manual_seed(0)
config.cache_size_limit = 100

-# TODO: use this to reduce the number of tests
-TENSOR_SUBCLASS_APIS = [
-    change_linear_weights_to_int8_dqtensors,
-    change_linear_weights_to_int8_woqtensors,
-    change_linear_weights_to_int4_woqtensors,
-]
-
COMMON_DEVICES = ["cpu", "cuda"]

COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()

+def _int8wo_api(mod):
+    if TORCH_VERSION_AFTER_2_4:
+        quantize(mod, int8wo())
+        unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int8_woqtensors(mod)
+
+def _int8da_int8w_api(mod):
+    if TORCH_VERSION_AFTER_2_4:
+        quantize(mod, int8da_int8w())
+        unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int8_dqtensors(mod)
+
+def _int4wo_api(mod):
+    if TORCH_VERSION_AFTER_2_4:
+        quantize(mod, int4wo())
+        unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int4_woqtensors(mod)
+
+# TODO: use this to reduce the number of tests
+TENSOR_SUBCLASS_APIS = [
+    _int8wo_api,
+    _int8da_int8w_api,
+    _int4wo_api,
+]
+
+
def combine_parameters(a, b):
    new_tuples = []
    for (tuple1, tuple2) in itertools.product(a, b):
@@ -756,14 +788,14 @@ def _test_lin_weight_subclass_api_impl(
    @unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen")
    def test_int8_dynamic_quant_subclass_api(self, device, dtype):
        self._test_lin_weight_subclass_api_impl(
-            change_linear_weights_to_int8_dqtensors, device, 35, test_dtype=dtype
+            _int8da_int8w_api, device, 35, test_dtype=dtype
        )

    @parameterized.expand(COMMON_DEVICE_DTYPE)
    @unittest.skipIf(is_fbcode(), "broken in fbcode")
    def test_int8_weight_only_quant_subclass_api(self, device, dtype):
        self._test_lin_weight_subclass_api_impl(
-            change_linear_weights_to_int8_woqtensors, device, 40, test_dtype=dtype
+            _int8wo_api, device, 40, test_dtype=dtype
        )

    @parameterized.expand(COMMON_DEVICE_DTYPE)
@@ -773,7 +805,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
            self.skipTest(f"Fails for {dtype}")
        for test_shape in ([(16, 1024, 16)] + ([(1, 1024, 256)] if device == 'cuda' else [])):
            self._test_lin_weight_subclass_api_impl(
-                change_linear_weights_to_int4_woqtensors,
+                _int4wo_api,
                device,
                15,
                test_shape=test_shape,
@@ -789,8 +821,16 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
        for groupsize in [64, 32]:
            for inner_k_tiles in [4, 2]:
                kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles}
+
+                def api(mod):
+                    if TORCH_VERSION_AFTER_2_4:
+                        quantize(mod, int4wo(**kwargs))
+                        unwrap_tensor_subclass(mod)
+                    else:
+                        change_linear_weights_to_int4_woqtensors(mod, **kwargs)
+
                self._test_lin_weight_subclass_api_impl(
-                    lambda mod: change_linear_weights_to_int4_woqtensors(mod, **kwargs),
+                    api,
                    device,
                    15,
                    test_shape=test_shape,
@@ -805,7 +845,7 @@ def test_dynamic_quant(self):
        m = nn.Sequential(nn.Linear(K, N))

        y_ref = m(x)
-        apply_dynamic_quant(m)
+        quantize(m, int8da_int8w())
        y_test = m(x)

        sqnr = compute_error(y_ref, y_test)
@@ -819,7 +859,7 @@ def test_weight_only_quant(self):
            x = torch.randn(*x_shape)
            m = nn.Sequential(nn.Linear(4, 5))
            y_ref = m(x)
-            apply_weight_only_int8_quant(m)
+            _int8wo_api(m)
            y_wo = m(x)
            sqnr = compute_error(y_ref, y_wo)
            self.assertGreater(sqnr, 44.0)
@@ -842,7 +882,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
            x = torch.randn(*x_shape).to(device).to(dtype)
            m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
            y_ref = m(x)
-            apply_weight_only_int8_quant(m)
+            _int8wo_api(m)
            m(x)
            m_c = torch.compile(m, mode="max-autotune")
            y_wo, (code,) = run_and_get_code(m_c, x)
@@ -869,7 +909,7 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype):
            x = torch.randn(*x_shape).to(device).to(dtype)
            m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
            y_ref = m(x)
-            apply_weight_only_int8_quant(m)
+            _int8wo_api(m)
            m_c = torch.compile(m, mode="max-autotune")
            y_wo, (code,) = run_and_get_code(m_c, x)
            sqnr = compute_error(y_ref, y_wo)
@@ -910,6 +950,7 @@ def forward(self, x):

        # save quantized state_dict
        api(model)
+
        torch.save(model.state_dict(), "test.pth")
        # get quantized reference
        model_qc = torch.compile(model, mode="max-autotune")
@@ -925,6 +966,7 @@ def forward(self, x):
        # load quantized state_dict
        state_dict = torch.load("test.pth", mmap=True)
        os.remove("test.pth")
+
        model.load_state_dict(state_dict, assign=True)
        model = model.to(device=test_device, dtype=test_dtype).eval()

@@ -941,21 +983,21 @@ def forward(self, x):
    def test_save_load_dqtensors(self, device, dtype):
        if device == "cpu":
            self.skipTest(f"indcutor failed for cpu right now")
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_dqtensors, device, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(_int8da_int8w_api, device, test_dtype=dtype)

    @parameterized.expand(COMMON_DEVICE_DTYPE)
    @torch.no_grad()
    @unittest.skipIf(is_fbcode(), "broken in fbcode")
    def test_save_load_int8woqtensors(self, device, dtype):
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_woqtensors, device, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype)

    @parameterized.expand(COMMON_DEVICE_DTYPE)
    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
    @torch.no_grad()
    def test_save_load_int4woqtensors(self, device, dtype):
        if dtype != torch.bfloat16:
            self.skipTest(f"Fails for {dtype}")
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int4_woqtensors, device, 20, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(_int4wo_api, device, 20, test_dtype=dtype)


class TorchCompileUnitTest(unittest.TestCase):
@@ -1275,8 +1317,7 @@ def forward(self, x):
        model = test_model().to(dtype=test_dtype, device=test_device).eval()
        ref_f = model(x)

-        kwargs = {"dtype": test_dtype}
-        api(model, **kwargs)
+        api(model)

        # running model
        model(x)
@@ -1321,8 +1362,7 @@ def forward(self, x):
        model = test_model().to(dtype=test_dtype, device=test_device).eval()
        ref_f = model(x)

-        kwargs = {"dtype": test_dtype}
-        api(model, **kwargs)
+        api(model)

        # running model
        ref = model(x)
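
Reviewer note: the same migration pattern repeats through every hunk above. On torch 2.4+ the tests call the new quantize() config API (int8wo, int8da_int8w, int4wo) followed by unwrap_tensor_subclass(), while on torch 2.2.2 and 2.3 they fall back to the deprecated change_linear_weights_to_* functions. A minimal standalone sketch of that pattern follows; the toy model and shapes are illustrative only and not taken from this diff.

import torch
import torch.nn as nn

from torchao.quantization.quant_api import (
    quantize,
    int8wo,
    change_linear_weights_to_int8_woqtensors,  # deprecated path
)
from torchao.utils import TORCH_VERSION_AFTER_2_4, unwrap_tensor_subclass

# illustrative toy model, analogous to the nn.Sequential(nn.Linear(...)) used in the tests
model = nn.Sequential(nn.Linear(4, 5)).eval()

if TORCH_VERSION_AFTER_2_4:
    # new API: swap Linear weights for int8 weight-only quantized tensor subclasses
    quantize(model, int8wo())
    # unwrap the tensor subclasses, as the updated tests do before compiling/saving
    unwrap_tensor_subclass(model)
else:
    # fallback kept for torch 2.2.2 and 2.3
    change_linear_weights_to_int8_woqtensors(model)

y = model(torch.randn(2, 4))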