diff --git a/vllm/config/model.py b/vllm/config/model.py
index 1de9d15cf8c5..1e6f467ffd4f 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -470,6 +470,17 @@ def __post_init__(
             )
             self.hf_config = hf_config

+        # Ensure Gemma2 configs have hidden_act for backward compatibility.
+        # GGUF configs may only have hidden_activation; model code expects both.
+        if (
+            hasattr(hf_config, "model_type")
+            and hf_config.model_type == "gemma2"
+            and not hasattr(hf_config, "hidden_act")
+            and hasattr(hf_config, "hidden_activation")
+        ):
+            hf_config.hidden_act = hf_config.hidden_activation
+
         if dict_overrides:
             self._apply_dict_overrides(hf_config, dict_overrides)
         self.hf_text_config = get_hf_text_config(self.hf_config)
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index ace5adc109d8..945b5f6fd080 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -389,11 +389,39 @@ def _get_quantization_config(
             )
             supported_dtypes = quant_config.get_supported_act_dtypes()
             if model_config.dtype not in supported_dtypes:
-                raise ValueError(
-                    f"{model_config.dtype} is not supported for quantization "
-                    f"method {model_config.quantization}. Supported dtypes: "
-                    f"{supported_dtypes}"
-                )
+                # Handle dtype conflicts between model restrictions and
+                # quantization restrictions (e.g., Gemma3 GGUF on Blackwell,
+                # where Gemma3 blocks float16 and GGUF blocks bfloat16).
+                from vllm.config.model import _is_valid_dtype
+
+                model_type = getattr(model_config.hf_config, "model_type", None)
+                compatible_dtypes = [
+                    d
+                    for d in supported_dtypes
+                    if model_type is None or _is_valid_dtype(model_type, d)
+                ]
+                if compatible_dtypes:
+                    # Prefer float16 > bfloat16 > float32 for performance
+                    dtype_preference = [torch.float16, torch.bfloat16, torch.float32]
+                    for preferred in dtype_preference:
+                        if preferred in compatible_dtypes:
+                            logger.warning(
+                                "dtype=%s is not supported for quantization "
+                                "method %s with model type %s. "
+                                "Automatically selecting %s as a compatible dtype.",
+                                model_config.dtype,
+                                model_config.quantization,
+                                model_type,
+                                preferred,
+                            )
+                            model_config.dtype = preferred
+                            break
+                else:
+                    raise ValueError(
+                        f"{model_config.dtype} is not supported for quantization "
+                        f"method {model_config.quantization}. Supported dtypes: "
+                        f"{supported_dtypes}"
+                    )
             quant_config.maybe_update_config(model_config.model)
             return quant_config
         return None
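
To make the fallback concrete, here is a self-contained sketch of the selection rule; `pick_compatible_dtype` and the `model_allows` table are illustrative stand-ins, not vLLM APIs:

```python
# Stand-in for the fallback above: pick the first preferred dtype that
# both the quantization method and the model type accept.
import torch

def pick_compatible_dtype(
    requested: torch.dtype,
    supported_by_quant: list[torch.dtype],
    model_type: str | None,
    model_allows: dict[str, set[torch.dtype]],
) -> torch.dtype:
    if requested in supported_by_quant:
        return requested
    compatible = [
        d
        for d in supported_by_quant
        if model_type is None or d in model_allows.get(model_type, set())
    ]
    # Preference order mirrors the comment in the hunk: float16 first
    # for performance, float32 as the last resort.
    for preferred in (torch.float16, torch.bfloat16, torch.float32):
        if preferred in compatible:
            return preferred
    raise ValueError(f"{requested} is not supported; no compatible fallback")

# Gemma3 GGUF on Blackwell: GGUF drops bfloat16, Gemma3 blocks float16,
# so float32 is the only dtype both sides accept.
assert pick_compatible_dtype(
    torch.bfloat16,
    [torch.float16, torch.float32],
    "gemma3",
    {"gemma3": {torch.bfloat16, torch.float32}},
) == torch.float32
```
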
@@ -666,9 +694,8 @@ def has_blocked_weights():
         default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
         self._apply_optimization_level_defaults(default_config)

         if (
-            self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
+            self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
             and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
         ):
             logger.info(
@@ -693,29 +720,22 @@ def has_blocked_weights():

         if current_platform.support_static_graph_mode():
             # if cudagraph_mode has full cudagraphs, we need to check support
-            if model_config := self.model_config:
-                if (
-                    self.compilation_config.cudagraph_mode.has_full_cudagraphs()
-                    and model_config.pooler_config is not None
-                ):
+            if (
+                self.compilation_config.cudagraph_mode.has_full_cudagraphs()
+                and self.model_config is not None
+            ):
+                if self.model_config.pooler_config is not None:
                     logger.warning_once(
                         "Pooling models do not support full cudagraphs. "
                         "Overriding cudagraph_mode to PIECEWISE."
                     )
                     self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
-                elif (
-                    model_config.is_encoder_decoder
-                    and self.compilation_config.cudagraph_mode
-                    not in (CUDAGraphMode.NONE, CUDAGraphMode.FULL_DECODE_ONLY)
-                ):
-                    logger.info_once(
-                        "Encoder-decoder models do not support %s. "
-                        "Overriding cudagraph_mode to FULL_DECODE_ONLY.",
-                        self.compilation_config.cudagraph_mode.name,
-                    )
-                    self.compilation_config.cudagraph_mode = (
-                        CUDAGraphMode.FULL_DECODE_ONLY
+                elif self.model_config.is_encoder_decoder:
+                    logger.warning_once(
+                        "Encoder-decoder models do not support full cudagraphs. "
+                        "Overriding cudagraph_mode to PIECEWISE."
                     )
+                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

         # disable cudagraph when enforce eager execution
         if self.model_config is not None and self.model_config.enforce_eager:
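
The override restored above amounts to a small decision rule. A minimal sketch, assuming a simplified three-member enum (the real `CUDAGraphMode` has more members plus a `has_full_cudagraphs()` helper):

```python
# Illustrative decision rule, not vLLM API: with full cudagraphs
# requested, pooling and encoder-decoder models are both downgraded
# to PIECEWISE.
from enum import Enum

class CUDAGraphMode(Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2

def resolve_cudagraph_mode(
    requested: CUDAGraphMode, is_pooling: bool, is_encoder_decoder: bool
) -> CUDAGraphMode:
    if requested is CUDAGraphMode.FULL and (is_pooling or is_encoder_decoder):
        return CUDAGraphMode.PIECEWISE
    return requested

assert resolve_cudagraph_mode(CUDAGraphMode.FULL, True, False) is CUDAGraphMode.PIECEWISE
assert resolve_cudagraph_mode(CUDAGraphMode.FULL, False, True) is CUDAGraphMode.PIECEWISE
assert resolve_cudagraph_mode(CUDAGraphMode.FULL, False, False) is CUDAGraphMode.FULL
```
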
@@ -750,17 +770,27 @@ def has_blocked_weights():
         # TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands
         self._set_compile_ranges()

-        if (
-            self.model_config
-            and self.model_config.architecture == "WhisperForConditionalGeneration"
-            and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
-        ):
-            logger.warning(
-                "Whisper is known to have issues with "
-                "forked workers. If startup is hanging, "
-                "try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
-                "to 'spawn'."
+        if self.model_config and self.model_config.is_encoder_decoder:
+            from vllm.multimodal import MULTIMODAL_REGISTRY
+
+            self.scheduler_config.max_num_encoder_input_tokens = (
+                MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
             )
+            logger.debug(
+                "Encoder-decoder model detected: setting "
+                "`max_num_encoder_input_tokens` to the encoder length (%s)",
+                self.scheduler_config.max_num_encoder_input_tokens,
+            )
+            if (
+                self.model_config.architecture == "WhisperForConditionalGeneration"
+                and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
+            ):
+                logger.warning(
+                    "Whisper is known to have issues with "
+                    "forked workers. If startup is hanging, "
+                    "try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
+                    "to 'spawn'."
+                )

         if (
             self.kv_events_config is not None
@@ -810,6 +840,11 @@ def has_blocked_weights():
                 f"({self.parallel_config.cp_kv_cache_interleave_size})."
             )

+        assert (
+            self.parallel_config.cp_kv_cache_interleave_size == 1
+            or self.speculative_config is None
+        ), "MTP with cp_kv_cache_interleave_size > 1 is not yet supported."
+
         # Do this after all the updates to compilation_config.mode
         self.compilation_config.set_splitting_ops_for_v1(
             all2all_backend=self.parallel_config.all2all_backend,
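
The new assertion is equivalent to a one-line compatibility predicate; the helper below is hypothetical and only restates the inline check:

```python
# Sketch of the compatibility guard above (the real check lives inline
# in VllmConfig.__post_init__): speculative decoding (MTP) and
# interleaved context-parallel KV cache layouts are mutually exclusive.
def check_cp_interleave_compat(
    cp_kv_cache_interleave_size: int, has_speculative_config: bool
) -> None:
    assert cp_kv_cache_interleave_size == 1 or not has_speculative_config, (
        "MTP with cp_kv_cache_interleave_size > 1 is not yet supported."
    )

check_cp_interleave_compat(1, True)    # OK: no interleaving
check_cp_interleave_compat(8, False)   # OK: no speculative config
# check_cp_interleave_compat(8, True)  # would raise AssertionError
```
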
@@ -887,48 +922,17 @@ def has_blocked_weights():
         if not self.instance_id:
             self.instance_id = random_uuid()[:5]

-        # Hybrid KV cache manager (HMA) runtime rules:
-        # - Explicit enable (--no-disable-kv-cache-manager): error if runtime
-        #   disables it
-        # - No preference: auto-disable for unsupported features (e.g. kv connector)
-        # - Explicit disable (--disable-kv-cache-manager): always respect it
-        need_disable_hybrid_kv_cache_manager = False
-        # logger should only print warning message for hybrid models. As we
-        # can't know whether the model is hybrid or not now, so we don't log
-        # warning message here and will log it later.
-        if not current_platform.support_hybrid_kv_cache():
-            # Hybrid KV cache manager is not supported on non-GPU platforms.
-            need_disable_hybrid_kv_cache_manager = True
-        if self.kv_events_config is not None:
-            # Hybrid KV cache manager is not compatible with KV events.
-            need_disable_hybrid_kv_cache_manager = True
-        if (
-            self.model_config is not None
-            and self.model_config.attention_chunk_size is not None
-        ):
-            if (
-                self.speculative_config is not None
-                and self.speculative_config.use_eagle()
-            ):
-                # Hybrid KV cache manager is not yet supported with chunked
-                # local attention + eagle.
-                need_disable_hybrid_kv_cache_manager = True
-            elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE:
-                logger.warning(
-                    "There is a latency regression when using chunked local"
-                    " attention with the hybrid KV cache manager. Disabling"
-                    " it, by default. To enable it, set the environment "
-                    "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
-                )
-                # Hybrid KV cache manager is not yet supported with chunked
-                # local attention.
-                need_disable_hybrid_kv_cache_manager = True
-
-        if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
-            # Default to disable HMA, but only if the user didn't express a preference.
+        if not self.scheduler_config.disable_hybrid_kv_cache_manager:
+            # The logger should only print a warning for hybrid models. Since
+            # we can't yet know whether the model is hybrid, we don't log the
+            # warning here; it is logged later.
+            if not current_platform.support_hybrid_kv_cache():
+                # Hybrid KV cache manager is not supported on non-GPU platforms.
+                self.scheduler_config.disable_hybrid_kv_cache_manager = True
             if self.kv_transfer_config is not None:
-                # NOTE(Kuntai): turn HMA off for connector unless specifically enabled.
-                need_disable_hybrid_kv_cache_manager = True
+                # NOTE(Kuntai): turn HMA off for connector for now.
+                # TODO(Kuntai): have a more elegant solution to check for and
+                # turn off HMA for connectors that do not support HMA.
                 logger.warning(
                     "Turning off hybrid kv cache manager because "
                     "`--kv-transfer-config` is set. This will reduce the "
@@ -936,26 +940,33 @@ def has_blocked_weights():
                     "or Mamba attention. If you are a developer of kv connector"
                     ", please consider supporting hybrid kv cache manager for "
                     "your connector by making sure your connector is a subclass"
-                    " of `SupportsHMA` defined in kv_connector/v1/base.py and"
-                    " use --no-disable-hybrid-kv-cache-manager to start vLLM."
+                    " of `SupportsHMA` defined in kv_connector/v1/base.py."
                 )
-            self.scheduler_config.disable_hybrid_kv_cache_manager = (
-                need_disable_hybrid_kv_cache_manager
-            )
-        elif (
-            self.scheduler_config.disable_hybrid_kv_cache_manager is False
-            and need_disable_hybrid_kv_cache_manager
-        ):
-            raise ValueError(
-                "Hybrid KV cache manager was explicitly enabled but is not "
-                "supported in this configuration. Consider omitting the "
-                "--no-disable-hybrid-kv-cache-manager flag to let vLLM decide"
-                " automatically."
-            )
-
-        if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
-            # Default to enable HMA if not explicitly disabled by user or logic above.
-            self.scheduler_config.disable_hybrid_kv_cache_manager = False
+                self.scheduler_config.disable_hybrid_kv_cache_manager = True
+            if self.kv_events_config is not None:
+                # Hybrid KV cache manager is not compatible with KV events.
+                self.scheduler_config.disable_hybrid_kv_cache_manager = True
+            if (
+                self.model_config is not None
+                and self.model_config.attention_chunk_size is not None
+            ):
+                if (
+                    self.speculative_config is not None
+                    and self.speculative_config.use_eagle()
+                ):
+                    # Hybrid KV cache manager is not yet supported with chunked
+                    # local attention + eagle.
+                    self.scheduler_config.disable_hybrid_kv_cache_manager = True
+                elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE:
+                    logger.warning(
+                        "There is a latency regression when using chunked local"
+                        " attention with the hybrid KV cache manager. Disabling"
+                        " it by default. To enable it, set the environment "
+                        "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
+                    )
+                    # Hybrid KV cache manager is not yet supported with chunked
+                    # local attention.
+                    self.scheduler_config.disable_hybrid_kv_cache_manager = True

         if self.compilation_config.debug_dump_path:
             self.compilation_config.debug_dump_path = (
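
Collapsed into one place, the restored branches form a simple auto-disable predicate for the hybrid KV cache manager (HMA). An illustrative condensation, not vLLM API:

```python
# HMA stays on unless one of the known-unsupported features is active.
def should_disable_hybrid_kv_cache_manager(
    platform_supports_hybrid: bool,
    has_kv_transfer_config: bool,
    has_kv_events_config: bool,
    has_chunked_local_attention: bool,
    uses_eagle: bool,
    allow_chunked_local_attn: bool,
) -> bool:
    if not platform_supports_hybrid:
        return True
    if has_kv_transfer_config or has_kv_events_config:
        return True
    if has_chunked_local_attention and (uses_eagle or not allow_chunked_local_attn):
        return True
    return False

# Chunked local attention alone disables HMA unless the escape hatch
# VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1 is set.
assert should_disable_hybrid_kv_cache_manager(True, False, False, True, False, False)
assert not should_disable_hybrid_kv_cache_manager(True, False, False, True, False, True)
```
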
@@ -1023,7 +1034,7 @@ def _set_cudagraph_sizes(self):
         max_graph_size = min(max_num_seqs * 2, 512)
         # 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
         # up to max_graph_size
-        cudagraph_capture_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
+        cuda_graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
             range(256, max_graph_size + 1, 16))

     In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
@@ -1064,14 +1075,8 @@ def _set_cudagraph_sizes(self):
                 self.compilation_config.max_cudagraph_capture_size
             )
             if max_cudagraph_capture_size is None:
-                decode_query_len = 1
-                if (
-                    self.speculative_config
-                    and self.speculative_config.num_speculative_tokens
-                ):
-                    decode_query_len += self.speculative_config.num_speculative_tokens
                 max_cudagraph_capture_size = min(
-                    self.scheduler_config.max_num_seqs * decode_query_len * 2, 512
+                    self.scheduler_config.max_num_seqs * 2, 512
                 )
             max_num_tokens = self.scheduler_config.max_num_batched_tokens
             max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index 9dd734f2fea6..5c9da892f001 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -54,9 +54,12 @@ def get_name(self) -> QuantizationMethods:

     def get_supported_act_dtypes(self) -> list[torch.dtype]:
         # GGUF dequantization kernels use half precision (fp16) internally.
-        # bfloat16 has precision issues on Blackwell devices.
+        # bfloat16 has precision issues on SM 10.0+ devices (Blackwell).
         if current_platform.has_device_capability(100):
-            logger.warning_once("GGUF has precision issues with bfloat16 on Blackwell.")
+            logger.warning_once(
+                "GGUF has precision issues with bfloat16 on Blackwell (SM 10.0+). "
+                "Excluding bfloat16 from the supported dtypes."
+            )
             return [torch.half, torch.float32]
         return [torch.half, torch.bfloat16, torch.float32]
diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py
index 7f94bd234fd3..c00d9c850c89 100644
--- a/vllm/model_executor/model_loader/gguf_loader.py
+++ b/vllm/model_executor/model_loader/gguf_loader.py
@@ -147,6 +147,11 @@ def _get_gguf_weights_map(self, model_config: ModelConfig):
             )
         )

+        # For models with tied word embeddings, lm_head.weight is initialized
+        # from embed_tokens and doesn't need to be mapped from the GGUF file.
+        if getattr(config, "tie_word_embeddings", False):
+            sideload_params.append(re.compile(r"lm_head\.weight"))
+
         arch = None
         for key, value in gguf.MODEL_ARCH_NAMES.items():
             if value == model_type:
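
The capture-size schedule in the `_set_cudagraph_sizes` docstring above can be checked numerically. A standalone sketch; the explicit `<= max_graph_size` filter is spelled out here, while the docstring snippet leaves the cap implicit:

```python
# Worked example of the schedule: 1, 2, 4, multiples of 8 up to 256,
# then multiples of 16 up to max_graph_size = min(max_num_seqs * 2, 512).
def cudagraph_capture_sizes(max_num_seqs: int) -> list[int]:
    max_graph_size = min(max_num_seqs * 2, 512)
    sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
        range(256, max_graph_size + 1, 16)
    )
    return [s for s in sizes if s <= max_graph_size]

# max_num_seqs=64 -> max_graph_size=128 -> 1, 2, 4, 8, 16, ..., 128
print(cudagraph_capture_sizes(64)[-3:])  # [112, 120, 128]
```
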
diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index fe6ec5ff83de..ab0fc925d8b5 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -266,6 +266,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.embed_tokens",
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
@@ -365,6 +367,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 continue
             if is_pp_missing_parameter(name, self):
                 continue
+            # Skip parameters not present in the model (e.g., GGUF quantization
+            # metadata such as qweight_type for embeddings)
+            if name not in params_dict:
+                continue
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index a11d37b4b2ed..8ecde9c53360 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -505,14 +505,26 @@ def maybe_override_with_speculators(
     else:
         gguf_model_repo = None
     kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
-    config_dict, _ = PretrainedConfig.get_config_dict(
-        model if gguf_model_repo is None else gguf_model_repo,
-        revision=revision,
-        trust_remote_code=trust_remote_code,
-        token=_get_hf_token(),
-        **kwargs,
-    )
-    speculators_config = config_dict.get("speculators_config")
+    try:
+        config_dict, _ = PretrainedConfig.get_config_dict(
+            model if gguf_model_repo is None else gguf_model_repo,
+            revision=revision,
+            trust_remote_code=trust_remote_code,
+            token=_get_hf_token(),
+            **kwargs,
+        )
+        speculators_config = config_dict.get("speculators_config")
+    except OSError as e:
+        # GGUF models without a config.json cannot have a speculators config
+        # (it is defined in config.json), so skip gracefully. Only suppress
+        # "file not found" errors, not other OS errors such as permission
+        # denied.
+        is_file_not_found = isinstance(
+            e, FileNotFoundError
+        ) or "does not appear to have a file named" in str(e)
+        if gguf_model_repo is not None and is_file_not_found:
+            return model, tokenizer, vllm_speculative_config
+        raise

     if speculators_config is None:
         # No speculators config found, return original values
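
The error classification added to `maybe_override_with_speculators` is easy to exercise in isolation. A sketch with a hypothetical helper name:

```python
# Only "file not found" style OSErrors are swallowed for GGUF repos;
# everything else (e.g., PermissionError) still propagates.
def should_skip_speculators_lookup(e: OSError, is_gguf_repo: bool) -> bool:
    is_file_not_found = isinstance(
        e, FileNotFoundError
    ) or "does not appear to have a file named" in str(e)
    return is_gguf_repo and is_file_not_found

assert should_skip_speculators_lookup(FileNotFoundError("config.json"), True)
assert should_skip_speculators_lookup(
    OSError("repo does not appear to have a file named config.json"), True
)
assert not should_skip_speculators_lookup(PermissionError("denied"), True)
assert not should_skip_speculators_lookup(FileNotFoundError("config.json"), False)
```
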
diff --git a/vllm/transformers_utils/gguf_utils.py b/vllm/transformers_utils/gguf_utils.py
index f3fd43c6ace5..67858dcd4752 100644
--- a/vllm/transformers_utils/gguf_utils.py
+++ b/vllm/transformers_utils/gguf_utils.py
@@ -95,13 +95,13 @@ def detect_gguf_multimodal(model: str) -> Path | None:
     Returns:
         Path to mmproj file if found, None otherwise
     """
-    if not model.endswith(".gguf"):
+    # Use magic-bytes detection instead of the file-extension heuristic:
+    # HuggingFace blob paths don't have a .gguf extension.
+    if not check_gguf_file(model):
         return None

     try:
         model_path = Path(model)
-        if not model_path.is_file():
-            return None
         model_dir = model_path.parent

         mmproj_patterns = ["mmproj.gguf", "mmproj-*.gguf", "*mmproj*.gguf"]
@@ -199,6 +199,67 @@ def extract_vision_config_from_gguf(mmproj_path: str) -> "SiglipVisionConfig | N
     return config


+def extract_softcap_from_gguf(model: str) -> dict[str, float]:
+    """Extract attention and final logit softcap values from GGUF metadata.
+
+    Reads softcap parameters from GGUF metadata using arch-specific keys.
+    These parameters are critical for models like Gemma2, where attention
+    logit softcapping prevents numerical instability.
+
+    Args:
+        model: Path to GGUF model file
+
+    Returns:
+        Dictionary with 'attn_logit_softcapping' and/or 'final_logit_softcapping'
+        keys if found in GGUF metadata, empty dict otherwise
+    """
+    # Use magic-bytes detection instead of the file-extension heuristic:
+    # HuggingFace blob paths don't have a .gguf extension.
+    if not check_gguf_file(model):
+        return {}
+
+    try:
+        model_path = Path(model)
+
+        reader = gguf.GGUFReader(str(model_path))
+
+        # Get architecture name to build arch-specific keys
+        arch_field = reader.get_field(Keys.General.ARCHITECTURE)
+        if arch_field is None:
+            logger.debug("No architecture field found in GGUF metadata")
+            return {}
+
+        arch = bytes(arch_field.parts[-1]).decode("utf-8")
+
+        result = {}
+
+        # Extract attention logit softcapping
+        attn_key = Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=arch)
+        attn_field = reader.get_field(attn_key)
+        if attn_field is not None:
+            result["attn_logit_softcapping"] = float(attn_field.parts[-1])
+            logger.info(
+                "Extracted attn_logit_softcapping=%.2f from GGUF metadata",
+                result["attn_logit_softcapping"],
+            )
+
+        # Extract final logit softcapping
+        final_key = Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=arch)
+        final_field = reader.get_field(final_key)
+        if final_field is not None:
+            result["final_logit_softcapping"] = float(final_field.parts[-1])
+            logger.info(
+                "Extracted final_logit_softcapping=%.2f from GGUF metadata",
+                result["final_logit_softcapping"],
+            )
+
+        return result
+
+    except Exception as e:
+        logger.warning("Error extracting softcap from GGUF '%s': %s", model, e)
+        return {}
+
+
 def maybe_patch_hf_config_from_gguf(
     model: str,
     hf_config: PretrainedConfig,
@@ -207,7 +268,8 @@ def maybe_patch_hf_config_from_gguf(
     Applies GGUF-specific patches to HuggingFace config:

     1. For multimodal models: patches architecture and vision config
-    2. For all GGUF models: overrides vocab_size from embedding tensor
+    2. For models with softcap (e.g., Gemma2): patches attention/logit softcapping
+    3. For all GGUF models: overrides vocab_size from embedding tensor

     This ensures compatibility with GGUF models that have extended
     vocabularies (e.g., Unsloth) where the GGUF file contains more
@@ -236,6 +298,15 @@ def maybe_patch_hf_config_from_gguf(
             )
             hf_config = new_hf_config

+    # Patch softcap parameters from GGUF metadata.
+    # Critical for models like Gemma2, where attention softcapping
+    # prevents numerical instability and ensures correct output.
+    softcap_params = extract_softcap_from_gguf(model)
+    if "attn_logit_softcapping" in softcap_params:
+        hf_config.attn_logit_softcapping = softcap_params["attn_logit_softcapping"]
+    if "final_logit_softcapping" in softcap_params:
+        hf_config.final_logit_softcapping = softcap_params["final_logit_softcapping"]
+
     return hf_config
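
Both new call sites above go through `check_gguf_file` instead of matching on the `.gguf` suffix. A sketch of what magic-bytes detection looks like, assuming the helper reads the 4-byte `GGUF` header (a stand-in, not vLLM's implementation):

```python
# Extension-independent GGUF detection: HuggingFace blob-cache paths
# have no .gguf suffix, but every GGUF file starts with b"GGUF".
from pathlib import Path

GGUF_MAGIC = b"GGUF"

def looks_like_gguf(path: str) -> bool:
    p = Path(path)
    if not p.is_file():
        return False
    with p.open("rb") as f:
        return f.read(4) == GGUF_MAGIC
```
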