diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index e46934b9caeb..98e7572981de 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -263,10 +263,15 @@ See [this page](#generative-models) for more information on how to use generativ
   * ✅︎
   * ✅︎
 - * `Gemma2ForCausalLM`
-  * Gemma2
+  * Gemma 2
   * `google/gemma-2-9b`, `google/gemma-2-27b`, etc.
   * ✅︎
   * ✅︎
+- * `Gemma3ForCausalLM`
+  * Gemma 3
+  * `google/gemma-3-1b-it`, etc.
+  * ✅︎
+  * ✅︎
 - * `GlmForCausalLM`
   * GLM-4
   * `THUDM/glm-4-9b-chat-hf`, etc.
@@ -504,7 +509,7 @@ you should explicitly specify the task type to ensure that the model is used in
   *
   *
 - * `Gemma2Model`
-  * Gemma2-based
+  * Gemma 2-based
   * `BAAI/bge-multilingual-gemma2`, etc.
   *
   * ✅︎
@@ -752,6 +757,13 @@ See [this page](#generative-models) for more information on how to use generativ
   *
   * ✅︎
   * ✅︎
+- * `Gemma3ForConditionalGeneration`
+  * Gemma 3
+  * T + I<sup>+</sup>
+  * `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc.
+  * ✅︎
+  * ✅︎
+  * ✅︎\*
 - * `GLM4VForCausalLM`^
   * GLM-4V
   * T + I
@@ -937,6 +949,31 @@ For more details, please see:
 To use Qwen2.5-VL series models, you have to install Hugging Face Transformers library from source via `pip install git+https://github.com/huggingface/transformers`.
 :::
 
+:::{note}
+To use Gemma 3 series models, you have to install the Hugging Face Transformers library from source via
+`pip install git+https://github.com/huggingface/transformers`.
+The earliest commit that supports this is [`50d3530aa04e7a7d003e6b255a98f79fd0447357`](https://github.com/huggingface/transformers/commit/50d3530aa04e7a7d003e6b255a98f79fd0447357).
+
+Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs.
+However, there are differences in how they handle text + image inputs:
+
+V0 correctly implements the model's attention pattern:
+- Uses bidirectional attention between the image tokens corresponding to the same image
+- Uses causal attention for other tokens
+- Implemented via (naive) PyTorch SDPA with masking tensors
+- Note: May use significant memory for long prompts with images
+
+V1 currently uses a simplified attention pattern:
+- Uses causal attention for all tokens, including image tokens
+- Generates reasonable outputs but does not match the original model's attention for text + image inputs
+- Will be updated in the future to support the correct behavior
+
+This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
+
+Additionally, vLLM's current Gemma 3 implementation does not support the pan-and-scan image pre-processing algorithm, which helps handle images with skewed aspect ratios by intelligently cropping them into multiple views.
+Without this feature, model performance may degrade when processing images that deviate significantly from square dimensions.
+:::
+
 ### Pooling Models
 
 See [this page](pooling-models) for more information on how to use pooling models.
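The note above describes the mixed attention pattern that V0 implements. As a rough standalone illustration — not part of the patch, using made-up token IDs and the same additive-mask convention as `prepare_attn_masks` in `gemma3_mm.py` further down — the pattern amounts to a causal mask whose image-token rows and columns are opened up to each other:

```python
import torch

IMAGE_TOKEN_ID = 99  # stand-in for config.image_token_index
token_ids = torch.tensor([3, 7, 99, 99, 99, 11, 5])  # text, image, text tokens
seq_len = token_ids.numel()

# Causal base mask: 0 where attention is allowed, -inf strictly above the diagonal.
mask = torch.full((seq_len, seq_len), float("-inf")).triu(diagonal=1)

# Let image tokens attend to each other in both directions.
is_image = token_ids == IMAGE_TOKEN_ID
mask = mask.masked_fill(is_image.unsqueeze(0) & is_image.unsqueeze(1), 0.0)

print(mask)  # rows = query positions, columns = key positions
```

The local (sliding-window) layers apply the same idea on top of a windowed causal mask, which is what `prepare_attn_masks` builds per sequence.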
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 716c31b96ed1..39acab4765a3 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -118,6 +118,23 @@ def run_fuyu(questions: list[str], modality: str):
     return llm, prompts, stop_token_ids
 
 
+# Gemma 3
+def run_gemma3(questions: list[str], modality: str):
+    assert modality == "image"
+    model_name = "google/gemma-3-4b-it"
+
+    llm = LLM(model=model_name,
+              max_model_len=2048,
+              max_num_seqs=2,
+              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
+
+    prompts = [("<start_of_turn>user\n"
+                f"<start_of_image>{question}<end_of_turn>\n"
+                "<start_of_turn>model\n") for question in questions]
+    stop_token_ids = None
+    return llm, prompts, stop_token_ids
+
+
 # GLM-4v
 def run_glm4v(questions: list[str], modality: str):
     assert modality == "image"
@@ -405,7 +422,7 @@ def run_mllama(questions: list[str], modality: str):
             "type": "image"
         }, {
             "type": "text",
-            "text": f"{question}"
+            "text": question
         }]
     }] for question in questions]
     prompts = tokenizer.apply_chat_template(messages,
@@ -664,6 +681,7 @@ def run_qwen2_5_vl(questions: list[str], modality: str):
     "deepseek_vl_v2": run_deepseek_vl2,
     "florence2": run_florence2,
     "fuyu": run_fuyu,
+    "gemma3": run_gemma3,
     "glm4v": run_glm4v,
     "h2ovl_chat": run_h2ovl,
     "idefics3": run_idefics3,
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index 6fdd4383c1a1..4963e6a8c4e7 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -80,6 +80,42 @@ def load_deepseek_vl2(question: str, image_urls: list[str]):
     )
 
 
+def load_gemma3(question, image_urls: list[str]) -> ModelRequestData:
+    model_name = "google/gemma-3-4b-it"
+
+    llm = LLM(model=model_name,
+              max_model_len=8192,
+              max_num_seqs=2,
+              limit_mm_per_prompt={"image": len(image_urls)})
+
+    placeholders = [{"type": "image", "image": url} for url in image_urls]
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            *placeholders,
+            {
+                "type": "text",
+                "text": question
+            },
+        ],
+    }]
+
+    processor = AutoProcessor.from_pretrained(model_name)
+
+    prompt = processor.apply_chat_template(messages,
+                                           tokenize=False,
+                                           add_generation_prompt=True)
+
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=None,
+        image_data=[fetch_image(url) for url in image_urls],
+        chat_template=None,
+    )
+
+
 def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData:
     model_name = "h2oai/h2ovl-mississippi-800m"
 
@@ -496,6 +532,7 @@ def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData:
 model_example_map = {
     "aria": load_aria,
     "deepseek_vl_v2": load_deepseek_vl2,
+    "gemma3": load_gemma3,
     "h2ovl_chat": load_h2ovl,
     "idefics3": load_idefics3,
     "internvl_chat": load_internvl,
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index e64b703cc520..467114eedb01 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -162,6 +162,7 @@ def _test_processing_correctness(
         "deepseek-ai/deepseek-vl2-tiny",
         "microsoft/Florence-2-base",
         "adept/fuyu-8b",
+        "google/gemma-3-4b-it",
         "THUDM/glm-4v-9b",
         "h2oai/h2ovl-mississippi-800m",
         "OpenGVLab/InternVL2-1B",
diff --git a/tests/models/registry.py b/tests/models/registry.py
index a7a88d199047..eadbd7e6f492 100644
--- a/tests/models/registry.py
+++
b/tests/models/registry.py @@ -124,6 +124,8 @@ def check_available_online( "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), + "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it", + min_transformers_version="4.50"), "GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"), "GPT2LMHeadModel": _HfExamplesInfo("gpt2"), "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"), @@ -241,6 +243,8 @@ def check_available_online( "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501 hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), + "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it", + min_transformers_version="4.50"), "GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b", trust_remote_code=True, hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 diff --git a/vllm/config.py b/vllm/config.py index a0f30d0e7b77..2ee45f1837c4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -350,10 +350,11 @@ def __init__( if self.enforce_eager is None: self.enforce_eager = False + interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"] sliding_window = getattr(self.hf_text_config, "sliding_window", None) has_interleaved_attention = (sliding_window is not None) and ( isinstance(sliding_window, list) or - (self.hf_text_config.model_type in ["gemma2", "cohere2"])) + (self.hf_text_config.model_type in interleaved_attn_models)) if (not self.disable_sliding_window and has_interleaved_attention): if (backend := @@ -2501,11 +2502,11 @@ def _get_and_verify_dtype( dtype = dtype.lower() if dtype == "auto": if config_dtype == torch.float32: - if config.model_type == "gemma2": + if config.model_type in ("gemma2", "gemma3", "gemma3_text"): logger.info( - "For Gemma 2, we downcast float32 to bfloat16 instead " - "of float16 by default. Please specify `dtype` if you " - "want to use float16.") + "For Gemma 2 and 3, we downcast float32 to bfloat16 " + "instead of float16 by default. Please specify `dtype` " + "if you want to use float16.") torch_dtype = torch.bfloat16 else: # Following the common practice, we use float16 for float32 @@ -2637,7 +2638,9 @@ def _get_and_verify_max_len( derived_max_model_len = default_max_len rope_scaling = getattr(hf_config, "rope_scaling", None) - if rope_scaling is not None: + # NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE + # scaling, so we skip applying the scaling factor again. 
+    if rope_scaling is not None and "gemma3" not in hf_config.model_type:
         # No need to consider "type" key because of patch_rope_scaling when
         # loading HF config
         rope_type = rope_scaling["rope_type"]
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index b51ade17def6..61f21482f707 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -433,6 +433,8 @@ def _placeholder_str(self, modality: ModalityStr,
                 return ""
             if model_type == "aria":
                 return "<|fim_prefix|><|img|><|fim_suffix|>"
+            if model_type == "gemma3":
+                return "<start_of_image>"
 
             raise TypeError(f"Unknown {modality} model type: {model_type}")
         elif modality == "audio":
diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py
new file mode 100644
index 000000000000..f1ecf7fa821d
--- /dev/null
+++ b/vllm/model_executor/models/gemma3.py
@@ -0,0 +1,533 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM team.
+# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Iterable, Optional, Set, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import Gemma3TextConfig
+
+from vllm.attention import Attention
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, VllmConfig
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import GeluAndMul
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from .interfaces import SupportsLoRA, SupportsPP
+from .utils import (AutoWeightsLoader, extract_layer_index,
+                    is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)
+
+logger = init_logger(__name__)
+
+
+class Gemma3MLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_activation: str,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config)
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
bias=False, + quant_config=quant_config) + if hidden_activation != "gelu_pytorch_tanh": + raise ValueError( + "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation " + "function. Please set `hidden_act` and `hidden_activation` to " + "`gelu_pytorch_tanh`.") + self.act_fn = GeluAndMul(approximate="tanh") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Gemma3Attention(nn.Module): + + def __init__(self, + config: Gemma3TextConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + attn_logits_soft_cap: Optional[float] = None, + prefix: str = "") -> None: + super().__init__() + self.config = config + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = config.query_pre_attn_scalar**-0.5 + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + ) + + self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + # TODO(woosuk): Add reference to the original HF implementation. + layer_idx = extract_layer_index(prefix) + self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) + # Initialize the rotary embedding. + if self.is_sliding: + # Local attention. Override the values in config.json. + self.rope_theta = config.rope_local_base_freq + self.rope_scaling = {"rope_type": "default"} + self.sliding_window = config.interleaved_sliding_window + else: + # Global attention. Use the values in config.json. + self.rope_theta = config.rope_theta + self.rope_scaling = config.rope_scaling + self.sliding_window = None + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=self.rope_theta, + is_neox_style=True, + rope_scaling=self.rope_scaling, + ) + + # Initialize the attention. 
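+        # For layers with self.is_sliding set, per_layer_sliding_window below
+        # restricts attention to the last `interleaved_sliding_window` tokens;
+        # global layers pass None and attend over the full context. Every
+        # `sliding_window_pattern`-th layer is global, the rest are local.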
+ self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=attn_logits_soft_cap, + per_layer_sliding_window=self.sliding_window, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q = q.unflatten(-1, (self.num_heads, self.head_dim)) + q = self.q_norm(q) + q = q.flatten(-2, -1) + k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) + k = self.k_norm(k) + k = k.flatten(-2, -1) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + + if not kwargs.get("has_images", False): + # Fast path for text-only inputs. The performance for the text-only + # inputs are not affected by the naive attention below. + output, _ = self.o_proj(attn_output) + return output + + # NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens + # that correspond to the same image while using causal attention + # otherwise. Current attention backends cannot handle this pattern, so + # we temporarily use a naive attention implementation with mask tensors. + + # We intentionally keep the attention backend as-is and only override + # `attn_output` with the naive implementation's output. This minimizes + # changes to existing model runners and attention backends. The call to + # `self.attn(q, k, v)` is only used to populate the KV cache - its + # output is discarded and overwritten below. While this duplicates + # computation, it maintains compatibility. + # TODO(woosuk): Optimize by implementing custom attention kernels. + attn_output = self.naive_attn_with_masks(q, + k, + v, + out=attn_output, + **kwargs) + output, _ = self.o_proj(attn_output) + return output + + def naive_attn_with_masks( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + # NOTE(woosuk): As described in the comment above, this code is not + # meant to be performant. It is only meant to be correct. + q = q.view(-1, self.num_heads, self.head_dim) + # Expand the key and value to handle GQA. + num_queries_per_kv = self.num_heads // self.num_kv_heads + k = k.view(-1, self.num_kv_heads, self.head_dim) + k = k.repeat_interleave(num_queries_per_kv, dim=-2) + v = v.view(-1, self.num_kv_heads, self.head_dim) + v = v.repeat_interleave(num_queries_per_kv, dim=-2) + + if self.is_sliding: + attn_masks = kwargs["local_attn_masks"] + else: + attn_masks = kwargs["global_attn_masks"] + + seq_lens = kwargs["seq_lens"] + start_idx = 0 + for seq_len, attn_mask in zip(seq_lens, attn_masks): + end_idx = start_idx + seq_len + query = q[start_idx:end_idx].unsqueeze(0) + key = k[start_idx:end_idx].unsqueeze(0) + value = v[start_idx:end_idx].unsqueeze(0) + + # Transpose. 
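+            # SDPA expects (batch, num_heads, seq_len, head_dim), so move the
+            # head dimension in front of the sequence dimension.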
+            query = query.transpose(1, 2)
+            key = key.transpose(1, 2)
+            value = value.transpose(1, 2)
+
+            output = F.scaled_dot_product_attention(
+                query,
+                key,
+                value,
+                attn_mask,
+                scale=self.scaling,
+            )
+            output = output.transpose(1, 2).flatten(-2, -1)
+            out[start_idx:end_idx] = output
+            start_idx = end_idx
+        return out
+
+
+class Gemma3DecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: Gemma3TextConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = Gemma3Attention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            head_dim=config.head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            attn_logits_soft_cap=None,
+            prefix=f"{prefix}.self_attn",
+        )
+        self.hidden_size = config.hidden_size
+        self.mlp = Gemma3MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_activation=config.hidden_activation,
+            quant_config=quant_config,
+        )
+        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
+                                            eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
+                                                     eps=config.rms_norm_eps)
+        self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
+                                                      eps=config.rms_norm_eps)
+        self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
+                                                       eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            **kwargs,
+        )
+        hidden_states = self.post_attention_layernorm(hidden_states)
+
+        hidden_states, residual = self.pre_feedforward_layernorm(
+            hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.post_feedforward_layernorm(hidden_states)
+        return hidden_states, residual
+
+
+@support_torch_compile
+class Gemma3Model(nn.Module):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: Gemma3DecoderLayer(
+                config, cache_config, quant_config, prefix=prefix),
+            prefix=f"{prefix}.layers")
+        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        # Normalize the embedding by sqrt(hidden_size)
+        # The normalizer's data type should be downcasted to the model's
+        # data type such as bfloat16, not float32.
+ # See https://github.com/huggingface/transformers/pull/29402 + normalizer = self.config.hidden_size**0.5 + self.register_buffer("normalizer", torch.tensor(normalizer)) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + # NOTE(woosuk): Only apply the normalizer to the output of + # vocab embedding. Don't apply it to the vision embedding. + return self.embed_tokens(input_ids) * self.normalizer + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + **kwargs, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for (param_name, shard_name, shard_id) in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + unloaded_params = params_dict.keys() - loaded_params + if unloaded_params: + logger.warning( + "Some weights are not initialized from checkpoints: %s", + unloaded_params) + return loaded_params + + +class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + del lora_config # Unused. + super().__init__() + self.config = config + # currently all existing Gemma models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings + self.quant_config = quant_config + self.model = Gemma3Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.logits_processor = LogitsProcessor( + config.vocab_size, soft_cap=config.final_logit_softcapping) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds, **kwargs) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.model.embed_tokens, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py new file mode 100644 index 000000000000..121aee51786b --- /dev/null +++ b/vllm/model_executor/models/gemma3_mm.py @@ -0,0 +1,425 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import (Any, Iterable, Literal, Mapping, Optional, Sequence, Set, + Tuple, TypedDict, Union) + +import torch +from torch import nn +from transformers import BatchFeature, Gemma3Config, ProcessorMixin + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) +from vllm.multimodal.parse 
import ImageSize, MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsMultiModal, SupportsPP +from .siglip import SiglipVisionModel +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) + +logger = init_logger(__name__) + + +class Gemma3ImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: `(batch_size * num_images, num_channels, height, width)`""" + + +Gemma3ImageInputs = Gemma3ImagePixelInputs + + +class Gemma3ProcessingInfo(BaseProcessingInfo): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + hf_config = self.ctx.get_hf_config() + return {"image": hf_config.mm_tokens_per_image} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Optional[ProcessorMixin], + ) -> int: + hf_config = self.ctx.get_hf_config() + return hf_config.mm_tokens_per_image + + def get_image_size_with_most_features(self) -> ImageSize: + # Result in the max possible feature size (h:w = 16:1) + return ImageSize(height=8000, width=50) + + +class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + tokenizer = self.info.get_tokenizer() + boi_token = tokenizer.boi_token + + num_images = mm_counts.get("image", 0) + target_width, target_height = \ + self.info.get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + return ProcessorInputs( + prompt_text=" ".join([boi_token] * num_images), + mm_data=mm_data, + ) + + +class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + # TODO(woosuk): Support pan-and-scan. 
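+        # With pan-and-scan force-disabled, every image is resized to a single
+        # fixed-size crop and contributes exactly `mm_tokens_per_image` tokens
+        # (see the limitation noted in supported_models.md above).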
+ img_kwargs = mm_kwargs.get("images_kwargs", {}) + img_kwargs["do_pan_and_scan"] = False + mm_kwargs["images_kwargs"] = img_kwargs + return super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + tokenizer = self.info.get_tokenizer() + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + hf_config = self.info.get_hf_config() + + boi_token = tokenizer.boi_token + image_token = tokenizer.image_token + mm_tokens_per_image = hf_config.mm_tokens_per_image + image_tokens_expanded = "".join([image_token] * mm_tokens_per_image) + + def get_replacement_gemma3(item_idx: int): + return PromptUpdateDetails( + full=hf_processor.full_image_sequence, + features=image_tokens_expanded, + ) + + return [ + PromptReplacement( + modality="image", + target=boi_token, + replacement=get_replacement_gemma3, + ) + ] + + +class Gemma3MultiModalProjector(nn.Module): + + def __init__(self, config: Gemma3Config): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.zeros(config.vision_config.hidden_size, + config.text_config.hidden_size)) + + self.mm_soft_emb_norm = GemmaRMSNorm( + config.vision_config.hidden_size, + eps=config.vision_config.layer_norm_eps) + + self.patches_per_image = int(config.vision_config.image_size // + config.vision_config.patch_size) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, + stride=self.kernel_size) + + def forward(self, vision_outputs: torch.Tensor): + batch_size, _, seq_length = vision_outputs.shape + + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, seq_length, self.patches_per_image, + self.patches_per_image) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) + + projected_vision_outputs = torch.matmul( + normed_vision_outputs, self.mm_input_projection_weight) + return projected_vision_outputs.type_as(vision_outputs) + + +@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor, + info=Gemma3ProcessingInfo, + dummy_inputs=Gemma3DummyInputsBuilder) +class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.quant_config = quant_config + self.multimodal_config = multimodal_config + self.sliding_window = config.text_config.interleaved_sliding_window + + self.vision_tower 
= SiglipVisionModel(config.vision_config, + quant_config, + prefix=maybe_prefix( + prefix, "vision_tower")) + self.multi_modal_projector = Gemma3MultiModalProjector(config) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Gemma3ForCausalLM"], + ) + logit_scale = getattr(config, "logit_scale", 1.0) + self.language_model.logits_processor.scale *= logit_scale + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @property + def sampler(self): + return self.language_model.sampler + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + if d.shape != expected_dims: + raise ValueError( + "The expected shape of pixel values per image per batch " + f"is {expected_dims}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Gemma3ImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + assert image_embeds is None, "Gemma3 does not support image_embeds." + if pixel_values is None: + return None + + if not isinstance(pixel_values, (torch.Tensor, list[torch.Tensor])): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + pixel_values = flatten_bn(pixel_values, concat=True) + return Gemma3ImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values(pixel_values), + ) + + def _image_pixels_to_features( + self, + vision_tower: SiglipVisionModel, + pixel_values: torch.Tensor, + ) -> torch.Tensor: + target_dtype = vision_tower.get_input_embeddings().weight.dtype + image_features = vision_tower(pixel_values.to(dtype=target_dtype)) + return image_features + + def _process_image_input( + self, + image_input: Gemma3ImageInputs, + ) -> torch.Tensor: + assert self.vision_tower is not None + pixel_values = image_input["data"] + vision_outputs = self._image_pixels_to_features( + self.vision_tower, + pixel_values, + ) + return self.multi_modal_projector(vision_outputs) + + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + if multimodal_embeddings is None: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + else: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.image_token_index) + return inputs_embeds + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
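+        # On this V0 path, vision embeddings are computed and merged into the
+        # text embeddings, the global/local attention masks are built for any
+        # request containing images, and input_ids is set to None so that the
+        # language model consumes inputs_embeds instead.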
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + if vision_embeddings is not None: + kwargs = self.prepare_attn_masks( + input_ids, + positions, + mask_dtype=vision_embeddings.dtype, + **kwargs) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs) + + return hidden_states + + def prepare_attn_masks( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mask_dtype: torch.dtype, + **kwargs, + ): + kwargs["has_images"] = True + # NOTE(woosuk): Here, we distinguish the sequences by the position id 0. + # This is a HACK. Fix this. + start_idices = (positions == 0).cpu().nonzero() + num_seqs = len(start_idices) + seq_lens = [] + for i in range(num_seqs): + start_idx = start_idices[i].item() + if i < num_seqs - 1: + end_idx = start_idices[i + 1].item() + else: + end_idx = len(input_ids) + seq_lens.append(end_idx - start_idx) + kwargs["seq_lens"] = seq_lens + + global_attn_masks = [] + local_attn_masks = [] + start_idx = 0 + for seq_len in seq_lens: + end_idx = start_idx + seq_len + input_token_ids = input_ids[start_idx:end_idx] + start_idx = end_idx + # Create a global causal mask. + global_attn_mask = torch.empty( + 1, + 1, + seq_len, + seq_len, + dtype=mask_dtype, + device=input_ids.device, + ) + global_attn_mask.fill_(float("-inf")) + # Fill the lower triangle with 0. + global_attn_mask = global_attn_mask.triu(diagonal=1) + + # Consider the bidirectional attention between image tokens. + img_mask = torch.zeros_like(global_attn_mask) + img_pos = (input_token_ids == self.config.image_token_index) + img_mask[:, :, :, img_pos] += 1 + img_mask[:, :, img_pos, :] += 1 + global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask) + global_attn_masks.append(global_attn_mask) + + # Create a local causal mask with sliding window (1024). 
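+            # tril(diagonal=-sliding_window) marks keys that lie at least
+            # `sliding_window` tokens behind the query; those entries become
+            # -inf, while everything inside the window falls back to the
+            # global mask (which already encodes causality and image blocks).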
+ local_attn_mask = torch.ones_like(global_attn_mask) + local_attn_mask = torch.tril(local_attn_mask, + diagonal=-self.sliding_window) + local_attn_mask = torch.where(local_attn_mask == 0, + global_attn_mask, float("-inf")) + local_attn_masks.append(local_attn_mask) + kwargs["global_attn_masks"] = global_attn_masks + kwargs["local_attn_masks"] = local_attn_masks + return kwargs + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 74160e2d9ee4..5dd3aa2973cd 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -53,6 +53,7 @@ "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), + "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), @@ -161,6 +162,7 @@ "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), + "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"), "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"),
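For reference, the pooling arithmetic inside `Gemma3MultiModalProjector` (added in `gemma3_mm.py` above) can be sanity-checked in isolation. The sketch below is only a shape walk-through under assumed config values for `google/gemma-3-4b-it` (896×896 SigLIP input, 14-pixel patches, 1152-dim vision hidden size, 256 tokens per image); the RMSNorm and the learned projection to the text hidden size are omitted:

```python
import torch

# Assumed Gemma 3 config values (not read from the checkpoint here).
image_size, patch_size, vision_hidden = 896, 14, 1152
mm_tokens_per_image = 256

patches_per_image = image_size // patch_size        # 64 patches per side
tokens_per_side = int(mm_tokens_per_image ** 0.5)   # 16
kernel_size = patches_per_image // tokens_per_side  # 4x4 average pooling

# Dummy SigLIP output: (batch, num_patches, hidden) = (1, 4096, 1152).
vision_outputs = torch.randn(1, patches_per_image ** 2, vision_hidden)

x = vision_outputs.transpose(1, 2)                  # (1, 1152, 4096)
x = x.reshape(1, vision_hidden, patches_per_image, patches_per_image)
x = torch.nn.AvgPool2d(kernel_size, stride=kernel_size)(x)  # (1, 1152, 16, 16)
x = x.flatten(2).transpose(1, 2)                    # (1, 256, 1152)
assert x.shape == (1, mm_tokens_per_image, vision_hidden)
```

Each image therefore costs a fixed 256 language-model tokens, which is also what `get_mm_max_tokens_per_item` reports.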