From 00c20ae3947711582d7d54f0cae58091995f34bf Mon Sep 17 00:00:00 2001 From: Christina Date: Tue, 9 Dec 2025 18:42:01 -0600 Subject: [PATCH 01/17] fix: Address review feedback - platform-agnostic docs and context manager Signed-off-by: Christina (cherry picked from commit 0f27680b0cb1ff8138fb2593185a5d9088f7803c) --- .../device_communicators/shm_broadcast.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 114516ff07a1..31c6084c9b50 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools import pickle +import threading import time from contextlib import contextmanager from dataclasses import dataclass, field @@ -43,6 +44,33 @@ from_bytes_big = functools.partial(int.from_bytes, byteorder="big") +# Memory fence for cross-process shared memory visibility. +# Required for correct producer-consumer synchronization when using +# shared memory without locks. +_memory_fence_lock = threading.Lock() + + +def memory_fence(): + """ + Full memory barrier for shared memory synchronization. + + Ensures all prior memory writes are visible to other processes before + any subsequent reads. This is critical for lock-free producer-consumer + patterns using shared memory. + + Implementation acquires and immediately releases a lock. Python's + threading.Lock provides sequentially consistent memory barrier semantics + across all major platforms (POSIX, Windows). This is a lightweight + operation (~20ns) that guarantees: + - All stores before the barrier are visible to other threads/processes + - All loads after the barrier see the latest values + """ + # Lock acquire/release provides full memory barrier semantics. + # Using context manager ensures lock release even on exceptions. + with _memory_fence_lock: + pass + + def to_bytes_big(value: int, size: int) -> bytes: return value.to_bytes(size, byteorder="big") @@ -414,6 +442,10 @@ def acquire_write(self, timeout: float | None = None): n_warning = 1 while True: with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + # Memory fence ensures we see the latest read flags from readers. + # Without this, we may read stale flags from our CPU cache and + # spin indefinitely even though readers have completed. + memory_fence() read_count = sum(metadata_buffer[1:]) written_flag = metadata_buffer[0] if written_flag and read_count != self.buffer.n_reader: @@ -458,6 +490,10 @@ def acquire_write(self, timeout: float | None = None): metadata_buffer[i] = 0 # mark the block as written metadata_buffer[0] = 1 + # Memory fence ensures the write is visible to readers on other cores + # before we proceed. Without this, readers may spin indefinitely + # waiting for a write that's stuck in our CPU's store buffer. + memory_fence() self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks break @@ -473,6 +509,10 @@ def acquire_read( n_warning = 1 while True: with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + # Memory fence ensures we see the latest writes from the writer. + # Without this, we may read stale flags from our CPU cache + # and spin indefinitely even though writer has updated them. 
+ memory_fence() read_flag = metadata_buffer[self.local_reader_rank + 1] written_flag = metadata_buffer[0] if not written_flag or read_flag: @@ -513,6 +553,10 @@ def acquire_read( # caller has read from the buffer # set the read flag metadata_buffer[self.local_reader_rank + 1] = 1 + # Memory fence ensures the read flag is visible to the writer. + # Without this, writer may not see our read completion and + # could wait indefinitely for all readers to finish. + memory_fence() self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks self._read_spin_timer.record_activity() From c2c25e5fef61241dc994e8e18d9caf94a38b3ad0 Mon Sep 17 00:00:00 2001 From: Christina Date: Mon, 8 Dec 2025 18:00:24 -0600 Subject: [PATCH 02/17] fix: Lazy tokenizer init in StructuredOutputManager to prevent semaphore leak GGUF models without precomputed merges trigger `build_merges_on_the_fly` in the transformers library, which uses multiprocessing primitives. When this happens in both the APIServer process (for request validation) and the EngineCore subprocess (via StructuredOutputManager), the subprocess leaks a semaphore, causing the server to hang indefinitely. This change makes tokenizer initialization lazy in StructuredOutputManager: - Tokenizer is only loaded when grammar_init() is first called - Most inference requests don't use structured output, so the tokenizer in EngineCore is never loaded - For requests that do use structured output, tokenizer is loaded on-demand The fix resolves the following symptoms: - Server hangs after "resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown" - Tokenizer merges being built twice (once in APIServer, once in EngineCore) - GGUF models failing to start even though weights load successfully Tested with bartowski/Phi-3.5-mini-instruct-GGUF (Q5_K_M). Signed-off-by: Christina (cherry picked from commit a72d1f9a3469f754adc6cd272ae14e066a481f2b) --- vllm/v1/structured_output/__init__.py | 85 ++++++++++++++++++--------- 1 file changed, 57 insertions(+), 28 deletions(-) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 4dd478804049..37e8a359717a 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import multiprocessing +import threading from concurrent.futures import Future, ThreadPoolExecutor from typing import TYPE_CHECKING @@ -63,39 +64,62 @@ def __init__(self, vllm_config: VllmConfig): max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8)) self.executor_for_fillmask = ThreadPoolExecutor(max_workers=max_workers) - if not self.vllm_config.model_config.skip_tokenizer_init: - # The default max_workers if not specified is the number of - # CPUs * 5, which is way too high since these tasks are CPU-bound, - # not I/O bound. We also know we would never dominate CPU usage - # with just grammar compilation, so we set it to half the number - # of CPUs. 
- max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) - self.executor = ThreadPoolExecutor(max_workers=max_workers) - self.tokenizer = init_tokenizer_from_config( - model_config=self.vllm_config.model_config - ) - reasoning_parser = ( - self.vllm_config.structured_outputs_config.reasoning_parser - ) - reasoning_parser_plugin = ( - self.vllm_config.structured_outputs_config.reasoning_parser_plugin - ) - if reasoning_parser_plugin and len(reasoning_parser_plugin) > 3: - ReasoningParserManager.import_reasoning_parser(reasoning_parser_plugin) - - reasoning_parser = ( - self.vllm_config.structured_outputs_config.reasoning_parser - ) - if reasoning_parser: - reasoner_cls = ReasoningParserManager.get_reasoning_parser( - reasoning_parser - ) - self.reasoner = reasoner_cls(tokenizer=self.tokenizer) + # Tokenizer is loaded lazily to avoid duplicate tokenizer initialization + # in multiprocess mode. For GGUF models, this prevents a semaphore leak + # that causes server hangs (tokenizer builds merges on the fly, which + # uses multiprocessing primitives that don't clean up in subprocesses). + self._tokenizer = None + self._tokenizer_initialized = False + self._tokenizer_init_lock = threading.Lock() + self.executor: ThreadPoolExecutor | None = None self.enable_in_reasoning = ( self.vllm_config.structured_outputs_config.enable_in_reasoning ) + @property + def tokenizer(self): + """Lazily initialize tokenizer when first accessed (thread-safe).""" + # Double-checked locking pattern for thread-safe lazy initialization + if not self._tokenizer_initialized: + with self._tokenizer_init_lock: + if not self._tokenizer_initialized: + self._init_tokenizer() + return self._tokenizer + + def _init_tokenizer(self): + """Initialize tokenizer and related components on first use.""" + if self._tokenizer_initialized: + return + + if self.vllm_config.model_config.skip_tokenizer_init: + self._tokenizer_initialized = True + return + + # The default max_workers if not specified is the number of + # CPUs * 5, which is way too high since these tasks are CPU-bound, + # not I/O bound. We also know we would never dominate CPU usage + # with just grammar compilation, so we set it to half the number + # of CPUs. 
+ max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) + self.executor = ThreadPoolExecutor(max_workers=max_workers) + self._tokenizer = init_tokenizer_from_config( + model_config=self.vllm_config.model_config + ) + + reasoning_parser = self.vllm_config.structured_outputs_config.reasoning_parser + reasoning_parser_plugin = ( + self.vllm_config.structured_outputs_config.reasoning_parser_plugin + ) + if reasoning_parser_plugin and len(reasoning_parser_plugin) > 3: + ReasoningParserManager.import_reasoning_parser(reasoning_parser_plugin) + + if reasoning_parser: + reasoner_cls = ReasoningParserManager.get_reasoning_parser(reasoning_parser) + self.reasoner = reasoner_cls(tokenizer=self._tokenizer) + + self._tokenizer_initialized = True + def grammar_init(self, request: Request) -> None: if request.structured_output_request is None: return @@ -149,6 +173,11 @@ def grammar_init(self, request: Request) -> None: raise ValueError(f"Unsupported structured output backend: {backend}") if self._use_async_grammar_compilation: + # Ensure tokenizer (and executor) is initialized + _ = self.tokenizer + assert self.executor is not None, ( + "Executor should be initialized with tokenizer" + ) grammar = self.executor.submit(self._create_grammar, request) else: grammar = self._create_grammar(request) # type: ignore[assignment] From 2c041523d0978bfd39b8af1758674b3718dbf2d7 Mon Sep 17 00:00:00 2001 From: Christina Date: Tue, 9 Dec 2025 22:19:11 -0600 Subject: [PATCH 03/17] fix(gguf): Ensure Gemma2 configs have hidden_act for backward compatibility GGUF-loaded configs may only have hidden_activation from config.json, but Gemma2MLP model code expects hidden_act attribute. This adds a post-processing step to copy hidden_activation to hidden_act when needed. Fixes AttributeError: 'Gemma2Config' object has no attribute 'hidden_act' when loading Gemma2 GGUF models. Signed-off-by: Christina (cherry picked from commit 04ceef56e2a267abc3862b7be1e65e64ee573442) --- vllm/config/model.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/vllm/config/model.py b/vllm/config/model.py index 764bdf700056..978852bfab4d 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -494,6 +494,17 @@ def __post_init__( ) self.hf_config = hf_config + + # Ensure Gemma2 configs have hidden_act for backward compatibility. + # GGUF configs may only have hidden_activation; model code expects both. + if ( + hasattr(hf_config, "model_type") + and hf_config.model_type == "gemma2" + and not hasattr(hf_config, "hidden_act") + and hasattr(hf_config, "hidden_activation") + ): + hf_config.hidden_act = hf_config.hidden_activation + if dict_overrides: self._apply_dict_overrides(hf_config, dict_overrides) self.hf_text_config = get_hf_text_config(self.hf_config) From a482bbc0e74ed6ade8749f570107878542e94436 Mon Sep 17 00:00:00 2001 From: Christina Date: Tue, 9 Dec 2025 18:49:45 -0600 Subject: [PATCH 04/17] fix(gguf): Skip lm_head mapping for models with tied word embeddings For models like Gemma2 that use tie_word_embeddings=True, the lm_head.weight is initialized from embed_tokens weights rather than loaded separately. Add lm_head.weight to sideload_params to allow GGUF loading to succeed without requiring this parameter to be mapped. 
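
For illustration, a minimal standalone sketch of why weight tying leaves
no separate lm_head tensor to map (plain PyTorch, not vLLM code; sizes
are illustrative):

    import torch.nn as nn

    vocab_size, hidden_size = 256_000, 2_304
    embed_tokens = nn.Embedding(vocab_size, hidden_size)
    lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    # With tie_word_embeddings=True the LM head reuses the embedding
    # matrix, so GGUF exports omit lm_head.weight entirely.
    lm_head.weight = embed_tokens.weight
    assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()
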
Fixes: RuntimeError: Failed to map GGUF parameters (1): ['lm_head.weight'] Signed-off-by: Christina (cherry picked from commit 9512f74ee802b073e188c4a49c902b992a73aad1) --- vllm/model_executor/model_loader/gguf_loader.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 7f94bd234fd3..c00d9c850c89 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -147,6 +147,11 @@ def _get_gguf_weights_map(self, model_config: ModelConfig): ) ) + # For models with tied word embeddings, lm_head.weight is initialized + # from embed_tokens and doesn't need to be mapped from GGUF file + if getattr(config, "tie_word_embeddings", False): + sideload_params.append(re.compile(r"lm_head\.weight")) + arch = None for key, value in gguf.MODEL_ARCH_NAMES.items(): if value == model_type: From ab8eea8094b6faeeae03c1e608bc54a752d07f51 Mon Sep 17 00:00:00 2001 From: Christina Date: Wed, 10 Dec 2025 13:24:11 -0600 Subject: [PATCH 05/17] fix(gemma2): Skip missing parameters during GGUF weight loading The GGUF loader yields quantization metadata parameters (qweight_type) for all quantized tensors, including embeddings. However, VocabParallelEmbedding doesn't have these parameters, causing a KeyError when loading GGUF Gemma2 models. This adds a safety check to skip parameters not present in the model, matching the pattern already used in llama.py (lines 502-503). Fixes KeyError: 'embed_tokens.qweight_type' during engine core init. (cherry picked from commit 1c144b23b5a098c77e1302eb8e5ab866411f74e9) --- vllm/model_executor/models/gemma2.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index cb36e0482458..cab434be673c 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -366,6 +366,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: continue if is_pp_missing_parameter(name, self): continue + # Skip parameters not in the model (e.g., GGUF quantization + # metadata like qweight_type for embeddings) + if name not in params_dict: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) From bfe6e8996f2df274ff18034ea537b54e50a91e15 Mon Sep 17 00:00:00 2001 From: Christina Date: Wed, 10 Dec 2025 13:46:11 -0600 Subject: [PATCH 06/17] fix(gemma2): Add quant_config to embedding layer for GGUF support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Gemma2 model was missing the quant_config parameter in the VocabParallelEmbedding initialization, causing GGUF quantized embeddings to be misinterpreted as float values. Without quant_config, GGUF models use UnquantizedEmbeddingMethod which calls F.embedding() directly on quantized bytes, resulting in garbage output during inference. This is the same bug that was fixed for DeepSeek in commit aa375dca9 ("Missing quant_config in deepseek embedding layer (#12836)"). The fix adds: - quant_config parameter to enable GGUFEmbeddingMethod selection - prefix parameter for proper weight mapping Fixes Gemma2 GGUF models (gemma-2-2b-it-GGUF, etc.) producing garbage output like: " GHFW側から ThinkmariKeywords!")... 
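
A toy repro of the failure mode (plain PyTorch, not the vLLM API; the
"dequant" step is only a stand-in for the real GGUF kernels):

    import torch

    # Embedding rows stored as raw quantized bytes.
    q_rows = torch.randint(0, 256, (4, 16), dtype=torch.uint8)

    # Without quant_config: UnquantizedEmbeddingMethod ends up treating
    # the raw bytes as floats, giving nonsense magnitudes -> garbage text.
    wrong = q_rows.view(torch.float16)

    # With quant_config: GGUFEmbeddingMethod dequantizes rows before the
    # lookup (real kernels decode blocks with scales; this is a stand-in).
    right = q_rows.to(torch.float16) / 255.0
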
(cherry picked from commit d41b71fc854e38da7d4144dd67fdd254f4a7ff27) --- vllm/model_executor/models/gemma2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index cab434be673c..256df3dfaa17 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -267,6 +267,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, From 4c0686eeee22196ec469deb1ec3219aca6b30881 Mon Sep 17 00:00:00 2001 From: Christina Date: Wed, 10 Dec 2025 15:21:18 -0600 Subject: [PATCH 07/17] fix(gguf): Extract attn_logit_softcapping from GGUF metadata Fixes garbage output from Gemma2 GGUF models by extracting the attn_logit_softcapping parameter from GGUF metadata and patching it onto the HuggingFace config. Root cause: GGUF models store softcap in metadata with arch-specific keys (e.g., gemma2.attn_logit_softcapping), but this wasn't being extracted and applied to the HF config. Without softcap, the V1 FlashAttention backend uses softcap=0 (disabled), causing numerical instability and garbage output. Changes: - Add extract_softcap_from_gguf() to read softcap from GGUF metadata - Update maybe_patch_hf_config_from_gguf() to apply softcap values - Support both attn_logit_softcapping and final_logit_softcapping Tested with: google/gemma-2-2b-it:Q4_K_M (cherry picked from commit b9e724d05e082e6e7e01d22cf9a3e9a9c59e5b08) --- vllm/transformers_utils/gguf_utils.py | 73 ++++++++++++++++++++++++++- 1 file changed, 72 insertions(+), 1 deletion(-) diff --git a/vllm/transformers_utils/gguf_utils.py b/vllm/transformers_utils/gguf_utils.py index f3fd43c6ace5..2ebcae902e97 100644 --- a/vllm/transformers_utils/gguf_utils.py +++ b/vllm/transformers_utils/gguf_utils.py @@ -199,6 +199,67 @@ def extract_vision_config_from_gguf(mmproj_path: str) -> "SiglipVisionConfig | N return config +def extract_softcap_from_gguf(model: str) -> dict[str, float]: + """Extract attention and final logit softcap values from GGUF metadata. + + Reads softcap parameters from GGUF metadata using arch-specific keys. + These parameters are critical for models like Gemma2 where attention + logit softcapping prevents numerical instability. 
+ + Args: + model: Path to GGUF model file + + Returns: + Dictionary with 'attn_logit_softcapping' and/or 'final_logit_softcapping' + keys if found in GGUF metadata, empty dict otherwise + """ + if not model.endswith(".gguf"): + return {} + + try: + model_path = Path(model) + if not model_path.is_file(): + return {} + + reader = gguf.GGUFReader(str(model_path)) + + # Get architecture name to build arch-specific keys + arch_field = reader.get_field(Keys.General.ARCHITECTURE) + if arch_field is None: + logger.debug("No architecture field found in GGUF metadata") + return {} + + arch = bytes(arch_field.parts[-1]).decode("utf-8") + + result = {} + + # Extract attention logit softcapping + attn_key = Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=arch) + attn_field = reader.get_field(attn_key) + if attn_field is not None: + result["attn_logit_softcapping"] = float(attn_field.parts[-1]) + logger.info( + "Extracted attn_logit_softcapping=%.2f from GGUF metadata", + result["attn_logit_softcapping"], + ) + + # Extract final logit softcapping + final_key = Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=arch) + final_field = reader.get_field(final_key) + if final_field is not None: + result["final_logit_softcapping"] = float(final_field.parts[-1]) + logger.info( + "Extracted final_logit_softcapping=%.2f from GGUF metadata", + result["final_logit_softcapping"], + ) + + return result + + except Exception as e: + logger.debug("Error extracting softcap from GGUF: %s", e) + return {} + + def maybe_patch_hf_config_from_gguf( model: str, hf_config: PretrainedConfig, @@ -207,7 +268,8 @@ def maybe_patch_hf_config_from_gguf( Applies GGUF-specific patches to HuggingFace config: 1. For multimodal models: patches architecture and vision config - 2. For all GGUF models: overrides vocab_size from embedding tensor + 2. For models with softcap (e.g., Gemma2): patches attention/logit softcapping + 3. 
For all GGUF models: overrides vocab_size from embedding tensor This ensures compatibility with GGUF models that have extended vocabularies (e.g., Unsloth) where the GGUF file contains more @@ -236,6 +298,15 @@ def maybe_patch_hf_config_from_gguf( ) hf_config = new_hf_config + # Patch softcap parameters from GGUF metadata + # Critical for models like Gemma2 where attention softcapping + # prevents numerical instability and ensures correct output + softcap_params = extract_softcap_from_gguf(model) + if "attn_logit_softcapping" in softcap_params: + hf_config.attn_logit_softcapping = softcap_params["attn_logit_softcapping"] + if "final_logit_softcapping" in softcap_params: + hf_config.final_logit_softcapping = softcap_params["final_logit_softcapping"] + return hf_config From cac1839270b983a98a78ccfdeae90bc72262926b Mon Sep 17 00:00:00 2001 From: Christina Date: Wed, 10 Dec 2025 15:46:08 -0600 Subject: [PATCH 08/17] fix(nemotron_h): Add rotary positional embeddings to attention layers --- vllm/model_executor/models/nemotron_h.py | 27 ++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 2d9dfbd3e768..7968bb17939c 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -49,6 +49,7 @@ MambaStateShapeCalculator, ) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -435,6 +436,7 @@ def __init__( self, config: NemotronHConfig, layer_idx: int, + max_position_embeddings: int, model_config: ModelConfig | None = None, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, @@ -490,13 +492,25 @@ def __init__( prefix=f"{prefix}.attn", ) + # Rotary embeddings for positional encoding + self.max_position_embeddings = max_position_embeddings + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + is_neox_style=True, + dtype=model_config.dtype if model_config else torch.get_default_dtype(), + ) + def forward( self, + positions: torch.Tensor, hidden_states: torch.Tensor, **kwargs, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -518,9 +532,10 @@ def __init__( self.mixer = NemotronHAttention( config, layer_idx, - model_config, - cache_config, - quant_config, + max_position_embeddings=config.max_position_embeddings, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, prefix=f"{prefix}.mixer", ) @@ -539,7 +554,7 @@ def forward( else: hidden_states, residual = self.norm(hidden_states, residual) - hidden_states = self.mixer(hidden_states=hidden_states) + hidden_states = self.mixer(positions=positions, hidden_states=hidden_states) return hidden_states, residual @@ -659,6 +674,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: + # Skip rotary embeddings - they are computed dynamically + if "rotary_emb.inv_freq" in name: + continue + if "scale" in name: # Remapping the 
name of FP8 kv-scale. name = maybe_remap_kv_scale_name(name, params_dict) From 9077a9cd4edf2bfda85498a7a266380dbccf3794 Mon Sep 17 00:00:00 2001 From: Christina Date: Tue, 9 Dec 2025 09:36:59 -0600 Subject: [PATCH 09/17] fix: Default GGUF to float16 while preserving bfloat16 option GGUF dequantization kernels use half precision (fp16) internally via the `dfloat` typedef. On Blackwell GPUs (sm_120), using bfloat16 causes garbage output due to dtype mismatch. Approach taken (middle ground): - arg_utils.py: Auto-set dtype to float16 when dtype="auto" for GGUF - gguf.py: Keep bfloat16 in supported_act_dtypes for explicit override This defaults to safe behavior while preserving user control. Users on hardware where bfloat16 works can still use --dtype bfloat16 explicitly. Options considered: 1. Blanket removal of bfloat16 from GGUF - rejected (breaks working configs) 2. Blackwell-specific detection - rejected (maintenance burden, edge cases) 3. Default fp16 + allow explicit bf16 - chosen (simple, safe, preserves choice) Tested on RTX 5090 (sm_120) with Qwen3-4B-GGUF: 583.8 tok/s Signed-off-by: Christina --- vllm/engine/arg_utils.py | 10 ++++++++++ vllm/model_executor/layers/quantization/gguf.py | 4 ++++ 2 files changed, 14 insertions(+) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 2f307a7ccf16..862a0f1b3332 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1201,6 +1201,16 @@ def create_model_config(self) -> ModelConfig: # gguf file needs a specific model loader if is_gguf(self.model): self.quantization = self.load_format = "gguf" + # GGUF dequantization kernels use half precision (fp16) internally. + # bfloat16 causes incorrect output on some architectures (e.g., Blackwell). + # Default to float16 for safety; explicit --dtype bfloat16 still allowed + # for users on hardware where it works. + if self.dtype == "auto": + self.dtype = "float16" + logger.info( + "GGUF models default to float16 (dequant kernels use fp16 " + "internally). Use --dtype bfloat16 to override if needed." + ) # NOTE(woosuk): In V1, we use separate processes for workers (unless # VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 13aa2bcad21b..a411d8e5c0d1 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -52,6 +52,10 @@ def get_name(self) -> QuantizationMethods: return "gguf" def get_supported_act_dtypes(self) -> list[torch.dtype]: + # GGUF dequantization kernels use half precision (fp16) internally. + # bfloat16 may cause incorrect output on some architectures (e.g., Blackwell) + # but is kept for users who explicitly request it on hardware where it works. + # See: arg_utils.py auto-defaults to float16 for GGUF when dtype="auto" return [torch.half, torch.bfloat16, torch.float32] @classmethod From af9873c738cdbd2ae5af6f83d45531e7503f2ac2 Mon Sep 17 00:00:00 2001 From: Christina Date: Tue, 9 Dec 2025 10:21:21 -0600 Subject: [PATCH 10/17] fix(gguf): Disable bfloat16 on Blackwell (SM 120+) via device capability check Instead of removing bfloat16 support globally, use device capability detection to disable bfloat16 only on SM 120+ devices (Blackwell). This preserves bfloat16 support on older architectures where tests show it works correctly, while preventing precision issues on Blackwell. 
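
For context, a sketch of how the gate behaves, assuming vLLM's integer
encoding of compute capability as major * 10 + minor (so SM 12.0 -> 120):

    import torch
    from vllm.platforms import current_platform

    def get_supported_act_dtypes() -> list[torch.dtype]:
        # has_device_capability(x) is a ">= x" check, so 120 matches only
        # SM 12.0+ consumer Blackwell; a follow-up patch in this series
        # lowers the threshold to 100 so SM 10.0 datacenter Blackwell is
        # covered as well.
        if current_platform.has_device_capability(120):
            return [torch.half, torch.float32]
        return [torch.half, torch.bfloat16, torch.float32]
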
Co-Authored-By: Isotr0py Signed-off-by: Christina --- vllm/model_executor/layers/quantization/gguf.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index a411d8e5c0d1..9f64518a49b5 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -33,6 +33,7 @@ ) from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) @@ -53,9 +54,13 @@ def get_name(self) -> QuantizationMethods: def get_supported_act_dtypes(self) -> list[torch.dtype]: # GGUF dequantization kernels use half precision (fp16) internally. - # bfloat16 may cause incorrect output on some architectures (e.g., Blackwell) - # but is kept for users who explicitly request it on hardware where it works. - # See: arg_utils.py auto-defaults to float16 for GGUF when dtype="auto" + # bfloat16 has precision issues on SM 120+ devices (Blackwell). + if current_platform.has_device_capability(120): + logger.warning_once( + "GGUF has precision issues with bfloat16 on SM 120+ devices. " + "bfloat16 is unavailable for Blackwell devices." + ) + return [torch.half, torch.float32] return [torch.half, torch.bfloat16, torch.float32] @classmethod From 5cbc39b42a7910320f756fb064a855705f5e8dbe Mon Sep 17 00:00:00 2001 From: Christina Date: Tue, 9 Dec 2025 12:10:08 -0600 Subject: [PATCH 11/17] fix(gguf): Remove dtype auto-override, rely on quantization layer Per review feedback: the arg_utils.py dtype override breaks Gemma2 GGUF which doesn't support FP16. The Blackwell-specific bfloat16 restriction in gguf.py's get_supported_act_dtypes() is sufficient - let _resolve_auto_dtype handle dtype selection automatically. Signed-off-by: Christina --- vllm/engine/arg_utils.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 862a0f1b3332..2f307a7ccf16 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1201,16 +1201,6 @@ def create_model_config(self) -> ModelConfig: # gguf file needs a specific model loader if is_gguf(self.model): self.quantization = self.load_format = "gguf" - # GGUF dequantization kernels use half precision (fp16) internally. - # bfloat16 causes incorrect output on some architectures (e.g., Blackwell). - # Default to float16 for safety; explicit --dtype bfloat16 still allowed - # for users on hardware where it works. - if self.dtype == "auto": - self.dtype = "float16" - logger.info( - "GGUF models default to float16 (dequant kernels use fp16 " - "internally). Use --dtype bfloat16 to override if needed." - ) # NOTE(woosuk): In V1, we use separate processes for workers (unless # VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here From 50850996c7bd29fdd60aa71f0d81ebddb9502088 Mon Sep 17 00:00:00 2001 From: Christina Date: Tue, 9 Dec 2025 17:21:56 -0600 Subject: [PATCH 12/17] fix(gguf): Auto-select compatible dtype for GGUF models on Blackwell MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes Gemma3 GGUF models failing on Blackwell GPUs with --dtype auto. Problem: - Gemma3 blocks float16 (numerical instability) - GGUF on Blackwell blocks bfloat16 (precision issues) - Only float32 works, but dtype=auto picks bfloat16 → fails Changes: 1. 
gguf.py: Block bfloat16 on SM 120+ (Blackwell) devices 2. vllm.py: Auto-select compatible dtype when model and quantization restrictions conflict, instead of failing with an error This allows --dtype auto to work correctly with Gemma3 GGUF on Blackwell by automatically falling back to float32. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- vllm/config/vllm.py | 40 +++++++++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 607bb44cddd2..0629af29132d 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -389,11 +389,41 @@ def _get_quantization_config( ) supported_dtypes = quant_config.get_supported_act_dtypes() if model_config.dtype not in supported_dtypes: - raise ValueError( - f"{model_config.dtype} is not supported for quantization " - f"method {model_config.quantization}. Supported dtypes: " - f"{supported_dtypes}" - ) + # Handle dtype conflict between model restrictions and + # quantization restrictions (e.g., Gemma3 GGUF on Blackwell + # where Gemma3 blocks float16 and GGUF blocks bfloat16) + from vllm.config.model import _is_valid_dtype + + model_type = getattr(model_config.hf_config, "model_type", None) + compatible_dtypes = [ + d + for d in supported_dtypes + if model_type is None or _is_valid_dtype(model_type, d) + ] + if compatible_dtypes: + # Prefer float16 > bfloat16 > float32 for performance + import torch + + dtype_preference = [torch.float16, torch.bfloat16, torch.float32] + for preferred in dtype_preference: + if preferred in compatible_dtypes: + logger.warning( + "dtype=%s is not supported for quantization " + "method %s with model type %s. " + "Automatically selecting %s as compatible dtype.", + model_config.dtype, + model_config.quantization, + model_type, + preferred, + ) + model_config.dtype = preferred + break + else: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. 
Supported dtypes: " + f"{supported_dtypes}" + ) quant_config.maybe_update_config(model_config.model) return quant_config return None From e9112e9485239f0217f18c06bf8d86698c340d02 Mon Sep 17 00:00:00 2001 From: Christina Date: Wed, 10 Dec 2025 12:24:35 -0600 Subject: [PATCH 13/17] fix: Address review feedback - correct Blackwell compute capability - Change has_device_capability(120) to (100) for Blackwell SM 10.0 - Update comment and warning message to correctly reference SM 10.0 - Remove redundant torch import (already imported at file top) Signed-off-by: Christina --- vllm/config/vllm.py | 2 -- vllm/model_executor/layers/quantization/gguf.py | 8 ++++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 0629af29132d..6c2640b8156f 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -402,8 +402,6 @@ def _get_quantization_config( ] if compatible_dtypes: # Prefer float16 > bfloat16 > float32 for performance - import torch - dtype_preference = [torch.float16, torch.bfloat16, torch.float32] for preferred in dtype_preference: if preferred in compatible_dtypes: diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 9f64518a49b5..5c9da892f001 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -54,11 +54,11 @@ def get_name(self) -> QuantizationMethods: def get_supported_act_dtypes(self) -> list[torch.dtype]: # GGUF dequantization kernels use half precision (fp16) internally. - # bfloat16 has precision issues on SM 120+ devices (Blackwell). - if current_platform.has_device_capability(120): + # bfloat16 has precision issues on SM 10.0+ devices (Blackwell). + if current_platform.has_device_capability(100): logger.warning_once( - "GGUF has precision issues with bfloat16 on SM 120+ devices. " - "bfloat16 is unavailable for Blackwell devices." + "GGUF has precision issues with bfloat16 on Blackwell (SM 10.0+). " + "bfloat16 is unavailable." ) return [torch.half, torch.float32] return [torch.half, torch.bfloat16, torch.float32] From 42a7387474fee6d729139131a9e38407aaaf370f Mon Sep 17 00:00:00 2001 From: Christina Date: Wed, 10 Dec 2025 17:20:40 -0600 Subject: [PATCH 14/17] fix(gguf): Use EOS token ID from GGUF metadata instead of HF tokenizer GGUF files store the correct EOS token ID in tokenizer.ggml.eos_token_id metadata. However, vLLM was using the HuggingFace tokenizer's eos_token_id, which can differ from the GGUF value. This causes generation to not stop properly for models like Gemma 3, where: - GGUF metadata specifies EOS token ID 106 () - HF tokenizer reports EOS token ID 1 () The model generates to signal completion, but vLLM waits for token ID 1 which never comes, resulting in repeated EOS tokens until max_tokens is reached. 
Changes: - Add extract_eos_token_id_from_gguf() in gguf_utils.py to read EOS from GGUF - Patch tokenizer.eos_token_id in hf.py when loading GGUF tokenizers Signed-off-by: Christina Zhu Signed-off-by: Christina (cherry picked from commit d8cf5b747f9e56c26782cc3f8278640aa258d92f) --- vllm/tokenizers/hf.py | 22 +++++++++++++++ vllm/transformers_utils/gguf_utils.py | 40 +++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/vllm/tokenizers/hf.py b/vllm/tokenizers/hf.py index 344507312038..9e85f89c2b67 100644 --- a/vllm/tokenizers/hf.py +++ b/vllm/tokenizers/hf.py @@ -7,11 +7,15 @@ from transformers import AutoTokenizer +from vllm.logger import init_logger from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config +from vllm.transformers_utils.gguf_utils import extract_eos_token_id_from_gguf from .protocol import TokenizerLike from .registry import TokenizerRegistry +logger = init_logger(__name__) + if TYPE_CHECKING: from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -121,4 +125,22 @@ def from_pretrained( } tokenizer.add_special_tokens(special_tokens_map) + # Patch EOS token ID from GGUF metadata if available + # GGUF files may have a different EOS token ID than HF tokenizer config + # (e.g., Gemma uses ID 106 as EOS, but HF reports ID 1) + gguf_file = kwargs.get("gguf_file") + if gguf_file: + gguf_path = Path(path_or_repo_id) / gguf_file + gguf_eos_id = extract_eos_token_id_from_gguf(str(gguf_path)) + if gguf_eos_id is not None: + hf_eos_id = tokenizer.eos_token_id + if hf_eos_id != gguf_eos_id: + logger.info( + "Patching tokenizer eos_token_id from %d to %d " + "(using GGUF metadata)", + hf_eos_id, + gguf_eos_id, + ) + tokenizer.eos_token_id = gguf_eos_id + return get_cached_tokenizer(tokenizer) diff --git a/vllm/transformers_utils/gguf_utils.py b/vllm/transformers_utils/gguf_utils.py index 2ebcae902e97..b14470a99092 100644 --- a/vllm/transformers_utils/gguf_utils.py +++ b/vllm/transformers_utils/gguf_utils.py @@ -260,6 +260,46 @@ def extract_softcap_from_gguf(model: str) -> dict[str, float]: return {} +def extract_eos_token_id_from_gguf(model: str) -> int | None: + """Extract EOS token ID from GGUF metadata. + + GGUF files store the EOS token ID in tokenizer.ggml.eos_token_id field. + This may differ from HuggingFace's tokenizer config (e.g., Gemma models + use token ID 106 as EOS in GGUF, but HF tokenizer reports + token ID 1). 
+ + Args: + model: Path to GGUF model file + + Returns: + EOS token ID from GGUF metadata, or None if not found + """ + if not model.endswith(".gguf"): + return None + + try: + model_path = Path(model) + if not model_path.is_file(): + return None + + reader = gguf.GGUFReader(str(model_path)) + + eos_field = reader.get_field(Keys.Tokenizer.EOS_ID) + if eos_field is not None: + eos_token_id = int(eos_field.parts[-1][0]) + logger.debug( + "Extracted eos_token_id=%d from GGUF metadata", + eos_token_id, + ) + return eos_token_id + + return None + + except Exception as e: + logger.debug("Error extracting EOS token ID from GGUF: %s", e) + return None + + def maybe_patch_hf_config_from_gguf( model: str, hf_config: PretrainedConfig, From 3edaa682f81db5a2ce0299f59af2d4648eeaae90 Mon Sep 17 00:00:00 2001 From: Christina Date: Wed, 10 Dec 2025 18:05:40 -0600 Subject: [PATCH 15/17] feat(gguf): Extract HF config from GGUF metadata for repos without config.json This enables vLLM to load GGUF models from repositories that don't include config.json (e.g., bartowski repos) by extracting the configuration values directly from GGUF metadata. Changes: - Add GGUF_ARCH_TO_HF_MODEL_TYPE mapping for architecture name translation - Add extract_hf_config_from_gguf() function that reads GGUF metadata and constructs a HuggingFace-compatible config dictionary - Add GGUFConfigParser class that uses the extraction function - Register "gguf" format in the config parser system - Update auto-detection to use GGUF parser for local GGUF files without config.json Extracted metadata fields: - Architecture (model_type) - hidden_size, intermediate_size, num_hidden_layers - num_attention_heads, num_key_value_heads - max_position_embeddings, rope_theta, rms_norm_eps - sliding_window, vocab_size - bos_token_id, eos_token_id - attn_logit_softcapping, final_logit_softcapping (for Gemma2) Signed-off-by: Christina --- vllm/transformers_utils/config.py | 86 +++++++++++++- vllm/transformers_utils/gguf_utils.py | 164 ++++++++++++++++++++++++++ 2 files changed, 246 insertions(+), 4 deletions(-) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index d761802da940..0cf3d6ceb3a6 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -31,6 +31,7 @@ from .config_parser_base import ConfigParserBase from .gguf_utils import ( check_gguf_file, + extract_hf_config_from_gguf, is_gguf, is_remote_gguf, split_remote_gguf, @@ -223,15 +224,87 @@ def parse( return config_dict, config +class GGUFConfigParser(ConfigParserBase): + """Config parser that extracts configuration from GGUF metadata. + + This parser is used for GGUF models from repositories that don't include + config.json (e.g., bartowski repos). It reads the GGUF file metadata + directly to construct a HuggingFace-compatible configuration. + """ + + def parse( + self, + model: str | Path, + trust_remote_code: bool, + revision: str | None = None, + code_revision: str | None = None, + **kwargs, + ) -> tuple[dict, PretrainedConfig]: + # Get the GGUF file path from kwargs + gguf_file = kwargs.get("gguf_file") + gguf_path = str(Path(model) / gguf_file) if gguf_file else str(model) + + # Extract config from GGUF metadata + config_dict = extract_hf_config_from_gguf(gguf_path) + if config_dict is None: + raise ValueError( + f"Failed to extract config from GGUF file: {gguf_path}. " + "The GGUF file may be corrupted or missing required metadata." 
+ ) + + model_type = config_dict.get("model_type") + + # Use hf_overrides if provided + if (hf_overrides := kwargs.pop("hf_overrides", None)) is not None: + config_dict.update(hf_overrides) + model_type = config_dict.get("model_type", model_type) + + # Create config using AutoConfig with the extracted dict + # We need to create a config class based on model_type + if model_type is not None and model_type in _CONFIG_REGISTRY: + config_class = _CONFIG_REGISTRY[model_type] + config = config_class(**config_dict) + else: + # Use AutoConfig to get the appropriate config class + try: + config_class = AutoConfig.for_model(model_type) + # Filter config_dict to only include valid keys for this config + valid_keys = ( + set(config_class.__dataclass_fields__.keys()) + if hasattr(config_class, "__dataclass_fields__") + else set(config_class().__dict__.keys()) + ) + filtered_dict = { + k: v + for k, v in config_dict.items() + if k in valid_keys or k == "model_type" + } + config = config_class(**filtered_dict) + except Exception as e: + logger.warning( + "Failed to create config with AutoConfig.for_model(%s): %s. " + "Falling back to PretrainedConfig.", + model_type, + e, + ) + # Fallback to basic PretrainedConfig + config = PretrainedConfig(**config_dict) + + config = _maybe_remap_hf_config_attrs(config) + return config_dict, config + + _CONFIG_FORMAT_TO_CONFIG_PARSER: dict[str, type[ConfigParserBase]] = { "hf": HFConfigParser, "mistral": MistralConfigParser, + "gguf": GGUFConfigParser, } ConfigFormat = Literal[ "auto", "hf", "mistral", + "gguf", ] @@ -556,13 +629,18 @@ def get_config( # Transformers implementation. if file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision): config_format = "mistral" - elif (_is_gguf and not _is_remote_gguf) or file_or_path_exists( - model, HF_CONFIG_NAME, revision=revision - ): + elif file_or_path_exists(model, HF_CONFIG_NAME, revision=revision): config_format = "hf" + # Local GGUF files without config.json - extract config from GGUF metadata + elif _is_gguf and not _is_remote_gguf: + logger.info( + "No config.json found for local GGUF model. " + "Extracting config from GGUF metadata." + ) + config_format = "gguf" # Remote GGUF models must have config.json in repo, # otherwise the config can't be parsed correctly. 
- # FIXME(Isotr0py): Support remote GGUF repos without config.json + # TODO(Isotr0py): Support remote GGUF repos without config.json elif _is_remote_gguf and not file_or_path_exists( model, HF_CONFIG_NAME, revision=revision ): diff --git a/vllm/transformers_utils/gguf_utils.py b/vllm/transformers_utils/gguf_utils.py index b14470a99092..ddde1fe398f3 100644 --- a/vllm/transformers_utils/gguf_utils.py +++ b/vllm/transformers_utils/gguf_utils.py @@ -5,6 +5,7 @@ from functools import cache from os import PathLike from pathlib import Path +from typing import Any import gguf import regex as re @@ -199,6 +200,169 @@ def extract_vision_config_from_gguf(mmproj_path: str) -> "SiglipVisionConfig | N return config +# Mapping from GGUF architecture names to HuggingFace model_type +GGUF_ARCH_TO_HF_MODEL_TYPE: dict[str, str] = { + "llama": "llama", + "phi3": "phi3", + "gemma": "gemma", + "gemma2": "gemma2", + "qwen2": "qwen2", + "qwen3": "qwen2", # Qwen3 uses qwen2 architecture in HF + "starcoder2": "starcoder2", + "gpt2": "gpt2", + "mistral": "mistral", + "mixtral": "mixtral", + "falcon": "falcon", + "phi2": "phi", + "phi": "phi", + "baichuan": "baichuan", + "internlm2": "internlm2", + "mamba": "mamba", + "nemotron": "nemotron", +} + + +def extract_hf_config_from_gguf(model: str) -> dict[str, Any] | None: + """Extract HuggingFace-compatible config dict from GGUF metadata. + + This function reads GGUF metadata and constructs a config dictionary + that can be used to create a PretrainedConfig. Useful for GGUF repos + that don't include config.json (e.g., bartowski repos). + + Args: + model: Path to GGUF model file + + Returns: + Dictionary with HF-compatible config values, or None if extraction fails + + Raises: + Exception: Exceptions from GGUF reading propagate directly + """ + # Use check_gguf_file to validate - it reads the header magic bytes + # This handles both .gguf extension and HuggingFace cache blob paths + if not check_gguf_file(model): + return None + + try: + model_path = Path(model) + + reader = gguf.GGUFReader(str(model_path)) + + # Get architecture name + arch_field = reader.get_field(Keys.General.ARCHITECTURE) + if arch_field is None: + logger.warning("No architecture field found in GGUF metadata") + return None + + arch = bytes(arch_field.parts[-1]).decode("utf-8") + logger.info("Extracting config from GGUF metadata (architecture: %s)", arch) + + # Map GGUF architecture to HF model_type + model_type = GGUF_ARCH_TO_HF_MODEL_TYPE.get(arch, arch) + + config_dict: dict[str, Any] = { + "model_type": model_type, + } + + # Helper to extract field value + def get_field_value(key: str, default=None): + field = reader.get_field(key.format(arch=arch)) + if field is not None: + val = field.parts[-1] + # Handle arrays vs scalars + if hasattr(val, "__len__") and len(val) == 1: + return val[0] + return val + return default + + # Extract core architecture parameters + # Using arch-specific keys from gguf.constants.Keys + + # Context length -> max_position_embeddings + ctx_len = get_field_value(Keys.LLM.CONTEXT_LENGTH) + if ctx_len is not None: + config_dict["max_position_embeddings"] = int(ctx_len) + + # Embedding length -> hidden_size + embed_len = get_field_value(Keys.LLM.EMBEDDING_LENGTH) + if embed_len is not None: + config_dict["hidden_size"] = int(embed_len) + + # Feed forward length -> intermediate_size + ff_len = get_field_value(Keys.LLM.FEED_FORWARD_LENGTH) + if ff_len is not None: + config_dict["intermediate_size"] = int(ff_len) + + # Block count -> num_hidden_layers + block_count = 
get_field_value(Keys.LLM.BLOCK_COUNT) + if block_count is not None: + config_dict["num_hidden_layers"] = int(block_count) + + # Attention head count -> num_attention_heads + head_count = get_field_value(Keys.Attention.HEAD_COUNT) + if head_count is not None: + config_dict["num_attention_heads"] = int(head_count) + + # KV head count -> num_key_value_heads + kv_head_count = get_field_value(Keys.Attention.HEAD_COUNT_KV) + if kv_head_count is not None: + config_dict["num_key_value_heads"] = int(kv_head_count) + + # RoPE frequency base -> rope_theta + rope_freq = get_field_value(Keys.Rope.FREQ_BASE) + if rope_freq is not None: + config_dict["rope_theta"] = float(rope_freq) + + # Layer norm epsilon + rms_eps = get_field_value(Keys.Attention.LAYERNORM_RMS_EPS) + if rms_eps is not None: + config_dict["rms_norm_eps"] = float(rms_eps) + + # Sliding window attention + sliding_window = get_field_value(Keys.Attention.SLIDING_WINDOW) + if sliding_window is not None: + config_dict["sliding_window"] = int(sliding_window) + + # Vocab size - from tokenizer tokens list or arch-specific field + vocab_size = get_field_value(Keys.LLM.VOCAB_SIZE) + if vocab_size is None: + tokens_field = reader.get_field(Keys.Tokenizer.LIST) + if tokens_field is not None: + vocab_size = len(tokens_field.parts[-1]) + if vocab_size is not None: + config_dict["vocab_size"] = int(vocab_size) + + # Token IDs + bos_id = get_field_value(Keys.Tokenizer.BOS_ID) + if bos_id is not None: + config_dict["bos_token_id"] = int(bos_id) + + eos_id = get_field_value(Keys.Tokenizer.EOS_ID) + if eos_id is not None: + config_dict["eos_token_id"] = int(eos_id) + + # Attention softcapping (for Gemma2, etc.) + attn_softcap = get_field_value(Keys.LLM.ATTN_LOGIT_SOFTCAPPING) + if attn_softcap is not None: + config_dict["attn_logit_softcapping"] = float(attn_softcap) + + final_softcap = get_field_value(Keys.LLM.FINAL_LOGIT_SOFTCAPPING) + if final_softcap is not None: + config_dict["final_logit_softcapping"] = float(final_softcap) + + logger.info( + "Extracted %d config fields from GGUF metadata for %s", + len(config_dict), + model_type, + ) + + return config_dict + + except Exception as e: + logger.warning("Error extracting config from GGUF: %s", e) + return None + + def extract_softcap_from_gguf(model: str) -> dict[str, float]: """Extract attention and final logit softcap values from GGUF metadata. From 03d0317f00ba8f8da13c2915361055ffa1f4f5bc Mon Sep 17 00:00:00 2001 From: Christina Date: Mon, 1 Dec 2025 16:31:21 -0600 Subject: [PATCH 16/17] fix(shm): Add memory barriers for cross-process shared memory visibility The shared memory ring buffer protocol in shm_broadcast.py uses plain byte writes to signal between writer and reader processes. On multi-core systems, these writes may stay in CPU store buffers and not be visible to other processes running on different cores, causing indefinite spinning/freeze under sustained concurrent load. This patch adds explicit memory barriers using threading.Lock acquire/release (which provides full barrier semantics per POSIX.1-2008) at four critical points: - In acquire_write(): before reading flags and after setting written flag - In acquire_read(): before reading flags and after setting read flag The memory barrier ensures that: 1. All stores before the barrier are globally visible 2. 
All loads after the barrier see the latest values Fixes freeze observed during sustained concurrent batch inference (~500+ requests) where both writer and readers would spin indefinitely waiting for flags that were updated but not visible across CPU cores. Signed-off-by: Christina Holland Signed-off-by: Christina --- vllm/distributed/device_communicators/shm_broadcast.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 31c6084c9b50..96fd0edcae99 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -64,11 +64,13 @@ def memory_fence(): operation (~20ns) that guarantees: - All stores before the barrier are visible to other threads/processes - All loads after the barrier see the latest values + + Reference: POSIX.1-2008 specifies that mutex operations synchronize memory. """ # Lock acquire/release provides full memory barrier semantics. - # Using context manager ensures lock release even on exceptions. - with _memory_fence_lock: - pass + # This flushes CPU store buffers and invalidates stale cache lines. + _memory_fence_lock.acquire() + _memory_fence_lock.release() def to_bytes_big(value: int, size: int) -> bytes: From a207d146c24cc916f1285d4acbc25c3411b354b5 Mon Sep 17 00:00:00 2001 From: Christina Date: Tue, 9 Dec 2025 18:42:01 -0600 Subject: [PATCH 17/17] fix: Address review feedback - platform-agnostic docs and context manager Signed-off-by: Christina --- vllm/distributed/device_communicators/shm_broadcast.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 96fd0edcae99..31c6084c9b50 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -64,13 +64,11 @@ def memory_fence(): operation (~20ns) that guarantees: - All stores before the barrier are visible to other threads/processes - All loads after the barrier see the latest values - - Reference: POSIX.1-2008 specifies that mutex operations synchronize memory. """ # Lock acquire/release provides full memory barrier semantics. - # This flushes CPU store buffers and invalidates stale cache lines. - _memory_fence_lock.acquire() - _memory_fence_lock.release() + # Using context manager ensures lock release even on exceptions. + with _memory_fence_lock: + pass def to_bytes_big(value: int, size: int) -> bytes: