diff --git a/vllm/config/model.py b/vllm/config/model.py index 48e956467bf9..af10d646d294 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -12,13 +12,15 @@ import torch from pydantic import ConfigDict, SkipValidation, field_validator, model_validator from pydantic.dataclasses import dataclass -from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE import vllm.envs as envs +from vllm.config.model_arch import ( + ModelArchitectureConfig, +) from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig from vllm.config.pooler import PoolerConfig from vllm.config.scheduler import RunnerType -from vllm.config.utils import assert_hashable, config, getattr_iter +from vllm.config.utils import assert_hashable, config from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.transformers_utils.config import ( @@ -31,18 +33,16 @@ is_encoder_decoder, try_get_dense_modules, try_get_generation_config, - try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope, ) -from vllm.transformers_utils.model_arch_config import ( - SUPPORTED_ARCHITECTURES as MODEL_ARCH_CONFIG_SUPPORTED_ARCHITECTURES, +from vllm.transformers_utils.model_arch_config_parser import ( + MODEL_ARCH_CONFIG_CONVERTORS, + ModelArchConfigConvertorBase, ) -from vllm.transformers_utils.model_arch_config import get_model_arch_config from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri from vllm.transformers_utils.utils import maybe_model_redirect from vllm.utils.import_utils import LazyLoader -from vllm.utils.torch_utils import common_broadcastable_dtype if TYPE_CHECKING: from transformers import PretrainedConfig @@ -524,19 +524,11 @@ def __post_init__( self.model, hf_token=self.hf_token, revision=self.revision ) self.model_arch_config = None - if ( - len(self.architectures) == 1 - and self.architectures[0] in MODEL_ARCH_CONFIG_SUPPORTED_ARCHITECTURES - ): - assert hf_overrides_fn is None, "Not 
supported yet" - self.model_arch_config = get_model_arch_config( - self.hf_config_path or self.model, - self.trust_remote_code, - self.revision, - self.code_revision, - self.config_format, - model_arch_overrides_kw=hf_overrides_kw, - ) + convertor_cls = MODEL_ARCH_CONFIG_CONVERTORS.get( + hf_config.model_type, ModelArchConfigConvertorBase + ) + convertor = convertor_cls(hf_config) + self.model_arch_config = convertor.convert(self.model, self.revision) architectures = self.architectures registry = self.registry @@ -787,7 +779,7 @@ def registry(self): @property def architectures(self) -> list[str]: - return getattr(self.hf_config, "architectures", []) + return self.model_arch_config.architectures @property def architecture(self) -> str: @@ -962,50 +954,15 @@ def _get_default_pooling_task( return "embed" - def _parse_quant_hf_config(self, hf_config: PretrainedConfig): - quant_cfg = getattr(hf_config, "quantization_config", None) - if quant_cfg is None: - # compressed-tensors uses a "compression_config" key - quant_cfg = getattr(hf_config, "compression_config", None) - - else: - # Set quant_method for ModelOpt models. - producer_name = quant_cfg.get("producer", {}).get("name") - if producer_name == "modelopt": - quant_algo = quant_cfg.get("quantization", {}).get("quant_algo") - if quant_algo == "FP8": - quant_cfg["quant_method"] = "modelopt" - elif quant_algo == "NVFP4": - quant_cfg["quant_method"] = "modelopt_fp4" - elif quant_algo is not None: - raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}") - - return quant_cfg - def _verify_quantization(self) -> None: supported_quantization = me_quant.QUANTIZATION_METHODS if self.quantization is not None: self.quantization = cast(me_quant.QuantizationMethods, self.quantization) # Parse quantization method from the HF model config, if available. 
- quant_cfg = self._parse_quant_hf_config(self.hf_config) - if quant_cfg is None and ( - text_config := getattr(self.hf_config, "text_config", None) - ): - # Check the text config as well for multi-modal models. - quant_cfg = self._parse_quant_hf_config(text_config) + quant_cfg = ModelArchConfigConvertorBase.get_quantization_config(self.hf_config) if quant_cfg is not None: - # Use the community standard 'quant_method' - quant_method = quant_cfg.get("quant_method", "").lower() - - # Normalize library names - quant_method = quant_method.replace( - "compressed_tensors", "compressed-tensors" - ) - - quant_cfg["quant_method"] = quant_method - # Quantization methods which are overrides (i.e. they have a # `override_quantization_method` method) must be checked in order # of preference (this is particularly important for GPTQ). @@ -1085,7 +1042,7 @@ def _verify_cuda_graph(self) -> None: logger.warning( "CUDA graph is not supported for %s on ROCm yet, fallback " "to eager mode.", - self.hf_config.model_type, + self.model_arch_config.model_type, ) self.enforce_eager = True @@ -1096,11 +1053,9 @@ def _verify_bnb_config(self) -> None: # TODO Remove this when bitsandbytes supports. """ is_bitsandbytes = self.quantization == "bitsandbytes" - has_quantization_config = ( - getattr(self.hf_config, "quantization_config", None) is not None - ) + has_quantization_config = self.model_arch_config.quantization_config is not None is_8bit = ( - self.hf_config.quantization_config.get("load_in_8bit", False) + self.model_arch_config.quantization_config.get("load_in_8bit", False) if has_quantization_config else False ) @@ -1160,9 +1115,7 @@ def verify_with_parallel_config( "make sure sampling results are the same across workers." 
) - total_num_attention_heads = getattr( - self.hf_text_config, "num_attention_heads", 0 - ) + total_num_attention_heads = self.model_arch_config.total_num_attention_heads tensor_parallel_size = parallel_config.tensor_parallel_size if total_num_attention_heads % tensor_parallel_size != 0: raise ValueError( @@ -1205,137 +1158,26 @@ def get_sliding_window(self) -> int | None: return getattr(self.hf_text_config, "sliding_window", None) def get_vocab_size(self) -> int: - return getattr(self.hf_text_config, "vocab_size", 0) + return self.model_arch_config.vocab_size def get_hidden_size(self) -> int: - return getattr(self.hf_text_config, "hidden_size", 0) + return self.model_arch_config.hidden_size @property def is_deepseek_mla(self) -> bool: - if self.model_arch_config: - return self.model_arch_config.text_config.use_deepseek_mla - if not hasattr(self.hf_text_config, "model_type"): - return False - elif self.hf_text_config.model_type in ( - "deepseek_v2", - "deepseek_v3", - "deepseek_v32", - "deepseek_mtp", - "kimi_k2", - "kimi_linear", - "longcat_flash", - "pangu_ultra_moe", - "pangu_ultra_moe_mtp", - ): - return self.hf_text_config.kv_lora_rank is not None - elif self.hf_text_config.model_type == "eagle": - # if the model is an EAGLE module, check for the - # underlying architecture - return ( - self.hf_text_config.model.model_type - in ("deepseek_v2", "deepseek_v3", "deepseek_v32") - and self.hf_text_config.kv_lora_rank is not None - ) - return False + return self.model_arch_config.is_deepseek_mla def get_head_size(self) -> int: - if self.model_arch_config: - return self.model_arch_config.text_config.head_dim - - # TODO remove hard code - if self.is_deepseek_mla: - qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", 0) - if self.use_mla: - return self.hf_text_config.kv_lora_rank + qk_rope_head_dim - else: - qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim", 0) - if qk_rope_head_dim and qk_nope_head_dim: - return qk_rope_head_dim + 
qk_nope_head_dim - - if hasattr(self.hf_text_config, "model_type") and ( - self.hf_text_config.model_type == "zamba2" - ): - return self.hf_text_config.attention_head_dim - if self.is_attention_free: return 0 - - # NOTE: Some configs may set head_dim=None in the config - if getattr(self.hf_text_config, "head_dim", None) is not None: - return self.hf_text_config.head_dim - - # NOTE: Some models (such as PLaMo2.1) use `hidden_size_per_head` - if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None: - return self.hf_text_config.hidden_size_per_head - - # FIXME(woosuk): This may not be true for all models. - return ( - self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads - ) + return self.model_arch_config.head_size def get_total_num_kv_heads(self) -> int: """Returns the total number of KV heads.""" - if self.model_arch_config: - return self.model_arch_config.text_config.num_key_value_heads - - # For GPTBigCode & Falcon: - # NOTE: for falcon, when new_decoder_architecture is True, the - # multi_query flag is ignored and we use n_head_kv for the number of - # KV heads. - falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] - new_decoder_arch_falcon = ( - self.hf_config.model_type in falcon_model_types - and getattr(self.hf_config, "new_decoder_architecture", False) - ) - if not new_decoder_arch_falcon and getattr( - self.hf_text_config, "multi_query", False - ): - # Multi-query attention, only one KV head. - # Currently, tensor parallelism is not supported in this case. 
- return 1 - - # For DBRX and MPT - if self.hf_config.model_type == "mpt": - if "kv_n_heads" in self.hf_config.attn_config: - return self.hf_config.attn_config["kv_n_heads"] - return self.hf_config.num_attention_heads - if self.hf_config.model_type == "dbrx": - return getattr( - self.hf_config.attn_config, - "kv_n_heads", - self.hf_config.num_attention_heads, - ) - - if self.hf_config.model_type == "nemotron-nas": - for block in self.hf_config.block_configs: - if not block.attention.no_op: - return ( - self.hf_config.num_attention_heads - // block.attention.n_heads_in_group - ) - - raise RuntimeError("Couldn't determine number of kv heads") - if self.is_attention_free: return 0 - attributes = [ - # For Falcon: - "n_head_kv", - "num_kv_heads", - # For LLaMA-2: - "num_key_value_heads", - # For ChatGLM: - "multi_query_group_num", - ] - for attr in attributes: - num_kv_heads = getattr(self.hf_text_config, attr, None) - if num_kv_heads is not None: - return num_kv_heads - - # For non-grouped-query attention models, the number of KV heads is - # equal to the number of attention heads. 
- return self.hf_text_config.num_attention_heads + return self.model_arch_config.total_num_kv_heads def get_num_kv_heads(self, parallel_config: ParallelConfig) -> int: """Returns the number of KV heads per GPU.""" @@ -1351,54 +1193,14 @@ def get_num_kv_heads(self, parallel_config: ParallelConfig) -> int: return max(1, total_num_kv_heads // parallel_config.tensor_parallel_size) def get_num_attention_heads(self, parallel_config: ParallelConfig) -> int: - num_heads = getattr(self.hf_text_config, "num_attention_heads", 0) + num_heads = self.model_arch_config.total_num_attention_heads return num_heads // parallel_config.tensor_parallel_size def get_num_experts(self) -> int: - """Returns the number of experts in the model.""" - if self.model_arch_config: - return self.model_arch_config.text_config.num_experts - - num_expert_names = [ - "num_experts", # Jamba - "moe_num_experts", # Dbrx - "n_routed_experts", # DeepSeek - "num_local_experts", # Mixtral - ] - num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0) - if isinstance(num_experts, list): - # Ernie VL's remote code uses list[int]... - # The values are always the same so we just take the first one. 
- return num_experts[0] - return num_experts + return self.model_arch_config.num_experts def get_num_hidden_layers(self) -> int: - if self.model_arch_config: - total_num_hidden_layers = ( - self.model_arch_config.text_config.num_hidden_layers - ) - else: - if ( - self.hf_text_config.model_type == "deepseek_mtp" - or self.hf_config.model_type == "mimo_mtp" - or self.hf_config.model_type == "glm4_moe_mtp" - or self.hf_config.model_type == "ernie_mtp" - or self.hf_config.model_type == "qwen3_next_mtp" - or self.hf_config.model_type == "qwen3_next_mtp" - or self.hf_config.model_type == "pangu_ultra_moe_mtp" - ): - total_num_hidden_layers = getattr( - self.hf_text_config, "num_nextn_predict_layers", 0 - ) - elif self.hf_config.model_type == "longcat_flash_mtp": - total_num_hidden_layers = getattr( - self.hf_text_config, "num_nextn_predict_layers", 1 - ) - else: - total_num_hidden_layers = getattr( - self.hf_text_config, "num_hidden_layers", 0 - ) - return total_num_hidden_layers + return self.model_arch_config.num_hidden_layers def get_layers_start_end_indices( self, parallel_config: ParallelConfig @@ -1449,9 +1251,7 @@ def get_num_layers_by_block_type( self.hf_text_config, "layers_block_type", None ) if layers_block_type_value is not None: - if hasattr(self.hf_text_config, "model_type") and ( - self.hf_text_config.model_type == "zamba2" - ): + if self.model_arch_config.text_model_type == "zamba2": if attn_block_type: return sum( t == "hybrid" for t in layers_block_type_value[start:end] @@ -1731,10 +1531,7 @@ def head_dtype(self) -> torch.dtype: @property def hidden_size(self): - if hasattr(self.hf_config, "hidden_size"): - return self.hf_config.hidden_size - text_config = self.hf_config.get_text_config() - return text_config.hidden_size + return self.model_arch_config.hidden_size @property def embedding_size(self): @@ -1758,6 +1555,7 @@ def get_and_verify_max_len(self, max_model_len: int): ) max_model_len = _get_and_verify_max_len( hf_config=self.hf_text_config, + 
model_arch_config=self.model_arch_config, tokenizer_config=tokenizer_config, max_model_len=max_model_len, disable_sliding_window=self.disable_sliding_window, @@ -1865,46 +1663,6 @@ def _check_valid_dtype(model_type: str, dtype: torch.dtype): return True -def _find_dtype( - model_id: str, - config: PretrainedConfig, - *, - revision: str | None, -): - # NOTE: getattr(config, "dtype", torch.float32) is not correct - # because config.dtype can be None. - config_dtype = getattr(config, "dtype", None) - - # Fallbacks for multi-modal models if the root config - # does not define dtype - if config_dtype is None: - config_dtype = getattr(config.get_text_config(), "dtype", None) - if config_dtype is None and hasattr(config, "vision_config"): - config_dtype = getattr(config.vision_config, "dtype", None) - if config_dtype is None and hasattr(config, "encoder_config"): - config_dtype = getattr(config.encoder_config, "dtype", None) - - # Try to read the dtype of the weights if they are in safetensors format - if config_dtype is None: - repo_mt = try_get_safetensors_metadata(model_id, revision=revision) - - if repo_mt and (files_mt := repo_mt.files_metadata): - param_dtypes: set[torch.dtype] = { - _SAFETENSORS_TO_TORCH_DTYPE[dtype_str] - for file_mt in files_mt.values() - for dtype_str in file_mt.parameter_count - if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE - } - - if param_dtypes: - return common_broadcastable_dtype(param_dtypes) - - if config_dtype is None: - config_dtype = torch.float32 - - return config_dtype - - def _resolve_auto_dtype( model_type: str, config_dtype: torch.dtype, @@ -1959,7 +1717,8 @@ def _get_and_verify_dtype( is_pooling_model: bool, revision: str | None = None, ) -> torch.dtype: - config_dtype = _find_dtype(model_id, config, revision=revision) + convertor = ModelArchConfigConvertorBase(config) + config_dtype = convertor.get_torch_dtype(model_id, revision=revision) model_type = config.model_type if isinstance(dtype, str): @@ -2022,6 +1781,7 @@ def 
_get_head_dtype( def _get_and_verify_max_len( hf_config: PretrainedConfig, + model_arch_config: ModelArchitectureConfig, tokenizer_config: dict | None, max_model_len: int | None, disable_sliding_window: bool, @@ -2030,36 +1790,9 @@ def _get_and_verify_max_len( encoder_config: Any | None = None, ) -> int: """Get and verify the model's maximum length.""" - derived_max_model_len = float("inf") - possible_keys = [ - # OPT - "max_position_embeddings", - # GPT-2 - "n_positions", - # MPT - "max_seq_len", - # ChatGLM2 - "seq_length", - # Command-R - "model_max_length", - # Whisper - "max_target_positions", - # Others - "max_sequence_length", - "max_seq_length", - "seq_len", - ] - # Choose the smallest "max_length" from the possible keys - max_len_key = None - for key in possible_keys: - max_len = getattr(hf_config, key, None) - if max_len is not None: - max_len_key = key if max_len < derived_max_model_len else max_len_key - derived_max_model_len = min(derived_max_model_len, max_len) - # For Command-R / Cohere, Cohere2 / Aya Vision models - if tmp_max_len := getattr(hf_config, "model_max_length", None): - max_len_key = "model_max_length" - derived_max_model_len = tmp_max_len + (derived_max_model_len, max_len_key) = ( + model_arch_config.derived_max_model_len_and_key + ) # If sliding window is manually disabled, max_length should be less # than the sliding window length in the model config. @@ -2092,10 +1825,9 @@ def _get_and_verify_max_len( default_max_len = 2048 logger.warning( - "The model's config.json does not contain any of the following " - "keys to determine the original maximum length of the model: " - "%s. Assuming the model's maximum length is %d.", - possible_keys, + "The model's config.json does not contain any of the keys " + "to determine the original maximum length of the model. 
" + "Assuming the model's maximum length is %d.", default_max_len, ) derived_max_model_len = default_max_len @@ -2103,7 +1835,7 @@ def _get_and_verify_max_len( rope_scaling = getattr(hf_config, "rope_scaling", None) # NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE # scaling, so we skip applying the scaling factor again. - if rope_scaling is not None and "gemma3" not in hf_config.model_type: + if rope_scaling is not None and "gemma3" not in model_arch_config.model_type: # No need to consider "type" key because of patch_rope_scaling when # loading HF config rope_type = rope_scaling["rope_type"] diff --git a/vllm/config/model_arch.py b/vllm/config/model_arch.py index f4e8a2ab065a..cd55dcbacd32 100644 --- a/vllm/config/model_arch.py +++ b/vllm/config/model_arch.py @@ -1,145 +1,62 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import functools -import json -from dataclasses import field from typing import Any -from pydantic import ConfigDict +import torch from pydantic.dataclasses import dataclass -from torch import nn -from vllm.config.utils import config from vllm.logger import init_logger logger = init_logger(__name__) -@dataclass(config=ConfigDict(arbitrary_types_allowed=True, extra="allow")) -class ModelArchitectureTextConfig: +@dataclass +class ModelArchitectureConfig: + """ + Configuration for model architecture that required by vLLM runtime + """ + + architectures: list[str] + """List of model architecture class names (e.g., ['LlamaForCausalLM']).""" + model_type: str + """Model type identifier (e.g., 'llama', 'gpt_oss').""" + + text_model_type: str | None + """Text model type identifier (e.g., 'llama4_text').""" + hidden_size: int + """Hidden size of the model.""" + num_hidden_layers: int - num_attention_heads: int - use_deepseek_mla: bool - head_dim: int - vocab_size: int - num_key_value_heads: int - num_experts: int + """Number of hidden layers in the model.""" - def 
__init__( - self, - model_type: str, - hidden_size: int, - num_hidden_layers: int, - num_attention_heads: int, - use_deepseek_mla: bool, - head_dim: int, - vocab_size: int, - num_key_value_heads: int, - num_experts: int, - **kwargs, - ): - self.model_type = model_type - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.head_dim = head_dim - self.vocab_size = vocab_size - self.num_key_value_heads = num_key_value_heads - self.num_experts = num_experts - self.use_deepseek_mla = use_deepseek_mla - - for key, value in kwargs.items(): - setattr(self, key, value) - - def __repr__(self): - config_dict_json = json.dumps(self.__dict__, indent=2, sort_keys=True) + "\n" - return f"{self.__class__.__name__} {config_dict_json}" - - -@dataclass(config=ConfigDict(arbitrary_types_allowed=True, extra="allow")) -class ModelArchitectureVisionConfig: - def __init__( - self, - **kwargs, - ): - for key, value in kwargs.items(): - setattr(self, key, value) - - def __repr__(self): - config_dict_json = json.dumps(self.__dict__, indent=2, sort_keys=True) + "\n" - return f"{self.__class__.__name__} {config_dict_json}" - - -@dataclass(config=ConfigDict(arbitrary_types_allowed=True, extra="allow")) -class ModelArchitectureAudioConfig: - def __init__( - self, - **kwargs, - ): - for key, value in kwargs.items(): - setattr(self, key, value) - - def __repr__(self): - config_dict_json = json.dumps(self.__dict__, indent=2, sort_keys=True) + "\n" - return f"{self.__class__.__name__} {config_dict_json}" - - -@config -@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) -class ModelArchitectureConfig: - """ - Configuration for model architecture - """ + total_num_attention_heads: int + """Number of attention heads in the model.""" - text_config: ModelArchitectureTextConfig = field(init=True) - """Text model configuration containing text-specific architecture details.""" + head_size: int + """Head dimension of the 
model.""" - architectures: list[str] = field(default_factory=list) - """List of model architecture class names (e.g., ['LlamaForCausalLM']).""" + vocab_size: int + """Vocabulary size of the model.""" + + total_num_kv_heads: int + """Number of key value heads in the model.""" - model_type: str = "" - """Model type identifier (e.g., 'llama', 'gpt2').""" + num_experts: int + """Number of experts in the model.""" - # TODO: Formalize quantization_config in parser - quantization_config: dict[str, Any] = field(default_factory=dict) + quantization_config: dict[str, Any] """Quantization configuration dictionary containing quantization parameters.""" - torch_dtype: str = "" + torch_dtype: torch.dtype """PyTorch data type for model weights (e.g., 'float16', 'bfloat16').""" - per_layer_attention_cls: list[type[nn.Module]] = field(default_factory=list) - """Per-layer attention class of the model.""" - - vision_config: ModelArchitectureVisionConfig | None = None - """Vision model configuration for multimodal models (optional).""" - - audio_config: ModelArchitectureAudioConfig | None = None - """Audio model configuration for multimodal models (optional).""" - - def __init__( - self, - architectures: list[str], - model_type: str, - quantization_config: dict[str, Any], - torch_dtype: str, - text_config: ModelArchitectureTextConfig, - per_layer_attention_cls: list[type[nn.Module]] | None = None, - vision: ModelArchitectureVisionConfig | None = None, - audio: ModelArchitectureAudioConfig | None = None, - ): - self.architectures = architectures - self.model_type = model_type - self.quantization_config = quantization_config - self.torch_dtype = torch_dtype - self.text_config = text_config - self.per_layer_attention_cls = ( - per_layer_attention_cls if per_layer_attention_cls is not None else [] - ) - self.vision = vision - self.audio = audio - - @functools.cached_property - def support_multimodal(self) -> bool: - raise NotImplementedError + support_multimodal: bool + """Whether the 
model supports multimodal input.""" + + is_deepseek_mla: bool + """Whether the model is a DeepSeek MLA model.""" + + derived_max_model_len_and_key: tuple[float, str | None] + """Derived maximum model length and key from the hf config.""" diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index c80cd384d375..0a08bd376bad 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -36,7 +36,6 @@ from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.config.model_arch import ModelArchitectureTextConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -117,7 +116,7 @@ def forward(self, x): class LlamaAttention(nn.Module): def __init__( self, - config: ModelArchitectureTextConfig, + config: LlamaConfig, hidden_size: int, num_heads: int, num_kv_heads: int, @@ -287,7 +286,7 @@ def __init__( ) -> None: super().__init__() - config = config or vllm_config.model_config.model_arch_config.text_config + config = config or vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = self.get_quant_config(vllm_config) @@ -385,7 +384,7 @@ def __init__( ): super().__init__() - config = vllm_config.model_config.model_arch_config.text_config + config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -577,7 +576,7 @@ def __init__( layer_type: type[nn.Module] = LlamaDecoderLayer, ): super().__init__() - config = vllm_config.model_config.model_arch_config.text_config + config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config self.config = config @@ -729,12 +728,3 @@ def 
permute(w: torch.Tensor, n_heads: int, attn_out: int): name = name.replace(item, mapping[item]) return name, loaded_weight - - @classmethod - def get_per_layer_attention_cls(cls, text_config: ModelArchitectureTextConfig): - if getattr(text_config, "is_causal", True): - attn_cls = Attention - else: - attn_cls = EncoderOnlyAttention - - return [attn_cls] * text_config.num_hidden_layers diff --git a/vllm/transformers_utils/model_arch_config.py b/vllm/transformers_utils/model_arch_config.py deleted file mode 100644 index 3d0a816418e9..000000000000 --- a/vllm/transformers_utils/model_arch_config.py +++ /dev/null @@ -1,110 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from collections.abc import Callable -from pathlib import Path -from typing import Any - -from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME - -from vllm.config.model_arch import ModelArchitectureConfig -from vllm.logger import init_logger -from vllm.transformers_utils.config import ( - MISTRAL_CONFIG_NAME, - ConfigFormat, - file_or_path_exists, -) -from vllm.transformers_utils.model_arch_config_parser import ( - HFModelArchConfigParser, - ModelArchConfigParserBase, -) -from vllm.transformers_utils.utils import ( - check_gguf_file, -) - -logger = init_logger(__name__) - - -_CONFIG_FORMAT_TO_MODEL_ARCH_CONFIG_PARSER: dict[ - str, type[ModelArchConfigParserBase] -] = { - "hf": HFModelArchConfigParser, -} -SUPPORTED_ARCHITECTURES: list[str] = ["LlamaForCausalLM"] - - -def get_model_arch_config_parser(config_format: str) -> ModelArchConfigParserBase: - """Get the model architecture config parser for a given config format.""" - if config_format not in _CONFIG_FORMAT_TO_MODEL_ARCH_CONFIG_PARSER: - raise ValueError(f"Unknown config format `{config_format}`.") - return _CONFIG_FORMAT_TO_MODEL_ARCH_CONFIG_PARSER[config_format]() - - -def get_model_arch_config( - model: str | Path, - trust_remote_code: bool, - revision: str | None = 
None, - code_revision: str | None = None, - config_format: str | ConfigFormat = "auto", - model_arch_overrides_kw: dict[str, Any] | None = None, - model_arch_overrides_fn: Callable[ - ["ModelArchitectureConfig"], "ModelArchitectureConfig" - ] - | None = None, - **kwargs, -) -> "ModelArchitectureConfig": - # Separate model folder from file path for GGUF models - is_gguf = check_gguf_file(model) - kwargs["is_gguf"] = is_gguf - - if config_format == "auto": - try: - if is_gguf or file_or_path_exists(model, HF_CONFIG_NAME, revision=revision): - config_format = "hf" - elif file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision): - config_format = "mistral" - else: - raise ValueError( - "Could not detect config format for no config file found. " - "With config_format 'auto', ensure your model has either " - "config.json (HF format) or params.json (Mistral format). " - "Otherwise please specify your_custom_config_format " - "in engine args for customized config parser." - ) - - except Exception as e: - error_message = ( - "Invalid repository ID or local directory specified:" - " '{model}'.\nPlease verify the following requirements:\n" - "1. Provide a valid Hugging Face repository ID.\n" - "2. Specify a local directory that contains a recognized " - "configuration file.\n" - " - For Hugging Face models: ensure the presence of a " - "'config.json'.\n" - " - For Mistral models: ensure the presence of a " - "'params.json'.\n" - "3. 
For GGUF: pass the local path of the GGUF checkpoint.\n" - " Loading GGUF from a remote repo directly is not yet " - "supported.\n" - ).format(model=model) - - raise ValueError(error_message) from e - - config_parser = get_model_arch_config_parser(config_format) - config_dict, config = config_parser.parse( - model, - trust_remote_code=trust_remote_code, - revision=revision, - code_revision=code_revision, - **kwargs, - ) - - if model_arch_overrides_kw: - logger.debug("Overriding model arch config with %s", model_arch_overrides_kw) - for key, value in model_arch_overrides_kw.items(): - setattr(config, key, value) - if model_arch_overrides_fn: - logger.debug("Overriding model arch config with %s", model_arch_overrides_fn) - config = model_arch_overrides_fn(config) - - return config diff --git a/vllm/transformers_utils/model_arch_config_parser.py b/vllm/transformers_utils/model_arch_config_parser.py index e3e74b180ed2..6bcc463b6c67 100644 --- a/vllm/transformers_utils/model_arch_config_parser.py +++ b/vllm/transformers_utils/model_arch_config_parser.py @@ -1,441 +1,468 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os -from abc import ABC, abstractmethod -from copy import deepcopy -from pathlib import Path -from typing import TYPE_CHECKING, Any - -import huggingface_hub -from torch import nn -from transformers import AutoConfig, PretrainedConfig -from transformers.models.auto.modeling_auto import ( - MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, - MODEL_MAPPING_NAMES, -) + +import torch +from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE +from transformers import PretrainedConfig from vllm import envs from vllm.config.model_arch import ( - ModelArchitectureAudioConfig, ModelArchitectureConfig, - ModelArchitectureTextConfig, - ModelArchitectureVisionConfig, ) +from vllm.config.utils import getattr_iter from vllm.logger import init_logger from vllm.transformers_utils.config import ( - 
_CONFIG_REGISTRY, - _get_hf_token, - _maybe_update_auto_config_kwargs, - file_or_path_exists, - get_hf_file_to_dict, + get_hf_text_config, + try_get_safetensors_metadata, ) -from vllm.utils.import_utils import LazyLoader +from vllm.utils.torch_utils import common_broadcastable_dtype logger = init_logger(__name__) +MULTIMODAL_MODEL_ARCHS = [ + "AriaForConditionalGeneration", + "AyaVisionForConditionalGeneration", + "BeeForConditionalGeneration", + "Blip2ForConditionalGeneration", + "ChameleonForConditionalGeneration", + "CLIPEmbeddingModel", + "Cohere2VisionForConditionalGeneration", + "DeepseekOCRForCausalLM", + "DeepseekVLV2ForCausalLM", + "DotsOCRForCausalLM", + "Ernie4_5_VLMoeForConditionalGeneration", + "FuyuForCausalLM", + "Gemma3ForConditionalGeneration", + "Gemma3nForConditionalGeneration", + "GLM4VForCausalLM", + "Glm4vForConditionalGeneration", + "Glm4vMoeForConditionalGeneration", + "GraniteSpeechForConditionalGeneration", + "H2OVLChatModel", + "HCXVisionForCausalLM", + "Idefics3ForConditionalGeneration", + "InternS1ForConditionalGeneration", + "InternVLChatModel", + "KeyeForConditionalGeneration", + "KeyeVL1_5ForConditionalGeneration", + "KimiVLForConditionalGeneration", + "LightOnOCRForConditionalGeneration", + "Llama4ForConditionalGeneration", + "LlamaNemotronVLChatModel", + "LlavaForConditionalGeneration", + "LlavaNextForConditionalGeneration", + "LlavaNextVideoForConditionalGeneration", + "LlavaOnevisionForConditionalGeneration", + "MantisForConditionalGeneration", + "MiDashengLMModel", + "MiniCPMO", + "MiniCPMV", + "MiniCPMVBaseModel", + "MiniMaxVL01ForConditionalGeneration", + "Mistral3ForConditionalGeneration", + "MolmoForCausalLM", + "MultiModalMixin", + "NemotronH_Nano_VL_V2", + "NVLM_D_Model", + "Ovis", + "Ovis2_5", + "PaddleOCRVLForConditionalGeneration", + "PaliGemmaForConditionalGeneration", + "Phi3VForCausalLM", + "Phi4MMForCausalLM", + "Phi4MultimodalForCausalLM", + "PixtralForConditionalGeneration", + "Qwen2_5_VLForConditionalGeneration", 
+ "Qwen2_5OmniThinkerForConditionalGeneration", + "Qwen2AudioForConditionalGeneration", + "Qwen2VLForConditionalGeneration", + "Qwen3OmniMoeThinkerForConditionalGeneration", + "Qwen3VLForConditionalGeneration", + "Qwen3VLMoeForConditionalGeneration", + "QwenVLForConditionalGeneration", + "RForConditionalGeneration", + "SiglipEmbeddingModel", + "SkyworkR1VChatModel", + "SmolVLMForConditionalGeneration", + "Step3VLForConditionalGeneration", + "Tarsier2ForConditionalGeneration", + "TarsierForConditionalGeneration", + "Terratorch", + "TransformersMultiModalForCausalLM", + "TransformersMultiModalMoEForCausalLM", + "TransformersMultiModalEmbeddingModel", + "TransformersMultiModalForSequenceClassification", + "UltravoxModel", + "VoxtralForConditionalGeneration", + "WhisperForConditionalGeneration", +] + + +class ModelArchConfigConvertorBase: + def __init__(self, hf_config: PretrainedConfig): + self.hf_config = hf_config + self.hf_text_config = get_hf_text_config(hf_config) + + def get_num_hidden_layers(self) -> int: + return getattr(self.hf_text_config, "num_hidden_layers", 0) + + def get_total_num_attention_heads(self) -> int: + return getattr(self.hf_text_config, "num_attention_heads", 0) + + def get_vocab_size(self) -> int: + return getattr(self.hf_text_config, "vocab_size", 0) + + def get_hidden_size(self) -> int: + return getattr(self.hf_text_config, "hidden_size", 0) + + def get_head_size(self) -> int: + if self.is_deepseek_mla(): + qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", 0) + if not envs.VLLM_MLA_DISABLE: + return self.hf_text_config.kv_lora_rank + qk_rope_head_dim + else: + qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim", 0) + if qk_rope_head_dim and qk_nope_head_dim: + return qk_rope_head_dim + qk_nope_head_dim -class ModelArchConfigParserBase(ABC): - @abstractmethod - def parse( - self, - model: str | Path, - trust_remote_code: bool, - revision: str | None = None, - code_revision: str | None = None, - **kwargs, - ) -> 
tuple[dict[str, Any], "ModelArchitectureConfig"]: - raise NotImplementedError - - -def extract_num_hidden_layers(config_dict: dict[str, Any], model_type: str) -> int: - if model_type in [ - "deepseek_mtp", - "mimo_mtp", - "glm4_moe_mtp", - "ernie_mtp", - "qwen3_next_mtp", - ]: - total_num_hidden_layers = config_dict.pop("num_nextn_predict_layers", 0) - elif model_type == "longcat_flash_mtp": - total_num_hidden_layers = config_dict.pop("num_nextn_predict_layers", 1) - else: - total_num_hidden_layers = config_dict.pop("num_hidden_layers", 0) - - return total_num_hidden_layers - - -def extract_use_deepseek_mla( - config_dict: dict[str, Any], model_type: str | None -) -> bool: - if not model_type: - return False - elif model_type in ( - "deepseek_v2", - "deepseek_v3", - "deepseek_v32", - "deepseek_mtp", - "kimi_k2", - "kimi_linear", - "longcat_flash", - ): - return config_dict.get("kv_lora_rank") is not None - elif model_type == "eagle": - # if the model is an EAGLE module, check for the - # underlying architecture + # NOTE: Some configs may set head_dim=None in the config + if getattr(self.hf_text_config, "head_dim", None) is not None: + return self.hf_text_config.head_dim + + # NOTE: Some models (such as PLaMo2.1) use `hidden_size_per_head` + if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None: + return self.hf_text_config.hidden_size_per_head + + # FIXME(woosuk): This may not be true for all models. 
return ( - config_dict["model"]["model_type"] - in ("deepseek_v2", "deepseek_v3", "deepseek_v32") - and config_dict.get("kv_lora_rank") is not None + self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads ) - return False + def get_total_num_kv_heads(self) -> int: + attributes = [ + # For Falcon: + "n_head_kv", + "num_kv_heads", + # For LLaMA-2: + "num_key_value_heads", + # For ChatGLM: + "multi_query_group_num", + ] + for attr in attributes: + num_kv_heads = getattr(self.hf_text_config, attr, None) + if num_kv_heads is not None: + return num_kv_heads + + return self.hf_text_config.num_attention_heads + + def get_num_experts(self) -> int: + """Returns the number of experts in the model.""" + num_expert_names = [ + "num_experts", # Jamba + "moe_num_experts", # Dbrx + "n_routed_experts", # DeepSeek + "num_local_experts", # Mixtral + ] + num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0) + if isinstance(num_experts, list): + # Ernie VL's remote code uses list[int]... + # The values are always the same so we just take the first one. + return num_experts[0] + # Coerce to 0 if explicitly set to None + return num_experts or 0 + + def get_torch_dtype(self, model_id: str, revision: str | None): + # NOTE: getattr(config, "dtype", torch.float32) is not correct + # because config.dtype can be None. 
+ config_dtype = getattr(self.hf_config, "dtype", None) + + # Fallbacks for multi-modal models if the root config + # does not define dtype + if config_dtype is None: + config_dtype = getattr(self.hf_text_config, "dtype", None) + if config_dtype is None and hasattr(self.hf_config, "vision_config"): + config_dtype = getattr(self.hf_config.vision_config, "dtype", None) + if config_dtype is None and hasattr(self.hf_config, "encoder_config"): + config_dtype = getattr(self.hf_config.encoder_config, "dtype", None) + + # Try to read the dtype of the weights if they are in safetensors format + if config_dtype is None: + repo_mt = try_get_safetensors_metadata(model_id, revision=revision) + + if repo_mt and (files_mt := repo_mt.files_metadata): + param_dtypes: set[torch.dtype] = { + _SAFETENSORS_TO_TORCH_DTYPE[dtype_str] + for file_mt in files_mt.values() + for dtype_str in file_mt.parameter_count + if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE + } + + if param_dtypes: + return common_broadcastable_dtype(param_dtypes) + + if config_dtype is None: + config_dtype = torch.float32 + + return config_dtype + + @classmethod + def _normalize_quantization_config(cls, config: PretrainedConfig): + quant_cfg = getattr(config, "quantization_config", None) + if quant_cfg is None: + # compressed-tensors uses a "compression_config" key + quant_cfg = getattr(config, "compression_config", None) -def extract_head_size( - config_dict: dict[str, Any], standard_fields: dict[str, Any] -) -> int: - # TODO remove hard code - if standard_fields["use_deepseek_mla"]: - qk_rope_head_dim = config_dict.get("qk_rope_head_dim", 0) - if not envs.VLLM_MLA_DISABLE: - return config_dict["kv_lora_rank"] + qk_rope_head_dim else: - qk_nope_head_dim = config_dict.get("qk_nope_head_dim", 0) - if qk_rope_head_dim and qk_nope_head_dim: - return qk_rope_head_dim + qk_nope_head_dim - - if standard_fields["model_type"] == "zamba2": - return config_dict.pop("attention_head_dim") - - # TODO(xingyuliu): Check attention_free - 
- # NOTE: Some configs may set head_dim=None in the config - if config_dict.get("head_dim") is not None: - return config_dict.pop("head_dim") - - # NOTE: Some models (such as PLaMo2.1) use `hidden_size_per_head` - if config_dict.get("hidden_size_per_head") is not None: - return config_dict.pop("hidden_size_per_head") - - # FIXME(woosuk): This may not be true for all models. - return standard_fields["hidden_size"] // standard_fields["num_attention_heads"] - - -def extract_total_num_kv_heads( - config_dict: dict[str, Any], standard_fields: dict[str, Any] -) -> int: - """Returns the total number of KV heads.""" - model_type = standard_fields["model_type"] - # For GPTBigCode & Falcon: - # NOTE: for falcon, when new_decoder_architecture is True, the - # multi_query flag is ignored and we use n_head_kv for the number of - # KV heads. - falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] - new_decoder_arch_falcon = model_type in falcon_model_types and config_dict.get( - "new_decoder_architecture", False - ) - if not new_decoder_arch_falcon and config_dict.get("multi_query", False): - # Multi-query attention, only one KV head. - # Currently, tensor parallelism is not supported in this case. - return 1 - - # For DBRX and MPT - if model_type == "mpt": - if "kv_n_heads" in config_dict["attn_config"]: - return config_dict["attn_config"]["kv_n_heads"] - return standard_fields["num_attention_heads"] - if model_type == "dbrx": - attn_config = config_dict["attn_config"] - return attn_config.get("kv_n_heads", standard_fields["num_attention_heads"]) - - if model_type == "nemotron-nas": - for block in config_dict["block_configs"]: - if not block.attention.no_op: - return ( - standard_fields["num_attention_heads"] - // block.attention.n_heads_in_group - ) + # Set quant_method for ModelOpt models. 
+ producer_name = quant_cfg.get("producer", {}).get("name") + if producer_name == "modelopt": + quant_algo = quant_cfg.get("quantization", {}).get("quant_algo") + if quant_algo == "FP8": + quant_cfg["quant_method"] = "modelopt" + elif quant_algo == "NVFP4": + quant_cfg["quant_method"] = "modelopt_fp4" + elif quant_algo is not None: + raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}") + + if quant_cfg is not None: + # Use the community standard 'quant_method' + quant_method = quant_cfg.get("quant_method", "").lower() + + # Normalize library names + quant_method = quant_method.replace( + "compressed_tensors", "compressed-tensors" + ) + + quant_cfg["quant_method"] = quant_method + + return quant_cfg + + @classmethod + def get_quantization_config(cls, hf_config: PretrainedConfig): + quant_cfg = cls._normalize_quantization_config(hf_config) + if quant_cfg is None and ( + text_config := getattr(hf_config, "text_config", None) + ): + # Check the text config as well for multi-modal models. 
+ quant_cfg = cls._normalize_quantization_config(text_config) + return quant_cfg + + def is_deepseek_mla(self) -> bool: + if not hasattr(self.hf_text_config, "model_type"): + return False + elif self.hf_text_config.model_type in ( + "deepseek_v2", + "deepseek_v3", + "deepseek_v32", + "deepseek_mtp", + "kimi_k2", + "kimi_linear", + "longcat_flash", + "pangu_ultra_moe", + "pangu_ultra_moe_mtp", + ): + return self.hf_text_config.kv_lora_rank is not None + elif self.hf_text_config.model_type == "eagle": + # if the model is an EAGLE module, check for the + # underlying architecture + return ( + self.hf_text_config.model.model_type + in ("deepseek_v2", "deepseek_v3", "deepseek_v32") + and self.hf_text_config.kv_lora_rank is not None + ) + return False - raise RuntimeError("Couldn't determine number of kv heads") + def derive_max_model_len_and_key(self) -> tuple[float, str | None]: + derived_max_model_len = float("inf") + possible_keys = [ + # OPT + "max_position_embeddings", + # GPT-2 + "n_positions", + # MPT + "max_seq_len", + # ChatGLM2 + "seq_length", + # Command-R + "model_max_length", + # Whisper + "max_target_positions", + # Others + "max_sequence_length", + "max_seq_length", + "seq_len", + ] + # Choose the smallest "max_length" from the possible keys + max_len_key = None + for key in possible_keys: + max_len = getattr(self.hf_text_config, key, None) + if max_len is not None: + if max_len < derived_max_model_len: + max_len_key = key + derived_max_model_len = min(derived_max_model_len, max_len) + + return derived_max_model_len, max_len_key + + def support_multimodal(self) -> bool: + return any( + multi_model_arch in self.hf_config.architectures + for multi_model_arch in MULTIMODAL_MODEL_ARCHS + ) - # TODO(xingyuliu): Check attention_free - - attributes = [ - # For Falcon: - "n_head_kv", - "num_kv_heads", - # For LLaMA-2: - "num_key_value_heads", - # For ChatGLM: - "multi_query_group_num", - ] - for attr in attributes: - num_kv_heads = config_dict.get(attr) - if 
num_kv_heads is not None: - config_dict.pop(attr) - return num_kv_heads - - # For non-grouped-query attention models, the number of KV heads is - # equal to the number of attention heads. - return standard_fields["num_attention_heads"] - - -def extract_num_experts(config_dict: dict[str, Any]) -> int: - """Returns the number of experts in the model.""" - num_expert_names = [ - "num_experts", # Jamba - "moe_num_experts", # Dbrx - "n_routed_experts", # DeepSeek - "num_local_experts", # Mixtral - ] - for attr in num_expert_names: - num_experts = config_dict.get(attr) - if num_experts is not None: - config_dict.pop(attr) - return num_experts - - return 0 - - -def extract_standard_text_config_field( - config_dict: dict[str, Any], -) -> tuple[dict[str, Any], dict[str, Any]]: - standard_fields = {} - if "text_config" in config_dict: - text_config_dict = config_dict["text_config"] - else: - text_config_dict = deepcopy(config_dict) - standard_fields["model_type"] = text_config_dict.pop("model_type") - - standard_fields["hidden_size"] = text_config_dict.pop("hidden_size") - - (standard_fields["num_hidden_layers"]) = extract_num_hidden_layers( - text_config_dict, standard_fields["model_type"] - ) - standard_fields["num_attention_heads"] = text_config_dict.pop("num_attention_heads") - - standard_fields["use_deepseek_mla"] = extract_use_deepseek_mla( - text_config_dict, standard_fields["model_type"] - ) - standard_fields["head_dim"] = extract_head_size(text_config_dict, standard_fields) - standard_fields["vocab_size"] = text_config_dict.pop("vocab_size") - standard_fields["num_key_value_heads"] = extract_total_num_kv_heads( - text_config_dict, standard_fields - ) - standard_fields["num_experts"] = extract_num_experts(text_config_dict) - - return standard_fields, text_config_dict - - -if TYPE_CHECKING: - import vllm.model_executor.models as me_models -else: - me_models = LazyLoader("model_executor", globals(), "vllm.model_executor.models") - - -def get_per_layer_attention_cls( - 
architectures: list[str], - model_impl: str, - text_config: ModelArchitectureTextConfig, -) -> list[type[nn.Module]]: - assert len(architectures) == 1, "Only support len(architectures) == 1 for now" - assert model_impl == "auto" or model_impl == "vllm" - assert architectures[0] in me_models.ModelRegistry.models - model_arch = architectures[0] - model_cls = me_models.registry._try_load_model_cls( - model_arch, me_models.ModelRegistry.models[model_arch] - ) - # TODO: need to split sliding window attention from vllm.attention.layer.Attention - per_layer_attention_cls = model_cls.get_per_layer_attention_cls(text_config) - - return per_layer_attention_cls - - -def get_quantization_config( - model: str | Path, revision: str | None, config_dict: dict[str, Any] -) -> dict[str, Any]: - # ModelOpt 0.31.0 and after saves the quantization config in the model - # config file. - quantization_config = config_dict.pop("quantization_config", None) - - # ModelOpt 0.29.0 and before saves the quantization config in a separate - # "hf_quant_config.json" in the same directory as the model config file. 
- if quantization_config is None and file_or_path_exists( - model, "hf_quant_config.json", revision - ): - quantization_config = get_hf_file_to_dict( - "hf_quant_config.json", model, revision + def convert(self, model_id: str, revision: str | None) -> ModelArchitectureConfig: + model_arch_config = ModelArchitectureConfig( + architectures=self.hf_config.architectures, + model_type=self.hf_config.model_type, + text_model_type=getattr(self.hf_text_config, "model_type", None), + hidden_size=self.get_hidden_size(), + num_hidden_layers=self.get_num_hidden_layers(), + total_num_attention_heads=self.get_total_num_attention_heads(), + head_size=self.get_head_size(), + vocab_size=self.get_vocab_size(), + total_num_kv_heads=self.get_total_num_kv_heads(), + num_experts=self.get_num_experts(), + quantization_config=self.get_quantization_config(self.hf_config), + torch_dtype=self.get_torch_dtype(model_id, revision), + support_multimodal=self.support_multimodal(), + is_deepseek_mla=self.is_deepseek_mla(), + derived_max_model_len_and_key=self.derive_max_model_len_and_key(), ) - if quantization_config is not None: - # config.quantization_config = quantization_config - # auto-enable DeepGEMM UE8M0 if model config requests it - scale_fmt = quantization_config.get("scale_fmt", None) - if scale_fmt in ("ue8m0",): - if not envs.is_set("VLLM_USE_DEEP_GEMM_E8M0"): - os.environ["VLLM_USE_DEEP_GEMM_E8M0"] = "1" - logger.info_once( - ( - "Detected quantization_config.scale_fmt=%s; " - "enabling UE8M0 for DeepGEMM." - ), - scale_fmt, - ) - elif not envs.VLLM_USE_DEEP_GEMM_E8M0: - logger.warning_once( - ( - "Model config requests UE8M0 " - "(quantization_config.scale_fmt=%s), but " - "VLLM_USE_DEEP_GEMM_E8M0=0 is set; " - "UE8M0 for DeepGEMM disabled." 
- ), - scale_fmt, - ) + return model_arch_config - return quantization_config or {} - - -def get_torch_dtype(config_dict: dict[str, Any]): - config_dtype = config_dict.pop("dtype", None) - - # Fallbacks for multi-modal models if the root config - # does not define dtype - if config_dtype is None: - config_dtype = config_dict["text_config"].get("dtype", None) - if config_dtype is None and "vision_config" in config_dict: - config_dtype = config_dict["vision_config"].get("dtype", None) - if config_dtype is None and hasattr(config_dict, "encoder_config"): - config_dtype = config_dict["encoder_config"].get("dtype", None) - - return config_dtype - - -class HFModelArchConfigParser(ModelArchConfigParserBase): - def parse( - self, - model: str | Path, - trust_remote_code: bool, - revision: str | None = None, - code_revision: str | None = None, - model_impl: str = "auto", - **kwargs, - ) -> tuple[dict[str, Any], "ModelArchitectureConfig"]: - """Parse the HF config and create ModelArchitectureConfig.""" - - is_gguf = kwargs.get("is_gguf", False) - if is_gguf: - kwargs["gguf_file"] = Path(model).name - model = Path(model).parent - - kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE - - config_dict, _ = PretrainedConfig.get_config_dict( - model, - revision=revision, - code_revision=code_revision, - token=_get_hf_token(), - **kwargs, - ) - # Use custom model class if it's in our registry - model_type = config_dict.get("model_type", "") - if model_type in _CONFIG_REGISTRY: - # TODO: check if need to write new config class that - # inherient ModelArchitectureTextConfig for each of those models - raise NotImplementedError - else: - # We use AutoConfig.from_pretrained to leverage some existing - # standardization in PretrainedConfig - try: - kwargs = _maybe_update_auto_config_kwargs(kwargs, model_type=model_type) - # https://github.com/huggingface/transformers/blob/e8a6eb3304033fdd9346fe3b3293309fe50de238/src/transformers/models/auto/configuration_auto.py#L1238 - 
config_dict = AutoConfig.from_pretrained( - model, - trust_remote_code=trust_remote_code, - revision=revision, - code_revision=code_revision, - token=_get_hf_token(), - **kwargs, - ).to_dict() - except ValueError as e: - if ( - not trust_remote_code - and "requires you to execute the configuration file" in str(e) - ): - err_msg = ( - "Failed to load the model config. If the model " - "is a custom model not yet available in the " - "HuggingFace transformers library, consider setting " - "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI." - ) - raise RuntimeError(err_msg) from e - else: - raise e - - architectures = config_dict.pop("architectures", []) - quantization_config = get_quantization_config(model, revision, config_dict) - torch_dtype = get_torch_dtype(config_dict) - - standard_fields, text_config_dict = extract_standard_text_config_field( - config_dict +class Zamba2ModelArchConfigConvertor(ModelArchConfigConvertorBase): + """Convertor for Zamba2 which uses attention_head_dim instead of head_dim.""" + + def get_head_size(self) -> int: + return getattr(self.hf_text_config, "attention_head_dim", 0) + + +class MPTModelArchConfigConvertor(ModelArchConfigConvertorBase): + """Convertor for MPT which has attn_config with kv_n_heads.""" + + def get_total_num_kv_heads(self) -> int: + if "kv_n_heads" in self.hf_text_config.attn_config: + return self.hf_text_config.attn_config["kv_n_heads"] + return self.hf_text_config.num_attention_heads + + +class DbrxModelArchConfigConvertor(ModelArchConfigConvertorBase): + """Convertor for Dbrx which has attn_config with kv_n_heads.""" + + def get_total_num_kv_heads(self) -> int: + return getattr( + self.hf_text_config.attn_config, + "kv_n_heads", + self.hf_text_config.num_attention_heads, ) - # Ensure no overlap between standard fields and remaining text config - overlap = set(standard_fields.keys()) & set(text_config_dict.keys()) - assert len(overlap) == 0, ( - f"Standard fields and text 
config dict should not overlap, got {overlap}" + + +class FalconModelArchConfigConvertor(ModelArchConfigConvertorBase): + """Convertor for Falcon which uses multi_query and new_decoder_architecture.""" + + def get_total_num_kv_heads(self) -> int: + # NOTE: for falcon, when new_decoder_architecture is True, the + # multi_query flag is ignored and we use n_head_kv for the number of + # KV heads. + new_decoder_arch_falcon = getattr( + self.hf_text_config, "new_decoder_architecture", False ) - # Extract text config fields - text_config = ModelArchitectureTextConfig(**standard_fields, **text_config_dict) - - # Special architecture mapping check for GGUF models - if is_gguf: - if model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: - raise RuntimeError(f"Can't get gguf config for {model_type}.") - model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type] - architectures = [model_type] - - # Architecture mapping for models without explicit architectures field - if not architectures: - if model_type not in MODEL_MAPPING_NAMES: - logger.warning( - "Model config does not have a top-level " - "'architectures' field: expecting " - "`model_arch_overrides={'architectures': ['...']}` " - "to be passed in engine args." + + if not new_decoder_arch_falcon and getattr( + self.hf_text_config, "multi_query", False + ): + # Multi-query attention, only one KV head. + return 1 + + # Use the base implementation which checks n_head_kv, num_kv_heads, etc. 
+ return super().get_total_num_kv_heads() + + +class NemotronNasModelArchConfigConvertor(ModelArchConfigConvertorBase): + """Convertor for Nemotron-NAS which has block_configs.""" + + def get_total_num_kv_heads(self) -> int: + for block in self.hf_text_config.block_configs: + if not block.attention.no_op: + return ( + self.hf_text_config.num_attention_heads + // block.attention.n_heads_in_group ) - else: - model_type = MODEL_MAPPING_NAMES[model_type] - architectures = [model_type] + raise RuntimeError("Couldn't determine number of kv heads") - vision_config_dict = config_dict.get("vision_config", {}) - audio_config_dict = config_dict.get("audio_config", {}) - per_layer_attention_cls = get_per_layer_attention_cls( - architectures, model_impl, text_config - ) +class DeepSeekMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): + def get_num_hidden_layers(self) -> int: + return getattr(self.hf_text_config, "num_nextn_predict_layers", 0) - # Create ModelArchitectureConfig - vision_config = ( - ModelArchitectureVisionConfig(**vision_config_dict) - if vision_config_dict - else None - ) - audio_config = ( - ModelArchitectureAudioConfig(**audio_config_dict) - if audio_config_dict - else None - ) - arch_config = ModelArchitectureConfig( - architectures=architectures, - model_type=model_type, - quantization_config=quantization_config, - torch_dtype=torch_dtype, - per_layer_attention_cls=per_layer_attention_cls, - text_config=text_config, - vision=vision_config, - audio=audio_config, - ) +class Qwen3NextMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): + def get_num_hidden_layers(self) -> int: + return getattr(self.hf_text_config, "num_nextn_predict_layers", 0) + + +class MimoMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): + """Convertor for MIMO MTP.""" + + def get_num_hidden_layers(self) -> int: + return getattr(self.hf_text_config, "num_nextn_predict_layers", 0) + + +class GLM4MoeMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): + 
"""Convertor for GLM4 MoE MTP.""" + + def get_num_hidden_layers(self) -> int: + return getattr(self.hf_text_config, "num_nextn_predict_layers", 0) + + +class ErnieMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): + """Convertor for Ernie MTP.""" + + def get_num_hidden_layers(self) -> int: + return getattr(self.hf_text_config, "num_nextn_predict_layers", 0) + + +class PanguUltraMoeMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): + """Convertor for Pangu Ultra MoE MTP.""" + + def get_num_hidden_layers(self) -> int: + return getattr(self.hf_text_config, "num_nextn_predict_layers", 0) + + +class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): + """Convertor for LongCat Flash MTP which defaults to 1 layer.""" + + def get_num_hidden_layers(self) -> int: + return getattr(self.hf_text_config, "num_nextn_predict_layers", 1) + + +class CohereModelArchConfigConvertor(ModelArchConfigConvertorBase): + def derive_max_model_len_and_key(self) -> tuple[float, str | None]: + derived_max_model_len, max_len_key = super().derive_max_model_len_and_key() + if tmp_max_len := getattr(self.hf_text_config, "model_max_length", None): + max_len_key = "model_max_length" + derived_max_model_len = tmp_max_len + return derived_max_model_len, max_len_key + - return config_dict, arch_config +# hf_config.model_type -> convertor class +MODEL_ARCH_CONFIG_CONVERTORS = { + "zamba2": Zamba2ModelArchConfigConvertor, + "mpt": MPTModelArchConfigConvertor, + "dbrx": DbrxModelArchConfigConvertor, + "falcon": FalconModelArchConfigConvertor, + "RefinedWeb": FalconModelArchConfigConvertor, + "RefinedWebModel": FalconModelArchConfigConvertor, + "nemotron-nas": NemotronNasModelArchConfigConvertor, + "deepseek_mtp": DeepSeekMTPModelArchConfigConvertor, + "qwen3_next_mtp": Qwen3NextMTPModelArchConfigConvertor, + "mimo_mtp": MimoMTPModelArchConfigConvertor, + "glm4_moe_mtp": GLM4MoeMTPModelArchConfigConvertor, + "ernie_mtp": ErnieMTPModelArchConfigConvertor, + 
"pangu_ultra_moe_mtp": PanguUltraMoeMTPModelArchConfigConvertor, + "longcat_flash_mtp": LongCatFlashMTPModelArchConfigConvertor, + "commandr": CohereModelArchConfigConvertor, + "aya_vision": CohereModelArchConfigConvertor, +}