Merged

54 commits
046e05a
Gemma3 1B working
WoosukKwon Mar 4, 2025
669ae5a
Add modeling_gemma3.py
WoosukKwon Mar 6, 2025
e044645
remove
WoosukKwon Mar 9, 2025
bd78da5
[TMP] Add HF Gemma 3
WoosukKwon Mar 9, 2025
a50c2d5
Fix config:
WoosukKwon Mar 9, 2025
d2562cb
[TMP] image input
WoosukKwon Mar 10, 2025
017239e
Update
WoosukKwon Mar 11, 2025
de0136a
Update
WoosukKwon Mar 11, 2025
12b7e9d
update
WoosukKwon Mar 11, 2025
26b6199
minor
WoosukKwon Mar 11, 2025
b99336a
Remove
WoosukKwon Mar 11, 2025
004dc92
fix
WoosukKwon Mar 11, 2025
2bb965b
Add kwargs
WoosukKwon Mar 11, 2025
b119945
Hew HF
WoosukKwon Mar 11, 2025
4c67573
Fix scaling
WoosukKwon Mar 11, 2025
d90d410
bidirectional attn
WoosukKwon Mar 11, 2025
366e4b5
sliding window
WoosukKwon Mar 11, 2025
77b9dd7
Remove HF
WoosukKwon Mar 11, 2025
f0f8e9d
revert
WoosukKwon Mar 11, 2025
28e757b
add placeholder str
WoosukKwon Mar 11, 2025
a935b24
Merge branch 'main' into woosuk-gemma3
WoosukKwon Mar 12, 2025
285ffc4
minor
WoosukKwon Mar 12, 2025
713766b
Add comments
WoosukKwon Mar 12, 2025
6746086
Minor
WoosukKwon Mar 12, 2025
6fa0336
cleanup
WoosukKwon Mar 12, 2025
0384ceb
tmp
WoosukKwon Mar 12, 2025
2beb199
minor
WoosukKwon Mar 12, 2025
64ef15f
Docs
WoosukKwon Mar 12, 2025
d92c7c1
ruff
WoosukKwon Mar 12, 2025
82acdcd
Update transformers version
WoosukKwon Mar 12, 2025
2943443
minor
WoosukKwon Mar 12, 2025
635e1a9
Minor
WoosukKwon Mar 12, 2025
02bf606
gg-hf-g -> google
WoosukKwon Mar 12, 2025
a4c924b
Merge branch 'main' into woosuk-gemma3
WoosukKwon Mar 12, 2025
7097fa8
Fix example
WoosukKwon Mar 12, 2025
ac72f69
minor
WoosukKwon Mar 12, 2025
5ab60d0
minor
WoosukKwon Mar 12, 2025
3c42695
fix docs
WoosukKwon Mar 12, 2025
883c656
minor docs
WoosukKwon Mar 12, 2025
1d2a064
polish
WoosukKwon Mar 12, 2025
0230237
polish
WoosukKwon Mar 12, 2025
32ebaf1
fix reqs
WoosukKwon Mar 12, 2025
4b65deb
Remove hardcoded values
WoosukKwon Mar 12, 2025
7ca73b1
Fix
WoosukKwon Mar 12, 2025
d5f2eef
Add min_transformers_version
WoosukKwon Mar 12, 2025
a4dbd56
fix
WoosukKwon Mar 12, 2025
2a8e2fa
comment about pas
WoosukKwon Mar 12, 2025
6ecd1ec
fix input indices
ywang96 Mar 12, 2025
e2e2a22
fix batch with mixed numbers of images
ywang96 Mar 12, 2025
49b29ca
update examples with instruct format
ywang96 Mar 12, 2025
4111004
update doc
ywang96 Mar 12, 2025
154c24c
Fix registry
DarkLight1337 Mar 12, 2025
a2d2062
Update doc
DarkLight1337 Mar 12, 2025
d328675
Clean up examples
DarkLight1337 Mar 12, 2025
41 changes: 39 additions & 2 deletions docs/source/models/supported_models.md
@@ -263,10 +263,15 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
* ✅︎
- * `Gemma2ForCausalLM`
* Gemma2
* Gemma 2
* `google/gemma-2-9b`, `google/gemma-2-27b`, etc.
* ✅︎
* ✅︎
- * `Gemma3ForCausalLM`
* Gemma 3
* `google/gemma-3-1b-it`, etc.
* ✅︎
* ✅︎
- * `GlmForCausalLM`
* GLM-4
* `THUDM/glm-4-9b-chat-hf`, etc.
@@ -504,7 +509,7 @@ you should explicitly specify the task type to ensure that the model is used in
*
*
- * `Gemma2Model`
* Gemma2-based
* Gemma 2-based
* `BAAI/bge-multilingual-gemma2`, etc.
*
* ✅︎
@@ -752,6 +757,13 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
- * `Gemma3ForConditionalGeneration`
* Gemma 3
* T + I<sup>+</sup>
* `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc.
* ✅︎
* ✅︎
* ✅︎\*
- * `GLM4VForCausalLM`<sup>^</sup>
* GLM-4V
* T + I
@@ -937,6 +949,31 @@ For more details, please see: <gh-pr:4087#issuecomment-2250397630>
To use Qwen2.5-VL series models, you have to install the Hugging Face Transformers library from source via `pip install git+https://github.com/huggingface/transformers`.
:::

:::{note}
To use Gemma 3 series models, you have to install the Hugging Face Transformers library from source via
`pip install git+https://github.com/huggingface/transformers`.
The earliest commit that supports this is [`50d3530aa04e7a7d003e6b255a98f79fd0447357`](https://github.com/huggingface/transformers/commit/50d3530aa04e7a7d003e6b255a98f79fd0447357).

Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs.
However, there are differences in how they handle text + image inputs:

V0 correctly implements the model's attention pattern:
- Uses bidirectional attention among the image tokens belonging to the same image
- Uses causal attention for all other tokens
- Implemented via (naive) PyTorch SDPA with masking tensors (a sketch of this masking appears after this note)
- Note: this may use significant memory for long prompts containing images

V1 currently uses a simplified attention pattern:
- Uses causal attention for all tokens, including image tokens
- Generates reasonable outputs but does not match the original model's attention for text + image inputs
- Will be updated in the future to support the correct behavior

This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.

Additionally, vLLM's current Gemma 3 implementation does not support the pan-and-scan image pre-processing algorithm, which helps handle images with skewed aspect ratios by intelligently cropping them into multiple views.
Without this feature, model performance may degrade when processing images whose aspect ratios deviate significantly from square.
:::
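
To make the V0 behavior above concrete, here is a minimal, illustrative sketch of such a mixed mask. This is not the code added by this PR; `image_spans`, a list of `(start, end)` token index pairs marking where each image's tokens sit in the prompt, is an assumed input.

```python
# Illustrative only: causal attention everywhere, plus bidirectional attention
# among tokens that belong to the same image, fed to PyTorch SDPA as a
# boolean mask (True = may attend).
import torch
import torch.nn.functional as F


def build_mixed_mask(seq_len: int,
                     image_spans: list[tuple[int, int]]) -> torch.Tensor:
    # Standard causal (lower-triangular) mask.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Let tokens of the same image attend to each other in both directions.
    for start, end in image_spans:
        mask[start:end, start:end] = True
    return mask


# Example: a 10-token prompt whose tokens 2..5 belong to a single image.
q = k = v = torch.randn(1, 1, 10, 64)
attn_mask = build_mixed_mask(10, [(2, 6)])
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```

Because the full `seq_len x seq_len` boolean mask is materialized, memory grows quadratically with prompt length, which is the memory caveat noted above.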

### Pooling Models

See [this page](pooling-models) for more information on how to use pooling models.
20 changes: 19 additions & 1 deletion examples/offline_inference/vision_language.py
@@ -118,6 +118,23 @@ def run_fuyu(questions: list[str], modality: str):
return llm, prompts, stop_token_ids


# Gemma 3
def run_gemma3(questions: list[str], modality: str):
assert modality == "image"
model_name = "google/gemma-3-4b-it"

llm = LLM(model=model_name,
max_model_len=2048,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)

prompts = [("<bos><start_of_turn>user\n"
f"<start_of_image>{question}<end_of_turn>\n"
"<start_of_turn>model\n") for question in questions]
stop_token_ids = None
return llm, prompts, stop_token_ids


# GLM-4v
def run_glm4v(questions: list[str], modality: str):
assert modality == "image"
@@ -405,7 +422,7 @@ def run_mllama(questions: list[str], modality: str):
"type": "image"
}, {
"type": "text",
"text": f"{question}"
"text": question
}]
}] for question in questions]
prompts = tokenizer.apply_chat_template(messages,
@@ -664,6 +681,7 @@ def run_qwen2_5_vl(questions: list[str], modality: str):
"deepseek_vl_v2": run_deepseek_vl2,
"florence2": run_florence2,
"fuyu": run_fuyu,
"gemma3": run_gemma3,
"glm4v": run_glm4v,
"h2ovl_chat": run_h2ovl,
"idefics3": run_idefics3,
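
For reference, the single-image prompt format used by `run_gemma3` above can also be exercised directly with vLLM's offline API. This is a minimal sketch, not part of the PR; the image path is a placeholder and the generation settings are arbitrary.

```python
# Minimal sketch: run a single Gemma 3 image prompt outside the example script.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="google/gemma-3-4b-it", max_model_len=2048, max_num_seqs=2)

prompt = ("<bos><start_of_turn>user\n"
          "<start_of_image>Describe this image.<end_of_turn>\n"
          "<start_of_turn>model\n")

image = Image.open("example.jpg")  # placeholder path
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```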
37 changes: 37 additions & 0 deletions examples/offline_inference/vision_language_multi_image.py
@@ -80,6 +80,42 @@ def load_deepseek_vl2(question: str, image_urls: list[str]):
)


def load_gemma3(question, image_urls: list[str]) -> ModelRequestData:
model_name = "google/gemma-3-4b-it"

llm = LLM(model=model_name,
max_model_len=8192,
max_num_seqs=2,
limit_mm_per_prompt={"image": len(image_urls)})

placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{
"role":
"user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
],
}]

processor = AutoProcessor.from_pretrained(model_name)

prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)

return ModelRequestData(
llm=llm,
prompt=prompt,
stop_token_ids=None,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)


def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "h2oai/h2ovl-mississippi-800m"

@@ -496,6 +532,7 @@ def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData:
model_example_map = {
"aria": load_aria,
"deepseek_vl_v2": load_deepseek_vl2,
"gemma3": load_gemma3,
"h2ovl_chat": load_h2ovl,
"idefics3": load_idefics3,
"internvl_chat": load_internvl,
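
As a hedged illustration of how a `ModelRequestData` returned by `load_gemma3` above is consumed, the sketch below mirrors the example script's main loop; the URLs are placeholders and the snippet is not part of this PR.

```python
# Illustrative only: feed the prompt and the fetched images back into vLLM.
# Assumes load_gemma3 as defined in the example above is importable/in scope.
from vllm import SamplingParams

req = load_gemma3("What do these two images have in common?",
                  ["https://example.com/a.jpg", "https://example.com/b.jpg"])

outputs = req.llm.generate(
    {
        "prompt": req.prompt,
        # A list of PIL images; bounded by limit_mm_per_prompt above.
        "multi_modal_data": {"image": req.image_data},
    },
    SamplingParams(max_tokens=64, stop_token_ids=req.stop_token_ids),
)
print(outputs[0].outputs[0].text)
```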
1 change: 1 addition & 0 deletions tests/models/multimodal/processing/test_common.py
@@ -162,6 +162,7 @@ def _test_processing_correctness(
"deepseek-ai/deepseek-vl2-tiny",
"microsoft/Florence-2-base",
"adept/fuyu-8b",
"google/gemma-3-4b-it",
"THUDM/glm-4v-9b",
"h2oai/h2ovl-mississippi-800m",
"OpenGVLab/InternVL2-1B",
4 changes: 4 additions & 0 deletions tests/models/registry.py
@@ -124,6 +124,8 @@ def check_available_online(
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it",
min_transformers_version="4.50"),
"GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"),
"GPT2LMHeadModel": _HfExamplesInfo("gpt2"),
"GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"),
@@ -241,6 +243,8 @@ def check_available_online(
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it",
min_transformers_version="4.50"),
"GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b",
trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
15 changes: 9 additions & 6 deletions vllm/config.py
@@ -350,10 +350,11 @@ def __init__(
if self.enforce_eager is None:
self.enforce_eager = False

interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
sliding_window = getattr(self.hf_text_config, "sliding_window", None)
has_interleaved_attention = (sliding_window is not None) and (
isinstance(sliding_window, list) or
(self.hf_text_config.model_type in ["gemma2", "cohere2"]))
(self.hf_text_config.model_type in interleaved_attn_models))

if (not self.disable_sliding_window and has_interleaved_attention):
if (backend :=
@@ -2501,11 +2502,11 @@ def _get_and_verify_dtype(
dtype = dtype.lower()
if dtype == "auto":
if config_dtype == torch.float32:
if config.model_type == "gemma2":
if config.model_type in ("gemma2", "gemma3", "gemma3_text"):
logger.info(
"For Gemma 2, we downcast float32 to bfloat16 instead "
"of float16 by default. Please specify `dtype` if you "
"want to use float16.")
"For Gemma 2 and 3, we downcast float32 to bfloat16 "
"instead of float16 by default. Please specify `dtype` "
"if you want to use float16.")
torch_dtype = torch.bfloat16
else:
# Following the common practice, we use float16 for float32
@@ -2637,7 +2638,9 @@ def _get_and_verify_max_len(
derived_max_model_len = default_max_len

rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None:
# NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE
# scaling, so we skip applying the scaling factor again.
if rope_scaling is not None and "gemma3" not in hf_config.model_type:
# No need to consider "type" key because of patch_rope_scaling when
# loading HF config
rope_type = rope_scaling["rope_type"]
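
As the log message in `_get_and_verify_dtype` indicates, the bfloat16 downcast for float32 Gemma checkpoints is only a default; a user who prefers float16 can request it explicitly. A minimal sketch (model name taken from this PR's examples):

```python
# Sketch: opt out of the default bfloat16 downcast for a Gemma 3 checkpoint.
from vllm import LLM

llm = LLM(model="google/gemma-3-1b-it", dtype="float16")
```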
2 changes: 2 additions & 0 deletions vllm/entrypoints/chat_utils.py
@@ -433,6 +433,8 @@ def _placeholder_str(self, modality: ModalityStr,
return "<image>"
if model_type == "aria":
return "<|fim_prefix|><|img|><|fim_suffix|>"
if model_type == "gemma3":
return "<start_of_image>"

raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "audio":
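
With `<start_of_image>` registered as the gemma3 placeholder above, chat-style requests do not need to spell out the image token themselves. A minimal sketch using `LLM.chat`; the image URL is a placeholder and the message schema follows vLLM's chat API.

```python
# Sketch: the gemma3 "<start_of_image>" placeholder is filled in automatically
# when a chat message carries an image part.
from vllm import LLM

llm = LLM(model="google/gemma-3-4b-it", max_model_len=2048)
messages = [{
    "role": "user",
    "content": [
        {"type": "image_url",
         "image_url": {"url": "https://example.com/cat.jpg"}},  # placeholder URL
        {"type": "text", "text": "What animal is in this picture?"},
    ],
}]
outputs = llm.chat(messages)
print(outputs[0].outputs[0].text)
```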