    fake_quantize_per_channel_group,
    fake_quantize_per_token,
)
- from torchao.quantization.utils import get_group_qparams_symmetric
+ from torchao.quantization.quant_primitives import (
+     fake_quantize_affine,
+     ZeroPointDomain,
+ )
+ from torchao.quantization.utils import (
+     get_group_qparams_symmetric,
+     get_groupwise_affine_qparams,
+     groupwise_affine_quantize_tensor,
+ )
from torchao.utils import TORCH_VERSION_AFTER_2_4


# TODO: put this in a common test utils file
+ _CUDA_IS_AVAILABLE = torch.cuda.is_available()
+
class Sub(torch.nn.Module):
    def __init__(self):
        super().__init__()
-         self.linear = torch.nn.Linear(32, 32, bias=False).to(torch.float)
+         self.linear = torch.nn.Linear(256, 256, bias=False).to(torch.float)

    def example_inputs(self):
-         return (torch.randn(1, 32).to(torch.float),)
+         return (torch.randn(1, 256).to(torch.float),)

    def forward(self, x):
        return self.linear(x)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
-         self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
+         self.linear1 = torch.nn.Linear(512, 256, bias=False).to(torch.float)
        self.sub = Sub()
-         self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
+         self.linear2 = torch.nn.Linear(256, 512, bias=False).to(torch.float)

    def example_inputs(self):
-         return (torch.randn(1, 64).to(torch.float),)
+         return (torch.randn(1, 512).to(torch.float),)

    def forward(self, x):
        x = self.linear1(x)
@@ -111,23 +121,46 @@ def test_fake_quantize_per_token(self):

    def _set_ptq_weight(
        self,
-         ptq_linear: "Int8DynActInt4WeightLinear",
-         fp32_weight: torch.Tensor,
-         group_size: int,
+         ptq_linear: torch.nn.Module,
+         qat_linear: torch.nn.Module,
    ):
        """
        Set the weight to the quantized version of the given fp32 weights,
        for making linear outputs comparable with QAT.
        """
+         from torchao.quantization.GPTQ import (
+             Int8DynActInt4WeightLinear,
+             WeightOnlyInt4Linear,
+         )
+         from torchao.quantization.prototype.qat import (
+             Int8DynActInt4WeightQATLinear,
+             Int4WeightOnlyQATLinear,
+         )
        n_bit = 4
        (qmin, qmax) = self._get_qmin_qmax(n_bit)
-         (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
-         q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
-             fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
-         )
-         ptq_linear.weight = q_weight
-         ptq_linear.scales = s
-         ptq_linear.zeros = zp
+         if isinstance(ptq_linear, Int8DynActInt4WeightLinear):
+             assert isinstance(qat_linear, Int8DynActInt4WeightQATLinear)
+             fp32_weight = qat_linear.weight
+             group_size = qat_linear.groupsize
+             (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
+             q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
+                 fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
+             )
+             ptq_linear.weight = q_weight
+             ptq_linear.scales = s
+             ptq_linear.zeros = zp
+         elif isinstance(ptq_linear, WeightOnlyInt4Linear):
+             assert isinstance(qat_linear, Int4WeightOnlyQATLinear)
+             (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
+                 qat_linear.weight, n_bit, qat_linear.groupsize,
+             )
+             q_weight = torch.ops.aten._convert_weight_to_int4pack(
+                 q_weight.to("cuda"), qat_linear.inner_k_tiles,
+             )
+             ptq_linear.weight = q_weight
+             ptq_linear.scales_and_zeros = scales_and_zeros
+         else:
+             raise ValueError("Unknown ptq_linear type: %s" % type(ptq_linear))

    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
    def test_qat_8da4w_linear(self):
@@ -144,7 +177,7 @@ def test_qat_8da4w_linear(self):
        )

        # Force the weights to be the same
-         self._set_ptq_weight(ptq_linear, qat_linear.weight, group_size)
+         self._set_ptq_weight(ptq_linear, qat_linear)

        # Compare linear values
        torch.manual_seed(self.SEED)
@@ -280,7 +313,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
        loss_fn1 = torch.nn.CrossEntropyLoss()
        loss_fn2 = torch.nn.CrossEntropyLoss()
        example_inputs = nn_model.example_inputs()
-         target = torch.randn(1, 64).float()
+         target = torch.randn(1, 512).float()
        output1 = nn_model(*example_inputs)
        output2 = qat_model(*example_inputs)
        torch.testing.assert_close(output1, output2, atol=0, rtol=0)
@@ -322,6 +355,130 @@ def test_qat_generic_fake_quantize(self):
        torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0)
        torch.testing.assert_close(py_input.grad, ao_input.grad, atol=0, rtol=0)

+     def _assert_close_4w(self, val, ref):
+         # Note: for int4 weight-only quantization, we do not expect exact match
+         # because torch._weight_int4pack_mm and torch.mm do not match exactly.
+         # Here we use the same error bar as PyTorch core to determine closeness:
+         # https://github.com/pytorch/pytorch/blob/6079c5091091d872b8dafbaa4e31a5b6194647ad/test/test_linalg.py#L6079
+         mean_err = ((val - ref) / ref).mean().abs()
+         print(mean_err)
+         self.assertTrue(mean_err < 0.05)
+
+     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+     def test_qat_4w_primitives(self):
+         n_bit = 4
+         group_size = 32
+         inner_k_tiles = 8
+         scales_precision = torch.bfloat16
+         device = torch.device("cuda")
+         dtype = torch.bfloat16
+         torch.manual_seed(self.SEED)
+         x = torch.randn(100, 256, dtype=dtype, device=device)
+         weight = torch.randn(512, 256, dtype=dtype, device=device)
+
+         # PTQ
+         (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
+             weight, n_bit, group_size, scales_precision,
+         )
+         q_weight = torch.ops.aten._convert_weight_to_int4pack(
+             q_weight.to(device), inner_k_tiles,
+         )
+         ptq_out = torch.ops.aten._weight_int4pack_mm(
+             x, q_weight, group_size, scales_and_zeros
+         )
+
+         # QAT
+         block_size = (1, group_size)
+         quant_min = 0
+         quant_max = 2 ** n_bit - 1
+         scales, zero_points = get_groupwise_affine_qparams(
+             weight, n_bit, group_size, scales_precision,
+         )
+         w_fq = fake_quantize_affine(
+             weight,
+             block_size,
+             scales,
+             zero_points,
+             torch.int32,
+             quant_min,
+             quant_max,
+             zero_point_domain=ZeroPointDomain.FLOAT,
+         )
+         qat_out = torch.nn.functional.linear(x, w_fq)
+
+         self._assert_close_4w(qat_out, ptq_out)
+
+     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+     def test_qat_4w_linear(self):
+         from torchao.quantization.prototype.qat import Int4WeightOnlyQATLinear
+         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
+
+         group_size = 128
+         device = torch.device("cuda")
+         dtype = torch.bfloat16
+         torch.manual_seed(self.SEED)
+         qat_linear = Int4WeightOnlyQATLinear(
+             256, 688, bias=False, groupsize=group_size, device=device,
+         )
+         ptq_linear = WeightOnlyInt4Linear(
+             256, 688, bias=False, groupsize=group_size, device=device,
+         )
+
+         # Force the weights to be the same
+         self._set_ptq_weight(ptq_linear, qat_linear)
+
+         # Compare linear values
+         torch.manual_seed(self.SEED)
+         x = torch.randn(100, 256, dtype=dtype, device=device)
+         x2 = copy.deepcopy(x)
+         qat_out = qat_linear(x)
+         ptq_out = ptq_linear(x2)
+         self._assert_close_4w(qat_out, ptq_out)
+
+     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+     def test_qat_4w_quantizer(self):
+         from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
+
+         group_size = 32
+         inner_k_tiles = 8
+         device = torch.device("cuda")
+         dtype = torch.bfloat16
+         torch.manual_seed(self.SEED)
+         m = M().to(device).to(dtype)
+         m2 = copy.deepcopy(m)
+         qat_quantizer = Int4WeightOnlyQATQuantizer(
+             groupsize=group_size, inner_k_tiles=inner_k_tiles,
+         )
+         ptq_quantizer = Int4WeightOnlyQuantizer(
+             groupsize=group_size, inner_k_tiles=inner_k_tiles,
+         )
+         qat_model = qat_quantizer.prepare(m)
+         ptq_model = ptq_quantizer.quantize(m2)
+
+         # Compare model values
+         torch.manual_seed(self.SEED)
+         x = [i.to(device).to(dtype) for i in m.example_inputs()]
+         x2 = copy.deepcopy(x)
+         qat_out = qat_model(*x)
+         ptq_out = ptq_model(*x2)
+         self._assert_close_4w(qat_out, ptq_out)
+
+         # Convert QAT model and compare model values
+         converted_model = qat_quantizer.convert(qat_model)
+         converted_out = converted_model(*x)
+         torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)
+
+         # Compare converted state dict
+         ptq_state_dict = ptq_model.state_dict()
+         converted_state_dict = converted_model.state_dict()
+         self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
+         for k in ptq_state_dict.keys():
+             torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
+

if __name__ == "__main__":
    unittest.main()