2424import torch
2525
2626from neural_compressor .common .base_config import BaseConfig , config_registry , register_config
27- from neural_compressor .common .utility import (
28- DEFAULT_WHITE_LIST ,
29- FP8_QUANT ,
30- GPTQ ,
31- OP_NAME_OR_MODULE_TYPE ,
32- RTN_WEIGHT_ONLY_QUANT ,
33- )
27+ from neural_compressor .common .utility import DEFAULT_WHITE_LIST , FP8_QUANT , GPTQ , OP_NAME_OR_MODULE_TYPE , RTN
3428from neural_compressor .torch .utils .constants import PRIORITY_GPTQ , PRIORITY_RTN
3529from neural_compressor .torch .utils .utility import is_hpex_avaliable , logger
3630
@@ -60,8 +54,8 @@ class OperatorConfig(NamedTuple):
6054######################## RTN Config ###############################
6155
6256
63- @register_config (framework_name = FRAMEWORK_NAME , algo_name = RTN_WEIGHT_ONLY_QUANT , priority = PRIORITY_RTN )
64- class RTNWeightQuantConfig (BaseConfig ):
57+ @register_config (framework_name = FRAMEWORK_NAME , algo_name = RTN , priority = PRIORITY_RTN )
58+ class RTNConfig (BaseConfig ):
6559 """Config class for round-to-nearest weight-only quantization."""
6660
6761 supported_configs : List [OperatorConfig ] = []
@@ -80,7 +74,7 @@ class RTNWeightQuantConfig(BaseConfig):
8074 "double_quant_sym" ,
8175 "double_quant_group_size" ,
8276 ]
83- name = RTN_WEIGHT_ONLY_QUANT
77+ name = RTN
8478
8579 def __init__ (
8680 self ,
@@ -137,12 +131,12 @@ def to_dict(self):
137131
138132 @classmethod
139133 def from_dict (cls , config_dict ):
140- return super (RTNWeightQuantConfig , cls ).from_dict (config_dict = config_dict , str2operator = str2operator )
134+ return super (RTNConfig , cls ).from_dict (config_dict = config_dict , str2operator = str2operator )
141135
142136 @classmethod
143137 def register_supported_configs (cls ) -> List [OperatorConfig ]:
144138 supported_configs = []
145- linear_rtn_config = RTNWeightQuantConfig (
139+ linear_rtn_config = RTNConfig (
146140 weight_dtype = ["int" , "int8" , "int4" , "nf4" , "fp4" , "fp4_e2m1_bnb" , "fp4_e2m1" ],
147141 weight_bits = [4 , 1 , 2 , 3 , 5 , 6 , 7 , 8 ],
148142 weight_group_size = [32 , - 1 , 1 , 4 , 8 , 16 , 64 , 128 , 256 , 512 , 1024 ],
@@ -173,16 +167,16 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
173167
174168
175169# TODO(Yi) run `register_supported_configs` for all registered configs.
176- RTNWeightQuantConfig .register_supported_configs ()
170+ RTNConfig .register_supported_configs ()
177171
178172
179- def get_default_rtn_config () -> RTNWeightQuantConfig :
173+ def get_default_rtn_config () -> RTNConfig :
180174 """Generate the default rtn config.
181175
182176 Returns:
183177 the default rtn config.
184178 """
185- return RTNWeightQuantConfig ()
179+ return RTNConfig ()
186180
187181
188182######################## GPTQ Config ###############################
0 commit comments