 
 import torch
 
-from neural_compressor.torch.utils import logger
-from neural_compressor.torch.utils.utility import set_module
+from neural_compressor.torch.utils import logger, set_module
 
 from .utility import quant_tensor, search_clip
 
 
 @torch.no_grad()
 def rtn_quantize(
     model,
-    num_bits=4,
+    dtype="int",
+    bits=4,
+    scheme="sym",
     group_size=32,
-    scheme="asym",
+    group_dim=1,
     quantile=1.0,
     weight_config={},
-    return_int=False,
-    dtype="int",
-    enable_full_range=False,
-    enable_mse_search=False,
-    group_dim=1,
+    export_compressed_model=False,
+    use_full_range=False,
+    use_mse_search=False,
     **kwargs,
 ):
-    """Quant the model with round to nearest method.
+    """Quantize the model with the round-to-nearest (RTN) method; the model is modified in place.
 
     Args:
         model: torch module
-        num_bits: num bits. Defaults to 4.
+        bits (int, optional): number of bits. Defaults to 4.
         group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
-        scheme (str, optional): sym or asym. Defaults to "asym".
+        scheme (str, optional): sym or asym. Defaults to "sym".
         quantile (float, optional): percentile of clip. Defaults to 1.0.
         dtype (str, optional): select from int, nf4, fp4. Defaults to int.
         weight_config (dict, optional): specific layer-wise configurations. Defaults to {}.
@@ -60,88 +59,98 @@ def rtn_quantize(
                             'bits': 4,
                             'group_size': 32,
                             'scheme': 'sym'
-                            'gptq_perm': [1, 1, ...]  # for gptq perm
                         }
                 }
-        return_int (bool, optional): Choose return fp32 or int32 model.
+        export_compressed_model (bool, optional): Whether to return a compressed int model rather than a fake-quantized fp32 model.
                                     Defaults to False.
-        enable_full_range (bool, optional): Choose sym range whether use -2**(bits-1).
+        use_full_range (bool, optional): Whether the symmetric range uses -2**(bits-1).
                                     Defaults to False.
-        enable_mse_search (bool, optional): Whether search clip range.
+        use_mse_search (bool, optional): Whether to search for the best clip range.
                                     Defaults to False.
         group_dim (int, optional):  0 means splitting output channel,
                                     1 means splitting input channel. Defaults to 1.
 
     Returns:
         model: fake quantized torch module
     """
+    device = "cpu"
     assert isinstance(model, torch.nn.Module), "only support torch module"
     supported_layers = ["Linear"]
-    double_quant_dtype = kwargs.get("double_quant_dtype", "fp32")
+    # initialize global configuration
     double_quant_config = {
-        "double_quant": False if double_quant_dtype == "fp32" else True,
-        "double_quant_dtype": double_quant_dtype,
-        "double_quant_num_bits": kwargs.get("double_quant_num_bits", 8),
+        "double_quant": kwargs.get("use_double_quant", False),
+        "double_quant_dtype": kwargs.get("double_quant_dtype", "int"),
+        "double_quant_bits": kwargs.get("double_quant_bits", 8),
         "double_quant_scheme": kwargs.get("double_quant_scheme", "sym"),
         "double_quant_group_size": kwargs.get("double_quant_group_size", 256),
     }
-    if return_int:
-        compression_dtype = kwargs.get("compression_dtype", torch.int32)
-        compression_dim = kwargs.get("compression_dim", 1)
-        scale_dtype = kwargs.get("scale_dtype", torch.float32)
-        device = kwargs.get("device", "cpu")
+    if export_compressed_model:
+        use_optimum_format = kwargs.get("use_optimum_format", True)
     for name, m in model.named_modules():
         if m.__class__.__name__ not in supported_layers:
             continue
         if name in weight_config:  # pragma: no cover
+            # initialize op configuration
             dtype = weight_config[name].get("dtype", "int")
-            num_bits = weight_config[name]["bits"]
+            bits = weight_config[name].get("bits", 4)
            group_size = weight_config[name]["group_size"]
            scheme = weight_config[name]["scheme"]
            quantile = weight_config[name].get("quantile", 1.0)
+            group_dim = weight_config[name]["group_dim"]
+            use_full_range = weight_config[name]["use_full_range"]
+            use_mse_search = weight_config[name]["use_mse_search"]
+            use_layer_wise = weight_config[name]["use_layer_wise"]
+            export_compressed_model = weight_config[name]["export_compressed_model"]
+            if export_compressed_model:
+                use_optimum_format = kwargs.get("use_optimum_format", True)
+            # double quant config
+            double_quant_config = {
+                "double_quant": weight_config[name]["use_double_quant"],
+                "double_quant_dtype": weight_config[name]["double_quant_dtype"],
+                "double_quant_bits": weight_config[name]["double_quant_bits"],
+                "double_quant_scheme": weight_config[name]["double_quant_scheme"],
+                "double_quant_group_size": weight_config[name]["double_quant_group_size"],
+            }
         log_msg = (
-            f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, "
-            + f"scheme={scheme}, quantile={quantile}"
+            f"RTN quantization config: bits={bits}, group_size={group_size}, " + f"scheme={scheme}, quantile={quantile}"
         )
         if dtype != "int":
             log_msg += f", dtype={dtype}"
         elif scheme == "sym":  # nf4/fp4 is always [-7,7]
-            log_msg += f", enable_full_range={enable_full_range}"
+            log_msg += f", use_full_range={use_full_range}"
         if dtype == "fp32":
             continue
         logger.debug(f"RTN quantized module:{name, m}")
         logger.debug(log_msg)
-        weight = m.weight.T if group_dim == 0 else m.weight
-        if enable_mse_search:
-            quantile = search_clip(m, num_bits, group_size, scheme, dtype, enable_full_range)
-        if return_int:
+        weight = m.weight.t_().contiguous() if group_dim == 0 else m.weight
+        if use_mse_search:
+            quantile = search_clip(m, bits, group_size, scheme, dtype, use_full_range)
+        if export_compressed_model:
             int_weight, scale, zp = quant_tensor(
                 weight,
-                num_bits,
-                group_size,
-                scheme,
-                quantile,
                 dtype=dtype,
+                bits=bits,
+                group_size=group_size,
+                scheme=scheme,
+                quantile=quantile,
                 return_int=True,
-                full_range=enable_full_range,
+                full_range=use_full_range,
                 **double_quant_config,
             )
-            int_weight = int_weight.T if group_dim == 0 else int_weight
-            scale = scale.T if group_dim == 0 else scale
-            zp = zp.T if group_dim == 0 and zp is not None else zp
+            int_weight = int_weight.t_().contiguous() if group_dim == 0 else int_weight
+            scale = scale.t_().contiguous() if group_dim == 0 else scale
+            zp = zp.t_().contiguous() if group_dim == 0 and zp is not None else zp
             from neural_compressor.torch.quantization.layers import WeightOnlyLinear
 
             new_module = WeightOnlyLinear(
                 m.in_features,
                 m.out_features,
-                num_bits,
-                group_size,
+                bits=bits,
+                group_size=group_size,
                 dtype=dtype,
                 zp=zp is not None,
                 bias=m.bias is not None,
-                compression_dtype=compression_dtype,
-                compression_dim=compression_dim,
-                scale_dtype=scale_dtype,
+                use_optimum_format=use_optimum_format,
                 device=device,
             )
             new_module.pack(int_weight, scale, zp, m.bias)
@@ -150,50 +159,16 @@ def rtn_quantize(
             else:
                 set_module(model, name, new_module)
         else:
-            q_weight = quant_tensor(
+            weight = quant_tensor(
                 weight,
-                num_bits,
-                group_size,
-                scheme,
-                quantile,
                 dtype=dtype,
-                full_range=enable_full_range,
+                bits=bits,
+                group_size=group_size,
+                scheme=scheme,
+                quantile=quantile,
+                full_range=use_full_range,
                 **double_quant_config,
             )
-            q_weight = q_weight.T if group_dim == 0 else q_weight
-            m.weight.data.copy_(q_weight)
+            weight = weight.t_().contiguous() if group_dim == 0 else weight
+            m.weight.data.copy_(weight)
     return model
-
-
-from neural_compressor.torch.quantization.config import RTNConfig
-
-
-def apply_rtn_on_single_module(module: torch.nn.Module, quant_config: RTNConfig) -> torch.nn.Module:
-    # TODO (Yi) remove it
-    enable_full_range = quant_config.enable_full_range
-    enable_mse_search = quant_config.enable_mse_search
-    group_dim = quant_config.group_dim
-    dtype = quant_config.weight_dtype
-    num_bits = quant_config.weight_bits
-    scheme = "sym" if quant_config.weight_sym else "asym"
-    group_size = quant_config.weight_group_size
-    return_int = quant_config.return_int
-    double_quant_dtype = quant_config.double_quant_dtype
-    double_quant_num_bits = quant_config.double_quant_bits
-    double_quant_scheme = "sym" if quant_config.double_quant_sym else "asym"
-    double_quant_group_size = quant_config.double_quant_group_size
-    return rtn_quantize(
-        module,
-        num_bits,
-        group_size,
-        scheme,
-        return_int=return_int,
-        dtype=dtype,
-        enable_full_range=enable_full_range,
-        enable_mse_search=enable_mse_search,
-        group_dim=group_dim,
-        double_quant_dtype=double_quant_dtype,
-        double_quant_scheme=double_quant_scheme,
-        double_quant_num_bits=double_quant_num_bits,
-        double_quant_group_size=double_quant_group_size,
-    )
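
For reference, a minimal usage sketch of the renamed API. This is illustrative only: the import path of rtn_quantize is an assumption (the diff does not show where this module lives), the toy model is made up, and the argument names simply follow the new signature above.

import torch

# Assumed import path; adjust to wherever this rtn module actually lives.
from neural_compressor.torch.algorithms.weight_only.rtn import rtn_quantize


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(64, 64)
        self.fc2 = torch.nn.Linear(64, 16)

    def forward(self, x):
        return self.fc2(self.fc1(x))


model = ToyModel()

# Fake-quantize every Linear layer with 4-bit symmetric RTN and 32-element groups.
model = rtn_quantize(
    model,
    dtype="int",
    bits=4,
    scheme="sym",
    group_size=32,
    export_compressed_model=False,  # True packs weights into WeightOnlyLinear modules
)

Note that the per-op branch now indexes most keys directly (e.g. weight_config[name]["group_dim"], ["use_full_range"], ["use_mse_search"]), so a per-layer entry in weight_config must supply the full key set; only "dtype", "bits", and "quantile" fall back to defaults via .get().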