diff --git a/vllm/config/model.py b/vllm/config/model.py index 764bdf700056..978852bfab4d 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -494,6 +494,17 @@ def __post_init__( ) self.hf_config = hf_config + + # Ensure Gemma2 configs have hidden_act for backward compatibility. + # GGUF configs may only have hidden_activation; model code expects both. + if ( + hasattr(hf_config, "model_type") + and hf_config.model_type == "gemma2" + and not hasattr(hf_config, "hidden_act") + and hasattr(hf_config, "hidden_activation") + ): + hf_config.hidden_act = hf_config.hidden_activation + if dict_overrides: self._apply_dict_overrides(hf_config, dict_overrides) self.hf_text_config = get_hf_text_config(self.hf_config) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 607bb44cddd2..6c2640b8156f 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -389,11 +389,39 @@ def _get_quantization_config( ) supported_dtypes = quant_config.get_supported_act_dtypes() if model_config.dtype not in supported_dtypes: - raise ValueError( - f"{model_config.dtype} is not supported for quantization " - f"method {model_config.quantization}. Supported dtypes: " - f"{supported_dtypes}" - ) + # Handle dtype conflict between model restrictions and + # quantization restrictions (e.g., Gemma3 GGUF on Blackwell + # where Gemma3 blocks float16 and GGUF blocks bfloat16) + from vllm.config.model import _is_valid_dtype + + model_type = getattr(model_config.hf_config, "model_type", None) + compatible_dtypes = [ + d + for d in supported_dtypes + if model_type is None or _is_valid_dtype(model_type, d) + ] + if compatible_dtypes: + # Prefer float16 > bfloat16 > float32 for performance + dtype_preference = [torch.float16, torch.bfloat16, torch.float32] + for preferred in dtype_preference: + if preferred in compatible_dtypes: + logger.warning( + "dtype=%s is not supported for quantization " + "method %s with model type %s. " + "Automatically selecting %s as compatible dtype.", + model_config.dtype, + model_config.quantization, + model_type, + preferred, + ) + model_config.dtype = preferred + break + else: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. Supported dtypes: " + f"{supported_dtypes}" + ) quant_config.maybe_update_config(model_config.model) return quant_config return None diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 114516ff07a1..31c6084c9b50 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools import pickle +import threading import time from contextlib import contextmanager from dataclasses import dataclass, field @@ -43,6 +44,33 @@ from_bytes_big = functools.partial(int.from_bytes, byteorder="big") +# Memory fence for cross-process shared memory visibility. +# Required for correct producer-consumer synchronization when using +# shared memory without locks. +_memory_fence_lock = threading.Lock() + + +def memory_fence(): + """ + Full memory barrier for shared memory synchronization. + + Ensures all prior memory writes are visible to other processes before + any subsequent reads. This is critical for lock-free producer-consumer + patterns using shared memory. + + Implementation acquires and immediately releases a lock. 
Python's + threading.Lock provides sequentially consistent memory barrier semantics + across all major platforms (POSIX, Windows). This is a lightweight + operation (~20ns) that guarantees: + - All stores before the barrier are visible to other threads/processes + - All loads after the barrier see the latest values + """ + # Lock acquire/release provides full memory barrier semantics. + # Using context manager ensures lock release even on exceptions. + with _memory_fence_lock: + pass + + def to_bytes_big(value: int, size: int) -> bytes: return value.to_bytes(size, byteorder="big") @@ -414,6 +442,10 @@ def acquire_write(self, timeout: float | None = None): n_warning = 1 while True: with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + # Memory fence ensures we see the latest read flags from readers. + # Without this, we may read stale flags from our CPU cache and + # spin indefinitely even though readers have completed. + memory_fence() read_count = sum(metadata_buffer[1:]) written_flag = metadata_buffer[0] if written_flag and read_count != self.buffer.n_reader: @@ -458,6 +490,10 @@ def acquire_write(self, timeout: float | None = None): metadata_buffer[i] = 0 # mark the block as written metadata_buffer[0] = 1 + # Memory fence ensures the write is visible to readers on other cores + # before we proceed. Without this, readers may spin indefinitely + # waiting for a write that's stuck in our CPU's store buffer. + memory_fence() self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks break @@ -473,6 +509,10 @@ def acquire_read( n_warning = 1 while True: with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + # Memory fence ensures we see the latest writes from the writer. + # Without this, we may read stale flags from our CPU cache + # and spin indefinitely even though writer has updated them. + memory_fence() read_flag = metadata_buffer[self.local_reader_rank + 1] written_flag = metadata_buffer[0] if not written_flag or read_flag: @@ -513,6 +553,10 @@ def acquire_read( # caller has read from the buffer # set the read flag metadata_buffer[self.local_reader_rank + 1] = 1 + # Memory fence ensures the read flag is visible to the writer. + # Without this, writer may not see our read completion and + # could wait indefinitely for all readers to finish. + memory_fence() self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks self._read_spin_timer.record_activity() diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 13aa2bcad21b..5c9da892f001 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -33,6 +33,7 @@ ) from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) @@ -52,6 +53,14 @@ def get_name(self) -> QuantizationMethods: return "gguf" def get_supported_act_dtypes(self) -> list[torch.dtype]: + # GGUF dequantization kernels use half precision (fp16) internally. + # bfloat16 has precision issues on SM 10.0+ devices (Blackwell). + if current_platform.has_device_capability(100): + logger.warning_once( + "GGUF has precision issues with bfloat16 on Blackwell (SM 10.0+). " + "bfloat16 is unavailable." 
+ ) + return [torch.half, torch.float32] return [torch.half, torch.bfloat16, torch.float32] @classmethod diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 7f94bd234fd3..c00d9c850c89 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -147,6 +147,11 @@ def _get_gguf_weights_map(self, model_config: ModelConfig): ) ) + # For models with tied word embeddings, lm_head.weight is initialized + # from embed_tokens and doesn't need to be mapped from GGUF file + if getattr(config, "tie_word_embeddings", False): + sideload_params.append(re.compile(r"lm_head\.weight")) + arch = None for key, value in gguf.MODEL_ARCH_NAMES.items(): if value == model_type: diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index cb36e0482458..256df3dfaa17 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -267,6 +267,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, @@ -366,6 +368,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: continue if is_pp_missing_parameter(name, self): continue + # Skip parameters not in the model (e.g., GGUF quantization + # metadata like qweight_type for embeddings) + if name not in params_dict: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 2d9dfbd3e768..7968bb17939c 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -49,6 +49,7 @@ MambaStateShapeCalculator, ) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -435,6 +436,7 @@ def __init__( self, config: NemotronHConfig, layer_idx: int, + max_position_embeddings: int, model_config: ModelConfig | None = None, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, @@ -490,13 +492,25 @@ def __init__( prefix=f"{prefix}.attn", ) + # Rotary embeddings for positional encoding + self.max_position_embeddings = max_position_embeddings + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + is_neox_style=True, + dtype=model_config.dtype if model_config else torch.get_default_dtype(), + ) + def forward( self, + positions: torch.Tensor, hidden_states: torch.Tensor, **kwargs, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -518,9 +532,10 @@ def __init__( self.mixer = NemotronHAttention( config, layer_idx, - model_config, - cache_config, - quant_config, + max_position_embeddings=config.max_position_embeddings, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, 
prefix=f"{prefix}.mixer", ) @@ -539,7 +554,7 @@ def forward( else: hidden_states, residual = self.norm(hidden_states, residual) - hidden_states = self.mixer(hidden_states=hidden_states) + hidden_states = self.mixer(positions=positions, hidden_states=hidden_states) return hidden_states, residual @@ -659,6 +674,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: + # Skip rotary embeddings - they are computed dynamically + if "rotary_emb.inv_freq" in name: + continue + if "scale" in name: # Remapping the name of FP8 kv-scale. name = maybe_remap_kv_scale_name(name, params_dict) diff --git a/vllm/tokenizers/hf.py b/vllm/tokenizers/hf.py index 344507312038..9e85f89c2b67 100644 --- a/vllm/tokenizers/hf.py +++ b/vllm/tokenizers/hf.py @@ -7,11 +7,15 @@ from transformers import AutoTokenizer +from vllm.logger import init_logger from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config +from vllm.transformers_utils.gguf_utils import extract_eos_token_id_from_gguf from .protocol import TokenizerLike from .registry import TokenizerRegistry +logger = init_logger(__name__) + if TYPE_CHECKING: from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -121,4 +125,22 @@ def from_pretrained( } tokenizer.add_special_tokens(special_tokens_map) + # Patch EOS token ID from GGUF metadata if available + # GGUF files may have a different EOS token ID than HF tokenizer config + # (e.g., Gemma uses ID 106 as EOS, but HF reports ID 1) + gguf_file = kwargs.get("gguf_file") + if gguf_file: + gguf_path = Path(path_or_repo_id) / gguf_file + gguf_eos_id = extract_eos_token_id_from_gguf(str(gguf_path)) + if gguf_eos_id is not None: + hf_eos_id = tokenizer.eos_token_id + if hf_eos_id != gguf_eos_id: + logger.info( + "Patching tokenizer eos_token_id from %d to %d " + "(using GGUF metadata)", + hf_eos_id, + gguf_eos_id, + ) + tokenizer.eos_token_id = gguf_eos_id + return get_cached_tokenizer(tokenizer) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index d761802da940..0cf3d6ceb3a6 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -31,6 +31,7 @@ from .config_parser_base import ConfigParserBase from .gguf_utils import ( check_gguf_file, + extract_hf_config_from_gguf, is_gguf, is_remote_gguf, split_remote_gguf, @@ -223,15 +224,87 @@ def parse( return config_dict, config +class GGUFConfigParser(ConfigParserBase): + """Config parser that extracts configuration from GGUF metadata. + + This parser is used for GGUF models from repositories that don't include + config.json (e.g., bartowski repos). It reads the GGUF file metadata + directly to construct a HuggingFace-compatible configuration. + """ + + def parse( + self, + model: str | Path, + trust_remote_code: bool, + revision: str | None = None, + code_revision: str | None = None, + **kwargs, + ) -> tuple[dict, PretrainedConfig]: + # Get the GGUF file path from kwargs + gguf_file = kwargs.get("gguf_file") + gguf_path = str(Path(model) / gguf_file) if gguf_file else str(model) + + # Extract config from GGUF metadata + config_dict = extract_hf_config_from_gguf(gguf_path) + if config_dict is None: + raise ValueError( + f"Failed to extract config from GGUF file: {gguf_path}. " + "The GGUF file may be corrupted or missing required metadata." 
+            )
+
+        model_type = config_dict.get("model_type")
+
+        # Use hf_overrides if provided
+        if (hf_overrides := kwargs.pop("hf_overrides", None)) is not None:
+            config_dict.update(hf_overrides)
+            model_type = config_dict.get("model_type", model_type)
+
+        # Create the config object from the extracted dict. Prefer vLLM's
+        # own registry when it has an entry for this model type.
+        if model_type is not None and model_type in _CONFIG_REGISTRY:
+            config_class = _CONFIG_REGISTRY[model_type]
+            config = config_class(**config_dict)
+        else:
+            try:
+                # AutoConfig.for_model() looks up the config class registered
+                # for this model_type and returns an instantiated config, so
+                # the extracted values are passed directly as keyword
+                # arguments. Unrecognized keys are stored as plain attributes
+                # by PretrainedConfig, so no filtering is needed; model_type
+                # is passed positionally and dropped from the kwargs.
+                config = AutoConfig.for_model(
+                    model_type,
+                    **{k: v for k, v in config_dict.items() if k != "model_type"},
+                )
+            except Exception as e:
+                logger.warning(
+                    "Failed to create config with AutoConfig.for_model(%s): %s. "
+                    "Falling back to PretrainedConfig.",
+                    model_type,
+                    e,
+                )
+                # Fallback to basic PretrainedConfig
+                config = PretrainedConfig(**config_dict)
+
+        config = _maybe_remap_hf_config_attrs(config)
+        return config_dict, config
+
+
 _CONFIG_FORMAT_TO_CONFIG_PARSER: dict[str, type[ConfigParserBase]] = {
     "hf": HFConfigParser,
     "mistral": MistralConfigParser,
+    "gguf": GGUFConfigParser,
 }
 
 ConfigFormat = Literal[
     "auto",
     "hf",
     "mistral",
+    "gguf",
 ]
 
 
@@ -556,13 +629,18 @@ def get_config(
         # Transformers implementation.
         if file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision):
             config_format = "mistral"
-        elif (_is_gguf and not _is_remote_gguf) or file_or_path_exists(
-            model, HF_CONFIG_NAME, revision=revision
-        ):
+        elif file_or_path_exists(model, HF_CONFIG_NAME, revision=revision):
             config_format = "hf"
+        # Local GGUF files without config.json: extract config from GGUF metadata
+        elif _is_gguf and not _is_remote_gguf:
+            logger.info(
+                "No config.json found for local GGUF model. "
+                "Extracting config from GGUF metadata."
+            )
+            config_format = "gguf"
         # Remote GGUF models must have config.json in repo,
         # otherwise the config can't be parsed correctly.
- # FIXME(Isotr0py): Support remote GGUF repos without config.json + # TODO(Isotr0py): Support remote GGUF repos without config.json elif _is_remote_gguf and not file_or_path_exists( model, HF_CONFIG_NAME, revision=revision ): diff --git a/vllm/transformers_utils/gguf_utils.py b/vllm/transformers_utils/gguf_utils.py index f3fd43c6ace5..ddde1fe398f3 100644 --- a/vllm/transformers_utils/gguf_utils.py +++ b/vllm/transformers_utils/gguf_utils.py @@ -5,6 +5,7 @@ from functools import cache from os import PathLike from pathlib import Path +from typing import Any import gguf import regex as re @@ -199,6 +200,270 @@ def extract_vision_config_from_gguf(mmproj_path: str) -> "SiglipVisionConfig | N return config +# Mapping from GGUF architecture names to HuggingFace model_type +GGUF_ARCH_TO_HF_MODEL_TYPE: dict[str, str] = { + "llama": "llama", + "phi3": "phi3", + "gemma": "gemma", + "gemma2": "gemma2", + "qwen2": "qwen2", + "qwen3": "qwen2", # Qwen3 uses qwen2 architecture in HF + "starcoder2": "starcoder2", + "gpt2": "gpt2", + "mistral": "mistral", + "mixtral": "mixtral", + "falcon": "falcon", + "phi2": "phi", + "phi": "phi", + "baichuan": "baichuan", + "internlm2": "internlm2", + "mamba": "mamba", + "nemotron": "nemotron", +} + + +def extract_hf_config_from_gguf(model: str) -> dict[str, Any] | None: + """Extract HuggingFace-compatible config dict from GGUF metadata. + + This function reads GGUF metadata and constructs a config dictionary + that can be used to create a PretrainedConfig. Useful for GGUF repos + that don't include config.json (e.g., bartowski repos). + + Args: + model: Path to GGUF model file + + Returns: + Dictionary with HF-compatible config values, or None if extraction fails + + Raises: + Exception: Exceptions from GGUF reading propagate directly + """ + # Use check_gguf_file to validate - it reads the header magic bytes + # This handles both .gguf extension and HuggingFace cache blob paths + if not check_gguf_file(model): + return None + + try: + model_path = Path(model) + + reader = gguf.GGUFReader(str(model_path)) + + # Get architecture name + arch_field = reader.get_field(Keys.General.ARCHITECTURE) + if arch_field is None: + logger.warning("No architecture field found in GGUF metadata") + return None + + arch = bytes(arch_field.parts[-1]).decode("utf-8") + logger.info("Extracting config from GGUF metadata (architecture: %s)", arch) + + # Map GGUF architecture to HF model_type + model_type = GGUF_ARCH_TO_HF_MODEL_TYPE.get(arch, arch) + + config_dict: dict[str, Any] = { + "model_type": model_type, + } + + # Helper to extract field value + def get_field_value(key: str, default=None): + field = reader.get_field(key.format(arch=arch)) + if field is not None: + val = field.parts[-1] + # Handle arrays vs scalars + if hasattr(val, "__len__") and len(val) == 1: + return val[0] + return val + return default + + # Extract core architecture parameters + # Using arch-specific keys from gguf.constants.Keys + + # Context length -> max_position_embeddings + ctx_len = get_field_value(Keys.LLM.CONTEXT_LENGTH) + if ctx_len is not None: + config_dict["max_position_embeddings"] = int(ctx_len) + + # Embedding length -> hidden_size + embed_len = get_field_value(Keys.LLM.EMBEDDING_LENGTH) + if embed_len is not None: + config_dict["hidden_size"] = int(embed_len) + + # Feed forward length -> intermediate_size + ff_len = get_field_value(Keys.LLM.FEED_FORWARD_LENGTH) + if ff_len is not None: + config_dict["intermediate_size"] = int(ff_len) + + # Block count -> num_hidden_layers + block_count = 
get_field_value(Keys.LLM.BLOCK_COUNT) + if block_count is not None: + config_dict["num_hidden_layers"] = int(block_count) + + # Attention head count -> num_attention_heads + head_count = get_field_value(Keys.Attention.HEAD_COUNT) + if head_count is not None: + config_dict["num_attention_heads"] = int(head_count) + + # KV head count -> num_key_value_heads + kv_head_count = get_field_value(Keys.Attention.HEAD_COUNT_KV) + if kv_head_count is not None: + config_dict["num_key_value_heads"] = int(kv_head_count) + + # RoPE frequency base -> rope_theta + rope_freq = get_field_value(Keys.Rope.FREQ_BASE) + if rope_freq is not None: + config_dict["rope_theta"] = float(rope_freq) + + # Layer norm epsilon + rms_eps = get_field_value(Keys.Attention.LAYERNORM_RMS_EPS) + if rms_eps is not None: + config_dict["rms_norm_eps"] = float(rms_eps) + + # Sliding window attention + sliding_window = get_field_value(Keys.Attention.SLIDING_WINDOW) + if sliding_window is not None: + config_dict["sliding_window"] = int(sliding_window) + + # Vocab size - from tokenizer tokens list or arch-specific field + vocab_size = get_field_value(Keys.LLM.VOCAB_SIZE) + if vocab_size is None: + tokens_field = reader.get_field(Keys.Tokenizer.LIST) + if tokens_field is not None: + vocab_size = len(tokens_field.parts[-1]) + if vocab_size is not None: + config_dict["vocab_size"] = int(vocab_size) + + # Token IDs + bos_id = get_field_value(Keys.Tokenizer.BOS_ID) + if bos_id is not None: + config_dict["bos_token_id"] = int(bos_id) + + eos_id = get_field_value(Keys.Tokenizer.EOS_ID) + if eos_id is not None: + config_dict["eos_token_id"] = int(eos_id) + + # Attention softcapping (for Gemma2, etc.) + attn_softcap = get_field_value(Keys.LLM.ATTN_LOGIT_SOFTCAPPING) + if attn_softcap is not None: + config_dict["attn_logit_softcapping"] = float(attn_softcap) + + final_softcap = get_field_value(Keys.LLM.FINAL_LOGIT_SOFTCAPPING) + if final_softcap is not None: + config_dict["final_logit_softcapping"] = float(final_softcap) + + logger.info( + "Extracted %d config fields from GGUF metadata for %s", + len(config_dict), + model_type, + ) + + return config_dict + + except Exception as e: + logger.warning("Error extracting config from GGUF: %s", e) + return None + + +def extract_softcap_from_gguf(model: str) -> dict[str, float]: + """Extract attention and final logit softcap values from GGUF metadata. + + Reads softcap parameters from GGUF metadata using arch-specific keys. + These parameters are critical for models like Gemma2 where attention + logit softcapping prevents numerical instability. 
+ + Args: + model: Path to GGUF model file + + Returns: + Dictionary with 'attn_logit_softcapping' and/or 'final_logit_softcapping' + keys if found in GGUF metadata, empty dict otherwise + """ + if not model.endswith(".gguf"): + return {} + + try: + model_path = Path(model) + if not model_path.is_file(): + return {} + + reader = gguf.GGUFReader(str(model_path)) + + # Get architecture name to build arch-specific keys + arch_field = reader.get_field(Keys.General.ARCHITECTURE) + if arch_field is None: + logger.debug("No architecture field found in GGUF metadata") + return {} + + arch = bytes(arch_field.parts[-1]).decode("utf-8") + + result = {} + + # Extract attention logit softcapping + attn_key = Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=arch) + attn_field = reader.get_field(attn_key) + if attn_field is not None: + result["attn_logit_softcapping"] = float(attn_field.parts[-1]) + logger.info( + "Extracted attn_logit_softcapping=%.2f from GGUF metadata", + result["attn_logit_softcapping"], + ) + + # Extract final logit softcapping + final_key = Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=arch) + final_field = reader.get_field(final_key) + if final_field is not None: + result["final_logit_softcapping"] = float(final_field.parts[-1]) + logger.info( + "Extracted final_logit_softcapping=%.2f from GGUF metadata", + result["final_logit_softcapping"], + ) + + return result + + except Exception as e: + logger.debug("Error extracting softcap from GGUF: %s", e) + return {} + + +def extract_eos_token_id_from_gguf(model: str) -> int | None: + """Extract EOS token ID from GGUF metadata. + + GGUF files store the EOS token ID in tokenizer.ggml.eos_token_id field. + This may differ from HuggingFace's tokenizer config (e.g., Gemma models + use token ID 106 as EOS in GGUF, but HF tokenizer reports + token ID 1). + + Args: + model: Path to GGUF model file + + Returns: + EOS token ID from GGUF metadata, or None if not found + """ + if not model.endswith(".gguf"): + return None + + try: + model_path = Path(model) + if not model_path.is_file(): + return None + + reader = gguf.GGUFReader(str(model_path)) + + eos_field = reader.get_field(Keys.Tokenizer.EOS_ID) + if eos_field is not None: + eos_token_id = int(eos_field.parts[-1][0]) + logger.debug( + "Extracted eos_token_id=%d from GGUF metadata", + eos_token_id, + ) + return eos_token_id + + return None + + except Exception as e: + logger.debug("Error extracting EOS token ID from GGUF: %s", e) + return None + + def maybe_patch_hf_config_from_gguf( model: str, hf_config: PretrainedConfig, @@ -207,7 +472,8 @@ def maybe_patch_hf_config_from_gguf( Applies GGUF-specific patches to HuggingFace config: 1. For multimodal models: patches architecture and vision config - 2. For all GGUF models: overrides vocab_size from embedding tensor + 2. For models with softcap (e.g., Gemma2): patches attention/logit softcapping + 3. 
For all GGUF models: overrides vocab_size from embedding tensor This ensures compatibility with GGUF models that have extended vocabularies (e.g., Unsloth) where the GGUF file contains more @@ -236,6 +502,15 @@ def maybe_patch_hf_config_from_gguf( ) hf_config = new_hf_config + # Patch softcap parameters from GGUF metadata + # Critical for models like Gemma2 where attention softcapping + # prevents numerical instability and ensures correct output + softcap_params = extract_softcap_from_gguf(model) + if "attn_logit_softcapping" in softcap_params: + hf_config.attn_logit_softcapping = softcap_params["attn_logit_softcapping"] + if "final_logit_softcapping" in softcap_params: + hf_config.final_logit_softcapping = softcap_params["final_logit_softcapping"] + return hf_config diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 4dd478804049..37e8a359717a 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import multiprocessing +import threading from concurrent.futures import Future, ThreadPoolExecutor from typing import TYPE_CHECKING @@ -63,39 +64,62 @@ def __init__(self, vllm_config: VllmConfig): max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8)) self.executor_for_fillmask = ThreadPoolExecutor(max_workers=max_workers) - if not self.vllm_config.model_config.skip_tokenizer_init: - # The default max_workers if not specified is the number of - # CPUs * 5, which is way too high since these tasks are CPU-bound, - # not I/O bound. We also know we would never dominate CPU usage - # with just grammar compilation, so we set it to half the number - # of CPUs. - max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) - self.executor = ThreadPoolExecutor(max_workers=max_workers) - self.tokenizer = init_tokenizer_from_config( - model_config=self.vllm_config.model_config - ) - reasoning_parser = ( - self.vllm_config.structured_outputs_config.reasoning_parser - ) - reasoning_parser_plugin = ( - self.vllm_config.structured_outputs_config.reasoning_parser_plugin - ) - if reasoning_parser_plugin and len(reasoning_parser_plugin) > 3: - ReasoningParserManager.import_reasoning_parser(reasoning_parser_plugin) - - reasoning_parser = ( - self.vllm_config.structured_outputs_config.reasoning_parser - ) - if reasoning_parser: - reasoner_cls = ReasoningParserManager.get_reasoning_parser( - reasoning_parser - ) - self.reasoner = reasoner_cls(tokenizer=self.tokenizer) + # Tokenizer is loaded lazily to avoid duplicate tokenizer initialization + # in multiprocess mode. For GGUF models, this prevents a semaphore leak + # that causes server hangs (tokenizer builds merges on the fly, which + # uses multiprocessing primitives that don't clean up in subprocesses). 
+ self._tokenizer = None + self._tokenizer_initialized = False + self._tokenizer_init_lock = threading.Lock() + self.executor: ThreadPoolExecutor | None = None self.enable_in_reasoning = ( self.vllm_config.structured_outputs_config.enable_in_reasoning ) + @property + def tokenizer(self): + """Lazily initialize tokenizer when first accessed (thread-safe).""" + # Double-checked locking pattern for thread-safe lazy initialization + if not self._tokenizer_initialized: + with self._tokenizer_init_lock: + if not self._tokenizer_initialized: + self._init_tokenizer() + return self._tokenizer + + def _init_tokenizer(self): + """Initialize tokenizer and related components on first use.""" + if self._tokenizer_initialized: + return + + if self.vllm_config.model_config.skip_tokenizer_init: + self._tokenizer_initialized = True + return + + # The default max_workers if not specified is the number of + # CPUs * 5, which is way too high since these tasks are CPU-bound, + # not I/O bound. We also know we would never dominate CPU usage + # with just grammar compilation, so we set it to half the number + # of CPUs. + max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) + self.executor = ThreadPoolExecutor(max_workers=max_workers) + self._tokenizer = init_tokenizer_from_config( + model_config=self.vllm_config.model_config + ) + + reasoning_parser = self.vllm_config.structured_outputs_config.reasoning_parser + reasoning_parser_plugin = ( + self.vllm_config.structured_outputs_config.reasoning_parser_plugin + ) + if reasoning_parser_plugin and len(reasoning_parser_plugin) > 3: + ReasoningParserManager.import_reasoning_parser(reasoning_parser_plugin) + + if reasoning_parser: + reasoner_cls = ReasoningParserManager.get_reasoning_parser(reasoning_parser) + self.reasoner = reasoner_cls(tokenizer=self._tokenizer) + + self._tokenizer_initialized = True + def grammar_init(self, request: Request) -> None: if request.structured_output_request is None: return @@ -149,6 +173,11 @@ def grammar_init(self, request: Request) -> None: raise ValueError(f"Unsupported structured output backend: {backend}") if self._use_async_grammar_compilation: + # Ensure tokenizer (and executor) is initialized + _ = self.tokenizer + assert self.executor is not None, ( + "Executor should be initialized with tokenizer" + ) grammar = self.executor.submit(self._create_grammar, request) else: grammar = self._create_grammar(request) # type: ignore[assignment]
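
Usage sketch (illustrative only, not part of the diff): the snippet below exercises the GGUF metadata helpers added in vllm/transformers_utils/gguf_utils.py against a local single-file GGUF checkpoint; the model path is a hypothetical placeholder.

from vllm.transformers_utils.gguf_utils import (
    extract_eos_token_id_from_gguf,
    extract_hf_config_from_gguf,
    extract_softcap_from_gguf,
)

# Hypothetical local GGUF file from a repo that ships no config.json.
gguf_path = "/models/gemma-2-9b-it-Q4_K_M.gguf"

# HF-compatible config dict reconstructed from GGUF metadata
# (model_type, hidden_size, head counts, rope_theta, softcap values, ...),
# or None if the file is missing or lacks the required metadata.
config_dict = extract_hf_config_from_gguf(gguf_path)

# EOS token ID stored in the GGUF tokenizer metadata; may differ from the
# HF tokenizer config, in which case the tokenizer's eos_token_id is patched.
eos_id = extract_eos_token_id_from_gguf(gguf_path)

# Gemma2-style attention/final logit softcapping values, if present.
softcap = extract_softcap_from_gguf(gguf_path)

print(config_dict, eos_id, softcap)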