@@ -395,7 +395,10 @@ def test_eval_wrapper(self):
395
395
# TODO: move to a separate test file
396
396
@unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "Test only enabled for 2.4+" )
397
397
def test_quantized_tensor_subclass_8da4w (self ):
398
- from torchao .quantization .subclass import AffineQuantizedTensor
398
+ from torchao .quantization .subclass import (
399
+ AffineQuantizedTensor ,
400
+ LinearActQuantizedTensor ,
401
+ )
399
402
from torchao .quantization .quant_primitives import MappingType
400
403
import copy
401
404
@@ -409,6 +412,7 @@ def test_quantized_tensor_subclass_8da4w(self):
409
412
quant_max = 7
410
413
411
414
# TODO: make a general helper function?
415
+ # input settings
412
416
def get_per_token_block_size (x ):
413
417
block_size = []
414
418
for i in range (len (x .shape )- 1 ):
@@ -421,13 +425,18 @@ def get_per_token_block_size(x):
421
425
input_target_dtype = torch .int8
422
426
input_quant_func = lambda x : AffineQuantizedTensor .from_float (x , input_mapping_type , get_per_token_block_size (x ), input_target_dtype )
423
427
428
+ def dynamic_quant (linear ):
429
+ # note: order is important
430
+ linear .weight = torch .nn .Parameter (AffineQuantizedTensor .from_float (linear .weight , mapping_type , block_size , target_dtype , quant_min , quant_max , eps ), requires_grad = False )
431
+ linear .weight = torch .nn .Parameter (LinearActQuantizedTensor .from_float (linear .weight , input_quant_func ), requires_grad = False )
432
+
424
433
m = ToyLinearModel ().eval ()
425
434
m_copy = copy .deepcopy (m )
426
435
example_inputs = m .example_inputs ()
427
- m . linear1 . weight = torch . nn . Parameter ( AffineQuantizedTensor . from_float ( m .linear1 . weight , mapping_type , block_size , target_dtype , quant_min , quant_max , eps , input_quant_func = input_quant_func ), requires_grad = False )
428
- m . linear2 . weight = torch . nn . Parameter ( AffineQuantizedTensor . from_float ( m .linear2 . weight , mapping_type , block_size , target_dtype , quant_min , quant_max , eps , input_quant_func = input_quant_func ), requires_grad = False )
429
- assert isinstance (m .linear1 .weight , AffineQuantizedTensor )
430
- assert isinstance (m .linear2 .weight , AffineQuantizedTensor )
436
+ dynamic_quant ( m .linear1 )
437
+ dynamic_quant ( m .linear2 )
438
+ assert isinstance (m .linear1 .weight , LinearActQuantizedTensor )
439
+ assert isinstance (m .linear2 .weight , LinearActQuantizedTensor )
431
440
432
441
# reference
433
442
from torchao .quantization .quant_api import Int8DynActInt4WeightQuantizer
@@ -461,9 +470,6 @@ def test_quantized_tensor_subclass_int4(self):
461
470
preserve_zero = False
462
471
zero_point_dtype = torch .bfloat16
463
472
464
- # weight only quantization
465
- input_quant_func = None
466
-
467
473
# use 1024 so that we don't need padding
468
474
m = ToyLinearModel (1024 , 1024 , 1024 ).eval ().to (torch .bfloat16 ).to ("cuda" )
469
475
m_copy = copy .deepcopy (m )
@@ -475,7 +481,6 @@ def to_quantized(weight):
475
481
zero_point_dtype = zero_point_dtype ,
476
482
preserve_zero = preserve_zero ,
477
483
zero_point_domain = ZeroPointDomain .FLOAT ,
478
- input_quant_func = input_quant_func ,
479
484
)
480
485
481
486
m .linear1 .weight = torch .nn .Parameter (to_quantized (m .linear1 .weight ), requires_grad = False )
@@ -506,16 +511,13 @@ def test_quantized_tensor_subclass_int8(self):
506
511
eps = torch .finfo (torch .float32 ).eps
507
512
zero_point_dtype = torch .int64
508
513
509
- # weight only quantization
510
- input_quant_func = None
511
-
512
514
m = ToyLinearModel ().eval ().to (torch .bfloat16 )
513
515
m_copy = copy .deepcopy (m )
514
516
example_inputs = tuple (map (lambda x : x .to (torch .bfloat16 ), m .example_inputs ()))
515
517
516
518
def to_quantized (weight ):
517
519
block_size = (1 , weight .shape [1 ])
518
- return AffineQuantizedTensor .from_float (weight , mapping_type , block_size , target_dtype , eps = eps , zero_point_dtype = zero_point_dtype , input_quant_func = input_quant_func )
520
+ return AffineQuantizedTensor .from_float (weight , mapping_type , block_size , target_dtype , eps = eps , zero_point_dtype = zero_point_dtype )
519
521
520
522
m .linear1 .weight = torch .nn .Parameter (to_quantized (m .linear1 .weight ), requires_grad = False )
521
523
m .linear2 .weight = torch .nn .Parameter (to_quantized (m .linear2 .weight ), requires_grad = False )
@@ -532,5 +534,63 @@ def to_quantized(weight):
532
534
torch .testing .assert_close (res , ref , rtol = 0.00001 , atol = 1e-2 )
533
535
534
536
537
+ @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "Test only enabled for 2.4+" )
538
+ @unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
539
+ def test_quantized_tensor_subclass_int8_dyn_quant (self ):
540
+ from torchao .quantization .subclass import AffineQuantizedTensor
541
+ from torchao .quantization .subclass import LinearActQuantizedTensor
542
+ from torchao .quantization .quant_primitives import MappingType
543
+ from torchao .quantization .quant_primitives import ZeroPointDomain
544
+ import copy
545
+
546
+ # weight settings
547
+ mapping_type = MappingType .SYMMETRIC
548
+ def get_weight_block_size (x ):
549
+ return (1 , x .shape [1 ])
550
+ target_dtype = torch .int8
551
+ eps = torch .finfo (torch .float32 ).eps
552
+ zero_point_dtype = torch .int64
553
+
554
+ # input settings
555
+ def get_per_token_block_size (x ):
556
+ block_size = list (x .shape )
557
+ for i in range (len (block_size )- 1 ):
558
+ block_size [i ] = 1
559
+ return block_size
560
+
561
+ input_mapping_type = MappingType .SYMMETRIC
562
+ input_target_dtype = torch .int8
563
+ input_eps = 1e-5
564
+ input_quant_min = - 127
565
+ input_quant_max = 127
566
+ input_quant_func = lambda x : AffineQuantizedTensor .from_float (x , input_mapping_type , get_per_token_block_size (x ), input_target_dtype , eps = input_eps , quant_min = input_quant_min , quant_max = input_quant_max , scale_dtype = torch .float )
567
+
568
+ # use 1024 so that we don't need padding
569
+ m = ToyLinearModel (1024 , 1024 , 1024 ).eval ().to (torch .bfloat16 ).to ("cuda" )
570
+ m_copy = copy .deepcopy (m )
571
+ example_inputs = tuple (map (lambda x : x .to (torch .bfloat16 ).to ("cuda" ), m .example_inputs ()))
572
+
573
+ def dynamic_quant (linear ):
574
+ # note: order is important
575
+ linear .weight = torch .nn .Parameter (AffineQuantizedTensor .from_float (linear .weight , mapping_type , get_weight_block_size (linear .weight ), target_dtype , eps = eps , zero_point_dtype = zero_point_dtype ), requires_grad = False )
576
+ linear .weight = torch .nn .Parameter (LinearActQuantizedTensor .from_float (linear .weight , input_quant_func ), requires_grad = False )
577
+
578
+ dynamic_quant (m .linear1 )
579
+ dynamic_quant (m .linear2 )
580
+ assert isinstance (m .linear1 .weight , LinearActQuantizedTensor )
581
+ assert isinstance (m .linear2 .weight , LinearActQuantizedTensor )
582
+ assert isinstance (m .linear1 .weight .original_weight_tensor , AffineQuantizedTensor )
583
+ assert isinstance (m .linear2 .weight .original_weight_tensor , AffineQuantizedTensor )
584
+
585
+ # reference
586
+ from torchao .quantization .quant_api import change_linear_weights_to_int8_dqtensors
587
+ change_linear_weights_to_int8_dqtensors (m_copy )
588
+
589
+ res = m (* example_inputs )
590
+ ref = m_copy (* example_inputs )
591
+
592
+ self .assertTrue (torch .equal (res , ref ))
593
+
594
+
535
595
if __name__ == "__main__" :
536
596
unittest .main ()
0 commit comments