Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition.
| `FireRedLIDForConditionalGeneration` | FireRedLID | `PatchyTisa/FireRedLID-vllm`, etc. | | |
| `FunASRForConditionalGeneration` | FunASR | `allendou/Fun-ASR-Nano-2512-vllm`, etc. | | |
| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
| `Gemma4ForConditionalGeneration` | Gemma 4 | `google/gemma-4-E2B-it`, `google/gemma-4-E4B-it`, etc. | | ✅︎ |
| `GlmAsrForConditionalGeneration` | GLM-ASR | `zai-org/GLM-ASR-Nano-2512` | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | `ibm-granite/granite-4.0-1b-speech`, `ibm-granite/granite-speech-3.3-2b`, etc. | ✅︎ | ✅︎ |
| `Qwen3ASRForConditionalGeneration` | Qwen3-ASR | `Qwen/Qwen3-ASR-1.7B`, etc. | ✅︎ | ✅︎ |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ async def test_basic_audio_with_lora(mary_had_lamb, rocm_aiter_fa_attention):

@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name", ["google/gemma-3n-E2B-it", "Qwen/Qwen3-ASR-0.6B"]
"model_name",
["google/gemma-3n-E2B-it", "google/gemma-4-E2B-it", "Qwen/Qwen3-ASR-0.6B"],
)
async def test_basic_audio_foscolo(foscolo, rocm_aiter_fa_attention, model_name):
# Gemma accuracy on some of the audio samples we use is particularly bad,
Expand All @@ -152,5 +153,5 @@ async def test_basic_audio_foscolo(foscolo, rocm_aiter_fa_attention, model_name)
model_name,
foscolo,
language="it",
expected_text="ove il mio corpo fanciulletto",
expected_text="ove il mio corpo",
)
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@ def _get_server_args(attention_config):


@pytest.fixture(
scope="module", params=["openai/whisper-small", "google/gemma-3n-E2B-it"]
scope="module",
params=[
"openai/whisper-small",
"google/gemma-3n-E2B-it",
"google/gemma-4-E2B-it",
],
)
def server(request):
# Parametrize over model name
Expand Down Expand Up @@ -261,8 +266,8 @@ async def test_stream_options(foscolo, server):
@pytest.mark.asyncio
async def test_long_audio_request(foscolo, client_and_model):
client, model_name = client_and_model
if model_name == "google/gemma-3n-E2B-it":
pytest.skip("Gemma3n does not support long audio requests")
if model_name in ("google/gemma-3n-E2B-it", "google/gemma-4-E2B-it"):
pytest.skip(f"{model_name} does not support audio chunking in vLLM yet")
foscolo.seek(0)
audio, sr = load_audio(foscolo)
repeated_audio = np.tile(audio, 2)
Expand Down
44 changes: 44 additions & 0 deletions tests/models/multimodal/processing/test_gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Mapping
from typing import cast

import numpy as np
import pytest
import torch
from PIL import Image as PILImage

from vllm.config.model import ModelConfig
from vllm.config.speech_to_text import SpeechToTextConfig, SpeechToTextParams
from vllm.model_executor.models.gemma4_mm import (
Gemma4ForConditionalGeneration,
Gemma4ImagePixelInputs,
Expand Down Expand Up @@ -285,3 +289,43 @@ def test_encoder_chunk_no_free_memory_falls_back_to_one():
)
== 1
)


# --- STT prompt generation ---


def _make_stt_params(
    *,
    language: str | None = "en",
    task_type: str = "transcribe",
    to_language: str | None = None,
) -> SpeechToTextParams:
    """Build a minimal ``SpeechToTextParams`` for prompt-generation tests.

    The audio is 1600 zero-valued float32 samples, the speech-to-text config
    is created with ``sample_rate=16000``, and the model config is a bare
    ``object()`` cast to ``ModelConfig`` as a placeholder.
    """
    silent_clip = np.zeros(1600, dtype=np.float32)
    config = SpeechToTextConfig(sample_rate=16000)
    placeholder_model_config = cast(ModelConfig, object())
    return SpeechToTextParams(
        audio=silent_clip,
        stt_config=config,
        model_config=placeholder_model_config,
        language=language,
        task_type=task_type,
        to_language=to_language,
    )


def test_gemma4_transcription_prompt_uses_audio_token():
    """Default transcription params yield the expected prompt and sample rate."""
    expected_text = "".join(
        [
            "<bos><|turn>user\n",
            "Transcribe this audio into English: <|audio|><turn|>\n",
            "<|turn>model\n",
        ]
    )
    params = _make_stt_params()
    result = Gemma4ForConditionalGeneration.get_generation_prompt(params)

    assert result["prompt"] == expected_text

    # The audio entry is indexed at position 1 for the sample rate.
    audio_entry = result["multi_modal_data"]["audio"]
    assert audio_entry[1] == 16000


def test_gemma4_translation_prompt_includes_source_and_target_language():
    """An "it" to "en" translation request names both languages in the prompt."""
    params = _make_stt_params(task_type="translate", language="it", to_language="en")
    result = Gemma4ForConditionalGeneration.get_generation_prompt(params)

    expected_fragment = "Translate this audio from Italian into English: <|audio|>"
    assert expected_fragment in result["prompt"]
69 changes: 67 additions & 2 deletions vllm/model_executor/models/gemma4_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,16 @@
Gemma4TextConfig,
)

from vllm.config import VllmConfig
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.inputs import MultiModalDataDict
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.inputs import MultiModalDataDict, PromptType, TextPrompt
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.gemma4 import Gemma4ForCausalLM
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalFieldConfig,
Expand All @@ -63,6 +65,7 @@
)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (
Expand All @@ -71,6 +74,7 @@
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
SupportsTranscription,
)
from .utils import (
AutoWeightsLoader,
Expand Down Expand Up @@ -920,7 +924,10 @@ class Gemma4ForConditionalGeneration(
SupportsPP,
SupportsLoRA,
SupportsEagle3,
SupportsTranscription,
):
supported_languages = ISO639_1_SUPPORTED_LANGS

packed_modules_mapping = {
"qkv_proj": [
"q_proj",
Expand Down Expand Up @@ -1599,3 +1606,61 @@ def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality == "video":
return "<|video|>"
raise ValueError(f"Unsupported modality: {modality}")

@classmethod
def get_generation_prompt(cls, stt_params: SpeechToTextParams) -> PromptType:
    """Build the single-turn text prompt for a speech-to-text request.

    The prompt asks the model to transcribe or translate the audio, naming
    the source and/or target language when their codes are found in
    ``cls.supported_languages``. It ends with the ``<|audio|>`` placeholder
    followed by the opening of the model turn.

    Args:
        stt_params: Request parameters; ``audio``, ``stt_config``,
            ``language``, ``task_type`` and ``to_language`` are read.

    Returns:
        A ``TextPrompt`` carrying the prompt string, with the audio paired
        with ``stt_config.sample_rate`` under ``multi_modal_data["audio"]``.
    """
    task = stt_params.task_type
    languages = cls.supported_languages
    # Unknown or missing codes map to "" so the matching clause is omitted.
    source_name = languages.get(stt_params.language, "")
    target_name = languages.get(stt_params.to_language, "")

    # Any task other than "transcribe" uses the "Translate" verb.
    verb = "Transcribe" if task == "transcribe" else "Translate"
    pieces = ["<bos><|turn>user\n", verb, " this audio"]

    if task == "transcribe":
        if source_name:
            pieces.append(f" into {source_name}")
    elif task == "translate":
        if source_name:
            pieces.append(f" from {source_name}")
        if target_name:
            pieces.append(f" into {target_name}")

    pieces.append(": <|audio|><turn|>\n<|turn>model\n")

    return TextPrompt(
        prompt="".join(pieces),
        multi_modal_data={
            "audio": (stt_params.audio, stt_params.stt_config.sample_rate)
        },
    )

@classmethod
def get_speech_to_text_config(
    cls, model_config: ModelConfig, task_type: str
) -> SpeechToTextConfig:
    """Derive speech-to-text settings from the model's cached processor.

    Args:
        model_config: Config used to look up the cached processor.
        task_type: Not read by this implementation.

    Returns:
        A ``SpeechToTextConfig`` whose ``max_audio_clip_s`` is
        ``audio_seq_length * audio_ms_per_token / 1000`` rounded down, whose
        ``sample_rate`` is the feature extractor's ``sampling_rate``, and
        whose ``min_energy_split_window_size`` is ``None``.
    """
    proc = cached_processor_from_config(model_config)
    extractor = proc.feature_extractor

    # Product of sequence length and per-token milliseconds, floored to
    # whole seconds.
    clip_ms = proc.audio_seq_length * proc.audio_ms_per_token
    clip_limit_s = math.floor(clip_ms / 1000)

    return SpeechToTextConfig(
        max_audio_clip_s=clip_limit_s,
        sample_rate=extractor.sampling_rate,
        min_energy_split_window_size=None,
    )

@classmethod
def get_num_audio_tokens(
    cls,
    audio_duration_s: float,
    stt_config: SpeechToTextConfig,
    model_config: ModelConfig,
) -> int | None:
    """Estimate the number of audio tokens for a clip of the given duration.

    The duration is scaled by 1000, divided by the processor's
    ``audio_ms_per_token`` and rounded up, then capped at the processor's
    ``audio_seq_length``; 2 is added to the capped value.

    Args:
        audio_duration_s: Clip duration used for the estimate.
        stt_config: Not read by this implementation.
        model_config: Config used to look up the cached processor.

    Returns:
        The capped token count plus 2.
    """
    proc = cached_processor_from_config(model_config)

    duration_ms = audio_duration_s * 1000
    uncapped = math.ceil(duration_ms / proc.audio_ms_per_token)
    capped = min(uncapped, proc.audio_seq_length)

    # NOTE(review): the +2 presumably covers audio boundary tokens — confirm.
    return capped + 2
Loading