diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 02267b8f6825..23aa32101791 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from functools import cached_property from typing import Any, Union import torch @@ -168,6 +169,12 @@ def get_quant_method( return None +class BitsAndBytesWeightParameter(torch.nn.Parameter): + @cached_property + def dtype(self) -> torch.dtype: + return torch.get_default_dtype() + + def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]): # Split the prefix into its dot-separated components components = prefix.split(".") @@ -246,7 +253,7 @@ def create_qweight_for_4bit(): "The input size is not aligned with the quantized weight shape." ) - qweight = torch.nn.Parameter( + qweight = BitsAndBytesWeightParameter( torch.empty(total_size // quant_ratio, 1, dtype=torch.uint8), requires_grad=False, ) diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index 4d9ae267e084..bc2504b09c5b 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -744,6 +744,29 @@ def _stack_quantization_states( stacked_quant_state_dict[quant_param_name][shard_index] = quant_state_dict[ non_stacked_param_name ] + + # repeat k_proj for v_proj for k_eq_v models (e.g. Gemma4) + config = getattr(model, "config", None) + if config is not None: + text_config = config.get_text_config() + if getattr(text_config, "attention_k_eq_v", False): + shard_packed = { + name + for name, subs in self.modules_mapping.packed_mapping.items() + if len(subs) == 3 + } + for param_name, shards in stacked_quant_state_dict.items(): + is_target = ( + isinstance(shards, dict) + and len(shards) == 2 + and any( + param_name.endswith(f"{p}.weight") for p in shard_packed + ) + ) + if is_target: + assert 1 in shards and 2 not in shards + shards[2] = shards[1] + return stacked_quant_state_dict def _bind_quant_states_to_params( diff --git a/vllm/model_executor/models/gemma4_mm.py b/vllm/model_executor/models/gemma4_mm.py index b546040b7414..43e27054596d 100644 --- a/vllm/model_executor/models/gemma4_mm.py +++ b/vllm/model_executor/models/gemma4_mm.py @@ -16,7 +16,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Any, Literal +from typing import TYPE_CHECKING, Annotated, Any, Literal import numpy as np import torch @@ -41,6 +41,7 @@ from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.models.gemma4 import Gemma4ForCausalLM from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.transformers.utils import recursive_replace_linear from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalFieldConfig, @@ -71,6 +72,7 @@ SupportsLoRA, SupportsMultiModal, SupportsPP, + SupportsQuant, ) from .utils import ( AutoWeightsLoader, @@ -79,6 +81,9 @@ maybe_prefix, ) +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization import QuantizationConfig + logger = init_logger(__name__) # Video constants — match transformers Gemma4VideoProcessor defaults. @@ -872,6 +877,9 @@ def __init__( self, multimodal_config: Gemma4VisionConfig | Gemma4AudioConfig, text_config: Gemma4TextConfig, + *, + quant_config: "QuantizationConfig | None" = None, + prefix: str = "", ): super().__init__() @@ -895,6 +903,8 @@ def __init__( embedding_dim, self.text_hidden_size, bias=False, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "embedding_projection"), ) def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor: @@ -917,6 +927,7 @@ def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor: class Gemma4ForConditionalGeneration( nn.Module, SupportsMultiModal, + SupportsQuant, SupportsPP, SupportsLoRA, SupportsEagle3, @@ -936,11 +947,14 @@ class Gemma4ForConditionalGeneration( # Maps checkpoint prefixes to vLLM module paths. hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ + # vision tower + "model.vision_tower": "vision_tower", + "model.embed_vision": "embed_vision", + # audio tower + "model.audio_tower.": "audio_tower.", "model.embed_audio.": "embed_audio.", - "model.embed_vision.": "embed_vision.", + # backbone "model.language_model.": "language_model.model.", - "model.vision_tower.": "vision_tower.", - "model.audio_tower.": "audio_tower.", "lm_head.": "language_model.lm_head.", "model": "language_model.model", } @@ -959,7 +973,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): with self._mark_tower_model(vllm_config, {"image", "video"}): self.vision_tower = AutoModel.from_config(config=config.vision_config) self.embed_vision = Gemma4MultimodalEmbedder( - config.vision_config, config.text_config + config.vision_config, + config.text_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "embed_vision"), + ) + recursive_replace_linear( + self.vision_tower, + quant_config, + prefix=maybe_prefix(prefix, "vision_tower"), ) # ---- Audio tower (variants with audio_config) ---- @@ -972,7 +994,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # position embeddings, softcap, gradient_clipping). self.audio_tower.post_init() self.embed_audio = Gemma4MultimodalEmbedder( - config.audio_config, config.text_config + config.audio_config, + config.text_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "embed_audio"), + ) + recursive_replace_linear( + self.audio_tower, + quant_config, + prefix=maybe_prefix(prefix, "audio_tower"), ) else: self.audio_tower = None @@ -1153,6 +1183,7 @@ def _process_image_input( vt = self.vision_tower vision_cfg = self.config.vision_config pooling_k2 = vision_cfg.pooling_kernel_size**2 + target_dtype = self.language_model.model.embed_tokens.weight.dtype # Concurrent requests with different image resolutions may # arrive as a list of per-image tensors, while same-resolution @@ -1193,7 +1224,11 @@ def _process_image_input( ) pad_tensor = (pp_tensor == -1).all(dim=-1) - inputs_embeds = vt.patch_embedder(pv_tensor, pp_tensor, pad_tensor) + inputs_embeds = vt.patch_embedder( + pv_tensor, + pp_tensor, + pad_tensor, + ).to(target_dtype) encoder_outputs = vt.encoder( inputs_embeds=inputs_embeds, attention_mask=~pad_tensor, @@ -1230,7 +1265,9 @@ def _process_image_input( all_valid_states[orig_idx] = valid_states valid_lens[orig_idx] = valid_states.shape[0] - target_dtype = self.embed_vision.embedding_projection.weight.dtype + # Use embed_tokens dtype as compute dtype; embedding_projection.weight + # may be uint8 under BnB 4-bit, which would corrupt the cast. + target_dtype = self.language_model.model.embed_tokens.weight.dtype # Project all images in a single batched call. flat_valid_states = torch.cat(all_valid_states, dim=0).to(target_dtype) @@ -1273,7 +1310,7 @@ def _process_video_input( vt = self.vision_tower vision_cfg = self.config.vision_config pooling_k2 = vision_cfg.pooling_kernel_size**2 - target_dtype = self.embed_vision.embedding_projection.weight.dtype + target_dtype = self.language_model.model.embed_tokens.weight.dtype if isinstance(frame_counts, torch.Tensor): fc_list = frame_counts.tolist() @@ -1301,7 +1338,11 @@ def _process_video_input( pp_chunk = pixel_position_ids[i : i + max_batch_size] pad_chunk = padding_positions[i : i + max_batch_size] - inputs_embeds = vt.patch_embedder(pv_chunk, pp_chunk, pad_chunk) + inputs_embeds = vt.patch_embedder( + pv_chunk, + pp_chunk, + pad_chunk, + ).to(target_dtype) encoder_outputs = vt.encoder( inputs_embeds=inputs_embeds, attention_mask=~pad_chunk, diff --git a/vllm/model_executor/models/transformers/utils.py b/vllm/model_executor/models/transformers/utils.py index 04d6de28efd0..dbf0a084f783 100644 --- a/vllm/model_executor/models/transformers/utils.py +++ b/vllm/model_executor/models/transformers/utils.py @@ -32,6 +32,7 @@ ReplicatedLinear, RowParallelLinear, ) +from vllm.model_executor.models.utils import maybe_prefix from vllm.transformers_utils.config import is_rope_parameters_nested if TYPE_CHECKING: @@ -227,6 +228,34 @@ def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm: return RMSNorm(**kwargs) +def recursive_replace_linear( + model: nn.Module, + quant_config: "QuantizationConfig | None", + prefix: str = "", +): + """Recursively replace linear modules in the model as needed.""" + + def _recursive_replace(module: nn.Module, prefix: str): + for child_name, child_module in module.named_children(): + new_module = child_module + qual_name = maybe_prefix(prefix, child_name) + # Replace modules as needed + if isinstance(child_module, nn.Linear): + style = "replicate" + new_module = replace_linear_class( + child_module, + style, + quant_config, + prefix=qual_name, + ) + else: + _recursive_replace(child_module, prefix=qual_name) + if new_module is not child_module: + setattr(module, child_name, new_module) + + _recursive_replace(model, prefix=prefix) + + def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module): logger.debug("%s: %s -> %s", name, old_module, new_module)