diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index e46934b9caeb..98e7572981de 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -263,10 +263,15 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
* ✅︎
- * `Gemma2ForCausalLM`
- * Gemma2
+ * Gemma 2
* `google/gemma-2-9b`, `google/gemma-2-27b`, etc.
* ✅︎
* ✅︎
+- * `Gemma3ForCausalLM`
+ * Gemma 3
+ * `google/gemma-3-1b-it`, etc.
+ * ✅︎
+ * ✅︎
- * `GlmForCausalLM`
* GLM-4
* `THUDM/glm-4-9b-chat-hf`, etc.
@@ -504,7 +509,7 @@ you should explicitly specify the task type to ensure that the model is used in
*
*
- * `Gemma2Model`
- * Gemma2-based
+ * Gemma 2-based
* `BAAI/bge-multilingual-gemma2`, etc.
*
* ✅︎
@@ -752,6 +757,13 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
+- * `Gemma3ForConditionalGeneration`
+ * Gemma 3
+  * T + I<sup>+</sup>
+ * `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc.
+ * ✅︎
+ * ✅︎
+ * ✅︎\*
- * `GLM4VForCausalLM`^
* GLM-4V
* T + I
@@ -937,6 +949,31 @@ For more details, please see:
To use Qwen2.5-VL series models, you have to install Hugging Face Transformers library from source via `pip install git+https://github.com/huggingface/transformers`.
:::
+:::{note}
+To use Gemma 3 series models, you have to install Hugging Face Transformers library from source via
+`pip install git+https://github.com/huggingface/transformers`.
+The earliest commit that supports this is [`50d3530aa04e7a7d003e6b255a98f79fd0447357`](https://github.com/huggingface/transformers/commit/50d3530aa04e7a7d003e6b255a98f79fd0447357).
+
+Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs.
+However, there are differences in how they handle text + image inputs:
+
+V0 correctly implements the model's attention pattern:
+- Uses bidirectional attention between the image tokens corresponding to the same image
+- Uses causal attention for other tokens
+- Implemented via (naive) PyTorch SDPA with masking tensors
+- Note: May use significant memory for long prompts with images
+
+V1 currently uses a simplified attention pattern:
+- Uses causal attention for all tokens, including image tokens
+- Generates reasonable outputs but does not match the original model's attention for text + image inputs
+- Will be updated in the future to support the correct behavior
+
+This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
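+
+As a workaround, you can force the V0 engine so that the exact text + image attention pattern is used. A minimal sketch (assuming the engine is selected via the `VLLM_USE_V1` environment variable):
+
+```python
+import os
+
+# Select the V0 engine before importing vLLM so that the bidirectional
+# image-token attention described above is used.
+os.environ["VLLM_USE_V1"] = "0"
+
+from vllm import LLM
+
+llm = LLM(model="google/gemma-3-4b-it", max_model_len=2048)
+```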
+
+Additionally, vLLM's current Gemma 3 implementation does not support the pan-and-scan image pre-processing algorithm, which helps handle images with skewed aspect ratios by intelligently cropping them into multiple views.
+Without this feature, model performance may degrade when processing images that deviate significantly from square dimensions.
+:::
+
### Pooling Models
See [this page](pooling-models) for more information on how to use pooling models.
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 716c31b96ed1..39acab4765a3 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -118,6 +118,23 @@ def run_fuyu(questions: list[str], modality: str):
return llm, prompts, stop_token_ids
+# Gemma 3
+def run_gemma3(questions: list[str], modality: str):
+ assert modality == "image"
+ model_name = "google/gemma-3-4b-it"
+
+ llm = LLM(model=model_name,
+ max_model_len=2048,
+ max_num_seqs=2,
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
+
+    prompts = [("<start_of_turn>user\n"
+                f"<start_of_image>{question}<end_of_turn>\n"
+                "<start_of_turn>model\n") for question in questions]
+ stop_token_ids = None
+ return llm, prompts, stop_token_ids
+
+
# GLM-4v
def run_glm4v(questions: list[str], modality: str):
assert modality == "image"
@@ -405,7 +422,7 @@ def run_mllama(questions: list[str], modality: str):
"type": "image"
}, {
"type": "text",
- "text": f"{question}"
+ "text": question
}]
}] for question in questions]
prompts = tokenizer.apply_chat_template(messages,
@@ -664,6 +681,7 @@ def run_qwen2_5_vl(questions: list[str], modality: str):
"deepseek_vl_v2": run_deepseek_vl2,
"florence2": run_florence2,
"fuyu": run_fuyu,
+ "gemma3": run_gemma3,
"glm4v": run_glm4v,
"h2ovl_chat": run_h2ovl,
"idefics3": run_idefics3,
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index 6fdd4383c1a1..4963e6a8c4e7 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -80,6 +80,42 @@ def load_deepseek_vl2(question: str, image_urls: list[str]):
)
+def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
+ model_name = "google/gemma-3-4b-it"
+
+ llm = LLM(model=model_name,
+ max_model_len=8192,
+ max_num_seqs=2,
+ limit_mm_per_prompt={"image": len(image_urls)})
+
+ placeholders = [{"type": "image", "image": url} for url in image_urls]
+ messages = [{
+ "role":
+ "user",
+ "content": [
+ *placeholders,
+ {
+ "type": "text",
+ "text": question
+ },
+ ],
+ }]
+
+ processor = AutoProcessor.from_pretrained(model_name)
+
+ prompt = processor.apply_chat_template(messages,
+ tokenize=False,
+ add_generation_prompt=True)
+
+ return ModelRequestData(
+ llm=llm,
+ prompt=prompt,
+ stop_token_ids=None,
+ image_data=[fetch_image(url) for url in image_urls],
+ chat_template=None,
+ )
+
+
def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "h2oai/h2ovl-mississippi-800m"
@@ -496,6 +532,7 @@ def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData:
model_example_map = {
"aria": load_aria,
"deepseek_vl_v2": load_deepseek_vl2,
+ "gemma3": load_gemma3,
"h2ovl_chat": load_h2ovl,
"idefics3": load_idefics3,
"internvl_chat": load_internvl,
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index e64b703cc520..467114eedb01 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -162,6 +162,7 @@ def _test_processing_correctness(
"deepseek-ai/deepseek-vl2-tiny",
"microsoft/Florence-2-base",
"adept/fuyu-8b",
+ "google/gemma-3-4b-it",
"THUDM/glm-4v-9b",
"h2oai/h2ovl-mississippi-800m",
"OpenGVLab/InternVL2-1B",
diff --git a/tests/models/registry.py b/tests/models/registry.py
index a7a88d199047..eadbd7e6f492 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -124,6 +124,8 @@ def check_available_online(
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
+ "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it",
+ min_transformers_version="4.50"),
"GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"),
"GPT2LMHeadModel": _HfExamplesInfo("gpt2"),
"GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"),
@@ -241,6 +243,8 @@ def check_available_online(
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
+ "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it",
+ min_transformers_version="4.50"),
"GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b",
trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
diff --git a/vllm/config.py b/vllm/config.py
index a0f30d0e7b77..2ee45f1837c4 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -350,10 +350,11 @@ def __init__(
if self.enforce_eager is None:
self.enforce_eager = False
+ interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
sliding_window = getattr(self.hf_text_config, "sliding_window", None)
has_interleaved_attention = (sliding_window is not None) and (
isinstance(sliding_window, list) or
- (self.hf_text_config.model_type in ["gemma2", "cohere2"]))
+ (self.hf_text_config.model_type in interleaved_attn_models))
if (not self.disable_sliding_window and has_interleaved_attention):
if (backend :=
@@ -2501,11 +2502,11 @@ def _get_and_verify_dtype(
dtype = dtype.lower()
if dtype == "auto":
if config_dtype == torch.float32:
- if config.model_type == "gemma2":
+ if config.model_type in ("gemma2", "gemma3", "gemma3_text"):
logger.info(
- "For Gemma 2, we downcast float32 to bfloat16 instead "
- "of float16 by default. Please specify `dtype` if you "
- "want to use float16.")
+ "For Gemma 2 and 3, we downcast float32 to bfloat16 "
+ "instead of float16 by default. Please specify `dtype` "
+ "if you want to use float16.")
torch_dtype = torch.bfloat16
else:
# Following the common practice, we use float16 for float32
@@ -2637,7 +2638,9 @@ def _get_and_verify_max_len(
derived_max_model_len = default_max_len
rope_scaling = getattr(hf_config, "rope_scaling", None)
- if rope_scaling is not None:
+ # NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE
+ # scaling, so we skip applying the scaling factor again.
+ if rope_scaling is not None and "gemma3" not in hf_config.model_type:
# No need to consider "type" key because of patch_rope_scaling when
# loading HF config
rope_type = rope_scaling["rope_type"]
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index b51ade17def6..61f21482f707 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -433,6 +433,8 @@ def _placeholder_str(self, modality: ModalityStr,
return ""
if model_type == "aria":
return "<|fim_prefix|><|img|><|fim_suffix|>"
+ if model_type == "gemma3":
+            return "<start_of_image>"
raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "audio":
diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py
new file mode 100644
index 000000000000..f1ecf7fa821d
--- /dev/null
+++ b/vllm/model_executor/models/gemma3.py
@@ -0,0 +1,533 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM team.
+# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Iterable, Optional, Set, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import Gemma3TextConfig
+
+from vllm.attention import Attention
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, VllmConfig
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import GeluAndMul
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from .interfaces import SupportsLoRA, SupportsPP
+from .utils import (AutoWeightsLoader, extract_layer_index,
+ is_pp_missing_parameter,
+ make_empty_intermediate_tensors_factory, make_layers,
+ maybe_prefix)
+
+logger = init_logger(__name__)
+
+
+class Gemma3MLP(nn.Module):
+
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_activation: str,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> None:
+ super().__init__()
+ self.gate_up_proj = MergedColumnParallelLinear(
+ hidden_size, [intermediate_size] * 2,
+ bias=False,
+ quant_config=quant_config)
+ self.down_proj = RowParallelLinear(intermediate_size,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config)
+ if hidden_activation != "gelu_pytorch_tanh":
+ raise ValueError(
+ "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
+ "function. Please set `hidden_act` and `hidden_activation` to "
+ "`gelu_pytorch_tanh`.")
+ self.act_fn = GeluAndMul(approximate="tanh")
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ gate_up, _ = self.gate_up_proj(x)
+ x = self.act_fn(gate_up)
+ x, _ = self.down_proj(x)
+ return x
+
+
+class Gemma3Attention(nn.Module):
+
+ def __init__(self,
+ config: Gemma3TextConfig,
+ hidden_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ head_dim: int,
+ max_position_embeddings: int,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ attn_logits_soft_cap: Optional[float] = None,
+ prefix: str = "") -> None:
+ super().__init__()
+ self.config = config
+ self.hidden_size = hidden_size
+ tp_size = get_tensor_model_parallel_world_size()
+ self.total_num_heads = num_heads
+ assert self.total_num_heads % tp_size == 0
+ self.num_heads = self.total_num_heads // tp_size
+ self.total_num_kv_heads = num_kv_heads
+ if self.total_num_kv_heads >= tp_size:
+ # Number of KV heads is greater than TP size, so we partition
+ # the KV heads across multiple tensor parallel GPUs.
+ assert self.total_num_kv_heads % tp_size == 0
+ else:
+ # Number of KV heads is less than TP size, so we replicate
+ # the KV heads across multiple tensor parallel GPUs.
+ assert tp_size % self.total_num_kv_heads == 0
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+ self.head_dim = head_dim
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+ self.scaling = config.query_pre_attn_scalar**-0.5
+
+ self.qkv_proj = QKVParallelLinear(
+ hidden_size,
+ self.head_dim,
+ self.total_num_heads,
+ self.total_num_kv_heads,
+ bias=config.attention_bias,
+ quant_config=quant_config,
+ )
+ self.o_proj = RowParallelLinear(
+ self.total_num_heads * self.head_dim,
+ hidden_size,
+ bias=config.attention_bias,
+ quant_config=quant_config,
+ )
+
+ self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+ # TODO(woosuk): Add reference to the original HF implementation.
+ layer_idx = extract_layer_index(prefix)
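+        # Every `sliding_window_pattern`-th layer uses global attention, while
+        # the remaining layers use local sliding-window attention (for the
+        # released Gemma 3 checkpoints the pattern is 6, i.e. 5 local layers
+        # per global layer).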
+ self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern)
+ # Initialize the rotary embedding.
+ if self.is_sliding:
+ # Local attention. Override the values in config.json.
+ self.rope_theta = config.rope_local_base_freq
+ self.rope_scaling = {"rope_type": "default"}
+ self.sliding_window = config.interleaved_sliding_window
+ else:
+ # Global attention. Use the values in config.json.
+ self.rope_theta = config.rope_theta
+ self.rope_scaling = config.rope_scaling
+ self.sliding_window = None
+ self.rotary_emb = get_rope(
+ self.head_dim,
+ rotary_dim=self.head_dim,
+ max_position=max_position_embeddings,
+ base=self.rope_theta,
+ is_neox_style=True,
+ rope_scaling=self.rope_scaling,
+ )
+
+ # Initialize the attention.
+ self.attn = Attention(self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_kv_heads,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ logits_soft_cap=attn_logits_soft_cap,
+ per_layer_sliding_window=self.sliding_window,
+ prefix=f"{prefix}.attn")
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ **kwargs,
+ ) -> torch.Tensor:
+ qkv, _ = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
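+        # Gemma 3 applies per-head RMSNorm to the queries and keys
+        # (q_norm/k_norm) before RoPE, unlike Gemma 2.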
+ q = q.unflatten(-1, (self.num_heads, self.head_dim))
+ q = self.q_norm(q)
+ q = q.flatten(-2, -1)
+ k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
+ k = self.k_norm(k)
+ k = k.flatten(-2, -1)
+
+ q, k = self.rotary_emb(positions, q, k)
+ attn_output = self.attn(q, k, v)
+
+ if not kwargs.get("has_images", False):
+            # Fast path for text-only inputs. The performance of text-only
+            # inputs is not affected by the naive attention below.
+ output, _ = self.o_proj(attn_output)
+ return output
+
+ # NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens
+ # that correspond to the same image while using causal attention
+ # otherwise. Current attention backends cannot handle this pattern, so
+ # we temporarily use a naive attention implementation with mask tensors.
+
+ # We intentionally keep the attention backend as-is and only override
+ # `attn_output` with the naive implementation's output. This minimizes
+ # changes to existing model runners and attention backends. The call to
+ # `self.attn(q, k, v)` is only used to populate the KV cache - its
+ # output is discarded and overwritten below. While this duplicates
+ # computation, it maintains compatibility.
+ # TODO(woosuk): Optimize by implementing custom attention kernels.
+ attn_output = self.naive_attn_with_masks(q,
+ k,
+ v,
+ out=attn_output,
+ **kwargs)
+ output, _ = self.o_proj(attn_output)
+ return output
+
+ def naive_attn_with_masks(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ out: torch.Tensor,
+ **kwargs,
+ ) -> torch.Tensor:
+ # NOTE(woosuk): As described in the comment above, this code is not
+ # meant to be performant. It is only meant to be correct.
+ q = q.view(-1, self.num_heads, self.head_dim)
+ # Expand the key and value to handle GQA.
+ num_queries_per_kv = self.num_heads // self.num_kv_heads
+ k = k.view(-1, self.num_kv_heads, self.head_dim)
+ k = k.repeat_interleave(num_queries_per_kv, dim=-2)
+ v = v.view(-1, self.num_kv_heads, self.head_dim)
+ v = v.repeat_interleave(num_queries_per_kv, dim=-2)
+
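+        # Local (sliding-window) layers use the windowed masks; global layers
+        # use the full causal masks with bidirectional image blocks.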
+ if self.is_sliding:
+ attn_masks = kwargs["local_attn_masks"]
+ else:
+ attn_masks = kwargs["global_attn_masks"]
+
+ seq_lens = kwargs["seq_lens"]
+ start_idx = 0
+ for seq_len, attn_mask in zip(seq_lens, attn_masks):
+ end_idx = start_idx + seq_len
+ query = q[start_idx:end_idx].unsqueeze(0)
+ key = k[start_idx:end_idx].unsqueeze(0)
+ value = v[start_idx:end_idx].unsqueeze(0)
+
+ # Transpose.
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ output = F.scaled_dot_product_attention(
+ query,
+ key,
+ value,
+ attn_mask,
+ self.scaling,
+ )
+ output = output.transpose(1, 2).flatten(-2, -1)
+ out[start_idx:end_idx] = output
+ start_idx = end_idx
+ return out
+
+
+class Gemma3DecoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ config: Gemma3TextConfig,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = Gemma3Attention(
+ config=config,
+ hidden_size=self.hidden_size,
+ num_heads=config.num_attention_heads,
+ num_kv_heads=config.num_key_value_heads,
+ head_dim=config.head_dim,
+ max_position_embeddings=config.max_position_embeddings,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ attn_logits_soft_cap=None,
+ prefix=f"{prefix}.self_attn",
+ )
+ self.hidden_size = config.hidden_size
+ self.mlp = Gemma3MLP(
+ hidden_size=self.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_activation=config.hidden_activation,
+ quant_config=quant_config,
+ )
+ self.input_layernorm = GemmaRMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+ self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+ self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+ self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ residual: Optional[torch.Tensor],
+ **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(
+ hidden_states, residual)
+ hidden_states = self.self_attn(
+ positions=positions,
+ hidden_states=hidden_states,
+ **kwargs,
+ )
+ hidden_states = self.post_attention_layernorm(hidden_states)
+
+ hidden_states, residual = self.pre_feedforward_layernorm(
+ hidden_states, residual)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = self.post_feedforward_layernorm(hidden_states)
+ return hidden_states, residual
+
+
+@support_torch_compile
+class Gemma3Model(nn.Module):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ cache_config = vllm_config.cache_config
+ quant_config = vllm_config.quant_config
+ self.config = config
+ self.quant_config = quant_config
+
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ )
+ self.start_layer, self.end_layer, self.layers = make_layers(
+ config.num_hidden_layers,
+ lambda prefix: Gemma3DecoderLayer(
+ config, cache_config, quant_config, prefix=prefix),
+ prefix=f"{prefix}.layers")
+ self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ # Normalize the embedding by sqrt(hidden_size)
+ # The normalizer's data type should be downcasted to the model's
+ # data type such as bfloat16, not float32.
+ # See https://github.com/huggingface/transformers/pull/29402
+ normalizer = self.config.hidden_size**0.5
+ self.register_buffer("normalizer", torch.tensor(normalizer))
+ self.make_empty_intermediate_tensors = (
+ make_empty_intermediate_tensors_factory(
+ ["hidden_states", "residual"], config.hidden_size))
+
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ # NOTE(woosuk): Only apply the normalizer to the output of
+ # vocab embedding. Don't apply it to the vision embedding.
+ return self.embed_tokens(input_ids) * self.normalizer
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor],
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ else:
+ hidden_states = self.get_input_embeddings(input_ids)
+ residual = None
+ else:
+ assert intermediate_tensors is not None
+ hidden_states = intermediate_tensors["hidden_states"]
+ residual = intermediate_tensors["residual"]
+ for layer in self.layers[self.start_layer:self.end_layer]:
+ hidden_states, residual = layer(
+ positions,
+ hidden_states,
+ residual,
+ **kwargs,
+ )
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors({
+ "hidden_states": hidden_states,
+ "residual": residual
+ })
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+ def load_weights(self, weights: Iterable[Tuple[str,
+ torch.Tensor]]) -> Set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+ params_dict = dict(self.named_parameters())
+ loaded_params: Set[str] = set()
+ for name, loaded_weight in weights:
+ if (self.quant_config is not None and
+ (scale_name := self.quant_config.get_cache_scale(name))):
+ # Loading kv cache scales for compressed-tensors quantization
+ param = params_dict[scale_name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ loaded_weight = loaded_weight[0]
+ weight_loader(param, loaded_weight)
+ loaded_params.add(scale_name)
+ continue
+ for (param_name, shard_name, shard_id) in stacked_params_mapping:
+ if shard_name not in name:
+ continue
+ name = name.replace(shard_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ if is_pp_missing_parameter(name, self):
+ continue
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ # Remapping the name of FP8 kv-scale.
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
+ if is_pp_missing_parameter(name, self):
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+
+ unloaded_params = params_dict.keys() - loaded_params
+ if unloaded_params:
+ logger.warning(
+ "Some weights are not initialized from checkpoints: %s",
+ unloaded_params)
+ return loaded_params
+
+
+class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": [
+ "gate_proj",
+ "up_proj",
+ ],
+ }
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ lora_config = vllm_config.lora_config
+ del lora_config # Unused.
+ super().__init__()
+ self.config = config
+        # Currently, all existing Gemma models have `tie_word_embeddings` enabled.
+ assert config.tie_word_embeddings
+ self.quant_config = quant_config
+ self.model = Gemma3Model(vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "model"))
+ self.logits_processor = LogitsProcessor(
+ config.vocab_size, soft_cap=config.final_logit_softcapping)
+ self.sampler = get_sampler()
+ self.make_empty_intermediate_tensors = (
+ self.model.make_empty_intermediate_tensors)
+
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.model.get_input_embeddings(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ hidden_states = self.model(input_ids, positions, intermediate_tensors,
+ inputs_embeds, **kwargs)
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ logits = self.logits_processor(self.model.embed_tokens, hidden_states,
+ sampling_metadata)
+ return logits
+
+ def sample(
+ self,
+ logits: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[SamplerOutput]:
+ next_tokens = self.sampler(logits, sampling_metadata)
+ return next_tokens
+
+ def load_weights(self, weights: Iterable[Tuple[str,
+ torch.Tensor]]) -> Set[str]:
+ loader = AutoWeightsLoader(
+ self,
+ skip_prefixes=(["lm_head."]
+ if self.config.tie_word_embeddings else None),
+ )
+ return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
new file mode 100644
index 000000000000..121aee51786b
--- /dev/null
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -0,0 +1,425 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import (Any, Iterable, Literal, Mapping, Optional, Sequence, Set,
+ Tuple, TypedDict, Union)
+
+import torch
+from torch import nn
+from transformers import BatchFeature, Gemma3Config, ProcessorMixin
+
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+ NestedTensors)
+from vllm.multimodal.parse import ImageSize, MultiModalDataItems
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ BaseProcessingInfo, PromptReplacement,
+ PromptUpdate, PromptUpdateDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.sequence import IntermediateTensors
+
+from .interfaces import SupportsMultiModal, SupportsPP
+from .siglip import SiglipVisionModel
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+ maybe_prefix, merge_multimodal_embeddings)
+
+logger = init_logger(__name__)
+
+
+class Gemma3ImagePixelInputs(TypedDict):
+ type: Literal["pixel_values"]
+ data: torch.Tensor
+ """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+
+
+Gemma3ImageInputs = Gemma3ImagePixelInputs
+
+
+class Gemma3ProcessingInfo(BaseProcessingInfo):
+
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+ return {"image": None}
+
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
+ hf_config = self.ctx.get_hf_config()
+ return {"image": hf_config.mm_tokens_per_image}
+
+ def get_num_image_tokens(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ processor: Optional[ProcessorMixin],
+ ) -> int:
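+        # Pan-and-scan is disabled in this implementation, so every image is
+        # resized to a fixed resolution and always maps to
+        # `mm_tokens_per_image` tokens.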
+ hf_config = self.ctx.get_hf_config()
+ return hf_config.mm_tokens_per_image
+
+ def get_image_size_with_most_features(self) -> ImageSize:
+        # Results in the max possible feature size (h:w = 160:1)
+ return ImageSize(height=8000, width=50)
+
+
+class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
+
+ def get_dummy_processor_inputs(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> ProcessorInputs:
+ tokenizer = self.info.get_tokenizer()
+ boi_token = tokenizer.boi_token
+
+ num_images = mm_counts.get("image", 0)
+ target_width, target_height = \
+ self.info.get_image_size_with_most_features()
+
+ mm_data = {
+ "image":
+ self._get_dummy_images(width=target_width,
+ height=target_height,
+ num_images=num_images)
+ }
+ return ProcessorInputs(
+ prompt_text=" ".join([boi_token] * num_images),
+ mm_data=mm_data,
+ )
+
+
+class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ # TODO(woosuk): Support pan-and-scan.
+ img_kwargs = mm_kwargs.get("images_kwargs", {})
+ img_kwargs["do_pan_and_scan"] = False
+ mm_kwargs["images_kwargs"] = img_kwargs
+ return super()._call_hf_processor(
+ prompt=prompt,
+ mm_data=mm_data,
+ mm_kwargs=mm_kwargs,
+ )
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return dict(pixel_values=MultiModalFieldConfig.batched("image"))
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, Any],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> Sequence[PromptUpdate]:
+ tokenizer = self.info.get_tokenizer()
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+ hf_config = self.info.get_hf_config()
+
+ boi_token = tokenizer.boi_token
+ image_token = tokenizer.image_token
+ mm_tokens_per_image = hf_config.mm_tokens_per_image
+ image_tokens_expanded = "".join([image_token] * mm_tokens_per_image)
+
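+        # Each `boi_token` in the prompt is replaced by the processor's full
+        # image sequence; only the expanded image tokens inside it are treated
+        # as feature placeholders.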
+ def get_replacement_gemma3(item_idx: int):
+ return PromptUpdateDetails(
+ full=hf_processor.full_image_sequence,
+ features=image_tokens_expanded,
+ )
+
+ return [
+ PromptReplacement(
+ modality="image",
+ target=boi_token,
+ replacement=get_replacement_gemma3,
+ )
+ ]
+
+
+class Gemma3MultiModalProjector(nn.Module):
+
+ def __init__(self, config: Gemma3Config):
+ super().__init__()
+
+ self.mm_input_projection_weight = nn.Parameter(
+ torch.zeros(config.vision_config.hidden_size,
+ config.text_config.hidden_size))
+
+ self.mm_soft_emb_norm = GemmaRMSNorm(
+ config.vision_config.hidden_size,
+ eps=config.vision_config.layer_norm_eps)
+
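+        # For the released Gemma 3 vision configs (image_size=896,
+        # patch_size=14), this yields 64x64 patches per image, which average
+        # pooling with a 4x4 kernel reduces to 16x16 = 256 soft tokens
+        # (mm_tokens_per_image).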
+ self.patches_per_image = int(config.vision_config.image_size //
+ config.vision_config.patch_size)
+ self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
+ self.kernel_size = self.patches_per_image // self.tokens_per_side
+ self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size,
+ stride=self.kernel_size)
+
+ def forward(self, vision_outputs: torch.Tensor):
+ batch_size, _, seq_length = vision_outputs.shape
+
+ reshaped_vision_outputs = vision_outputs.transpose(1, 2)
+ reshaped_vision_outputs = reshaped_vision_outputs.reshape(
+ batch_size, seq_length, self.patches_per_image,
+ self.patches_per_image)
+ reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
+
+ pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
+ pooled_vision_outputs = pooled_vision_outputs.flatten(2)
+ pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
+
+ normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
+
+ projected_vision_outputs = torch.matmul(
+ normed_vision_outputs, self.mm_input_projection_weight)
+ return projected_vision_outputs.type_as(vision_outputs)
+
+
+@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor,
+ info=Gemma3ProcessingInfo,
+ dummy_inputs=Gemma3DummyInputsBuilder)
+class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal,
+ SupportsPP):
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": [
+ "gate_proj",
+ "up_proj",
+ ],
+ }
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+ self.config = config
+ self.quant_config = quant_config
+ self.multimodal_config = multimodal_config
+ self.sliding_window = config.text_config.interleaved_sliding_window
+
+ self.vision_tower = SiglipVisionModel(config.vision_config,
+ quant_config,
+ prefix=maybe_prefix(
+ prefix, "vision_tower"))
+ self.multi_modal_projector = Gemma3MultiModalProjector(config)
+
+ self.language_model = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ hf_config=config.text_config,
+ prefix=maybe_prefix(prefix, "language_model"),
+ architectures=["Gemma3ForCausalLM"],
+ )
+ logit_scale = getattr(config, "logit_scale", 1.0)
+ self.language_model.logits_processor.scale *= logit_scale
+
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors)
+
+ @property
+ def sampler(self):
+ return self.language_model.sampler
+
+ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
+ h = w = self.config.vision_config.image_size
+ expected_dims = (3, h, w)
+
+ def _validate_shape(d: torch.Tensor):
+ if d.shape != expected_dims:
+ raise ValueError(
+ "The expected shape of pixel values per image per batch "
+ f"is {expected_dims}. You supplied {tuple(d.shape)}.")
+
+ for d in data:
+ _validate_shape(d)
+
+ return data
+
+ def _parse_and_validate_image_input(
+ self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
+ pixel_values = kwargs.pop("pixel_values", None)
+ image_embeds = kwargs.pop("image_embeds", None)
+ assert image_embeds is None, "Gemma3 does not support image_embeds."
+ if pixel_values is None:
+ return None
+
+        if not isinstance(pixel_values, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of pixel values. "
+ f"Got type: {type(pixel_values)}")
+
+ pixel_values = flatten_bn(pixel_values, concat=True)
+ return Gemma3ImagePixelInputs(
+ type="pixel_values",
+ data=self._validate_pixel_values(pixel_values),
+ )
+
+ def _image_pixels_to_features(
+ self,
+ vision_tower: SiglipVisionModel,
+ pixel_values: torch.Tensor,
+ ) -> torch.Tensor:
+ target_dtype = vision_tower.get_input_embeddings().weight.dtype
+ image_features = vision_tower(pixel_values.to(dtype=target_dtype))
+ return image_features
+
+ def _process_image_input(
+ self,
+ image_input: Gemma3ImageInputs,
+ ) -> torch.Tensor:
+ assert self.vision_tower is not None
+ pixel_values = image_input["data"]
+ vision_outputs = self._image_pixels_to_features(
+ self.vision_tower,
+ pixel_values,
+ )
+ return self.multi_modal_projector(vision_outputs)
+
+ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+ image_input = self._parse_and_validate_image_input(**kwargs)
+ if image_input is None:
+ return None
+ vision_embeddings = self._process_image_input(image_input)
+ return vision_embeddings
+
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[NestedTensors] = None,
+ ) -> torch.Tensor:
+        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+        if multimodal_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, multimodal_embeddings,
+                self.config.image_token_index)
+        return inputs_embeds
+
+ def forward(self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]:
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+
+        # NOTE: In v1, inputs_embeds is always generated by the model runner;
+        # this condition is only needed for v0 compatibility.
+ elif inputs_embeds is None:
+ vision_embeddings = self.get_multimodal_embeddings(**kwargs)
+
+ inputs_embeds = self.get_input_embeddings(input_ids,
+ vision_embeddings)
+ if vision_embeddings is not None:
+ kwargs = self.prepare_attn_masks(
+ input_ids,
+ positions,
+ mask_dtype=vision_embeddings.dtype,
+ **kwargs)
+ input_ids = None
+
+ hidden_states = self.language_model.model(input_ids,
+ positions,
+ intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ **kwargs)
+
+ return hidden_states
+
+ def prepare_attn_masks(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ mask_dtype: torch.dtype,
+ **kwargs,
+ ):
+ kwargs["has_images"] = True
+ # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
+ # This is a HACK. Fix this.
+        start_indices = (positions == 0).cpu().nonzero()
+        num_seqs = len(start_indices)
+        seq_lens = []
+        for i in range(num_seqs):
+            start_idx = start_indices[i].item()
+            if i < num_seqs - 1:
+                end_idx = start_indices[i + 1].item()
+            else:
+                end_idx = len(input_ids)
+            seq_lens.append(end_idx - start_idx)
+ kwargs["seq_lens"] = seq_lens
+
+ global_attn_masks = []
+ local_attn_masks = []
+ start_idx = 0
+ for seq_len in seq_lens:
+ end_idx = start_idx + seq_len
+ input_token_ids = input_ids[start_idx:end_idx]
+ start_idx = end_idx
+ # Create a global causal mask.
+ global_attn_mask = torch.empty(
+ 1,
+ 1,
+ seq_len,
+ seq_len,
+ dtype=mask_dtype,
+ device=input_ids.device,
+ )
+ global_attn_mask.fill_(float("-inf"))
+ # Fill the lower triangle with 0.
+ global_attn_mask = global_attn_mask.triu(diagonal=1)
+
+ # Consider the bidirectional attention between image tokens.
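+            # img_mask sums to 2 exactly where both the query and key
+            # positions are image tokens; those entries are unmasked (set to
+            # 0) so that image tokens attend to each other bidirectionally.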
+ img_mask = torch.zeros_like(global_attn_mask)
+ img_pos = (input_token_ids == self.config.image_token_index)
+ img_mask[:, :, :, img_pos] += 1
+ img_mask[:, :, img_pos, :] += 1
+ global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
+ global_attn_masks.append(global_attn_mask)
+
+ # Create a local causal mask with sliding window (1024).
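+            # Keys that are `sliding_window` or more tokens behind the query
+            # are re-masked with -inf; within the window, the global
+            # (causal + image) mask values are reused.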
+ local_attn_mask = torch.ones_like(global_attn_mask)
+ local_attn_mask = torch.tril(local_attn_mask,
+ diagonal=-self.sliding_window)
+ local_attn_mask = torch.where(local_attn_mask == 0,
+ global_attn_mask, float("-inf"))
+ local_attn_masks.append(local_attn_mask)
+ kwargs["global_attn_masks"] = global_attn_masks
+ kwargs["local_attn_masks"] = local_attn_masks
+ return kwargs
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ return self.language_model.compute_logits(hidden_states,
+ sampling_metadata)
+
+ def sample(
+ self,
+ logits: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[SamplerOutput]:
+ return self.language_model.sample(logits, sampling_metadata)
+
+ def load_weights(self, weights: Iterable[Tuple[str,
+ torch.Tensor]]) -> Set[str]:
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 74160e2d9ee4..5dd3aa2973cd 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -53,6 +53,7 @@
"Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
+ "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
@@ -161,6 +162,7 @@
"ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501
"DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
+ "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),