From b33cc2c3cfd283e3c098638193e7a2a13d0dfceb Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 19 Feb 2025 17:27:35 +0000 Subject: [PATCH 1/3] Init Signed-off-by: Jee Jee Li --- vllm/model_executor/model_loader/utils.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 9686231fb4bd..2c7162ace83b 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -13,10 +13,13 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.bitsandbytes import ( + BitsAndBytesConfig) from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.adapters import (as_classification_model, as_embedding_model, as_reward_model) +from vllm.model_executor.models.utils import WeightsMapper logger = init_logger(__name__) @@ -169,3 +172,15 @@ def configure_quant_config(quant_config: QuantizationConfig, "The model class %s has not defined `packed_modules_mapping`, " "this may lead to incorrect mapping of quantized or ignored " "modules", model_class.__name__) + if getattr(model_class, "hf_to_vllm_mapper", None) is None: + return + hf_to_vllm_mapper: WeightsMapper = model_class.hf_to_vllm_mapper + if isinstance(quant_config, + BitsAndBytesConfig) and (llm_int8_skip_modules := + quant_config.llm_int8_skip_modules): + new_modules_lst = [] + for skip_module in llm_int8_skip_modules: + module_name = hf_to_vllm_mapper._map_name(skip_module) + new_modules_lst.append(module_name) + quant_config.llm_int8_skip_modules = new_modules_lst + pass From dd2021fbe1039501266e701d80b83997b55aa69d Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 21 Feb 2025 02:00:51 +0000 Subject: [PATCH 2/3] Move forward Signed-off-by: Jee Jee Li --- vllm/model_executor/model_loader/utils.py | 72 ++++++++++++++--------- 1 file 
changed, 45 insertions(+), 27 deletions(-) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 2c7162ace83b..0aba7e9ada7e 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -11,6 +11,7 @@ from vllm.config import ModelConfig, ModelImpl from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.bitsandbytes import ( @@ -156,31 +157,48 @@ def get_sub_modules(self, def configure_quant_config(quant_config: QuantizationConfig, model_class: Type[nn.Module]): - """ - Pass packed_modules_mapping by reference to quant_config so that - quant_config can properly match fused modules - Note that model attributes are passed by reference to quant_config, - enabling them to be updated by model_class.__new__ (ex. chatglm, qwen) - """ - packed_mapping = getattr(model_class, "packed_modules_mapping", None) - if packed_mapping is not None: - # pass packed_modules_mapping by reference to quant_config - quant_config.packed_modules_mapping = packed_mapping - else: - logger.warning( - "The model class %s has not defined `packed_modules_mapping`, " - "this may lead to incorrect mapping of quantized or ignored " - "modules", model_class.__name__) - if getattr(model_class, "hf_to_vllm_mapper", None) is None: - return - hf_to_vllm_mapper: WeightsMapper = model_class.hf_to_vllm_mapper - if isinstance(quant_config, - BitsAndBytesConfig) and (llm_int8_skip_modules := - quant_config.llm_int8_skip_modules): - new_modules_lst = [] - for skip_module in llm_int8_skip_modules: - module_name = hf_to_vllm_mapper._map_name(skip_module) - new_modules_lst.append(module_name) - quant_config.llm_int8_skip_modules = new_modules_lst - pass + def configure_packed_modules_mapping(): + """ + Pass packed_modules_mapping by reference to 
quant_config so that + quant_config can properly match fused modules + + Note that model attributes are passed by reference to quant_config, + enabling them to be updated by model_class.__new__ (ex. chatglm, qwen) + """ + packed_mapping = getattr(model_class, "packed_modules_mapping", None) + if packed_mapping is not None: + # pass packed_modules_mapping by reference to quant_config + quant_config.packed_modules_mapping = packed_mapping + else: + logger.warning( + "The model class %s has not defined `packed_modules_mapping`, " + "this may lead to incorrect mapping of quantized or ignored " + "modules", model_class.__name__) + + def configure_quant_skip_modules(): + + if getattr(model_class, "hf_to_vllm_mapper", None) is None: + return + hf_to_vllm_mapper: WeightsMapper = model_class.hf_to_vllm_mapper + # AWQ + if isinstance(quant_config, + AWQConfig) and (modules_to_not_convert := + quant_config.modules_to_not_convert): + new_modules_lst = [] + for skip_module in modules_to_not_convert: + module_name = hf_to_vllm_mapper._map_name(skip_module) + new_modules_lst.append(module_name) + quant_config.modules_to_not_convert = new_modules_lst + + # BitsAndBytes + elif isinstance(quant_config, BitsAndBytesConfig) and ( + llm_int8_skip_modules := quant_config.llm_int8_skip_modules): + new_modules_lst = [] + for skip_module in llm_int8_skip_modules: + module_name = hf_to_vllm_mapper._map_name(skip_module) + new_modules_lst.append(module_name) + quant_config.llm_int8_skip_modules = new_modules_lst + + configure_packed_modules_mapping() + configure_quant_skip_modules() From 8d2badd9f1a5ad80c9158022d20bffc0b4e171c2 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 5 Mar 2025 12:48:29 +0000 Subject: [PATCH 3/3] Done Signed-off-by: Jee Jee Li --- vllm/model_executor/model_loader/utils.py | 51 +++++++++++++---------- 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 
def configure_quant_config(quant_config: QuantizationConfig,
                           model_class: Type[nn.Module]):
    """Adapt ``quant_config`` in place to the given vLLM model class.

    Two adjustments are performed:
      1. ``packed_modules_mapping`` is shared (by reference) from the model
         class so the quant config can properly match fused modules.
      2. Quantization-specific skip lists, which hold HF-style module names,
         are rewritten into vLLM module names via the model's
         ``hf_to_vllm_mapper`` (when the model class defines one).

    Args:
        quant_config: The quantization config to mutate.
        model_class: The vLLM model class being loaded.
    """

    def _configure_packed_modules_mapping():
        """Pass ``packed_modules_mapping`` by reference to ``quant_config``.

        Passing by reference matters: model attributes may still be updated
        by ``model_class.__new__`` (e.g. chatglm, qwen), and the quant
        config must observe those updates.
        """
        packed_mapping = getattr(model_class, "packed_modules_mapping", None)
        if packed_mapping is not None:
            # Pass packed_modules_mapping by reference to quant_config.
            quant_config.packed_modules_mapping = packed_mapping
        else:
            logger.warning(
                "The model class %s has not defined `packed_modules_mapping`, "
                "this may lead to incorrect mapping of quantized or ignored "
                "modules", model_class.__name__)

    def _configure_quant_skip_modules():
        """Translate quantization skip lists from HF to vLLM module names.

        If the model class defines an ``hf_to_vllm_mapper``, use it to
        rewrite the list of modules to be skipped for the given
        quantization configuration:
          - ``BitsAndBytesConfig``: updates ``llm_int8_skip_modules``.
          - ``AWQConfig``: updates ``modules_to_not_convert``.
        """
        if getattr(model_class, "hf_to_vllm_mapper", None) is None:
            return
        hf_to_vllm_mapper: WeightsMapper = model_class.hf_to_vllm_mapper

        # BitsAndBytes
        if (isinstance(quant_config, BitsAndBytesConfig)
                and quant_config.llm_int8_skip_modules):
            quant_config.llm_int8_skip_modules = [
                hf_to_vllm_mapper._map_name(module)
                for module in quant_config.llm_int8_skip_modules
            ]
        # AWQ
        elif (isinstance(quant_config, AWQConfig)
              and quant_config.modules_to_not_convert):
            quant_config.modules_to_not_convert = [
                hf_to_vllm_mapper._map_name(module)
                for module in quant_config.modules_to_not_convert
            ]
        # TODO: Support more quantization types.

    _configure_packed_modules_mapping()
    _configure_quant_skip_modules()