 from .utils import _get_per_token_block_size
 import logging
 from .autoquant import autoquant, AutoQuantizableLinearWeight
+from torchao.utils import TORCH_VERSION_AFTER_2_5


 __all__ = [
@@ -326,6 +327,35 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
         _is_linear if filter_fn is None else filter_fn,
     )

+def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
+    # avoid circular dep
+    from torchao.dtypes import to_affine_quantized
+
+    mapping_type = MappingType.ASYMMETRIC
+    target_dtype = torch.int8
+    return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype)
+
+def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32):
+    if weight.shape[-1] % group_size != 0:
+        return weight
+
+    # avoid circular dep
+    from torchao.dtypes import to_affine_quantized
+
+    # weight settings
+    mapping_type = MappingType.SYMMETRIC
+    block_size = (1, group_size)
+    target_dtype = torch.int8
+    eps = torch.finfo(torch.float32).eps
+    quant_min = -8
+    quant_max = 7
+
+    # input settings
+    input_quant_func = _int8_asymm_per_token_quant
+
+    weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
+    weight = to_linear_activation_quantized(weight, input_quant_func)
+    return weight

 def int8_dynamic_activation_int4_weight(group_size=32):
     """Applies int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear
@@ -336,31 +366,11 @@ def int8_dynamic_activation_int4_weight(group_size=32):
         `group_size`: parameter for quantization, controls the granularity of quantization, smaller
          size is more fine grained
     """
-    def apply_int8_dynamic_activation_int4_weight_quant(weight):
-        if weight.shape[-1] % group_size != 0:
-            return weight
-
-        # avoid circular dep
-        from torchao.dtypes import to_affine_quantized
-
-        # weight settings
-        mapping_type = MappingType.SYMMETRIC
-        block_size = (1, group_size)
-        target_dtype = torch.int8
-        eps = torch.finfo(torch.float32).eps
-        quant_min = -8
-        quant_max = 7
-
-        # input settings
-        input_mapping_type = MappingType.ASYMMETRIC
-        input_target_dtype = torch.int8
-        input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, _get_per_token_block_size(x), input_target_dtype)
-
-        weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
-        weight = to_linear_activation_quantized(weight, input_quant_func)
-        return weight
+    def insert_subclass(lin):
+        lin.weight = torch.nn.Parameter(apply_int8_dynamic_activation_int4_weight_quant(lin.weight, group_size), requires_grad=False)
+        return lin

-    return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant)
+    return insert_subclass


 def int4_weight_only(group_size=128, inner_k_tiles=8):
@@ -421,6 +431,16 @@ def apply_int8wo_quant(weight):

     return _get_linear_subclass_inserter(apply_int8wo_quant)

+def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
+    # avoid circular dep
+    from torchao.dtypes import to_affine_quantized
+    mapping_type = MappingType.SYMMETRIC
+    target_dtype = torch.int8
+    eps = 1e-5
+    quant_min = -127
+    quant_max = 127
+    return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype, eps=eps, quant_min=quant_min, quant_max=quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+

 def int8_dynamic_activation_int8_weight(layout_type=PlainLayoutType()):
     """
@@ -444,12 +464,7 @@ def get_weight_block_size(x):
         zero_point_dtype = torch.int64

         # input settings
-        input_mapping_type = MappingType.SYMMETRIC
-        input_target_dtype = torch.int8
-        input_eps = 1e-5
-        input_quant_min = -127
-        input_quant_max = 127
-        input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, _get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+        input_quant_func = _int8_symm_per_token_reduced_range_quant

         block_size = get_weight_block_size(weight)
         weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
@@ -466,3 +481,7 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     """
     from torchao.dtypes import SemiSparseLayoutType
     return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())
+
+
+if TORCH_VERSION_AFTER_2_5:
+    torch.serialization.add_safe_globals([_int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant])
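
Usage note (not part of the diff): the refactor keeps the public behavior of `int8_dynamic_activation_int4_weight` unchanged; it still returns a callable that swaps a linear module's weight for the quantized subclass. A minimal sketch of how it is typically applied, assuming the module-level `quantize_` entry point from `torchao.quantization` (named `quantize` in some earlier revisions) and the import path shown:

    import torch
    from torchao.quantization import quantize_, int8_dynamic_activation_int4_weight

    # toy model; bfloat16 is assumed here, any float dtype the kernels accept works
    model = torch.nn.Sequential(torch.nn.Linear(128, 256)).to(torch.bfloat16)

    # int8 per-token dynamic activation quant + int4 per-group (group_size=32) weight quant
    quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))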
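
Serialization note (not part of the diff): registering the two named quant functions as safe globals is what lets a state_dict containing these tensor subclasses be reloaded under torch.load's weights_only pickler on torch 2.5+, which would reject the previous anonymous lambdas. A rough round-trip sketch under that assumption (the tensor subclass types themselves are assumed to be allow-listed elsewhere in torchao):

    import torch

    # after quantize_(model, ...) as above
    torch.save(model.state_dict(), "quantized_sd.pt")

    # weights_only=True uses the restricted unpickler; the add_safe_globals call above
    # allow-lists _int8_asymm_per_token_quant / _int8_symm_per_token_reduced_range_quant
    state_dict = torch.load("quantized_sd.pt", weights_only=True)

    # assign=True keeps the loaded tensor subclass instances instead of copying into plain tensors
    model.load_state_dict(state_dict, assign=True)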