diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 5aa984515b85..6313d34a6b38 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -2,7 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from collections.abc import Callable +from collections.abc import Callable, Iterator +from contextlib import contextmanager from dataclasses import asdict from functools import cache, partial from importlib.metadata import version @@ -10,8 +11,10 @@ from typing import Any, Literal, TypeAlias import huggingface_hub -from huggingface_hub import get_safetensors_metadata +import torch +from huggingface_hub import constants, get_safetensors_metadata from packaging.version import Version +from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from transformers import GenerationConfig, PretrainedConfig from transformers.models.auto.image_processing_auto import get_image_processor_config from transformers.models.auto.modeling_auto import ( @@ -28,6 +31,7 @@ parse_safetensors_file_metadata, without_trust_remote_code, ) +from vllm.utils.torch_utils import common_broadcastable_dtype from .config_parser_base import ConfigParserBase from .gguf_utils import ( @@ -135,6 +139,19 @@ def is_rope_parameters_nested(rope_parameters: dict[str, Any]) -> bool: return set(rope_parameters.keys()).issubset(ALLOWED_ATTENTION_LAYER_TYPES) +@contextmanager +def _mistral_patch_hf_hub_constants() -> Iterator[None]: + hf_safetensors_single_file = constants.SAFETENSORS_SINGLE_FILE + hf_safetensors_index_file = constants.SAFETENSORS_INDEX_FILE + constants.SAFETENSORS_SINGLE_FILE = "consolidated.safetensors" + constants.SAFETENSORS_INDEX_FILE = "consolidated.safetensors.index.json" + try: + yield + finally: + constants.SAFETENSORS_SINGLE_FILE = hf_safetensors_single_file + constants.SAFETENSORS_INDEX_FILE = hf_safetensors_index_file + + class HFConfigParser(ConfigParserBase): def parse( self, 
@@ -245,6 +262,25 @@ def parse( except OSError: # Not found hf_config_dict = {} + if config_dict.get("dtype") is None: + with _mistral_patch_hf_hub_constants(): + model_str = model if isinstance(model, str) else model.as_posix() + param_mt = get_safetensors_params_metadata(model_str, revision=revision) + if param_mt: + param_dtypes: set[torch.dtype] = { + _SAFETENSORS_TO_TORCH_DTYPE[dtype] + for info in param_mt.values() + if (dtype := info.get("dtype", None)) + and dtype in _SAFETENSORS_TO_TORCH_DTYPE + } + + if param_dtypes: + config_dict["dtype"] = common_broadcastable_dtype(param_dtypes) + logger.info_once( + f"Inferred dtype {config_dict['dtype']} " + "from consolidated*.safetensors files." + ) + config = adapt_config_dict(config_dict, defaults=hf_config_dict) return config_dict, config diff --git a/vllm/transformers_utils/configs/mistral.py b/vllm/transformers_utils/configs/mistral.py index 1e1e49f7c11f..90728bbffb60 100644 --- a/vllm/transformers_utils/configs/mistral.py +++ b/vllm/transformers_utils/configs/mistral.py @@ -113,12 +113,13 @@ def _remap_mistral_vision_args(config: dict) -> dict: def _remap_mistral_yarn_args(config: dict) -> dict: yarn_config_map = { - "factor": "factor", - "original_max_position_embeddings": "original_max_position_embeddings", - "beta": "beta_fast", - "alpha": "beta_slow", - "apply_scale": "apply_yarn_scaling", + "factor": ("factor", float), + "original_max_position_embeddings": ("original_max_position_embeddings", int), + "beta": ("beta_fast", float), + "alpha": ("beta_slow", float), + "apply_scale": ("apply_yarn_scaling", bool), } + yarn_config = config.get("yarn") or {} config["rope_parameters"] = { "rope_type": "yarn", @@ -128,9 +129,10 @@ def _remap_mistral_yarn_args(config: dict) -> dict: if rope_theta := config.pop("rope_theta", None): config["rope_parameters"]["rope_theta"] = rope_theta - for old_name, new_name in yarn_config_map.items(): + for old_name, (new_name, cast) in yarn_config_map.items(): if old_name in 
yarn_config: - config["rope_parameters"][new_name] = yarn_config.pop(old_name) + # Cast to remove Transformers > v5 type warnings + config["rope_parameters"][new_name] = cast(yarn_config.pop(old_name)) assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}" @@ -154,6 +156,7 @@ def _remap_general_mistral_args(config: dict) -> dict: "tie_word_embeddings": ("tied_embeddings", False), "max_seq_len": ("max_seq_len", config.get("max_position_embeddings", 128_000)), "max_position_embeddings": ("max_position_embeddings", 128_000), + "dtype": ("dtype", config.get("dtype")), } for key, new_key in config_mapping.items(): diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index 3aeb375028ab..b01592aa3291 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -1,12 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Iterator -from contextlib import contextmanager from typing import final import torch -from huggingface_hub import constants from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from transformers import PretrainedConfig @@ -25,22 +22,6 @@ logger = init_logger(__name__) -@contextmanager -def _maybe_patch_hf_hub_constants(config_format: ConfigFormat) -> Iterator[None]: - if config_format == "mistral": - hf_safetensors_single_file = constants.SAFETENSORS_SINGLE_FILE - hf_safetensors_index_file = constants.SAFETENSORS_INDEX_FILE - constants.SAFETENSORS_SINGLE_FILE = "consolidated.safetensors" - constants.SAFETENSORS_INDEX_FILE = "consolidated.safetensors.index.json" - try: - yield - finally: - constants.SAFETENSORS_SINGLE_FILE = hf_safetensors_single_file - constants.SAFETENSORS_INDEX_FILE = hf_safetensors_index_file - else: - yield - - class ModelArchConfigConvertorBase: def __init__(self, hf_config: 
PretrainedConfig, hf_text_config: PretrainedConfig): self.hf_config = hf_config @@ -164,8 +145,7 @@ def get_torch_dtype( # Try to read the dtype of the weights if they are in safetensors format if config_dtype is None: - with _maybe_patch_hf_hub_constants(config_format): - param_mt = get_safetensors_params_metadata(model_id, revision=revision) + param_mt = get_safetensors_params_metadata(model_id, revision=revision) if param_mt: param_dtypes: set[torch.dtype] = {