diff --git a/tests/models/test_gguf_download.py b/tests/models/test_gguf_download.py index e9ca35afd66a..8140de6c0a9c 100644 --- a/tests/models/test_gguf_download.py +++ b/tests/models/test_gguf_download.py @@ -150,14 +150,20 @@ def test_prepare_weights_repo_quant_type( """Test _prepare_weights with repo_id:quant_type format.""" mock_hf_config = MagicMock() mock_hf_config.architectures = ["Qwen3ForCausalLM"] + mock_hf_config.model_type = "qwen3" + mock_hf_config.quantization_config = None + mock_hf_config.compression_config = None class MockTextConfig: max_position_embeddings = 4096 sliding_window = None model_type = "qwen3" num_attention_heads = 32 + quantization_config = None + compression_config = None mock_text_config = MockTextConfig() + mock_hf_config.text_config = mock_text_config mock_hf_config.get_text_config.return_value = mock_text_config mock_hf_config.dtype = "bfloat16" mock_get_config.return_value = mock_hf_config @@ -197,14 +203,20 @@ def test_prepare_weights_invalid_format( """Test _prepare_weights with invalid format.""" mock_hf_config = MagicMock() mock_hf_config.architectures = ["Qwen3ForCausalLM"] + mock_hf_config.model_type = "qwen3" + mock_hf_config.quantization_config = None + mock_hf_config.compression_config = None class MockTextConfig: max_position_embeddings = 4096 sliding_window = None model_type = "qwen3" num_attention_heads = 32 + quantization_config = None + compression_config = None mock_text_config = MockTextConfig() + mock_hf_config.text_config = mock_text_config mock_hf_config.get_text_config.return_value = mock_text_config mock_hf_config.dtype = "bfloat16" mock_get_config.return_value = mock_hf_config diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index fc6f88b49ee1..2d2e9abbf568 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -100,6 +100,183 @@ def _get_all_gguf_files(model_path: str) -> list[str]: logger.info("Discovered %d GGUF shard files", len(files)) return files if files else [model_path] + @staticmethod + def _normalize_hf_name_for_gguf( + hf_name: str, *, is_multimodal: bool + ) -> tuple[str, str]: + """Normalize HF state dict names for GGUF tensor lookup. + + Returns: + (normalized_name, suffix) + """ + if is_multimodal and hf_name.startswith("model."): + hf_name = hf_name[6:] + + if hf_name.startswith("language_model."): + hf_name = hf_name[15:] + if is_multimodal: + hf_name = "model." 
+ hf_name + + if hf_name.endswith((".weight", ".bias")): + base_name, suffix = hf_name.rsplit(".", 1) + else: + base_name, suffix = hf_name, "" + if base_name.endswith("_weight"): + base_name = base_name[:-7] + suffix = "weight" + + return base_name, suffix + + @staticmethod + def _build_gemma4_manual_mapping( + normalized_state_names: set[str], + num_hidden_layers: int, + *, + vision_num_hidden_layers: int | None = None, + ) -> tuple[dict[str, str], set[str]]: + """Build Gemma4 GGUF mappings missing from gguf-py's tensor tables.""" + gguf_to_hf_name_map: dict[str, str] = {} + handled_params: set[str] = set() + + def add_mapping( + gguf_name: str, + hf_name: str, + *, + handled_name: str | None = None, + ) -> None: + if handled_name is None: + handled_name = hf_name + if handled_name in normalized_state_names: + gguf_to_hf_name_map[gguf_name] = hf_name + handled_params.add(handled_name) + + for idx in range(num_hidden_layers): + layer_prefix = f"model.layers.{idx}" + add_mapping( + f"blk.{idx}.layer_output_scale.weight", + f"{layer_prefix}.layer_scalar", + ) + add_mapping( + f"blk.{idx}.ffn_gate_inp.scale", + f"{layer_prefix}.router.scale", + ) + add_mapping( + f"blk.{idx}.ffn_down_exps.scale", + f"{layer_prefix}.router.per_expert_scale", + ) + add_mapping( + f"blk.{idx}.ffn_gate_inp.weight", + f"{layer_prefix}.router.proj.weight", + ) + add_mapping( + f"blk.{idx}.ffn_gate_up_exps.weight", + f"{layer_prefix}.moe.gate_up_proj.weight", + handled_name=f"{layer_prefix}.experts.gate_up_proj", + ) + add_mapping( + f"blk.{idx}.ffn_down_exps.weight", + f"{layer_prefix}.moe.down_proj.weight", + handled_name=f"{layer_prefix}.experts.down_proj", + ) + add_mapping( + f"blk.{idx}.post_ffw_norm_1.weight", + f"{layer_prefix}.post_feedforward_layernorm_1.weight", + ) + add_mapping( + f"blk.{idx}.post_ffw_norm_2.weight", + f"{layer_prefix}.post_feedforward_layernorm_2.weight", + ) + add_mapping( + f"blk.{idx}.pre_ffw_norm_2.weight", + f"{layer_prefix}.pre_feedforward_layernorm_2.weight", + ) + + add_mapping("v.std_bias", "vision_tower.std_bias") + add_mapping("v.std_scale", "vision_tower.std_scale") + add_mapping( + "v.patch_embd.weight", + "vision_tower.patch_embedder.input_proj.weight", + ) + add_mapping( + "v.position_embd.weight", + "vision_tower.patch_embedder.position_embedding_table", + ) + add_mapping( + "mm.input_projection.weight", + "embed_vision.embedding_projection.weight", + ) + + if vision_num_hidden_layers is not None: + for idx in range(vision_num_hidden_layers): + layer_prefix = f"vision_tower.encoder.layers.{idx}" + add_mapping( + f"v.blk.{idx}.attn_q.weight", + f"{layer_prefix}.self_attn.q_proj.linear.weight", + ) + add_mapping( + f"v.blk.{idx}.attn_k.weight", + f"{layer_prefix}.self_attn.k_proj.linear.weight", + ) + add_mapping( + f"v.blk.{idx}.attn_v.weight", + f"{layer_prefix}.self_attn.v_proj.linear.weight", + ) + add_mapping( + f"v.blk.{idx}.attn_out.weight", + f"{layer_prefix}.self_attn.o_proj.linear.weight", + ) + add_mapping( + f"v.blk.{idx}.attn_q_norm.weight", + f"{layer_prefix}.self_attn.q_norm.weight", + ) + add_mapping( + f"v.blk.{idx}.attn_k_norm.weight", + f"{layer_prefix}.self_attn.k_norm.weight", + ) + add_mapping( + f"v.blk.{idx}.ln1.weight", + f"{layer_prefix}.input_layernorm.weight", + ) + add_mapping( + f"v.blk.{idx}.attn_post_norm.weight", + f"{layer_prefix}.post_attention_layernorm.weight", + ) + add_mapping( + f"v.blk.{idx}.ln2.weight", + f"{layer_prefix}.pre_feedforward_layernorm.weight", + ) + add_mapping( + f"v.blk.{idx}.ffn_post_norm.weight", + 
f"{layer_prefix}.post_feedforward_layernorm.weight", + ) + add_mapping( + f"v.blk.{idx}.ffn_gate.weight", + f"{layer_prefix}.mlp.gate_proj.linear.weight", + ) + add_mapping( + f"v.blk.{idx}.ffn_up.weight", + f"{layer_prefix}.mlp.up_proj.linear.weight", + ) + add_mapping( + f"v.blk.{idx}.ffn_down.weight", + f"{layer_prefix}.mlp.down_proj.linear.weight", + ) + + return gguf_to_hf_name_map, handled_params + + @staticmethod + def _transform_gemma4_gguf_tensor_name_and_weight( + name: str, weight: torch.Tensor + ) -> tuple[str, torch.Tensor]: + """Adapt Gemma4 GGUF tensors to vLLM's final parameter layout.""" + if ( + name == "vision_tower.patch_embedder.input_proj.weight" + and weight.dim() == 4 + ): + return name, weight.reshape(weight.shape[0], -1).contiguous() + + return name, weight + def _get_gguf_weights_map(self, model_config: ModelConfig): """ GGUF uses this naming convention for their tensors from HF checkpoint: @@ -195,11 +372,21 @@ def _get_gguf_weights_map(self, model_config: ModelConfig): ) ) + gemma4_manual_map: dict[str, str] = {} + gemma4_handled_params: set[str] = set() + arch = None - for key, value in gguf.MODEL_ARCH_NAMES.items(): - if value == model_type: - arch = key - break + if model_type == "gemma4": + # gguf-py may lag behind Gemma4 architecture registration even when + # the tensor naming convention is largely compatible with Gemma3. + # Reuse the closest built-in table for common tensors and layer on + # top manual mappings for Gemma4-specific additions. + arch = gguf.MODEL_ARCH.GEMMA3 + else: + for key, value in gguf.MODEL_ARCH_NAMES.items(): + if value == model_type: + arch = key + break if arch is None: raise RuntimeError(f"Unknown gguf model_type: {model_type}") text_num_layers = text_config.num_hidden_layers @@ -246,6 +433,27 @@ def revert_hf_rename(name: str) -> str: for name, tensor in state_dict.items() } + normalized_state_names: set[str] = set() + for hf_name in state_dict: + base_name, suffix = self._normalize_hf_name_for_gguf( + hf_name, is_multimodal=is_multimodal + ) + normalized_state_names.add(base_name + (f".{suffix}" if suffix else "")) + + if model_type == "gemma4": + gemma4_manual_map, gemma4_handled_params = ( + self._build_gemma4_manual_mapping( + normalized_state_names, + text_num_layers, + vision_num_hidden_layers=( + config.vision_config.num_hidden_layers + if is_multimodal + else None + ), + ) + ) + gguf_to_hf_name_map.update(gemma4_manual_map) + def find_hf_name_in_tensor_map(hf_name: str) -> str | None: """ Map HuggingFace parameter name to GGUF tensor name. @@ -265,35 +473,9 @@ def find_hf_name_in_tensor_map(hf_name: str) -> str | None: GGUF tensor name with suffix (e.g., 'mm.soft_emb_norm.weight') or None if no mapping found """ - # In transformers v5, multimodal models (e.g. Gemma3) wrap - # all sub-models under an outer 'model.' attribute, producing - # state_dict keys like 'model.language_model.layers.0...' and - # 'model.vision_tower.vision_model...'. Strip this outer - # prefix so the keys match what gguf-py expects. - if is_multimodal and hf_name.startswith("model."): - hf_name = hf_name[6:] # Remove outer 'model.' - - # Strip 'language_model.' prefix for multimodal models - gguf-py - # tensor mappings expect parameter names without this prefix. - # Note: 'model.' prefix should be KEPT for text-only models as - # gguf-py expects it. - if hf_name.startswith("language_model."): - hf_name = hf_name[15:] # Remove 'language_model.' - # Re-add 'model.' prefix because gguf-py text tensor maps - # expect 'model.layers...' format. 
- if is_multimodal: - hf_name = "model." + hf_name - - # Parse parameter name and suffix - if hf_name.endswith((".weight", ".bias")): - base_name, suffix = hf_name.rsplit(".", 1) - else: - base_name, suffix = hf_name, "" - # Handle '_weight' suffix (Gemma3 naming: parameter ends with - # '_weight' instead of '.weight') - if base_name.endswith("_weight"): - base_name = base_name[:-7] # Remove '_weight' - suffix = "weight" + base_name, suffix = self._normalize_hf_name_for_gguf( + hf_name, is_multimodal=is_multimodal + ) gguf_name = None # Priority 1: Search vision/projector parameters for multimodal models @@ -313,14 +495,23 @@ def find_hf_name_in_tensor_map(hf_name: str) -> str | None: unmapped_params = [] for hf_name in state_dict: gguf_name_with_suffix = find_hf_name_in_tensor_map(hf_name) + normalized_base, normalized_suffix = self._normalize_hf_name_for_gguf( + hf_name, is_multimodal=is_multimodal + ) + normalized_hf_name = normalized_base + ( + f".{normalized_suffix}" if normalized_suffix else "" + ) # Track mapping success if gguf_name_with_suffix is not None: gguf_to_hf_name_map[gguf_name_with_suffix] = hf_name logger.debug("Mapped GGUF %s → HF %s", gguf_name_with_suffix, hf_name) - elif hf_name not in gguf_to_hf_name_map.values(): + elif ( + normalized_hf_name not in gemma4_handled_params + and hf_name not in gguf_to_hf_name_map.values() + ): # Parameter not in manual overrides either - unmapped_params.append(hf_name) + unmapped_params.append(normalized_hf_name) # All parameters (except those initialized by other means) must be mapped: # both vision/projector and backbone @@ -388,18 +579,33 @@ def _get_weights_iterator( assert mmproj_file is not None, ( "Could not find mm_proj file for multimodal GGUF model" ) - yield from gguf_quant_weights_iterator(mmproj_file, gguf_to_hf_name_map) + mmproj_iterator = gguf_quant_weights_iterator( + mmproj_file, gguf_to_hf_name_map + ) + if hf_config.model_type == "gemma4": + for name, weight in mmproj_iterator: + yield self._transform_gemma4_gguf_tensor_name_and_weight( + name, weight + ) + else: + yield from mmproj_iterator gguf_files = self._get_all_gguf_files(model_name_or_path) if len(gguf_files) > 1: - yield from gguf_quant_weights_iterator_multi( + iterator = gguf_quant_weights_iterator_multi( gguf_files, gguf_to_hf_name_map ) else: - yield from gguf_quant_weights_iterator( + iterator = gguf_quant_weights_iterator( model_name_or_path, gguf_to_hf_name_map ) + if hf_config.model_type == "gemma4": + for name, weight in iterator: + yield self._transform_gemma4_gguf_tensor_name_and_weight(name, weight) + else: + yield from iterator + def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config) diff --git a/vllm/model_executor/models/gemma4_mm.py b/vllm/model_executor/models/gemma4_mm.py index 9b2c54e27354..a8817921a091 100644 --- a/vllm/model_executor/models/gemma4_mm.py +++ b/vllm/model_executor/models/gemma4_mm.py @@ -962,12 +962,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Some variants have hidden_size_per_layer_input=None (no PLE). 
ple_dim = config.text_config.hidden_size_per_layer_input if ple_dim is not None: + model_device = next(self.language_model.parameters()).device self.per_layer_embeddings = torch.zeros( vllm_config.scheduler_config.max_num_batched_tokens, config.text_config.num_hidden_layers, ple_dim, - device=(self.language_model.model.embed_tokens.weight.device), - dtype=(self.language_model.model.embed_tokens.weight.dtype), + device=model_device, + dtype=vllm_config.model_config.dtype, ) else: self.per_layer_embeddings = None diff --git a/vllm/tokenizers/registry.py b/vllm/tokenizers/registry.py index 8778aa9d691f..62d2f0c92eb4 100644 --- a/vllm/tokenizers/registry.py +++ b/vllm/tokenizers/registry.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib +import copy from dataclasses import dataclass, field from functools import lru_cache from pathlib import Path @@ -15,6 +16,7 @@ from vllm.transformers_utils.gguf_utils import ( check_gguf_file, get_gguf_file_path_from_hf, + get_gguf_tokenizer_special_ids, is_gguf, is_remote_gguf, split_remote_gguf, @@ -92,6 +94,38 @@ def load_tokenizer(self, tokenizer_mode: str, *args, **kwargs) -> TokenizerLike: ) +def _maybe_patch_gemma4_gguf_tokenizer( + tokenizer: TokenizerLike, + model: str | Path, + model_type: str | None, +) -> TokenizerLike: + if model_type != "gemma4" or not check_gguf_file(model): + return tokenizer + + special_ids = get_gguf_tokenizer_special_ids(model) + if not special_ids: + return tokenizer + + patched_tokenizer = copy.copy(tokenizer) + token_attrs = { + "padding_token_id": "pad_token", + "bos_token_id": "bos_token", + "eos_token_id": "eos_token", + "unknown_token_id": "unk_token", + } + for id_attr, token_attr in token_attrs.items(): + token_id = special_ids.get(id_attr) + if token_id is None: + continue + tokens = patched_tokenizer.convert_ids_to_tokens([token_id]) + if not tokens: + continue + token = tokens[0] + setattr(patched_tokenizer, token_attr, token) + + return patched_tokenizer + + def resolve_tokenizer_args( tokenizer_name: str | Path, *args, @@ -258,7 +292,7 @@ def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs): if model_config.skip_tokenizer_init: return None - return cached_get_tokenizer( + tokenizer = cached_get_tokenizer( model_config.tokenizer, runner_type=model_config.runner_type, tokenizer_mode=model_config.tokenizer_mode, @@ -266,3 +300,8 @@ def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs): trust_remote_code=model_config.trust_remote_code, **kwargs, ) + return _maybe_patch_gemma4_gguf_tokenizer( + tokenizer, + model_config.model, + getattr(model_config.hf_config, "model_type", None), + ) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 2f00178ba6ef..d986e5e4619c 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -607,8 +607,14 @@ def maybe_override_with_speculators( Tuple of (resolved_model, resolved_tokenizer, speculative_config) """ if check_gguf_file(model): - kwargs["gguf_file"] = Path(model).name gguf_model_repo = Path(model).parent + # Prefer sibling config.json when present instead of forcing + # transformers to parse GGUF metadata directly. This keeps local GGUF + # models loadable even when the installed transformers GGUF parser + # lags behind the architecture but the repository config.json is + # already available. 
+ if not file_or_path_exists(gguf_model_repo, HF_CONFIG_NAME, revision=revision): + kwargs["gguf_file"] = Path(model).name elif is_remote_gguf(model): repo_id, _ = split_remote_gguf(model) gguf_model_repo = Path(repo_id) @@ -660,9 +666,15 @@ def get_config( _is_remote_gguf = is_remote_gguf(model) if _is_gguf: if check_gguf_file(model): - # Local GGUF file - kwargs["gguf_file"] = Path(model).name - model = Path(model).parent + # Local GGUF file. Prefer sibling config.json when available + # rather than routing through transformers' GGUF checkpoint + # loader, which may not support the architecture yet. + gguf_model_dir = Path(model).parent + if not file_or_path_exists( + gguf_model_dir, HF_CONFIG_NAME, revision=revision + ): + kwargs["gguf_file"] = Path(model).name + model = gguf_model_dir elif _is_remote_gguf: # Remote GGUF - extract repo_id from repo_id:quant_type format # The actual GGUF file will be downloaded later by GGUFModelLoader diff --git a/vllm/transformers_utils/gguf_utils.py b/vllm/transformers_utils/gguf_utils.py index 7708378ee13b..2c63f73dfe17 100644 --- a/vllm/transformers_utils/gguf_utils.py +++ b/vllm/transformers_utils/gguf_utils.py @@ -3,7 +3,6 @@ """GGUF utility functions.""" from functools import cache -from os import PathLike from pathlib import Path import gguf @@ -19,8 +18,16 @@ logger = init_logger(__name__) +_GGUF_TOKENIZER_SPECIAL_ID_FIELDS = { + "bos_token_id": "tokenizer.ggml.bos_token_id", + "eos_token_id": "tokenizer.ggml.eos_token_id", + "unknown_token_id": "tokenizer.ggml.unknown_token_id", + "padding_token_id": "tokenizer.ggml.padding_token_id", +} + + @cache -def check_gguf_file(model: str | PathLike) -> bool: +def check_gguf_file(model: str | Path) -> bool: """Check if the file is a GGUF model.""" model = Path(model) if not model.is_file(): @@ -170,6 +177,25 @@ def detect_gguf_multimodal(model: str) -> Path | None: return None +@cache +def get_gguf_tokenizer_special_ids(model: str | Path) -> dict[str, int]: + """Read tokenizer special token ids embedded in a local GGUF file.""" + if not check_gguf_file(model): + return {} + + reader = gguf.GGUFReader(str(model)) + special_ids: dict[str, int] = {} + for key, field_name in _GGUF_TOKENIZER_SPECIAL_ID_FIELDS.items(): + field = reader.get_field(field_name) + if field is None: + continue + try: + special_ids[key] = int(field.parts[-1]) + except (TypeError, ValueError): + logger.warning("Failed to parse GGUF tokenizer field %s", field_name) + return special_ids + + def extract_vision_config_from_gguf(mmproj_path: str) -> "SiglipVisionConfig | None": """Extract vision config parameters from mmproj.gguf metadata. 
diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index 0e241f6abfd1..b4b5d30f65ac 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy import importlib import inspect from functools import lru_cache @@ -24,6 +25,7 @@ from typing_extensions import TypeVar from vllm.logger import init_logger +from vllm.tokenizers.registry import _maybe_patch_gemma4_gguf_tokenizer from vllm.transformers_utils import processors from vllm.transformers_utils.gguf_utils import is_gguf from vllm.transformers_utils.repo_utils import get_hf_file_to_dict @@ -352,13 +354,24 @@ def cached_processor_from_config( model = model_config.model revision = model_config.revision - return cached_get_processor_without_dynamic_kwargs( + processor = cached_get_processor_without_dynamic_kwargs( model, revision=revision, trust_remote_code=model_config.trust_remote_code, processor_cls=processor_cls, # type: ignore[arg-type] **_merge_mm_kwargs(model_config, processor_cls, **kwargs), ) + tokenizer = getattr(processor, "tokenizer", None) + if tokenizer is not None: + tokenizer = _maybe_patch_gemma4_gguf_tokenizer( + tokenizer, + model_config.model, + getattr(model_config.hf_config, "model_type", None), + ) + if tokenizer is not processor.tokenizer: + processor = copy.copy(processor) + processor.tokenizer = tokenizer + return processor def get_feature_extractor(
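A minimal standalone sketch of the tokenizer flow that the registry.py, processor.py, and gguf_utils.py hunks above wire together: read the tokenizer.ggml.*_token_id metadata from a local GGUF file, then re-point the special tokens on a shallow copy of the tokenizer. It mirrors the new get_gguf_tokenizer_special_ids and _maybe_patch_gemma4_gguf_tokenizer helpers rather than importing them; the "gemma4.gguf" path and the tokenizer object are placeholders, and only gguf-py's GGUFReader/get_field API is assumed.

import copy
from pathlib import Path

import gguf

# Same GGUF metadata keys the patch reads in gguf_utils.py.
SPECIAL_ID_FIELDS = {
    "bos_token_id": "tokenizer.ggml.bos_token_id",
    "eos_token_id": "tokenizer.ggml.eos_token_id",
    "unknown_token_id": "tokenizer.ggml.unknown_token_id",
    "padding_token_id": "tokenizer.ggml.padding_token_id",
}


def read_special_ids(gguf_path: str | Path) -> dict[str, int]:
    """Read special-token ids from GGUF metadata (mirrors get_gguf_tokenizer_special_ids)."""
    reader = gguf.GGUFReader(str(gguf_path))
    ids: dict[str, int] = {}
    for key, field_name in SPECIAL_ID_FIELDS.items():
        field = reader.get_field(field_name)
        if field is not None:
            ids[key] = int(field.parts[-1])
    return ids


def patch_tokenizer(tokenizer, gguf_path: str | Path):
    """Re-point pad/bos/eos/unk tokens on a copy (mirrors _maybe_patch_gemma4_gguf_tokenizer)."""
    special_ids = read_special_ids(gguf_path)
    if not special_ids:
        return tokenizer
    patched = copy.copy(tokenizer)
    for id_key, token_attr in [
        ("padding_token_id", "pad_token"),
        ("bos_token_id", "bos_token"),
        ("eos_token_id", "eos_token"),
        ("unknown_token_id", "unk_token"),
    ]:
        token_id = special_ids.get(id_key)
        if token_id is None:
            continue
        tokens = patched.convert_ids_to_tokens([token_id])
        if tokens:
            setattr(patched, token_attr, tokens[0])
    return patched

# Example with a hypothetical local file: print(read_special_ids("gemma4.gguf"))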