Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion vllm/model_executor/layers/quantization/bitsandbytes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from functools import cached_property
from typing import Any, Union

import torch
Expand Down Expand Up @@ -168,6 +169,12 @@ def get_quant_method(
return None


class BitsAndBytesWeightParameter(torch.nn.Parameter):
    """Parameter whose ``dtype`` reports torch's default dtype.

    The ``dtype`` attribute is overridden so that it returns
    ``torch.get_default_dtype()`` rather than the dtype of the tensor the
    parameter was built from (the 4-bit weight path allocates that tensor
    as ``torch.uint8``).
    """

    @cached_property
    def dtype(self) -> torch.dtype:
        """Return torch's default dtype, cached on first access.

        NOTE(review): because this is a ``cached_property``, the value is
        computed once per instance; later changes to the global default
        dtype are not reflected -- confirm that this is intended.
        """
        return torch.get_default_dtype()


def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
# Split the prefix into its dot-separated components
components = prefix.split(".")
Expand Down Expand Up @@ -246,7 +253,7 @@ def create_qweight_for_4bit():
"The input size is not aligned with the quantized weight shape."
)

qweight = torch.nn.Parameter(
qweight = BitsAndBytesWeightParameter(
torch.empty(total_size // quant_ratio, 1, dtype=torch.uint8),
requires_grad=False,
)
Expand Down
23 changes: 23 additions & 0 deletions vllm/model_executor/model_loader/bitsandbytes_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,29 @@ def _stack_quantization_states(
stacked_quant_state_dict[quant_param_name][shard_index] = quant_state_dict[
non_stacked_param_name
]

# repeat k_proj for v_proj for k_eq_v models (e.g. Gemma4)
config = getattr(model, "config", None)
if config is not None:
text_config = config.get_text_config()
if getattr(text_config, "attention_k_eq_v", False):
shard_packed = {
name
for name, subs in self.modules_mapping.packed_mapping.items()
if len(subs) == 3
}
for param_name, shards in stacked_quant_state_dict.items():
is_target = (
isinstance(shards, dict)
and len(shards) == 2
and any(
param_name.endswith(f"{p}.weight") for p in shard_packed
)
)
if is_target:
assert 1 in shards and 2 not in shards
shards[2] = shards[1]

return stacked_quant_state_dict

def _bind_quant_states_to_params(
Expand Down
61 changes: 51 additions & 10 deletions vllm/model_executor/models/gemma4_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal
from typing import TYPE_CHECKING, Annotated, Any, Literal

import numpy as np
import torch
Expand All @@ -41,6 +41,7 @@
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.gemma4 import Gemma4ForCausalLM
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.transformers.utils import recursive_replace_linear
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalFieldConfig,
Expand Down Expand Up @@ -71,6 +72,7 @@
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
SupportsQuant,
)
from .utils import (
AutoWeightsLoader,
Expand All @@ -79,6 +81,9 @@
maybe_prefix,
)

if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationConfig

logger = init_logger(__name__)

# Video constants — match transformers Gemma4VideoProcessor defaults.
Expand Down Expand Up @@ -872,6 +877,9 @@ def __init__(
self,
multimodal_config: Gemma4VisionConfig | Gemma4AudioConfig,
text_config: Gemma4TextConfig,
*,
quant_config: "QuantizationConfig | None" = None,
prefix: str = "",
):
super().__init__()

Expand All @@ -895,6 +903,8 @@ def __init__(
embedding_dim,
self.text_hidden_size,
bias=False,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "embedding_projection"),
)

def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
Expand All @@ -917,6 +927,7 @@ def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
class Gemma4ForConditionalGeneration(
nn.Module,
SupportsMultiModal,
SupportsQuant,
SupportsPP,
SupportsLoRA,
SupportsEagle3,
Expand All @@ -936,11 +947,14 @@ class Gemma4ForConditionalGeneration(
# Maps checkpoint prefixes to vLLM module paths.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# vision tower
"model.vision_tower": "vision_tower",
"model.embed_vision": "embed_vision",
# audio tower
"model.audio_tower.": "audio_tower.",
"model.embed_audio.": "embed_audio.",
"model.embed_vision.": "embed_vision.",
# backbone
"model.language_model.": "language_model.model.",
"model.vision_tower.": "vision_tower.",
"model.audio_tower.": "audio_tower.",
"lm_head.": "language_model.lm_head.",
"model": "language_model.model",
}
Expand All @@ -959,7 +973,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.embed_vision = Gemma4MultimodalEmbedder(
config.vision_config, config.text_config
config.vision_config,
config.text_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "embed_vision"),
)
recursive_replace_linear(
self.vision_tower,
quant_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)

# ---- Audio tower (variants with audio_config) ----
Expand All @@ -972,7 +994,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# position embeddings, softcap, gradient_clipping).
self.audio_tower.post_init()
self.embed_audio = Gemma4MultimodalEmbedder(
config.audio_config, config.text_config
config.audio_config,
config.text_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "embed_audio"),
)
recursive_replace_linear(
self.audio_tower,
quant_config,
prefix=maybe_prefix(prefix, "audio_tower"),
)
else:
self.audio_tower = None
Expand Down Expand Up @@ -1153,6 +1183,7 @@ def _process_image_input(
vt = self.vision_tower
vision_cfg = self.config.vision_config
pooling_k2 = vision_cfg.pooling_kernel_size**2
target_dtype = self.language_model.model.embed_tokens.weight.dtype

# Concurrent requests with different image resolutions may
# arrive as a list of per-image tensors, while same-resolution
Expand Down Expand Up @@ -1193,7 +1224,11 @@ def _process_image_input(
)
pad_tensor = (pp_tensor == -1).all(dim=-1)

inputs_embeds = vt.patch_embedder(pv_tensor, pp_tensor, pad_tensor)
inputs_embeds = vt.patch_embedder(
pv_tensor,
pp_tensor,
pad_tensor,
).to(target_dtype)
encoder_outputs = vt.encoder(
inputs_embeds=inputs_embeds,
attention_mask=~pad_tensor,
Expand Down Expand Up @@ -1230,7 +1265,9 @@ def _process_image_input(
all_valid_states[orig_idx] = valid_states
valid_lens[orig_idx] = valid_states.shape[0]

target_dtype = self.embed_vision.embedding_projection.weight.dtype
# Use embed_tokens dtype as compute dtype; embedding_projection.weight
# may be uint8 under BnB 4-bit, which would corrupt the cast.
target_dtype = self.language_model.model.embed_tokens.weight.dtype

# Project all images in a single batched call.
flat_valid_states = torch.cat(all_valid_states, dim=0).to(target_dtype)
Expand Down Expand Up @@ -1273,7 +1310,7 @@ def _process_video_input(
vt = self.vision_tower
vision_cfg = self.config.vision_config
pooling_k2 = vision_cfg.pooling_kernel_size**2
target_dtype = self.embed_vision.embedding_projection.weight.dtype
target_dtype = self.language_model.model.embed_tokens.weight.dtype

if isinstance(frame_counts, torch.Tensor):
fc_list = frame_counts.tolist()
Expand Down Expand Up @@ -1301,7 +1338,11 @@ def _process_video_input(
pp_chunk = pixel_position_ids[i : i + max_batch_size]
pad_chunk = padding_positions[i : i + max_batch_size]

inputs_embeds = vt.patch_embedder(pv_chunk, pp_chunk, pad_chunk)
inputs_embeds = vt.patch_embedder(
pv_chunk,
pp_chunk,
pad_chunk,
).to(target_dtype)
encoder_outputs = vt.encoder(
inputs_embeds=inputs_embeds,
attention_mask=~pad_chunk,
Expand Down
29 changes: 29 additions & 0 deletions vllm/model_executor/models/transformers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.transformers_utils.config import is_rope_parameters_nested

if TYPE_CHECKING:
Expand Down Expand Up @@ -227,6 +228,34 @@ def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
return RMSNorm(**kwargs)


def recursive_replace_linear(
    model: nn.Module,
    quant_config: "QuantizationConfig | None",
    prefix: str = "",
):
    """Walk ``model`` and swap out every ``nn.Linear`` submodule in place.

    Each direct child that is an ``nn.Linear`` is passed to
    ``replace_linear_class`` with the ``"replicate"`` style; if the result
    is a different object it is assigned back onto the parent under the
    same attribute name. Any other child is descended into recursively.

    Args:
        model: Root module whose descendants are inspected and mutated.
        quant_config: Forwarded unchanged to ``replace_linear_class``.
        prefix: Qualified name of ``model``; child names are combined with
            it through ``maybe_prefix`` to form each child's prefix.
    """
    for name, child in model.named_children():
        child_prefix = maybe_prefix(prefix, name)

        if not isinstance(child, nn.Linear):
            # Not a linear layer: look for linear layers further down.
            recursive_replace_linear(child, quant_config, prefix=child_prefix)
            continue

        replacement = replace_linear_class(
            child,
            "replicate",
            quant_config,
            prefix=child_prefix,
        )
        # Only rebind when the helper actually produced a new module.
        if replacement is not child:
            setattr(model, name, replacement)


def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    """Log, at debug level, that ``old_module`` was replaced by ``new_module``.

    Args:
        name: Label printed before the two modules in the log line.
        old_module: The module that was replaced.
        new_module: The module that took its place.
    """
    # Arguments are handed to the logger unformatted so that rendering the
    # modules is left to the logging call.
    logger.debug(
        "%s: %s -> %s",
        name,
        old_module,
        new_module,
    )

Expand Down
Loading