diff --git a/vllm/model_executor/layers/fused_moe/activation.py b/vllm/model_executor/layers/fused_moe/activation.py
index 3112b3054fcd..7a013c8f28ae 100644
--- a/vllm/model_executor/layers/fused_moe/activation.py
+++ b/vllm/model_executor/layers/fused_moe/activation.py
@@ -15,6 +15,7 @@ class MoEActivation(Enum):
     # and produce output of shape [..., d]
     SILU = "silu"
     GELU = "gelu"
+    GELU_TANH = "gelu_tanh"
     RELU2 = "relu2"
     SWIGLUOAI = "swigluoai"
     SWIGLUSTEP = "swiglustep"
@@ -24,6 +25,7 @@ class MoEActivation(Enum):
     # NOTE: Non-gated activations require the "_no_mul" suffix to be present.
     SILU_NO_MUL = "silu_no_mul"
     GELU_NO_MUL = "gelu_no_mul"
+    GELU_TANH_NO_MUL = "gelu_tanh_no_mul"
     RELU2_NO_MUL = "relu2_no_mul"

     @property
@@ -53,6 +55,8 @@ def without_mul(self) -> "MoEActivation":
     @classmethod
     def from_str(cls, s: str) -> "MoEActivation":
         """Parse from string for backward compatibility."""
+        if s == "gelu_pytorch_tanh":
+            s = cls.GELU_TANH.value
         for member in cls:
             if member.value == s:
                 return member
@@ -64,17 +68,20 @@ def from_str(cls, s: str) -> "MoEActivation":
 _CUSTOM_OP_NAMES: dict[MoEActivation, str] = {
     MoEActivation.SILU: "silu_and_mul",
     MoEActivation.GELU: "gelu_and_mul",
+    MoEActivation.GELU_TANH: "gelu_tanh_and_mul",
     MoEActivation.SWIGLUOAI: "swigluoai_and_mul",
     MoEActivation.SWIGLUSTEP: "swiglustep_and_mul",
     MoEActivation.RELU2: "relu2",
     MoEActivation.SILU_NO_MUL: "silu_and_mul",
     MoEActivation.GELU_NO_MUL: "gelu_and_mul",
+    MoEActivation.GELU_TANH_NO_MUL: "gelu_tanh_and_mul",
     MoEActivation.RELU2_NO_MUL: "relu2",
 }

 _WITHOUT_MUL: dict[MoEActivation, MoEActivation] = {
     MoEActivation.SILU: MoEActivation.SILU_NO_MUL,
     MoEActivation.GELU: MoEActivation.GELU_NO_MUL,
+    MoEActivation.GELU_TANH: MoEActivation.GELU_TANH_NO_MUL,
     MoEActivation.RELU2: MoEActivation.RELU2_NO_MUL,
 }

@@ -115,6 +122,12 @@ def apply_moe_activation(
         torch.ops._C.silu_and_mul(output, input)
     elif activation == MoEActivation.GELU:
         torch.ops._C.gelu_and_mul(output, input)
+    elif activation == MoEActivation.GELU_TANH:
+        if hasattr(torch.ops._C, "gelu_tanh_and_mul"):
+            torch.ops._C.gelu_tanh_and_mul(output, input)
+        else:
+            gate, up = input.chunk(2, dim=-1)
+            output.copy_(F.gelu(gate, approximate="tanh") * up)
     elif activation == MoEActivation.SWIGLUOAI:
         torch.ops._C.swigluoai_and_mul(output, input)
     elif activation == MoEActivation.SWIGLUSTEP:
@@ -127,6 +140,8 @@ def apply_moe_activation(
         output.copy_(F.silu(input))
     elif activation == MoEActivation.GELU_NO_MUL:
         output.copy_(F.gelu(input))
+    elif activation == MoEActivation.GELU_TANH_NO_MUL:
+        output.copy_(F.gelu(input, approximate="tanh"))
     elif activation == MoEActivation.RELU2_NO_MUL:
         F.relu(input, inplace=True)
         torch.square(input, out=output)
diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
index e1bedd6f45be..89d69527779f 100644
--- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
@@ -48,6 +48,10 @@ def _gelu_and_mul(
     MoEActivation.SILU: SiluAndMul.forward_native,
     MoEActivation.SWIGLUOAI: _swigluoai_forward_native,
     MoEActivation.GELU: _gelu_and_mul,
+    MoEActivation.GELU_TANH: (
+        lambda x: F.gelu(x[..., : x.shape[-1] // 2], approximate="tanh")
+        * x[..., x.shape[-1] // 2 :]
+    ),
 }
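For reference, the gated "gelu_tanh" convention used above can be checked against a plain PyTorch sketch (not part of the patch; the tensor shape is illustrative). The input packs [gate, up] along the last dimension, and both the eager fallback in apply_moe_activation and the CPU lambda compute gelu_tanh(gate) * up:

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 2 * 128)  # illustrative [..., 2 * d] input
    gate, up = x.chunk(2, dim=-1)
    ref = F.gelu(gate, approximate="tanh") * up

    # Same computation written the way the CPU-path lambda slices it.
    d = x.shape[-1] // 2
    cpu_style = F.gelu(x[..., :d], approximate="tanh") * x[..., d:]
    assert torch.allclose(ref, cpu_style)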
diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index fdd802e7da3a..7c0d0d8d1771 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -319,6 +319,7 @@ def _supports_activation(activation: MoEActivation) -> bool:
         return activation in [
             MoEActivation.SILU,
             MoEActivation.GELU,
+            MoEActivation.GELU_TANH,
             MoEActivation.SWIGLUOAI,
         ]

@@ -709,10 +710,12 @@ def _supports_activation(activation: MoEActivation) -> bool:
         return activation in [
             MoEActivation.SILU,
             MoEActivation.GELU,
+            MoEActivation.GELU_TANH,
             MoEActivation.SWIGLUOAI,
             MoEActivation.SWIGLUSTEP,
             MoEActivation.SILU_NO_MUL,
             MoEActivation.GELU_NO_MUL,
+            MoEActivation.GELU_TANH_NO_MUL,
             MoEActivation.RELU2_NO_MUL,
         ]
diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
index 5554298bd090..cad85c241a47 100644
--- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
@@ -944,9 +944,11 @@ def _supports_activation(activation: MoEActivation) -> bool:
         return activation in [
             MoEActivation.SILU,
             MoEActivation.GELU,
+            MoEActivation.GELU_TANH,
             MoEActivation.SWIGLUOAI,
             MoEActivation.SILU_NO_MUL,
             MoEActivation.GELU_NO_MUL,
+            MoEActivation.GELU_TANH_NO_MUL,
             MoEActivation.RELU2_NO_MUL,
         ]
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index 6143c3d0adca..f0e48bf736ef 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -599,10 +599,12 @@ def _supports_activation(activation: MoEActivation) -> bool:
         return activation in [
             MoEActivation.SILU,
             MoEActivation.GELU,
+            MoEActivation.GELU_TANH,
             MoEActivation.SWIGLUOAI,
             MoEActivation.SWIGLUSTEP,
             MoEActivation.SILU_NO_MUL,
             MoEActivation.GELU_NO_MUL,
+            MoEActivation.GELU_TANH_NO_MUL,
             MoEActivation.RELU2_NO_MUL,
         ]
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index cf53907e2c3f..50e7a94b0e6a 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1940,10 +1940,12 @@ def _supports_activation(activation: MoEActivation) -> bool:
         return activation in [
             MoEActivation.SILU,
             MoEActivation.GELU,
+            MoEActivation.GELU_TANH,
             MoEActivation.SWIGLUOAI,
             MoEActivation.SWIGLUSTEP,
             MoEActivation.SILU_NO_MUL,
             MoEActivation.GELU_NO_MUL,
+            MoEActivation.GELU_TANH_NO_MUL,
             MoEActivation.RELU2_NO_MUL,
         ]
diff --git a/vllm/model_executor/layers/quantization/inc.py b/vllm/model_executor/layers/quantization/inc.py
index 4457555c0764..d67bff572a7d 100644
--- a/vllm/model_executor/layers/quantization/inc.py
+++ b/vllm/model_executor/layers/quantization/inc.py
@@ -6,8 +6,10 @@
 import regex as re
 import torch
+import torch.nn.functional as F
 from torch.nn.parameter import Parameter

+from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (
     LinearBase,
@@ -18,9 +20,17 @@
     QuantizationConfig,
     QuantizationMethods,
 )
+from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+    check_marlin_supports_layer,
+)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    unpack_quantized_values_into_int32,
+)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.parameter import (
+    ChannelQuantScaleParameter,
     GroupQuantScaleParameter,
+    PackedColumnParameter,
     PackedvLLMParameter,
     RowvLLMParameter,
 )
@@ -341,6 +351,22 @@ def apply_gptq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
             group_size,
             sym,
         )
+        if (
+            isinstance(layer, LinearBase)
+            and group_size > 0
+            and getattr(layer, "input_size_per_partition", layer.input_size)
+            % group_size
+            != 0
+        ):
+            # Gemma4 AutoRound row-parallel linears can produce TP shards that
+            # straddle a GPTQ group boundary. Fall back to a correctness-first
+            # path in that case instead of using Marlin/GPTQ kernels that
+            # assume group-aligned input shards.
+            return INCGPTQRowParallelTailLinearMethod(
+                weight_bits=weight_bits,
+                group_size=group_size,
+                sym=sym,
+            )
         if backend == "auto" or "marlin" in backend:
             GPTQ_TYPE_MAP = {
                 (4, True): scalar_types.uint4b8,
@@ -353,6 +379,10 @@ def apply_gptq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
                 use_marlin = use_marlin and check_moe_marlin_supports_layer(
                     layer, group_size
                 )
+            elif isinstance(layer, LinearBase):
+                use_marlin = use_marlin and check_marlin_supports_layer(
+                    layer, group_size
+                )
             else:
                 use_marlin = False
         if use_marlin:
@@ -625,3 +655,130 @@ def apply(
             None,  # g_idx not needed: desc_act is always False for INC models
         )
         return out.reshape(out_shape)
+
+
+class INCGPTQRowParallelTailLinearMethod(LinearMethodBase):
+    """Fallback for row-parallel GPTQ-family linears with group-tail shards."""
+
+    def __init__(self, weight_bits: int, group_size: int, sym: bool):
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.sym = sym
+        self.pack_factor = 32 // weight_bits
+        self.weight_type = {
+            4: scalar_types.uint4b8,
+            8: scalar_types.uint8b128,
+        }[weight_bits]
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: list[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        del output_size  # Unused.
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        full_num_groups = (input_size + self.group_size - 1) // self.group_size
+
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition // self.pack_factor,
+                output_size_per_partition,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=0,
+            packed_factor=self.pack_factor,
+            weight_loader=weight_loader,
+        )
+        scales = ChannelQuantScaleParameter(
+            data=torch.empty(
+                full_num_groups,
+                output_size_per_partition,
+                dtype=params_dtype,
+            ),
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        qzeros = PackedColumnParameter(
+            data=torch.empty(
+                full_num_groups,
+                output_size_per_partition // self.pack_factor,
+                dtype=torch.int32,
+            ),
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.pack_factor,
+            weight_loader=weight_loader,
+        )
+
+        layer.register_parameter("qweight", qweight)
+        layer.register_parameter("scales", scales)
+        layer.register_parameter("qzeros", qzeros)
+
+        shard_width = getattr(
+            layer, "input_size_per_partition", input_size_per_partition
+        )
+        shard_offset = get_tensor_model_parallel_rank() * shard_width
+        g_idx = (
+            torch.arange(input_size_per_partition, dtype=torch.int32) + shard_offset
+        ) // self.group_size
+        layer.register_parameter("g_idx", Parameter(g_idx, requires_grad=False))
+        layer._inc_tail_dequant_weight = None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if self.sym:
+            # The tail-shard fallback dequantizes weights on demand and handles
+            # the symmetric zero point via weight_type.bias in
+            # _get_dequantized_weight(), so the large packed qzeros tensor is
+            # replaced with a tiny placeholder after loading.
+            layer.qzeros = Parameter(
+                torch.tensor([8], dtype=torch.int8, device=layer.qweight.device),
+                requires_grad=False,
+            )
+        else:
+            layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
+        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
+        layer.scales = Parameter(layer.scales.data, requires_grad=False)
+
+    def _get_dequantized_weight(self, layer: torch.nn.Module) -> torch.Tensor:
+        cached = layer._inc_tail_dequant_weight
+        if cached is not None:
+            return cached
+
+        if not self.sym:
+            raise NotImplementedError(
+                "INCGPTQRowParallelTailLinearMethod currently supports only "
+                "symmetric checkpoints."
+            )
+
+        qweight = unpack_quantized_values_into_int32(
+            layer.qweight.data, self.weight_type, packed_dim=0
+        ).to(torch.float32)
+        qweight = qweight - float(self.weight_type.bias)
+
+        g_idx = layer.g_idx.data.to(torch.long)
+        scales = layer.scales.data.to(torch.float32)
+        dequant = qweight * scales.index_select(0, g_idx)
+        weight = dequant.t().contiguous()
+        # Cache the dequantized tail-shard weight after the first fallback use.
+        layer._inc_tail_dequant_weight = weight
+        return weight
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        out_shape = x.shape[:-1] + (layer.qweight.shape[1],)
+        x_2d = x.reshape(-1, x.shape[-1]).to(torch.float32)
+        bias_2d = bias.to(torch.float32) if bias is not None else None
+        output = F.linear(x_2d, self._get_dequantized_weight(layer), bias_2d)
+        return output.to(x.dtype).reshape(out_shape)
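A quick sketch of the shard/group arithmetic that selects this fallback (the sizes are made up for illustration and are not taken from any config): with row parallelism the input dimension is split across ranks, and when the per-rank width is not a multiple of the GPTQ group size, a shard starts or ends mid-group, which the group-aligned Marlin/GPTQ kernels cannot handle.

    # Illustrative sizes only: the full input dimension is a whole number of
    # GPTQ groups, but each row-parallel shard is not.
    input_size = 18944
    group_size = 128
    tp_size = 8
    input_size_per_partition = input_size // tp_size  # 2368

    assert input_size % group_size == 0            # full weight: 148 whole groups
    print(input_size_per_partition / group_size)   # 18.5 groups per TP shard
    # A non-zero remainder is exactly the condition checked in
    # apply_gptq_quant_layer before returning the tail-shard method.
    print(input_size_per_partition % group_size)   # 64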
diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index e5ef3f4c3168..d67b386b9330 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -6,7 +6,6 @@
 import torch

 from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
-from vllm.model_executor.layers.fused_moe.activation import MoEActivation
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
     int4_w4a16_moe_quant_config,
@@ -372,10 +371,6 @@ def apply(
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts

-        assert layer.activation == MoEActivation.SILU, (
-            f"Only SiLU activation is supported, not {layer.activation}."
-        )
-
         return fused_experts(
             x,
             layer.w13_qweight,
@@ -383,6 +378,7 @@ def apply(
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=not self.moe.disable_inplace,
+            activation=layer.activation,
             apply_router_weight_on_input=layer.apply_router_weight_on_input,
             global_num_experts=layer.global_num_experts,
             expert_map=layer.expert_map,
diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py
index fc6f88b49ee1..d62e10c04f49 100644
--- a/vllm/model_executor/model_loader/gguf_loader.py
+++ b/vllm/model_executor/model_loader/gguf_loader.py
@@ -100,6 +100,187 @@ def _get_all_gguf_files(model_path: str) -> list[str]:
         logger.info("Discovered %d GGUF shard files", len(files))
         return files if files else [model_path]

+    @staticmethod
+    def _normalize_hf_name_for_gguf(
+        hf_name: str, *, is_multimodal: bool
+    ) -> tuple[str, str]:
+        """Normalize HF state dict names for GGUF tensor lookup.
+
+        Returns:
+            (normalized_name, suffix)
+        """
+        if is_multimodal and hf_name.startswith("model."):
+            hf_name = hf_name[6:]
+
+        if hf_name.startswith("language_model."):
+            hf_name = hf_name[15:]
+            if is_multimodal:
+                hf_name = "model." + hf_name
+
+        if hf_name.endswith((".weight", ".bias")):
+            base_name, suffix = hf_name.rsplit(".", 1)
+        else:
+            base_name, suffix = hf_name, ""
+        if base_name.endswith("_weight"):
+            base_name = base_name[:-7]
+            suffix = "weight"
+
+        return base_name, suffix
+
+    @staticmethod
+    def _build_gemma4_manual_mapping(
+        normalized_state_names: set[str],
+        num_hidden_layers: int,
+        *,
+        vision_num_hidden_layers: int | None = None,
+    ) -> tuple[dict[str, str], set[str]]:
+        """Build Gemma4 GGUF mappings missing from gguf-py's tensor tables."""
+        gguf_to_hf_name_map: dict[str, str] = {}
+        handled_params: set[str] = set()
+
+        def add_mapping(
+            gguf_name: str,
+            hf_name: str,
+            *,
+            handled_name: str | None = None,
+        ) -> None:
+            if handled_name is None:
+                handled_name = hf_name
+            # handled_name must match the HF state_dict key emitted by the
+            # installed Gemma4 transformers config/model classes. If upstream
+            # renames these tensors, update handled_name alongside the manual
+            # GGUF mapping.
+            if handled_name in normalized_state_names:
+                gguf_to_hf_name_map[gguf_name] = hf_name
+                handled_params.add(handled_name)
+
+        for idx in range(num_hidden_layers):
+            layer_prefix = f"model.layers.{idx}"
+            add_mapping(
+                f"blk.{idx}.layer_output_scale.weight",
+                f"{layer_prefix}.layer_scalar",
+            )
+            add_mapping(
+                f"blk.{idx}.ffn_gate_inp.scale",
+                f"{layer_prefix}.router.scale",
+            )
+            add_mapping(
+                f"blk.{idx}.ffn_down_exps.scale",
+                f"{layer_prefix}.router.per_expert_scale",
+            )
+            add_mapping(
+                f"blk.{idx}.ffn_gate_inp.weight",
+                f"{layer_prefix}.router.proj.weight",
+            )
+            add_mapping(
+                f"blk.{idx}.ffn_gate_up_exps.weight",
+                f"{layer_prefix}.moe.gate_up_proj.weight",
+                handled_name=f"{layer_prefix}.experts.gate_up_proj",
+            )
+            add_mapping(
+                f"blk.{idx}.ffn_down_exps.weight",
+                f"{layer_prefix}.moe.down_proj.weight",
+                handled_name=f"{layer_prefix}.experts.down_proj",
+            )
+            add_mapping(
+                f"blk.{idx}.post_ffw_norm_1.weight",
+                f"{layer_prefix}.post_feedforward_layernorm_1.weight",
+            )
+            add_mapping(
+                f"blk.{idx}.post_ffw_norm_2.weight",
+                f"{layer_prefix}.post_feedforward_layernorm_2.weight",
+            )
+            add_mapping(
+                f"blk.{idx}.pre_ffw_norm_2.weight",
+                f"{layer_prefix}.pre_feedforward_layernorm_2.weight",
+            )
+
+        add_mapping("v.std_bias", "vision_tower.std_bias")
+        add_mapping("v.std_scale", "vision_tower.std_scale")
+        add_mapping(
+            "v.patch_embd.weight",
+            "vision_tower.patch_embedder.input_proj.weight",
+        )
+        add_mapping(
+            "v.position_embd.weight",
+            "vision_tower.patch_embedder.position_embedding_table",
+        )
+        add_mapping(
+            "mm.input_projection.weight",
+            "embed_vision.embedding_projection.weight",
+        )
+
+        if vision_num_hidden_layers is not None:
+            for idx in range(vision_num_hidden_layers):
+                layer_prefix = f"vision_tower.encoder.layers.{idx}"
+                add_mapping(
+                    f"v.blk.{idx}.attn_q.weight",
+                    f"{layer_prefix}.self_attn.q_proj.linear.weight",
+                )
+                add_mapping(
+                    f"v.blk.{idx}.attn_k.weight",
+                    f"{layer_prefix}.self_attn.k_proj.linear.weight",
+                )
+                add_mapping(
+                    f"v.blk.{idx}.attn_v.weight",
+                    f"{layer_prefix}.self_attn.v_proj.linear.weight",
+                )
+                add_mapping(
+                    f"v.blk.{idx}.attn_out.weight",
+                    f"{layer_prefix}.self_attn.o_proj.linear.weight",
+                )
+                add_mapping(
+                    f"v.blk.{idx}.attn_q_norm.weight",
+                    f"{layer_prefix}.self_attn.q_norm.weight",
+                )
+                add_mapping(
+                    f"v.blk.{idx}.attn_k_norm.weight",
+                    f"{layer_prefix}.self_attn.k_norm.weight",
+                )
+                add_mapping(
+                    f"v.blk.{idx}.ln1.weight",
+                    f"{layer_prefix}.input_layernorm.weight",
+                )
+                add_mapping(
+                    f"v.blk.{idx}.attn_post_norm.weight",
+                    f"{layer_prefix}.post_attention_layernorm.weight",
+                )
+                add_mapping(
+                    f"v.blk.{idx}.ln2.weight",
+                    f"{layer_prefix}.pre_feedforward_layernorm.weight",
+                )
+                add_mapping(
+                    f"v.blk.{idx}.ffn_post_norm.weight",
+                    f"{layer_prefix}.post_feedforward_layernorm.weight",
+                )
+                add_mapping(
+                    f"v.blk.{idx}.ffn_gate.weight",
+                    f"{layer_prefix}.mlp.gate_proj.linear.weight",
+                )
+                add_mapping(
+                    f"v.blk.{idx}.ffn_up.weight",
+                    f"{layer_prefix}.mlp.up_proj.linear.weight",
+                )
+                add_mapping(
+                    f"v.blk.{idx}.ffn_down.weight",
+                    f"{layer_prefix}.mlp.down_proj.linear.weight",
+                )
+
+        return gguf_to_hf_name_map, handled_params
+
+    @staticmethod
+    def _transform_gemma4_gguf_tensor_name_and_weight(
+        name: str, weight: torch.Tensor
+    ) -> tuple[str, torch.Tensor]:
+        """Adapt Gemma4 GGUF tensors to vLLM's final parameter layout."""
+        if (
+            name == "vision_tower.patch_embedder.input_proj.weight"
+            and weight.dim() == 4
+        ):
+            return name, weight.reshape(weight.shape[0], -1).contiguous()
+
+        return name, weight
+
     def _get_gguf_weights_map(self, model_config: ModelConfig):
         """
         GGUF uses this naming convention for their tensors from HF checkpoint:
@@ -195,11 +376,21 @@ def _get_gguf_weights_map(self, model_config: ModelConfig):
             )
         )

+        gemma4_manual_map: dict[str, str] = {}
+        gemma4_handled_params: set[str] = set()
+
         arch = None
-        for key, value in gguf.MODEL_ARCH_NAMES.items():
-            if value == model_type:
-                arch = key
-                break
+        if model_type == "gemma4":
+            # gguf-py may lag behind Gemma4 architecture registration even when
+            # the tensor naming convention is largely compatible with Gemma3.
+            # Reuse the closest built-in table for common tensors and layer on
+            # top manual mappings for Gemma4-specific additions.
+            arch = gguf.MODEL_ARCH.GEMMA3
+        else:
+            for key, value in gguf.MODEL_ARCH_NAMES.items():
+                if value == model_type:
+                    arch = key
+                    break
         if arch is None:
             raise RuntimeError(f"Unknown gguf model_type: {model_type}")
         text_num_layers = text_config.num_hidden_layers
@@ -246,6 +437,27 @@ def revert_hf_rename(name: str) -> str:
             for name, tensor in state_dict.items()
         }

+        normalized_state_names: set[str] = set()
+        for hf_name in state_dict:
+            base_name, suffix = self._normalize_hf_name_for_gguf(
+                hf_name, is_multimodal=is_multimodal
+            )
+            normalized_state_names.add(base_name + (f".{suffix}" if suffix else ""))
+
+        if model_type == "gemma4":
+            gemma4_manual_map, gemma4_handled_params = (
+                self._build_gemma4_manual_mapping(
+                    normalized_state_names,
+                    text_num_layers,
+                    vision_num_hidden_layers=(
+                        config.vision_config.num_hidden_layers
+                        if is_multimodal
+                        else None
+                    ),
+                )
+            )
+            gguf_to_hf_name_map.update(gemma4_manual_map)
+
         def find_hf_name_in_tensor_map(hf_name: str) -> str | None:
             """
             Map HuggingFace parameter name to GGUF tensor name.
@@ -265,35 +477,9 @@ def find_hf_name_in_tensor_map(hf_name: str) -> str | None:
                 GGUF tensor name with suffix (e.g., 'mm.soft_emb_norm.weight')
                 or None if no mapping found
             """
-            # In transformers v5, multimodal models (e.g. Gemma3) wrap
-            # all sub-models under an outer 'model.' attribute, producing
-            # state_dict keys like 'model.language_model.layers.0...' and
-            # 'model.vision_tower.vision_model...'. Strip this outer
-            # prefix so the keys match what gguf-py expects.
-            if is_multimodal and hf_name.startswith("model."):
-                hf_name = hf_name[6:]  # Remove outer 'model.'
-
-            # Strip 'language_model.' prefix for multimodal models - gguf-py
-            # tensor mappings expect parameter names without this prefix.
-            # Note: 'model.' prefix should be KEPT for text-only models as
-            # gguf-py expects it.
-            if hf_name.startswith("language_model."):
-                hf_name = hf_name[15:]  # Remove 'language_model.'
-                # Re-add 'model.' prefix because gguf-py text tensor maps
-                # expect 'model.layers...' format.
-                if is_multimodal:
-                    hf_name = "model." + hf_name
-
-            # Parse parameter name and suffix
-            if hf_name.endswith((".weight", ".bias")):
-                base_name, suffix = hf_name.rsplit(".", 1)
-            else:
-                base_name, suffix = hf_name, ""
-            # Handle '_weight' suffix (Gemma3 naming: parameter ends with
-            # '_weight' instead of '.weight')
-            if base_name.endswith("_weight"):
-                base_name = base_name[:-7]  # Remove '_weight'
-                suffix = "weight"
+            base_name, suffix = self._normalize_hf_name_for_gguf(
+                hf_name, is_multimodal=is_multimodal
+            )

             gguf_name = None
             # Priority 1: Search vision/projector parameters for multimodal models
@@ -313,14 +499,23 @@ def find_hf_name_in_tensor_map(hf_name: str) -> str | None:
         unmapped_params = []
         for hf_name in state_dict:
             gguf_name_with_suffix = find_hf_name_in_tensor_map(hf_name)
+            normalized_base, normalized_suffix = self._normalize_hf_name_for_gguf(
+                hf_name, is_multimodal=is_multimodal
+            )
+            normalized_hf_name = normalized_base + (
+                f".{normalized_suffix}" if normalized_suffix else ""
+            )

             # Track mapping success
             if gguf_name_with_suffix is not None:
                 gguf_to_hf_name_map[gguf_name_with_suffix] = hf_name
                 logger.debug("Mapped GGUF %s → HF %s", gguf_name_with_suffix, hf_name)
-            elif hf_name not in gguf_to_hf_name_map.values():
+            elif (
+                normalized_hf_name not in gemma4_handled_params
+                and hf_name not in gguf_to_hf_name_map.values()
+            ):
                 # Parameter not in manual overrides either
-                unmapped_params.append(hf_name)
+                unmapped_params.append(normalized_hf_name)

         # All parameters (except those initialized by other means) must be mapped:
         # both vision/projector and backbone
@@ -388,18 +583,33 @@ def _get_weights_iterator(
             assert mmproj_file is not None, (
                 "Could not find mm_proj file for multimodal GGUF model"
             )
-            yield from gguf_quant_weights_iterator(mmproj_file, gguf_to_hf_name_map)
+            mmproj_iterator = gguf_quant_weights_iterator(
+                mmproj_file, gguf_to_hf_name_map
+            )
+            if hf_config.model_type == "gemma4":
+                for name, weight in mmproj_iterator:
+                    yield self._transform_gemma4_gguf_tensor_name_and_weight(
+                        name, weight
+                    )
+            else:
+                yield from mmproj_iterator

         gguf_files = self._get_all_gguf_files(model_name_or_path)
         if len(gguf_files) > 1:
-            yield from gguf_quant_weights_iterator_multi(
+            iterator = gguf_quant_weights_iterator_multi(
                 gguf_files, gguf_to_hf_name_map
             )
         else:
-            yield from gguf_quant_weights_iterator(
+            iterator = gguf_quant_weights_iterator(
                 model_name_or_path, gguf_to_hf_name_map
             )
+        if hf_config.model_type == "gemma4":
+            for name, weight in iterator:
+                yield self._transform_gemma4_gguf_tensor_name_and_weight(name, weight)
+        else:
+            yield from iterator
+
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config)
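A small sketch (not part of the patch; the key is illustrative) of what the shared _normalize_hf_name_for_gguf helper does to a typical multimodal state_dict key, which is why both the tensor-map lookup and the unmapped-parameter bookkeeping can use the same normalized form:

    def normalize(hf_name: str, *, is_multimodal: bool) -> tuple[str, str]:
        # Mirrors _normalize_hf_name_for_gguf for illustration only.
        if is_multimodal and hf_name.startswith("model."):
            hf_name = hf_name[len("model."):]
        if hf_name.startswith("language_model."):
            hf_name = hf_name[len("language_model."):]
            if is_multimodal:
                hf_name = "model." + hf_name
        if hf_name.endswith((".weight", ".bias")):
            base, suffix = hf_name.rsplit(".", 1)
        else:
            base, suffix = hf_name, ""
        if base.endswith("_weight"):
            base, suffix = base[: -len("_weight")], "weight"
        return base, suffix

    # Outer "model." is stripped, then "model." is re-added after dropping
    # "language_model.", matching what gguf-py's text tensor table expects.
    print(normalize("model.language_model.layers.0.mlp.up_proj.weight",
                    is_multimodal=True))
    # -> ("model.layers.0.mlp.up_proj", "weight")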
diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py
index 42762e36f816..448dda1df066 100644
--- a/vllm/model_executor/models/gemma4.py
+++ b/vllm/model_executor/models/gemma4.py
@@ -48,6 +48,9 @@
 )
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    unpack_quantized_values_into_int32,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -58,6 +61,7 @@
     maybe_remap_kv_scale_name,
 )
 from vllm.platforms import current_platform
+from vllm.scalar_type import scalar_types
 from vllm.sequence import IntermediateTensors
 from vllm.triton_utils import tl, triton
 from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
@@ -195,6 +199,51 @@ def gemma4_routing_function_torch(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+
+def _dequantize_autoround_gptq_router_weight(
+    qweight: torch.Tensor,
+    qzeros: torch.Tensor,
+    scales: torch.Tensor,
+    num_bits: int,
+    group_size: int,
+    sym: bool,
+    params_dtype: torch.dtype,
+) -> torch.Tensor:
+    if num_bits not in (4, 8):
+        raise ValueError(f"Router dequant: unsupported num_bits={num_bits}")
+    weight_type = {
+        4: scalar_types.uint4b8,
+        8: scalar_types.uint8b128,
+    }[num_bits]
+    unpacked_qweight = unpack_quantized_values_into_int32(
+        qweight, weight_type, packed_dim=0
+    ).to(torch.float32)
+    unpacked_qzeros = unpack_quantized_values_into_int32(
+        qzeros, weight_type, packed_dim=1
+    ).to(torch.float32)
+    if sym:
+        # AutoRound's GPTQ-style symmetric checkpoints store router qzeros
+        # with a -1 offset relative to the effective zero point.
+        unpacked_qzeros = unpacked_qzeros + 1
+    row_groups = (
+        torch.arange(unpacked_qweight.shape[0], device=qweight.device) // group_size
+    )
+    scales_per_row = scales.to(torch.float32)[row_groups]
+    qzeros_per_row = unpacked_qzeros[row_groups]
+    weight = (unpacked_qweight - qzeros_per_row) * scales_per_row
+    return weight.t().to(params_dtype)
+
+
+def _map_gemma4_moe_suffix(name: str, source_base: str, target_base: str) -> str | None:
+    for source_suffix, target_suffix in (
+        (source_base, f"{target_base}_weight"),
+        (f"{source_base}_packed", f"{target_base}_weight_packed"),
+        (f"{source_base}_scale", f"{target_base}_weight_scale"),
+    ):
+        if name.endswith(source_suffix):
+            return name[: -len(source_suffix)] + target_suffix
+    return None
+
+
 def _get_text_config(config):
     """Dereference text_config if config is a nested Gemma4Config.
@@ -353,7 +402,11 @@ def routing_function(
             quant_config=quant_config,
             prefix=f"{prefix}.experts",
             custom_routing_function=routing_function,
-            activation="gelu",
+            activation=(
+                "gelu_tanh"
+                if config.hidden_activation == "gelu_pytorch_tanh"
+                else "gelu"
+            ),
         )

     def forward(self, x: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
@@ -1402,6 +1455,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         # Include buffers (e.g. layer_scalar) so they can be loaded too
         params_dict.update(dict(self.named_buffers()))
         loaded_params: set[str] = set()
+        router_quant_params: dict[str, dict[str, torch.Tensor]] = {}
         for name, loaded_weight in weights:
             if self.quant_config is not None and (
                 scale_name := self.quant_config.get_cache_scale(name)
@@ -1424,6 +1478,40 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 loaded_params.add(remapped_name)
                 continue
+            if (
+                self.quant_config is not None
+                and name.endswith((".qweight", ".qzeros", ".scales"))
+                and ".router.proj." in name
+            ):
+                router_name, _, suffix = name.rpartition(".")
+                router_quant_params.setdefault(router_name, {})[suffix] = loaded_weight
+                loaded_params.add(name)
+                quant_params = router_quant_params[router_name]
+                if len(quant_params) == 3:
+                    weight_name = f"{router_name}.weight"
+                    if is_pp_missing_parameter(weight_name, self):
+                        del router_quant_params[router_name]
+                        continue
+                    if weight_name not in params_dict:
+                        raise KeyError(weight_name)
+                    param = params_dict[weight_name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    router_weight = _dequantize_autoround_gptq_router_weight(
+                        qweight=quant_params["qweight"],
+                        qzeros=quant_params["qzeros"],
+                        scales=quant_params["scales"],
+                        num_bits=self.quant_config.weight_bits,
+                        group_size=self.quant_config.group_size,
+                        sym=self.quant_config.sym,
+                        params_dtype=param.dtype,
+                    )
+                    weight_loader(param, router_weight)
+                    loaded_params.add(weight_name)
+                    del router_quant_params[router_name]
+                continue
+
             for param_name, shard_name, shard_id in stacked_params_mapping:
                 if shard_name not in name:
                     continue
@@ -1455,12 +1543,13 @@
                     if weight_name in name:
                         # Has suffix (e.g., .weight_scale)
                         moe_name = name.replace(weight_name, param_name)
-                    elif name.endswith(weight_name_base):
-                        # Bare weight (no suffix)
-                        moe_name = name.replace(
-                            weight_name_base, param_name.rstrip("_") + "_weight"
-                        )
                     else:
+                        moe_name = _map_gemma4_moe_suffix(
+                            name,
+                            weight_name_base,
+                            param_name.rstrip("_"),
+                        )
+                        if moe_name is None:
                             continue
                     if moe_name not in params_dict:
                         continue
@@ -1497,6 +1586,14 @@
             weight_loader(param, loaded_weight)
             loaded_params.add(name)
+        if router_quant_params:
+            for router_name, quant_params in router_quant_params.items():
+                logger.warning(
+                    "Skipping incomplete quantized router params for %s: %s",
+                    router_name,
+                    sorted(quant_params),
+                )
+
         return loaded_params
@@ -1672,6 +1769,35 @@ def _weight_iterator():
             # No transpose needed: checkpoint orientation already
             # matches FusedMoE's expected layout.
if "moe.gate_up_proj" in name and weight.dim() == 3: + if name.endswith("_packed") or name.endswith("_scale"): + num_experts = weight.size(0) + split_size = weight.size(1) // 2 + for expert_id in range(num_experts): + gate_weight = weight[expert_id, :split_size, :] + up_weight = weight[expert_id, split_size:, :] + base = name.replace("moe.", f"moe.experts.{expert_id}.") + if name.endswith("_packed"): + gate_name = base.replace( + "gate_up_proj_packed", + "gate_proj.weight_packed", + ) + up_name = base.replace( + "gate_up_proj_packed", + "up_proj.weight_packed", + ) + else: + gate_name = base.replace( + "gate_up_proj_scale", + "gate_proj.weight_scale", + ) + up_name = base.replace( + "gate_up_proj_scale", + "up_proj.weight_scale", + ) + yield gate_name, gate_weight + yield up_name, up_weight + continue + num_experts = weight.size(0) intermediate_size = weight.size(1) // 2 for expert_id in range(num_experts): @@ -1686,6 +1812,14 @@ def _weight_iterator(): num_experts = weight.size(0) for expert_id in range(num_experts): expert_name = name.replace("moe.", f"moe.experts.{expert_id}.") + if expert_name.endswith("_packed"): + expert_name = expert_name.replace( + "down_proj_packed", "down_proj.weight_packed" + ) + elif expert_name.endswith("_scale"): + expert_name = expert_name.replace( + "down_proj_scale", "down_proj.weight_scale" + ) yield expert_name, weight[expert_id] continue diff --git a/vllm/model_executor/models/gemma4_mm.py b/vllm/model_executor/models/gemma4_mm.py index 46d0308f4c86..4865bba27a55 100644 --- a/vllm/model_executor/models/gemma4_mm.py +++ b/vllm/model_executor/models/gemma4_mm.py @@ -955,12 +955,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Some variants have hidden_size_per_layer_input=None (no PLE). ple_dim = config.text_config.hidden_size_per_layer_input if ple_dim is not None: + model_device = next(self.language_model.parameters()).device self.per_layer_embeddings = torch.zeros( vllm_config.scheduler_config.max_num_batched_tokens, config.text_config.num_hidden_layers, ple_dim, - device=(self.language_model.model.embed_tokens.weight.device), - dtype=(self.language_model.model.embed_tokens.weight.dtype), + device=model_device, + dtype=vllm_config.model_config.dtype, ) else: self.per_layer_embeddings = None diff --git a/vllm/tokenizers/registry.py b/vllm/tokenizers/registry.py index 8f16e6d28f43..9d1d812c2d19 100644 --- a/vllm/tokenizers/registry.py +++ b/vllm/tokenizers/registry.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib +import copy from dataclasses import dataclass, field from functools import lru_cache from pathlib import Path @@ -15,6 +16,7 @@ from vllm.transformers_utils.gguf_utils import ( check_gguf_file, get_gguf_file_path_from_hf, + get_gguf_tokenizer_special_ids, is_gguf, is_remote_gguf, split_remote_gguf, @@ -91,6 +93,37 @@ def load_tokenizer(self, tokenizer_mode: str, *args, **kwargs) -> TokenizerLike: ) +def _maybe_patch_gemma4_gguf_tokenizer( + tokenizer: TokenizerLike, + model: str | Path, + model_type: str | None, +) -> TokenizerLike: + if model_type != "gemma4" or not check_gguf_file(model): + return tokenizer + + special_ids = get_gguf_tokenizer_special_ids(model) + if not special_ids: + return tokenizer + + patched_tokenizer = copy.copy(tokenizer) + token_attrs = { + "padding_token_id": "pad_token", + "bos_token_id": "bos_token", + "eos_token_id": "eos_token", + "unknown_token_id": "unk_token", + } + for id_attr, token_attr 
+        token_id = special_ids.get(id_attr)
+        if token_id is None:
+            continue
+        token = patched_tokenizer.convert_ids_to_tokens(token_id)
+        if token is None:
+            continue
+        setattr(patched_tokenizer, token_attr, token)
+
+    return patched_tokenizer
+
+
 def resolve_tokenizer_args(
     tokenizer_name: str | Path,
     *args,
@@ -257,7 +290,7 @@ def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs):
     if model_config.skip_tokenizer_init:
         return None

-    return cached_get_tokenizer(
+    tokenizer = cached_get_tokenizer(
         model_config.tokenizer,
         runner_type=model_config.runner_type,
         tokenizer_mode=model_config.tokenizer_mode,
@@ -265,3 +298,8 @@
         trust_remote_code=model_config.trust_remote_code,
         **kwargs,
     )
+    return _maybe_patch_gemma4_gguf_tokenizer(
+        tokenizer,
+        model_config.model,
+        getattr(model_config.hf_config, "model_type", None),
+    )
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index cf2676a8f724..f1eabaa1fff8 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -585,8 +585,14 @@ def maybe_override_with_speculators(
         Tuple of (resolved_model, resolved_tokenizer, speculative_config)
     """
     if check_gguf_file(model):
-        kwargs["gguf_file"] = Path(model).name
         gguf_model_repo = Path(model).parent
+        # Prefer sibling config.json when present instead of forcing
+        # transformers to parse GGUF metadata directly. This keeps local GGUF
+        # models loadable even when the installed transformers GGUF parser
+        # lags behind the architecture but the repository config.json is
+        # already available.
+        if not file_or_path_exists(gguf_model_repo, HF_CONFIG_NAME, revision=revision):
+            kwargs["gguf_file"] = Path(model).name
     elif is_remote_gguf(model):
         repo_id, _ = split_remote_gguf(model)
         gguf_model_repo = Path(repo_id)
@@ -638,9 +644,15 @@ def get_config(
     _is_remote_gguf = is_remote_gguf(model)
     if _is_gguf:
         if check_gguf_file(model):
-            # Local GGUF file
-            kwargs["gguf_file"] = Path(model).name
-            model = Path(model).parent
+            # Local GGUF file. Prefer sibling config.json when available
+            # rather than routing through transformers' GGUF checkpoint
+            # loader, which may not support the architecture yet.
+            gguf_model_dir = Path(model).parent
+            if not file_or_path_exists(
+                gguf_model_dir, HF_CONFIG_NAME, revision=revision
+            ):
+                kwargs["gguf_file"] = Path(model).name
+            model = gguf_model_dir
         elif _is_remote_gguf:
             # Remote GGUF - extract repo_id from repo_id:quant_type format
             # The actual GGUF file will be downloaded later by GGUFModelLoader
diff --git a/vllm/transformers_utils/gguf_utils.py b/vllm/transformers_utils/gguf_utils.py
index 7708378ee13b..2c63f73dfe17 100644
--- a/vllm/transformers_utils/gguf_utils.py
+++ b/vllm/transformers_utils/gguf_utils.py
@@ -3,7 +3,6 @@
 """GGUF utility functions."""

 from functools import cache
-from os import PathLike
 from pathlib import Path

 import gguf
@@ -19,8 +18,16 @@
 logger = init_logger(__name__)

+_GGUF_TOKENIZER_SPECIAL_ID_FIELDS = {
+    "bos_token_id": "tokenizer.ggml.bos_token_id",
+    "eos_token_id": "tokenizer.ggml.eos_token_id",
+    "unknown_token_id": "tokenizer.ggml.unknown_token_id",
+    "padding_token_id": "tokenizer.ggml.padding_token_id",
+}
+
+
 @cache
-def check_gguf_file(model: str | PathLike) -> bool:
+def check_gguf_file(model: str | Path) -> bool:
     """Check if the file is a GGUF model."""
     model = Path(model)
     if not model.is_file():
@@ -170,6 +177,25 @@ def detect_gguf_multimodal(model: str) -> Path | None:
     return None

+@cache
+def get_gguf_tokenizer_special_ids(model: str | Path) -> dict[str, int]:
+    """Read tokenizer special token ids embedded in a local GGUF file."""
+    if not check_gguf_file(model):
+        return {}
+
+    reader = gguf.GGUFReader(str(model))
+    special_ids: dict[str, int] = {}
+    for key, field_name in _GGUF_TOKENIZER_SPECIAL_ID_FIELDS.items():
+        field = reader.get_field(field_name)
+        if field is None:
+            continue
+        try:
+            special_ids[key] = int(field.parts[-1])
+        except (TypeError, ValueError):
+            logger.warning("Failed to parse GGUF tokenizer field %s", field_name)
+    return special_ids
+
+
 def extract_vision_config_from_gguf(mmproj_path: str) -> "SiglipVisionConfig | None":
     """Extract vision config parameters from mmproj.gguf metadata.
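For context, a minimal sketch (not part of the patch; the file path is hypothetical) of reading the same special-token-id metadata directly with gguf-py, which is all get_gguf_tokenizer_special_ids relies on:

    import gguf

    # Hypothetical local GGUF file; the reader exposes metadata as named fields.
    reader = gguf.GGUFReader("/path/to/model.gguf")

    for field_name in (
        "tokenizer.ggml.bos_token_id",
        "tokenizer.ggml.eos_token_id",
        "tokenizer.ggml.unknown_token_id",
        "tokenizer.ggml.padding_token_id",
    ):
        field = reader.get_field(field_name)
        if field is None:
            continue
        # Scalar metadata is stored in field.parts; the last part holds the id,
        # mirroring how get_gguf_tokenizer_special_ids parses it above.
        print(field_name, int(field.parts[-1]))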
diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py
index 0e241f6abfd1..b4b5d30f65ac 100644
--- a/vllm/transformers_utils/processor.py
+++ b/vllm/transformers_utils/processor.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import copy
 import importlib
 import inspect
 from functools import lru_cache
@@ -24,6 +25,7 @@
 from typing_extensions import TypeVar

 from vllm.logger import init_logger
+from vllm.tokenizers.registry import _maybe_patch_gemma4_gguf_tokenizer
 from vllm.transformers_utils import processors
 from vllm.transformers_utils.gguf_utils import is_gguf
 from vllm.transformers_utils.repo_utils import get_hf_file_to_dict
@@ -352,13 +354,24 @@ def cached_processor_from_config(
     model = model_config.model
     revision = model_config.revision

-    return cached_get_processor_without_dynamic_kwargs(
+    processor = cached_get_processor_without_dynamic_kwargs(
         model,
         revision=revision,
         trust_remote_code=model_config.trust_remote_code,
         processor_cls=processor_cls,  # type: ignore[arg-type]
         **_merge_mm_kwargs(model_config, processor_cls, **kwargs),
     )
+    tokenizer = getattr(processor, "tokenizer", None)
+    if tokenizer is not None:
+        tokenizer = _maybe_patch_gemma4_gguf_tokenizer(
+            tokenizer,
+            model_config.model,
+            getattr(model_config.hf_config, "model_type", None),
+        )
+        if tokenizer is not processor.tokenizer:
+            processor = copy.copy(processor)
+            processor.tokenizer = tokenizer
+    return processor


 def get_feature_extractor(