17 commits
00c20ae  fix: Address review feedback - platform-agnostic docs and context man… (kitaekatt, Dec 10, 2025)
c2c25e5  fix: Lazy tokenizer init in StructuredOutputManager to prevent semaph… (kitaekatt, Dec 9, 2025)
2c04152  fix(gguf): Ensure Gemma2 configs have hidden_act for backward compati… (kitaekatt, Dec 10, 2025)
a482bbc  fix(gguf): Skip lm_head mapping for models with tied word embeddings (kitaekatt, Dec 10, 2025)
ab8eea8  fix(gemma2): Skip missing parameters during GGUF weight loading (kitaekatt, Dec 10, 2025)
bfe6e89  fix(gemma2): Add quant_config to embedding layer for GGUF support (kitaekatt, Dec 10, 2025)
4c0686e  fix(gguf): Extract attn_logit_softcapping from GGUF metadata (kitaekatt, Dec 10, 2025)
cac1839  fix(nemotron_h): Add rotary positional embeddings to attention layers (kitaekatt, Dec 10, 2025)
9077a9c  fix: Default GGUF to float16 while preserving bfloat16 option (kitaekatt, Dec 9, 2025)
af9873c  fix(gguf): Disable bfloat16 on Blackwell (SM 120+) via device capabil… (kitaekatt, Dec 9, 2025)
5cbc39b  fix(gguf): Remove dtype auto-override, rely on quantization layer (kitaekatt, Dec 9, 2025)
5085099  fix(gguf): Auto-select compatible dtype for GGUF models on Blackwell (kitaekatt, Dec 9, 2025)
e9112e9  fix: Address review feedback - correct Blackwell compute capability (kitaekatt, Dec 10, 2025)
42a7387  fix(gguf): Use EOS token ID from GGUF metadata instead of HF tokenizer (kitaekatt, Dec 10, 2025)
3edaa68  feat(gguf): Extract HF config from GGUF metadata for repos without co… (kitaekatt, Dec 11, 2025)
03d0317  fix(shm): Add memory barriers for cross-process shared memory visibility (kitaekatt, Dec 1, 2025)
a207d14  fix: Address review feedback - platform-agnostic docs and context man… (kitaekatt, Dec 10, 2025)
11 changes: 11 additions & 0 deletions vllm/config/model.py
@@ -494,6 +494,17 @@ def __post_init__(
)

self.hf_config = hf_config

# Ensure Gemma2 configs have hidden_act for backward compatibility.
# GGUF configs may only have hidden_activation; model code expects both.
if (
hasattr(hf_config, "model_type")
and hf_config.model_type == "gemma2"
and not hasattr(hf_config, "hidden_act")
and hasattr(hf_config, "hidden_activation")
):
hf_config.hidden_act = hf_config.hidden_activation

if dict_overrides:
self._apply_dict_overrides(hf_config, dict_overrides)
self.hf_text_config = get_hf_text_config(self.hf_config)
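For context on the hunk above: the backfill is a one-way attribute mirror on the config object. A minimal standalone sketch of the same normalization, using a `SimpleNamespace` stand-in rather than a real HF config:

```python
from types import SimpleNamespace

# Stand-in for a GGUF-derived Gemma2 config that only carries
# `hidden_activation`; HF-native configs normally expose `hidden_act` too.
hf_config = SimpleNamespace(model_type="gemma2", hidden_activation="gelu_pytorch_tanh")

if (
    getattr(hf_config, "model_type", None) == "gemma2"
    and not hasattr(hf_config, "hidden_act")
    and hasattr(hf_config, "hidden_activation")
):
    # Mirror the value so model code reading either attribute works.
    hf_config.hidden_act = hf_config.hidden_activation

assert hf_config.hidden_act == "gelu_pytorch_tanh"
```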
38 changes: 33 additions & 5 deletions vllm/config/vllm.py
@@ -389,11 +389,39 @@ def _get_quantization_config(
)
supported_dtypes = quant_config.get_supported_act_dtypes()
if model_config.dtype not in supported_dtypes:
raise ValueError(
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}"
)
# Handle dtype conflict between model restrictions and
# quantization restrictions (e.g., Gemma3 GGUF on Blackwell
# where Gemma3 blocks float16 and GGUF blocks bfloat16)
from vllm.config.model import _is_valid_dtype

model_type = getattr(model_config.hf_config, "model_type", None)
compatible_dtypes = [
d
for d in supported_dtypes
if model_type is None or _is_valid_dtype(model_type, d)
]
if compatible_dtypes:
# Prefer float16 > bfloat16 > float32 for performance
dtype_preference = [torch.float16, torch.bfloat16, torch.float32]
for preferred in dtype_preference:
if preferred in compatible_dtypes:
logger.warning(
"dtype=%s is not supported for quantization "
"method %s with model type %s. "
"Automatically selecting %s as compatible dtype.",
model_config.dtype,
model_config.quantization,
model_type,
preferred,
)
model_config.dtype = preferred
break
else:
raise ValueError(
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}"
)
quant_config.maybe_update_config(model_config.model)
return quant_config
return None
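The new fallback amounts to intersecting two dtype constraints — what the quantization method supports and what the model architecture allows — then picking the fastest survivor. A self-contained sketch of the selection rule, with strings in place of torch dtypes and a mocked `model_allows` standing in for `_is_valid_dtype` (both stand-ins are assumptions for illustration):

```python
# Dtypes the quantization method supports (e.g., GGUF on Blackwell).
quant_supported = ["float16", "float32"]

def model_allows(model_type: str, dtype: str) -> bool:
    # Toy stand-in for _is_valid_dtype: Gemma3 blocks float16.
    blocked = {"gemma3": {"float16"}}
    return dtype not in blocked.get(model_type, set())

model_type = "gemma3"
compatible = [d for d in quant_supported if model_allows(model_type, d)]

# Prefer float16 > bfloat16 > float32 for performance.
for preferred in ["float16", "bfloat16", "float32"]:
    if preferred in compatible:
        chosen = preferred
        break
else:
    raise ValueError("no dtype satisfies both model and quantization constraints")

print(chosen)  # "float32" -- the only dtype both constraints accept here
```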
44 changes: 44 additions & 0 deletions vllm/distributed/device_communicators/shm_broadcast.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import pickle
import threading
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
@@ -43,6 +44,33 @@
from_bytes_big = functools.partial(int.from_bytes, byteorder="big")


# Memory fence for cross-process shared memory visibility.
# Required for correct producer-consumer synchronization when using
# shared memory without locks.
_memory_fence_lock = threading.Lock()


def memory_fence():
"""
Full memory barrier for shared memory synchronization.

Ensures all prior memory writes are visible to other processes before
any subsequent reads. This is critical for lock-free producer-consumer
patterns using shared memory.

Implementation acquires and immediately releases a lock. Python's
threading.Lock provides sequentially consistent memory barrier semantics
across all major platforms (POSIX, Windows). This is a lightweight
operation (~20ns) that guarantees:
- All stores before the barrier are visible to other threads/processes
- All loads after the barrier see the latest values
"""
# Lock acquire/release provides full memory barrier semantics.
# Using context manager ensures lock release even on exceptions.
with _memory_fence_lock:
pass


def to_bytes_big(value: int, size: int) -> bytes:
return value.to_bytes(size, byteorder="big")

@@ -414,6 +442,10 @@ def acquire_write(self, timeout: float | None = None):
n_warning = 1
while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
# Memory fence ensures we see the latest read flags from readers.
# Without this, we may read stale flags from our CPU cache and
# spin indefinitely even though readers have completed.
memory_fence()
read_count = sum(metadata_buffer[1:])
written_flag = metadata_buffer[0]
if written_flag and read_count != self.buffer.n_reader:
@@ -458,6 +490,10 @@ def acquire_write(self, timeout: float | None = None):
metadata_buffer[i] = 0
# mark the block as written
metadata_buffer[0] = 1
# Memory fence ensures the write is visible to readers on other cores
# before we proceed. Without this, readers may spin indefinitely
# waiting for a write that's stuck in our CPU's store buffer.
memory_fence()
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
break

@@ -473,6 +509,10 @@ def acquire_read(
n_warning = 1
while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
# Memory fence ensures we see the latest writes from the writer.
# Without this, we may read stale flags from our CPU cache
# and spin indefinitely even though the writer has updated them.
memory_fence()
read_flag = metadata_buffer[self.local_reader_rank + 1]
written_flag = metadata_buffer[0]
if not written_flag or read_flag:
@@ -513,6 +553,10 @@ def acquire_read(
# caller has read from the buffer
# set the read flag
metadata_buffer[self.local_reader_rank + 1] = 1
# Memory fence ensures the read flag is visible to the writer.
# Without this, the writer may not see our read completion and
# could wait indefinitely for all readers to finish.
memory_fence()
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks

self._read_spin_timer.record_activity()
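To make the fence placement concrete, here is a minimal single-process sketch of the flag protocol the ring buffer metadata uses, with `memory_fence()` at the same points as the hunks above. Shared memory is replaced by a plain `bytearray`, so this only illustrates the ordering discipline, not cross-process behavior:

```python
import threading

_fence_lock = threading.Lock()

def memory_fence():
    # Acquire/release acts as a full barrier, mirroring the PR's helper.
    with _fence_lock:
        pass

metadata = bytearray(2)  # [written_flag, read_flag]

def writer():
    metadata[1] = 0   # reset the read flag
    metadata[0] = 1   # mark the block as written
    memory_fence()    # publish the write before advancing

def reader():
    memory_fence()    # observe the writer's latest flags
    if metadata[0] and not metadata[1]:
        metadata[1] = 1   # mark the block as read
        memory_fence()    # publish read completion to the writer

writer()
reader()
assert metadata == bytearray([1, 1])
```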
9 changes: 9 additions & 0 deletions vllm/model_executor/layers/quantization/gguf.py
@@ -33,6 +33,7 @@
)
from vllm.model_executor.models.utils import WeightsMapper
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op

logger = init_logger(__name__)
@@ -52,6 +53,14 @@ def get_name(self) -> QuantizationMethods:
return "gguf"

def get_supported_act_dtypes(self) -> list[torch.dtype]:
# GGUF dequantization kernels use half precision (fp16) internally.
# bfloat16 has precision issues on SM 10.0+ devices (Blackwell).
if current_platform.has_device_capability(100):
logger.warning_once(
"GGUF has precision issues with bfloat16 on Blackwell (SM 10.0+). "
"bfloat16 is unavailable."
)
return [torch.half, torch.float32]
return [torch.half, torch.bfloat16, torch.float32]

@classmethod
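The gating above is a plain device-capability check. A hedged sketch of the same shape using `torch.cuda.get_device_capability` directly (the real code goes through vLLM's `current_platform` abstraction, whose `has_device_capability(100)` compares against major*10 + minor):

```python
import torch

def supported_act_dtypes() -> list[torch.dtype]:
    # SM 10.0 (Blackwell) reports capability (10, 0), i.e. 100 on the
    # major*10 + minor scale that has_device_capability(100) uses.
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        if major * 10 + minor >= 100:
            return [torch.half, torch.float32]
    return [torch.half, torch.bfloat16, torch.float32]
```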
5 changes: 5 additions & 0 deletions vllm/model_executor/model_loader/gguf_loader.py
@@ -147,6 +147,11 @@ def _get_gguf_weights_map(self, model_config: ModelConfig):
)
)

# For models with tied word embeddings, lm_head.weight is initialized
# from embed_tokens and doesn't need to be mapped from the GGUF file
if getattr(config, "tie_word_embeddings", False):
sideload_params.append(re.compile(r"lm_head\.weight"))

arch = None
for key, value in gguf.MODEL_ARCH_NAMES.items():
if value == model_type:
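For reviewers unfamiliar with the sideload list: appending the pattern tells the loader not to expect `lm_head.weight` as a GGUF tensor, since tied models materialize it from `embed_tokens`. A toy sketch of the filtering effect (the tensor names and list here are illustrative, not the loader's real data):

```python
import re
from types import SimpleNamespace

config = SimpleNamespace(tie_word_embeddings=True)
sideload_params: list[re.Pattern] = []

# lm_head.weight comes from embed_tokens when embeddings are tied,
# so it will not appear as a tensor in the GGUF file.
if getattr(config, "tie_word_embeddings", False):
    sideload_params.append(re.compile(r"lm_head\.weight"))

names = ["model.embed_tokens.weight", "lm_head.weight"]
expected_in_gguf = [n for n in names if not any(p.search(n) for p in sideload_params)]
print(expected_in_gguf)  # ['model.embed_tokens.weight']
```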
6 changes: 6 additions & 0 deletions vllm/model_executor/models/gemma2.py
@@ -267,6 +267,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
@@ -366,6 +368,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
continue
if is_pp_missing_parameter(name, self):
continue
# Skip parameters not in the model (e.g., GGUF quantization
# metadata like qweight_type for embeddings)
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
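The added guard makes the loop tolerant of checkpoint entries with no matching module parameter. A minimal sketch of the pattern with dummy tensors (the real loop also handles stacked params and pipeline-parallel-missing names):

```python
import torch

params_dict = {"embed_tokens.weight": torch.zeros(4, 8)}
weights = [
    ("embed_tokens.weight", torch.ones(4, 8)),
    # GGUF quantization metadata with no counterpart in the model.
    ("embed_tokens.qweight_type", torch.tensor([2])),
]

for name, loaded_weight in weights:
    # Skip parameters not in the model (e.g., GGUF quantization metadata).
    if name not in params_dict:
        continue
    params_dict[name].copy_(loaded_weight)
```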
27 changes: 23 additions & 4 deletions vllm/model_executor/models/nemotron_h.py
@@ -49,6 +49,7 @@
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
@@ -435,6 +436,7 @@ def __init__(
self,
config: NemotronHConfig,
layer_idx: int,
max_position_embeddings: int,
model_config: ModelConfig | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
@@ -490,13 +492,25 @@ def __init__(
prefix=f"{prefix}.attn",
)

# Rotary embeddings for positional encoding
self.max_position_embeddings = max_position_embeddings
self.rotary_emb = get_rope(
head_size=self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
is_neox_style=True,
dtype=model_config.dtype if model_config else torch.get_default_dtype(),
)

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
**kwargs,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
@@ -518,9 +532,10 @@ def __init__(
self.mixer = NemotronHAttention(
config,
layer_idx,
model_config,
cache_config,
quant_config,
max_position_embeddings=config.max_position_embeddings,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.mixer",
)

@@ -539,7 +554,7 @@ def forward(
else:
hidden_states, residual = self.norm(hidden_states, residual)

hidden_states = self.mixer(hidden_states=hidden_states)
hidden_states = self.mixer(positions=positions, hidden_states=hidden_states)
return hidden_states, residual


@@ -659,6 +674,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
# Skip rotary embeddings - they are computed dynamically
if "rotary_emb.inv_freq" in name:
continue

if "scale" in name:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
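The fix threads `positions` through the mixer and rotates q/k before attention. A rough standalone sketch of neox-style RoPE on a `(seq, heads * head_dim)` tensor, hand-rolled for illustration (the real layer comes from `get_rope` and differs in caching and dtype handling):

```python
import torch

def apply_rope(x: torch.Tensor, positions: torch.Tensor, head_dim: int) -> torch.Tensor:
    # x: (seq, num_heads * head_dim); neox style rotates the two dim halves.
    seq = x.shape[0]
    x = x.view(seq, -1, head_dim)
    half = head_dim // 2
    inv_freq = 1.0 / (10000.0 ** (torch.arange(half) / half))
    angles = positions[:, None].float() * inv_freq[None, :]   # (seq, half)
    cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]
    x1, x2 = x[..., :half], x[..., half:]
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return rotated.reshape(seq, -1)

q = torch.randn(5, 2 * 8)  # seq=5, 2 heads, head_dim=8
q_rot = apply_rope(q, torch.arange(5), head_dim=8)
```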
22 changes: 22 additions & 0 deletions vllm/tokenizers/hf.py
@@ -7,11 +7,15 @@

from transformers import AutoTokenizer

from vllm.logger import init_logger
from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config
from vllm.transformers_utils.gguf_utils import extract_eos_token_id_from_gguf

from .protocol import TokenizerLike
from .registry import TokenizerRegistry

logger = init_logger(__name__)

if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

@@ -121,4 +125,22 @@ def from_pretrained(
}
tokenizer.add_special_tokens(special_tokens_map)

# Patch the EOS token ID from GGUF metadata if available.
# GGUF files may carry a different EOS token ID than the HF tokenizer config
# (e.g., Gemma uses <end_of_turn> ID 106 as EOS, but HF reports <eos> ID 1)
gguf_file = kwargs.get("gguf_file")
if gguf_file:
gguf_path = Path(path_or_repo_id) / gguf_file
gguf_eos_id = extract_eos_token_id_from_gguf(str(gguf_path))
if gguf_eos_id is not None:
hf_eos_id = tokenizer.eos_token_id
if hf_eos_id != gguf_eos_id:
logger.info(
"Patching tokenizer eos_token_id from %d to %d "
"(using GGUF metadata)",
hf_eos_id,
gguf_eos_id,
)
tokenizer.eos_token_id = gguf_eos_id

return get_cached_tokenizer(tokenizer)
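The patch boils down to trusting GGUF metadata for EOS when it disagrees with the tokenizer config. A sketch of reading the conventional `tokenizer.ggml.eos_token_id` key with gguf-py; the scalar-field access pattern below is assumed from gguf-py conventions, and the PR's `extract_eos_token_id_from_gguf` helper remains the source of truth:

```python
from gguf import GGUFReader  # gguf-py

def extract_eos_token_id(gguf_path: str) -> int | None:
    reader = GGUFReader(gguf_path)
    field = reader.get_field("tokenizer.ggml.eos_token_id")
    if field is None:
        return None
    # Scalar fields keep their payload in parts[data[0]].
    return int(field.parts[field.data[0]][0])

# Illustrative usage with a hypothetical file name:
# eos = extract_eos_token_id("gemma-2-9b-it-Q4_K_M.gguf")
# if eos is not None and tokenizer.eos_token_id != eos:
#     tokenizer.eos_token_id = eos
```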