@@ -365,6 +365,9 @@ def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_til
     weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
     return weight_int4pack, scales_and_zeros
 
+def _calc_padded_size(k, groupsize=1, inner_k_tiles=1):
+    from model import find_multiple
+    return find_multiple(k, 1024)
 
 def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
     origin_x_size = x.size()
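The new `_calc_padded_size` helper only ever rounds `k` up to the next multiple of 1024; its `groupsize` and `inner_k_tiles` parameters are accepted but unused. A minimal sketch of the rounding it delegates to, assuming `find_multiple` in `model.py` rounds up to the nearest multiple:

```python
# Assumed behavior of model.find_multiple: round n up to the nearest multiple of k.
def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)

print(find_multiple(4096, 1024))   # 4096  -- already aligned, no padding
print(find_multiple(11008, 1024))  # 11264 -- padded up to the next multiple
```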
@@ -378,29 +381,24 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, grou
 def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
     return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
 
-def replace_linear_int4(module, groupsize, inner_k_tiles, padding, use_cuda):
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
-            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
+            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False, use_cuda=use_cuda
-                ))
-            elif padding:
-                setattr(module, name, WeightOnlyInt4Linear(
-                    child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True, use_cuda=use_cuda
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda
                 ))
         else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding, use_cuda)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda)
 
 
 class WeightOnlyInt4QuantHandler:
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
         self.mod = mod
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
-        self.padding = padding
+        self.padding_allowed = padding_allowed
         assert groupsize in [32, 64, 128, 256]
         assert inner_k_tiles in [2, 4, 8]
@@ -417,7 +415,7 @@ def create_quantized_state_dict(self):
 
                 weight = mod.weight.data
                 if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
-                    if self.padding:
+                    if self.padding_allowed:
                         from model import find_multiple
                         import torch.nn.functional as F
                         print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
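Past the end of this hunk, the padding branch zero-pads the weight along the reduction dimension before packing (hence the `find_multiple` and `F.pad` imports here). A self-contained sketch of that step, with illustrative tensor shapes:

```python
import torch
import torch.nn.functional as F

weight = torch.randn(4096, 4000, dtype=torch.bfloat16)  # (out_features, in_features)
in_features = weight.shape[-1]
padded_k = -(-in_features // 1024) * 1024               # find_multiple(4000, 1024) == 4096
# F.pad's (left, right) pair applies to the last dim: zero-padding the
# reduction dim on the right leaves the matmul result unchanged.
weight = F.pad(weight, pad=(0, padded_k - in_features))
assert weight.shape == (4096, 4096)
```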
@@ -436,7 +434,7 @@ def create_quantized_state_dict(self):
         return cur_state_dict
 
     def convert_for_runtime(self, use_cuda):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda)
         return self.mod
 
 class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
@@ -485,11 +483,11 @@ class WeightOnlyInt4Linear(torch.nn.Module):
 
     def __init__(
             self, in_features: int, out_features: int,
-            bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True, use_cuda=True,
+            bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True,
     ) -> None:
         super().__init__()
-        self.padding = padding
-        if padding:
+        self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
+        if self.padding:
             from model import find_multiple
             self.origin_in_features = in_features
             in_features = find_multiple(in_features, 1024)
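With `self.padding` derived from the shape check (padding is needed exactly when the check fails), the constructor records `origin_in_features` so the forward path can zero-pad incoming activations to the padded width, matching the weight padded at quantization time. A standalone sketch of that step; the helper name is hypothetical and mirrors what the existing `forward` in this file does:

```python
import torch
import torch.nn.functional as F

def pad_activations(x: torch.Tensor, origin_in_features: int, in_features: int) -> torch.Tensor:
    # Zero-pad the last (k) dim of the activations from the original width
    # up to the padded width expected by the packed int4 weight.
    return F.pad(x, pad=(0, in_features - origin_in_features))

x = torch.randn(2, 4000, dtype=torch.bfloat16)
assert pad_activations(x, origin_in_features=4000, in_features=4096).shape == (2, 4096)
```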
@@ -597,7 +595,7 @@ def quantize(
 
         dir_name = checkpoint_path.parent
         base_name = checkpoint_path.name
-        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")
+        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth")
     else:
         raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gptq]")
 
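The last hunk embeds the quantization device in the GPTQ checkpoint filename. For example (values hypothetical):

```python
base_name, label, groupsize, device = "model.pth", "", 128, "cuda"
new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth")
print(new_base_name)  # model.int4-gptq.g128.cuda.pth
```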