diff --git a/vllm/config/model.py b/vllm/config/model.py
index 7ff095bcb9cc..6a812ef079af 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -470,6 +470,17 @@ def __post_init__(
         )
         self.hf_config = hf_config
+
+        # Ensure Gemma2 configs have hidden_act for backward compatibility.
+        # GGUF configs may only have hidden_activation; model code expects both.
+        if (
+            hasattr(hf_config, "model_type")
+            and hf_config.model_type == "gemma2"
+            and not hasattr(hf_config, "hidden_act")
+            and hasattr(hf_config, "hidden_activation")
+        ):
+            hf_config.hidden_act = hf_config.hidden_activation
+
         if dict_overrides:
             self._apply_dict_overrides(hf_config, dict_overrides)
         self.hf_text_config = get_hf_text_config(self.hf_config)
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index ace5adc109d8..c85cba6ee324 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -389,11 +389,41 @@ def _get_quantization_config(
                     )
             supported_dtypes = quant_config.get_supported_act_dtypes()
             if model_config.dtype not in supported_dtypes:
-                raise ValueError(
-                    f"{model_config.dtype} is not supported for quantization "
-                    f"method {model_config.quantization}. Supported dtypes: "
-                    f"{supported_dtypes}"
-                )
+                # Handle dtype conflict between model restrictions and
+                # quantization restrictions (e.g., Gemma3 GGUF on Blackwell
+                # where Gemma3 blocks float16 and GGUF blocks bfloat16)
+                from vllm.config.model import _is_valid_dtype
+
+                model_type = getattr(model_config.hf_config, "model_type", None)
+                compatible_dtypes = [
+                    d
+                    for d in supported_dtypes
+                    if model_type is None or _is_valid_dtype(model_type, d)
+                ]
+                if compatible_dtypes:
+                    # Prefer float16 > bfloat16 > float32 for performance
+                    import torch
+
+                    dtype_preference = [torch.float16, torch.bfloat16, torch.float32]
+                    for preferred in dtype_preference:
+                        if preferred in compatible_dtypes:
+                            logger.warning(
+                                "dtype=%s is not supported for quantization "
+                                "method %s with model type %s. "
+                                "Automatically selecting %s as compatible dtype.",
+                                model_config.dtype,
+                                model_config.quantization,
+                                model_type,
+                                preferred,
+                            )
+                            model_config.dtype = preferred
+                            break
+                else:
+                    raise ValueError(
+                        f"{model_config.dtype} is not supported for quantization "
+                        f"method {model_config.quantization}. Supported dtypes: "
+                        f"{supported_dtypes}"
+                    )
             quant_config.maybe_update_config(model_config.model)
             return quant_config
         return None
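
# A standalone sketch of the dtype-fallback rule introduced in vllm/config/vllm.py
# above: intersect the quantization method's supported activation dtypes with the
# dtypes the model type accepts, then pick the fastest one in a fixed preference
# order. The helper name and the `is_valid` callback are illustrative stand-ins,
# not vLLM APIs.
from collections.abc import Callable

import torch


def pick_compatible_dtype(
    supported: list[torch.dtype],
    is_valid: Callable[[torch.dtype], bool],
) -> torch.dtype | None:
    compatible = [d for d in supported if is_valid(d)]
    # Prefer float16 > bfloat16 > float32, mirroring the ordering in the patch.
    for preferred in (torch.float16, torch.bfloat16, torch.float32):
        if preferred in compatible:
            return preferred
    return None  # callers fall back to raising ValueError


# Example: GGUF supports float16/float32 but the model type rejects float16, so
# pick_compatible_dtype([torch.float16, torch.float32], lambda d: d != torch.float16)
# returns torch.float32.
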
diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py
index 7f94bd234fd3..c00d9c850c89 100644
--- a/vllm/model_executor/model_loader/gguf_loader.py
+++ b/vllm/model_executor/model_loader/gguf_loader.py
@@ -147,6 +147,11 @@ def _get_gguf_weights_map(self, model_config: ModelConfig):
             )
         )
 
+        # For models with tied word embeddings, lm_head.weight is initialized
+        # from embed_tokens and doesn't need to be mapped from GGUF file
+        if getattr(config, "tie_word_embeddings", False):
+            sideload_params.append(re.compile(r"lm_head\.weight"))
+
         arch = None
         for key, value in gguf.MODEL_ARCH_NAMES.items():
             if value == model_type:
diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index fe6ec5ff83de..7a9e62f76bf6 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -365,6 +365,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                     continue
                 if is_pp_missing_parameter(name, self):
                     continue
+                # Skip parameters not in the model (e.g., GGUF quantization
+                # metadata like qweight_type for embeddings)
+                if name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
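
# A minimal sketch of the skip guard added to Gemma2's load_weights above: GGUF
# checkpoints can expose per-tensor entries that have no counterpart in the
# instantiated model (such as qweight_type metadata for the embedding), so unknown
# names are skipped instead of raising KeyError. `load_filtered` and its weights
# argument are illustrative, not vLLM's loader API.
import torch


def load_filtered(module: torch.nn.Module, weights) -> set[str]:
    params_dict = dict(module.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        if name not in params_dict:
            # e.g., "model.embed_tokens.qweight_type" coming from a GGUF checkpoint
            continue
        with torch.no_grad():
            params_dict[name].copy_(loaded_weight)
        loaded_params.add(name)
    return loaded_params
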
diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py
index 2d9dfbd3e768..7968bb17939c 100644
--- a/vllm/model_executor/models/nemotron_h.py
+++ b/vllm/model_executor/models/nemotron_h.py
@@ -49,6 +49,7 @@
     MambaStateShapeCalculator,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -435,6 +436,7 @@ def __init__(
         self,
         config: NemotronHConfig,
         layer_idx: int,
+        max_position_embeddings: int,
         model_config: ModelConfig | None = None,
         cache_config: CacheConfig | None = None,
         quant_config: QuantizationConfig | None = None,
@@ -490,13 +492,25 @@ def __init__(
             prefix=f"{prefix}.attn",
         )
 
+        # Rotary embeddings for positional encoding
+        self.max_position_embeddings = max_position_embeddings
+        self.rotary_emb = get_rope(
+            head_size=self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            is_neox_style=True,
+            dtype=model_config.dtype if model_config else torch.get_default_dtype(),
+        )
+
     def forward(
         self,
+        positions: torch.Tensor,
         hidden_states: torch.Tensor,
         **kwargs,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output
@@ -518,9 +532,10 @@ def __init__(
         self.mixer = NemotronHAttention(
             config,
             layer_idx,
-            model_config,
-            cache_config,
-            quant_config,
+            max_position_embeddings=config.max_position_embeddings,
+            model_config=model_config,
+            cache_config=cache_config,
+            quant_config=quant_config,
             prefix=f"{prefix}.mixer",
         )
@@ -539,7 +554,7 @@ def forward(
         else:
             hidden_states, residual = self.norm(hidden_states, residual)
 
-        hidden_states = self.mixer(hidden_states=hidden_states)
+        hidden_states = self.mixer(positions=positions, hidden_states=hidden_states)
         return hidden_states, residual
@@ -659,6 +674,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
+            # Skip rotary embeddings - they are computed dynamically
+            if "rotary_emb.inv_freq" in name:
+                continue
+
             if "scale" in name:
                 # Remapping the name of FP8 kv-scale.
                 name = maybe_remap_kv_scale_name(name, params_dict)
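
# For reference, a simplified, self-contained version of what the rotary module
# added to NemotronHAttention applies to q and k (neox-style rotation over each
# head's first and second halves). This is a sketch for intuition only, not vLLM's
# get_rope implementation; `apply_neox_rope` is an illustrative name and head_dim
# is assumed to be even.
import torch


def apply_neox_rope(
    positions: torch.Tensor, x: torch.Tensor, head_dim: int, base: float = 10000.0
) -> torch.Tensor:
    # positions: [num_tokens]; x: [num_tokens, num_heads * head_dim]
    num_tokens = x.shape[0]
    xh = x.view(num_tokens, -1, head_dim)
    half = head_dim // 2
    inv_freq = 1.0 / (base ** (torch.arange(half, dtype=torch.float32) / half))
    angles = positions.float()[:, None] * inv_freq[None, :]  # [num_tokens, half]
    cos = angles.cos()[:, None, :]  # broadcast across heads
    sin = angles.sin()[:, None, :]
    x1, x2 = xh[..., :half], xh[..., half:]
    out = torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return out.flatten(1).to(x.dtype)


# Rotating q and k this way before attention is what makes the attention scores
# depend on relative token positions.
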
diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py
index 79ee4161e9df..2e60771dbbb7 100644
--- a/vllm/v1/structured_output/__init__.py
+++ b/vllm/v1/structured_output/__init__.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import multiprocessing
+import threading
 from concurrent.futures import Future, ThreadPoolExecutor
 from typing import TYPE_CHECKING
@@ -63,39 +64,62 @@ def __init__(self, vllm_config: VllmConfig):
         max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8))
         self.executor_for_fillmask = ThreadPoolExecutor(max_workers=max_workers)
 
-        if not self.vllm_config.model_config.skip_tokenizer_init:
-            # The default max_workers if not specified is the number of
-            # CPUs * 5, which is way too high since these tasks are CPU-bound,
-            # not I/O bound. We also know we would never dominate CPU usage
-            # with just grammar compilation, so we set it to half the number
-            # of CPUs.
-            max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
-            self.executor = ThreadPoolExecutor(max_workers=max_workers)
-            self.tokenizer = cached_tokenizer_from_config(
-                model_config=self.vllm_config.model_config
-            )
-            reasoning_parser = (
-                self.vllm_config.structured_outputs_config.reasoning_parser
-            )
-            reasoning_parser_plugin = (
-                self.vllm_config.structured_outputs_config.reasoning_parser_plugin
-            )
-            if reasoning_parser_plugin and len(reasoning_parser_plugin) > 3:
-                ReasoningParserManager.import_reasoning_parser(reasoning_parser_plugin)
-
-            reasoning_parser = (
-                self.vllm_config.structured_outputs_config.reasoning_parser
-            )
-            if reasoning_parser:
-                reasoner_cls = ReasoningParserManager.get_reasoning_parser(
-                    reasoning_parser
-                )
-                self.reasoner = reasoner_cls(tokenizer=self.tokenizer)
+        # Tokenizer is loaded lazily to avoid duplicate tokenizer initialization
+        # in multiprocess mode. For GGUF models, this prevents a semaphore leak
+        # that causes server hangs (tokenizer builds merges on the fly, which
+        # uses multiprocessing primitives that don't clean up in subprocesses).
+        self._tokenizer = None
+        self._tokenizer_initialized = False
+        self._tokenizer_init_lock = threading.Lock()
+        self.executor: ThreadPoolExecutor | None = None
         self.enable_in_reasoning = (
             self.vllm_config.structured_outputs_config.enable_in_reasoning
         )
 
+    @property
+    def tokenizer(self):
+        """Lazily initialize tokenizer when first accessed (thread-safe)."""
+        # Double-checked locking pattern for thread-safe lazy initialization
+        if not self._tokenizer_initialized:
+            with self._tokenizer_init_lock:
+                if not self._tokenizer_initialized:
+                    self._init_tokenizer()
+        return self._tokenizer
+
+    def _init_tokenizer(self):
+        """Initialize tokenizer and related components on first use."""
+        if self._tokenizer_initialized:
+            return
+
+        if self.vllm_config.model_config.skip_tokenizer_init:
+            self._tokenizer_initialized = True
+            return
+
+        # The default max_workers if not specified is the number of
+        # CPUs * 5, which is way too high since these tasks are CPU-bound,
+        # not I/O bound. We also know we would never dominate CPU usage
+        # with just grammar compilation, so we set it to half the number
+        # of CPUs.
+        max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
+        self.executor = ThreadPoolExecutor(max_workers=max_workers)
+        self._tokenizer = cached_tokenizer_from_config(
+            model_config=self.vllm_config.model_config
+        )
+
+        reasoning_parser = self.vllm_config.structured_outputs_config.reasoning_parser
+        reasoning_parser_plugin = (
+            self.vllm_config.structured_outputs_config.reasoning_parser_plugin
+        )
+        if reasoning_parser_plugin and len(reasoning_parser_plugin) > 3:
+            ReasoningParserManager.import_reasoning_parser(reasoning_parser_plugin)
+
+        if reasoning_parser:
+            reasoner_cls = ReasoningParserManager.get_reasoning_parser(reasoning_parser)
+            self.reasoner = reasoner_cls(tokenizer=self._tokenizer)
+
+        self._tokenizer_initialized = True
+
     def grammar_init(self, request: Request) -> None:
         if request.structured_output_request is None:
             return
@@ -149,6 +173,11 @@ def grammar_init(self, request: Request) -> None:
             raise ValueError(f"Unsupported structured output backend: {backend}")
 
         if self._use_async_grammar_compilation:
+            # Ensure tokenizer (and executor) is initialized
+            _ = self.tokenizer
+            assert self.executor is not None, (
+                "Executor should be initialized with tokenizer"
+            )
             grammar = self.executor.submit(self._create_grammar, request)
         else:
             grammar = self._create_grammar(request)  # type: ignore[assignment]
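
# The lazy tokenizer initialization above follows the classic double-checked
# locking pattern. Shown in isolation below with illustrative names (a generic
# sketch, not the StructuredOutputManager API): the unlocked fast path avoids
# contention once the resource exists, and the second check inside the lock keeps
# two racing threads from both running the expensive initialization.
import threading


class LazyResource:
    def __init__(self) -> None:
        self._value = None
        self._initialized = False
        self._lock = threading.Lock()

    @property
    def value(self):
        if not self._initialized:  # fast path, no lock once initialized
            with self._lock:
                if not self._initialized:  # re-check while holding the lock
                    self._value = self._create()
                    self._initialized = True
        return self._value

    def _create(self):
        # Stand-in for the expensive step, e.g. building a tokenizer.
        return object()
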