Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions vllm/transformers_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from collections.abc import Callable
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from dataclasses import asdict
from functools import cache, partial
from importlib.metadata import version
from pathlib import Path
from typing import Any, Literal, TypeAlias

import huggingface_hub
from huggingface_hub import get_safetensors_metadata
import torch
from huggingface_hub import constants, get_safetensors_metadata
from packaging.version import Version
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import GenerationConfig, PretrainedConfig
from transformers.models.auto.image_processing_auto import get_image_processor_config
from transformers.models.auto.modeling_auto import (
Expand All @@ -28,6 +31,7 @@
parse_safetensors_file_metadata,
without_trust_remote_code,
)
from vllm.utils.torch_utils import common_broadcastable_dtype

from .config_parser_base import ConfigParserBase
from .gguf_utils import (
Expand Down Expand Up @@ -135,6 +139,19 @@ def is_rope_parameters_nested(rope_parameters: dict[str, Any]) -> bool:
return set(rope_parameters.keys()).issubset(ALLOWED_ATTENTION_LAYER_TYPES)


import threading

# Serializes mutation of the huggingface_hub global filename constants below.
# Without it, two threads loading models concurrently (e.g. one Mistral model
# and one standard HF model) can observe each other's patched values and fail
# to resolve the expected safetensors files.
_MISTRAL_PATCH_LOCK = threading.Lock()


@contextmanager
def _mistral_patch_hf_hub_constants() -> Iterator[None]:
    """Temporarily point huggingface_hub at Mistral's safetensors filenames.

    Mistral-format checkpoints ship their weights as
    ``consolidated.safetensors`` (plus an index file) instead of the default
    Hugging Face names, so metadata helpers that read
    ``constants.SAFETENSORS_SINGLE_FILE`` / ``SAFETENSORS_INDEX_FILE`` would
    otherwise miss them. The original values are always restored on exit,
    even if the body raises.

    Note: this monkey-patches process-wide globals; the lock keeps concurrent
    users of this context manager from clobbering each other.
    """
    with _MISTRAL_PATCH_LOCK:
        saved_single_file = constants.SAFETENSORS_SINGLE_FILE
        saved_index_file = constants.SAFETENSORS_INDEX_FILE
        constants.SAFETENSORS_SINGLE_FILE = "consolidated.safetensors"
        constants.SAFETENSORS_INDEX_FILE = "consolidated.safetensors.index.json"
        try:
            yield
        finally:
            constants.SAFETENSORS_SINGLE_FILE = saved_single_file
            constants.SAFETENSORS_INDEX_FILE = saved_index_file
Comment on lines +142 to +152
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The modification of global constants in huggingface_hub.constants is not thread-safe. In a scenario where multiple models are loaded concurrently in different threads (e.g., one Mistral model and one standard Hugging Face model), this monkey-patching can create a race condition. One thread might be expecting the default constant values while another has temporarily changed them, potentially leading to FileNotFoundError or other unpredictable behavior during model loading. To prevent this, the critical section where constants are modified should be protected by a lock.

Please also add import threading at the top of the file.

_mistral_patch_lock = threading.Lock()


@contextmanager
def _mistral_patch_hf_hub_constants() -> Iterator[None]:
    with _mistral_patch_lock:
        hf_safetensors_single_file = constants.SAFETENSORS_SINGLE_FILE
        hf_safetensors_index_file = constants.SAFETENSORS_INDEX_FILE
        constants.SAFETENSORS_SINGLE_FILE = "consolidated.safetensors"
        constants.SAFETENSORS_INDEX_FILE = "consolidated.safetensors.index.json"
        try:
            yield
        finally:
            constants.SAFETENSORS_SINGLE_FILE = hf_safetensors_single_file
            constants.SAFETENSORS_INDEX_FILE = hf_safetensors_index_file



class HFConfigParser(ConfigParserBase):
def parse(
self,
Expand Down Expand Up @@ -245,6 +262,25 @@ def parse(
except OSError: # Not found
hf_config_dict = {}

if config_dict.get("dtype") is None:
with _mistral_patch_hf_hub_constants():
model_str = model if isinstance(model, str) else model.as_posix()
param_mt = get_safetensors_params_metadata(model_str, revision=revision)
if param_mt:
param_dtypes: set[torch.dtype] = {
_SAFETENSORS_TO_TORCH_DTYPE[dtype]
for info in param_mt.values()
if (dtype := info.get("dtype", None))
and dtype in _SAFETENSORS_TO_TORCH_DTYPE
}

if param_dtypes:
config_dict["dtype"] = common_broadcastable_dtype(param_dtypes)
logger.info_once(
"Inferred from consolidated*.safetensors files "
f"{config_dict['dtype']} dtype."
)

config = adapt_config_dict(config_dict, defaults=hf_config_dict)

return config_dict, config
Expand Down
17 changes: 10 additions & 7 deletions vllm/transformers_utils/configs/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,13 @@ def _remap_mistral_vision_args(config: dict) -> dict:

def _remap_mistral_yarn_args(config: dict) -> dict:
yarn_config_map = {
"factor": "factor",
"original_max_position_embeddings": "original_max_position_embeddings",
"beta": "beta_fast",
"alpha": "beta_slow",
"apply_scale": "apply_yarn_scaling",
"factor": ("factor", float),
"original_max_position_embeddings": ("original_max_position_embeddings", int),
"beta": ("beta_fast", float),
"alpha": ("beta_slow", float),
"apply_scale": ("apply_yarn_scaling", bool),
}

yarn_config = config.get("yarn") or {}
config["rope_parameters"] = {
"rope_type": "yarn",
Expand All @@ -128,9 +129,10 @@ def _remap_mistral_yarn_args(config: dict) -> dict:
if rope_theta := config.pop("rope_theta", None):
config["rope_parameters"]["rope_theta"] = rope_theta

for old_name, new_name in yarn_config_map.items():
for old_name, (new_name, cast) in yarn_config_map.items():
if old_name in yarn_config:
config["rope_parameters"][new_name] = yarn_config.pop(old_name)
# Cast to remove Transformers > v5 type warnings
config["rope_parameters"][new_name] = cast(yarn_config.pop(old_name))

assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}"

Expand All @@ -154,6 +156,7 @@ def _remap_general_mistral_args(config: dict) -> dict:
"tie_word_embeddings": ("tied_embeddings", False),
"max_seq_len": ("max_seq_len", config.get("max_position_embeddings", 128_000)),
"max_position_embeddings": ("max_position_embeddings", 128_000),
"dtype": ("dtype", config.get("dtype")),
}

for key, new_key in config_mapping.items():
Expand Down
22 changes: 1 addition & 21 deletions vllm/transformers_utils/model_arch_config_convertor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterator
from contextlib import contextmanager
from typing import final

import torch
from huggingface_hub import constants
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import PretrainedConfig

Expand All @@ -25,22 +22,6 @@
logger = init_logger(__name__)


@contextmanager
def _maybe_patch_hf_hub_constants(config_format: ConfigFormat) -> Iterator[None]:
    """Swap in Mistral's safetensors filenames for the duration of the block.

    When ``config_format`` is ``"mistral"``, huggingface_hub's module-level
    filename constants are pointed at ``consolidated.safetensors`` (and its
    index) so safetensors metadata lookups find Mistral-style checkpoints;
    the previous values are restored on exit regardless of exceptions.
    For any other format this is a no-op passthrough.
    """
    if config_format != "mistral":
        yield
        return

    saved_single_file = constants.SAFETENSORS_SINGLE_FILE
    saved_index_file = constants.SAFETENSORS_INDEX_FILE
    constants.SAFETENSORS_SINGLE_FILE = "consolidated.safetensors"
    constants.SAFETENSORS_INDEX_FILE = "consolidated.safetensors.index.json"
    try:
        yield
    finally:
        constants.SAFETENSORS_SINGLE_FILE = saved_single_file
        constants.SAFETENSORS_INDEX_FILE = saved_index_file


class ModelArchConfigConvertorBase:
def __init__(self, hf_config: PretrainedConfig, hf_text_config: PretrainedConfig):
self.hf_config = hf_config
Expand Down Expand Up @@ -164,8 +145,7 @@ def get_torch_dtype(

# Try to read the dtype of the weights if they are in safetensors format
if config_dtype is None:
with _maybe_patch_hf_hub_constants(config_format):
param_mt = get_safetensors_params_metadata(model_id, revision=revision)
param_mt = get_safetensors_params_metadata(model_id, revision=revision)

if param_mt:
param_dtypes: set[torch.dtype] = {
Expand Down
Loading