12 changes: 12 additions & 0 deletions tests/models/test_gguf_download.py
@@ -150,14 +150,20 @@ def test_prepare_weights_repo_quant_type(
"""Test _prepare_weights with repo_id:quant_type format."""
mock_hf_config = MagicMock()
mock_hf_config.architectures = ["Qwen3ForCausalLM"]
mock_hf_config.model_type = "qwen3"
mock_hf_config.quantization_config = None
mock_hf_config.compression_config = None

class MockTextConfig:
max_position_embeddings = 4096
sliding_window = None
model_type = "qwen3"
num_attention_heads = 32
quantization_config = None
compression_config = None

mock_text_config = MockTextConfig()
mock_hf_config.text_config = mock_text_config
mock_hf_config.get_text_config.return_value = mock_text_config
mock_hf_config.dtype = "bfloat16"
mock_get_config.return_value = mock_hf_config
@@ -197,14 +203,20 @@ def test_prepare_weights_invalid_format(
"""Test _prepare_weights with invalid format."""
mock_hf_config = MagicMock()
mock_hf_config.architectures = ["Qwen3ForCausalLM"]
mock_hf_config.model_type = "qwen3"
mock_hf_config.quantization_config = None
mock_hf_config.compression_config = None

class MockTextConfig:
max_position_embeddings = 4096
sliding_window = None
model_type = "qwen3"
num_attention_heads = 32
quantization_config = None
compression_config = None

mock_text_config = MockTextConfig()
mock_hf_config.text_config = mock_text_config
mock_hf_config.get_text_config.return_value = mock_text_config
mock_hf_config.dtype = "bfloat16"
mock_get_config.return_value = mock_hf_config
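Both test hunks build the same mock-config scaffolding by hand; a small shared helper (illustrative only, not part of the PR — the name and placement are assumptions) could express the pattern once:

from unittest.mock import MagicMock

def make_mock_qwen3_config() -> MagicMock:
    """Minimal HF config double matching what both tests construct."""
    class MockTextConfig:
        max_position_embeddings = 4096
        sliding_window = None
        model_type = "qwen3"
        num_attention_heads = 32
        quantization_config = None
        compression_config = None

    mock_hf_config = MagicMock()
    mock_hf_config.architectures = ["Qwen3ForCausalLM"]
    mock_hf_config.model_type = "qwen3"
    mock_hf_config.quantization_config = None
    mock_hf_config.compression_config = None
    text_config = MockTextConfig()
    mock_hf_config.text_config = text_config
    mock_hf_config.get_text_config.return_value = text_config
    mock_hf_config.dtype = "bfloat16"
    return mock_hf_config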
282 changes: 244 additions & 38 deletions vllm/model_executor/model_loader/gguf_loader.py
@@ -100,6 +100,183 @@ def _get_all_gguf_files(model_path: str) -> list[str]:
logger.info("Discovered %d GGUF shard files", len(files))
return files if files else [model_path]

@staticmethod
def _normalize_hf_name_for_gguf(
hf_name: str, *, is_multimodal: bool
) -> tuple[str, str]:
"""Normalize HF state dict names for GGUF tensor lookup.

Returns:
(normalized_name, suffix)
"""
if is_multimodal and hf_name.startswith("model."):
hf_name = hf_name[6:]

if hf_name.startswith("language_model."):
hf_name = hf_name[15:]
if is_multimodal:
hf_name = "model." + hf_name

if hf_name.endswith((".weight", ".bias")):
base_name, suffix = hf_name.rsplit(".", 1)
else:
base_name, suffix = hf_name, ""
if base_name.endswith("_weight"):
base_name = base_name[:-7]
suffix = "weight"

return base_name, suffix

@staticmethod
def _build_gemma4_manual_mapping(
normalized_state_names: set[str],
num_hidden_layers: int,
*,
vision_num_hidden_layers: int | None = None,
) -> tuple[dict[str, str], set[str]]:
"""Build Gemma4 GGUF mappings missing from gguf-py's tensor tables."""
gguf_to_hf_name_map: dict[str, str] = {}
handled_params: set[str] = set()

def add_mapping(
gguf_name: str,
hf_name: str,
*,
handled_name: str | None = None,
) -> None:
if handled_name is None:
handled_name = hf_name
if handled_name in normalized_state_names:
gguf_to_hf_name_map[gguf_name] = hf_name
handled_params.add(handled_name)

for idx in range(num_hidden_layers):
layer_prefix = f"model.layers.{idx}"
add_mapping(
f"blk.{idx}.layer_output_scale.weight",
f"{layer_prefix}.layer_scalar",
)
add_mapping(
f"blk.{idx}.ffn_gate_inp.scale",
f"{layer_prefix}.router.scale",
)
add_mapping(
f"blk.{idx}.ffn_down_exps.scale",
f"{layer_prefix}.router.per_expert_scale",
)
add_mapping(
f"blk.{idx}.ffn_gate_inp.weight",
f"{layer_prefix}.router.proj.weight",
)
add_mapping(
f"blk.{idx}.ffn_gate_up_exps.weight",
f"{layer_prefix}.moe.gate_up_proj.weight",
handled_name=f"{layer_prefix}.experts.gate_up_proj",
)
add_mapping(
f"blk.{idx}.ffn_down_exps.weight",
f"{layer_prefix}.moe.down_proj.weight",
handled_name=f"{layer_prefix}.experts.down_proj",
)
Comment on lines +171 to +180
Contributor (severity: high)
The handled_name for Gemma4 MoE expert weights (gate/up and down) is missing the .weight suffix. Since normalized_state_names (built at line 441) includes the suffix for weight tensors, this mismatch will cause the manual mapping to be skipped. Consequently, these parameters will be flagged as unmapped in the final check (line 509), leading to a RuntimeError during model loading.

Suggested change
add_mapping(
f"blk.{idx}.ffn_gate_up_exps.weight",
f"{layer_prefix}.moe.gate_up_proj.weight",
handled_name=f"{layer_prefix}.experts.gate_up_proj",
)
add_mapping(
f"blk.{idx}.ffn_down_exps.weight",
f"{layer_prefix}.moe.down_proj.weight",
handled_name=f"{layer_prefix}.experts.down_proj",
)
add_mapping(
f"blk.{idx}.ffn_gate_up_exps.weight",
f"{layer_prefix}.moe.gate_up_proj.weight",
handled_name=f"{layer_prefix}.experts.gate_up_proj.weight",
)
add_mapping(
f"blk.{idx}.ffn_down_exps.weight",
f"{layer_prefix}.moe.down_proj.weight",
handled_name=f"{layer_prefix}.experts.down_proj.weight",
)
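The mismatch is easy to reproduce in isolation (a minimal sketch mirroring how normalized_state_names is built at line 441 and checked in add_mapping):

# normalized_state_names keeps the '.weight' suffix for weight tensors:
base_name, suffix = "model.layers.0.experts.gate_up_proj", "weight"
normalized = base_name + (f".{suffix}" if suffix else "")
# -> "model.layers.0.experts.gate_up_proj.weight"

# The original handled_name drops the suffix, so the membership test in
# add_mapping() fails and the manual mapping is silently skipped:
handled_name = "model.layers.0.experts.gate_up_proj"
print(handled_name in {normalized})  # False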

add_mapping(
f"blk.{idx}.post_ffw_norm_1.weight",
f"{layer_prefix}.post_feedforward_layernorm_1.weight",
)
add_mapping(
f"blk.{idx}.post_ffw_norm_2.weight",
f"{layer_prefix}.post_feedforward_layernorm_2.weight",
)
add_mapping(
f"blk.{idx}.pre_ffw_norm_2.weight",
f"{layer_prefix}.pre_feedforward_layernorm_2.weight",
)

add_mapping("v.std_bias", "vision_tower.std_bias")
add_mapping("v.std_scale", "vision_tower.std_scale")
add_mapping(
"v.patch_embd.weight",
"vision_tower.patch_embedder.input_proj.weight",
)
add_mapping(
"v.position_embd.weight",
"vision_tower.patch_embedder.position_embedding_table",
)
add_mapping(
"mm.input_projection.weight",
"embed_vision.embedding_projection.weight",
)

if vision_num_hidden_layers is not None:
for idx in range(vision_num_hidden_layers):
layer_prefix = f"vision_tower.encoder.layers.{idx}"
add_mapping(
f"v.blk.{idx}.attn_q.weight",
f"{layer_prefix}.self_attn.q_proj.linear.weight",
)
add_mapping(
f"v.blk.{idx}.attn_k.weight",
f"{layer_prefix}.self_attn.k_proj.linear.weight",
)
add_mapping(
f"v.blk.{idx}.attn_v.weight",
f"{layer_prefix}.self_attn.v_proj.linear.weight",
)
add_mapping(
f"v.blk.{idx}.attn_out.weight",
f"{layer_prefix}.self_attn.o_proj.linear.weight",
)
add_mapping(
f"v.blk.{idx}.attn_q_norm.weight",
f"{layer_prefix}.self_attn.q_norm.weight",
)
add_mapping(
f"v.blk.{idx}.attn_k_norm.weight",
f"{layer_prefix}.self_attn.k_norm.weight",
)
add_mapping(
f"v.blk.{idx}.ln1.weight",
f"{layer_prefix}.input_layernorm.weight",
)
add_mapping(
f"v.blk.{idx}.attn_post_norm.weight",
f"{layer_prefix}.post_attention_layernorm.weight",
)
add_mapping(
f"v.blk.{idx}.ln2.weight",
f"{layer_prefix}.pre_feedforward_layernorm.weight",
)
add_mapping(
f"v.blk.{idx}.ffn_post_norm.weight",
f"{layer_prefix}.post_feedforward_layernorm.weight",
)
add_mapping(
f"v.blk.{idx}.ffn_gate.weight",
f"{layer_prefix}.mlp.gate_proj.linear.weight",
)
add_mapping(
f"v.blk.{idx}.ffn_up.weight",
f"{layer_prefix}.mlp.up_proj.linear.weight",
)
add_mapping(
f"v.blk.{idx}.ffn_down.weight",
f"{layer_prefix}.mlp.down_proj.linear.weight",
)

return gguf_to_hf_name_map, handled_params

@staticmethod
def _transform_gemma4_gguf_tensor_name_and_weight(
name: str, weight: torch.Tensor
) -> tuple[str, torch.Tensor]:
"""Adapt Gemma4 GGUF tensors to vLLM's final parameter layout."""
if (
name == "vision_tower.patch_embedder.input_proj.weight"
and weight.dim() == 4
):
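# Likely a Conv2d-style patch kernel (out_ch, in_ch, kh, kw); flatten to
# (out_ch, in_ch*kh*kw) so it matches a linear input projection. The
# layout is inferred from the 4-D check, not stated in the GGUF file.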
return name, weight.reshape(weight.shape[0], -1).contiguous()

return name, weight

def _get_gguf_weights_map(self, model_config: ModelConfig):
"""
GGUF uses this naming convention for their tensors from HF checkpoint:
@@ -195,11 +372,21 @@ def _get_gguf_weights_map(self, model_config: ModelConfig):
)
)

gemma4_manual_map: dict[str, str] = {}
gemma4_handled_params: set[str] = set()

arch = None
for key, value in gguf.MODEL_ARCH_NAMES.items():
if value == model_type:
arch = key
break
if model_type == "gemma4":
# gguf-py may lag behind Gemma4 architecture registration even when
# the tensor naming convention is largely compatible with Gemma3.
# Reuse the closest built-in table for common tensors and layer the
# Gemma4-specific manual mappings on top.
arch = gguf.MODEL_ARCH.GEMMA3
else:
for key, value in gguf.MODEL_ARCH_NAMES.items():
if value == model_type:
arch = key
break
if arch is None:
raise RuntimeError(f"Unknown gguf model_type: {model_type}")
text_num_layers = text_config.num_hidden_layers
@@ -246,6 +433,27 @@ def revert_hf_rename(name: str) -> str:
for name, tensor in state_dict.items()
}

normalized_state_names: set[str] = set()
for hf_name in state_dict:
base_name, suffix = self._normalize_hf_name_for_gguf(
hf_name, is_multimodal=is_multimodal
)
normalized_state_names.add(base_name + (f".{suffix}" if suffix else ""))

if model_type == "gemma4":
gemma4_manual_map, gemma4_handled_params = (
self._build_gemma4_manual_mapping(
normalized_state_names,
text_num_layers,
vision_num_hidden_layers=(
config.vision_config.num_hidden_layers
if is_multimodal
else None
),
)
)
gguf_to_hf_name_map.update(gemma4_manual_map)

def find_hf_name_in_tensor_map(hf_name: str) -> str | None:
"""
Map HuggingFace parameter name to GGUF tensor name.
@@ -265,35 +473,9 @@ def find_hf_name_in_tensor_map(hf_name: str) -> str | None:
GGUF tensor name with suffix (e.g., 'mm.soft_emb_norm.weight')
or None if no mapping found
"""
# In transformers v5, multimodal models (e.g. Gemma3) wrap
# all sub-models under an outer 'model.' attribute, producing
# state_dict keys like 'model.language_model.layers.0...' and
# 'model.vision_tower.vision_model...'. Strip this outer
# prefix so the keys match what gguf-py expects.
if is_multimodal and hf_name.startswith("model."):
hf_name = hf_name[6:] # Remove outer 'model.'

# Strip 'language_model.' prefix for multimodal models - gguf-py
# tensor mappings expect parameter names without this prefix.
# Note: 'model.' prefix should be KEPT for text-only models as
# gguf-py expects it.
if hf_name.startswith("language_model."):
hf_name = hf_name[15:] # Remove 'language_model.'
# Re-add 'model.' prefix because gguf-py text tensor maps
# expect 'model.layers...' format.
if is_multimodal:
hf_name = "model." + hf_name

# Parse parameter name and suffix
if hf_name.endswith((".weight", ".bias")):
base_name, suffix = hf_name.rsplit(".", 1)
else:
base_name, suffix = hf_name, ""
# Handle '_weight' suffix (Gemma3 naming: parameter ends with
# '_weight' instead of '.weight')
if base_name.endswith("_weight"):
base_name = base_name[:-7] # Remove '_weight'
suffix = "weight"
base_name, suffix = self._normalize_hf_name_for_gguf(
hf_name, is_multimodal=is_multimodal
)

gguf_name = None
# Priority 1: Search vision/projector parameters for multimodal models
Expand All @@ -313,14 +495,23 @@ def find_hf_name_in_tensor_map(hf_name: str) -> str | None:
unmapped_params = []
for hf_name in state_dict:
gguf_name_with_suffix = find_hf_name_in_tensor_map(hf_name)
normalized_base, normalized_suffix = self._normalize_hf_name_for_gguf(
hf_name, is_multimodal=is_multimodal
)
normalized_hf_name = normalized_base + (
f".{normalized_suffix}" if normalized_suffix else ""
)

# Track mapping success
if gguf_name_with_suffix is not None:
gguf_to_hf_name_map[gguf_name_with_suffix] = hf_name
logger.debug("Mapped GGUF %s → HF %s", gguf_name_with_suffix, hf_name)
elif hf_name not in gguf_to_hf_name_map.values():
elif (
normalized_hf_name not in gemma4_handled_params
and hf_name not in gguf_to_hf_name_map.values()
):
# Parameter not in manual overrides either
unmapped_params.append(hf_name)
unmapped_params.append(normalized_hf_name)

# All parameters (except those initialized by other means) must be mapped:
# both vision/projector and backbone
@@ -388,18 +579,33 @@ def _get_weights_iterator(
assert mmproj_file is not None, (
"Could not find mm_proj file for multimodal GGUF model"
)
yield from gguf_quant_weights_iterator(mmproj_file, gguf_to_hf_name_map)
mmproj_iterator = gguf_quant_weights_iterator(
mmproj_file, gguf_to_hf_name_map
)
if hf_config.model_type == "gemma4":
for name, weight in mmproj_iterator:
yield self._transform_gemma4_gguf_tensor_name_and_weight(
name, weight
)
else:
yield from mmproj_iterator

gguf_files = self._get_all_gguf_files(model_name_or_path)
if len(gguf_files) > 1:
yield from gguf_quant_weights_iterator_multi(
iterator = gguf_quant_weights_iterator_multi(
gguf_files, gguf_to_hf_name_map
)
else:
yield from gguf_quant_weights_iterator(
iterator = gguf_quant_weights_iterator(
model_name_or_path, gguf_to_hf_name_map
)

if hf_config.model_type == "gemma4":
for name, weight in iterator:
yield self._transform_gemma4_gguf_tensor_name_and_weight(name, weight)
else:
yield from iterator

def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config)

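For readers tracing the new helper, here is a standalone sketch of _normalize_hf_name_for_gguf with two hand-checked cases (an illustrative copy, with indentation reconstructed from the flattened diff; the real method lives on the loader class above):

def normalize_hf_name_for_gguf(hf_name: str, *, is_multimodal: bool) -> tuple[str, str]:
    """Standalone mirror of the loader's helper, for illustration only."""
    # Strip the outer 'model.' wrapper that transformers v5 adds to
    # multimodal checkpoints.
    if is_multimodal and hf_name.startswith("model."):
        hf_name = hf_name[len("model."):]
    # Drop 'language_model.' and, for multimodal models, restore the
    # 'model.' prefix that gguf-py's text tensor tables expect.
    if hf_name.startswith("language_model."):
        hf_name = hf_name[len("language_model."):]
        if is_multimodal:
            hf_name = "model." + hf_name
    # Split off the '.weight'/'.bias' suffix, treating Gemma-style
    # '_weight' names the same as '.weight'.
    if hf_name.endswith((".weight", ".bias")):
        base_name, suffix = hf_name.rsplit(".", 1)
    else:
        base_name, suffix = hf_name, ""
        if base_name.endswith("_weight"):
            base_name = base_name[: -len("_weight")]
            suffix = "weight"
    return base_name, suffix

# Multimodal text weights regain the 'model.' prefix gguf-py expects:
assert normalize_hf_name_for_gguf(
    "model.language_model.layers.0.self_attn.q_proj.weight", is_multimodal=True
) == ("model.layers.0.self_attn.q_proj", "weight")

# Gemma-style '_weight' names split the same way as '.weight':
assert normalize_hf_name_for_gguf(
    "model.vision_tower.std_scale_weight", is_multimodal=True
) == ("vision_tower.std_scale", "weight")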
5 changes: 3 additions & 2 deletions vllm/model_executor/models/gemma4_mm.py
@@ -962,12 +962,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Some variants have hidden_size_per_layer_input=None (no PLE).
ple_dim = config.text_config.hidden_size_per_layer_input
if ple_dim is not None:
model_device = next(self.language_model.parameters()).device
self.per_layer_embeddings = torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.num_hidden_layers,
ple_dim,
device=(self.language_model.model.embed_tokens.weight.device),
dtype=(self.language_model.model.embed_tokens.weight.dtype),
device=model_device,
dtype=vllm_config.model_config.dtype,
)
else:
self.per_layer_embeddings = None
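A plausible reading of this change (my inference, not stated in the diff): under GGUF quantization the embedding table may be stored in a packed integer dtype, so deriving the PLE buffer's dtype from embed_tokens.weight would be wrong, and model_config.dtype supplies the activation dtype directly. A toy illustration:

import torch

# Hypothetical stand-in for a GGUF-packed embedding table: its storage
# dtype (uint8 here) is not the dtype activations should use.
packed_embed = torch.zeros(32, 8, dtype=torch.uint8)
activation_dtype = torch.bfloat16  # what model_config.dtype supplies

# The old lookup would have allocated the PLE buffer in the packed dtype;
# the new code allocates it in the configured activation dtype on the
# model's device.
ple = torch.zeros(4, 2, 8, device=packed_embed.device, dtype=activation_dtype)
print(ple.dtype)  # torch.bfloat16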