Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 62 additions & 24 deletions tests/models/multimodal/generation/test_voxtral_realtime.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from dataclasses import asdict

import pytest
import pytest_asyncio
from mistral_common.audio import Audio
from mistral_common.protocol.instruct.chunk import RawAudio
from mistral_common.protocol.transcription.request import (
Expand All @@ -17,18 +19,21 @@
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM

from ....utils import ROCM_ENGINE_KWARGS

MODEL_NAME = "mistralai/Voxtral-Mini-4B-Realtime-2602"

# Shared engine configuration for both the sync (LLM) and async (AsyncLLM)
# fixtures below. ROCM_ENGINE_KWARGS is spread last so ROCm-specific
# overrides (e.g. a smaller max_num_seqs) take precedence on ROCm hosts;
# on other platforms it is an empty dict and changes nothing.
# NOTE: the span previously carried both the old dict(...) construction and
# this literal (diff residue); only the merged literal is kept.
ENGINE_CONFIG = {
    "model": MODEL_NAME,
    "max_model_len": 8192,
    "max_num_seqs": 4,
    "limit_mm_per_prompt": {"audio": 1},
    "config_format": "mistral",
    "load_format": "mistral",
    "tokenizer_mode": "mistral",
    "enforce_eager": True,
    "gpu_memory_utilization": 0.9,
    **ROCM_ENGINE_KWARGS,
}


EXPECTED_TEXT = [
Expand All @@ -49,6 +54,14 @@
]


def _normalize(texts: list[str]) -> list[str]:
# The model occasionally transcribes "OBS" as "a base hit" and
# "oh, my" as "oh my", but both are acoustically valid. Normalise so
# the assertion is stable across runs and hardware.
texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my")
return texts


@pytest.fixture
def audio_assets() -> list[AudioAsset]:
    """Audio clips exercised by the realtime transcription tests."""
    return [AudioAsset(name) for name in ("mary_had_lamb", "winning_call")]
Expand All @@ -60,15 +73,27 @@ def tokenizer() -> MistralTokenizer:


@pytest.fixture
def engine():
    """Yield a synchronous ``LLM`` built from ENGINE_CONFIG.

    Implemented as a yield fixture so the engine is torn down after each
    test: the engine core is shut down and the accelerator cache cleared,
    releasing GPU memory between tests. (The span interleaved the old
    ``return LLM(...)`` form with this generator form; only the merged
    generator version is kept.)
    """
    engine_args = EngineArgs(**ENGINE_CONFIG)
    llm = LLM(**asdict(engine_args))
    try:
        yield llm
    finally:
        # Best-effort teardown: cleanup failures (e.g. torch builds without
        # torch.accelerator) must not mask the test outcome.
        with contextlib.suppress(Exception):
            llm.llm_engine.engine_core.shutdown()
            import torch

            torch.accelerator.empty_cache()

@pytest_asyncio.fixture
async def async_engine():
    """Yield an ``AsyncLLM`` built from ENGINE_CONFIG.

    Uses ``pytest_asyncio.fixture`` (a plain ``@pytest.fixture`` does not
    support async generators) and always shuts the engine down afterwards
    so engine-core processes do not leak across tests. (The span
    interleaved the old ``return AsyncLLM...`` form with this generator
    form; only the merged generator version is kept.)
    """
    engine_args = AsyncEngineArgs(**ENGINE_CONFIG)
    llm = AsyncLLM.from_engine_args(engine_args)
    try:
        yield llm
    finally:
        llm.shutdown()


def test_voxtral_realtime_forward(audio_assets, tokenizer, engine):
Expand Down Expand Up @@ -108,8 +133,13 @@ def from_file(file_path: str):
sampling_params=sampling_params,
)

texts = [out.outputs[0].text for out in outputs]
assert texts == EXPECTED_TEXT
texts = _normalize([out.outputs[0].text for out in outputs])
for i, (got, expected) in enumerate(zip(texts, EXPECTED_TEXT)):
assert got == expected, (
f"Output mismatch at index {i}:\n"
f" got: {got!r}\n"
f" expected: {expected!r}"
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -149,9 +179,17 @@ async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine)

output_tokens_list.append(output_tokens)

texts = [
tokenizer.decode(output_tokens, special_token_policy=SpecialTokenPolicy.IGNORE)
for output_tokens in output_tokens_list
]
texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my")
assert texts == EXPECTED_TEXT
texts = _normalize(
[
tokenizer.decode(
output_tokens, special_token_policy=SpecialTokenPolicy.IGNORE
)
for output_tokens in output_tokens_list
]
)
for i, (got, expected) in enumerate(zip(texts, EXPECTED_TEXT)):
assert got == expected, (
f"Output mismatch at index {i}:\n"
f" got: {got!r}\n"
f" expected: {expected!r}"
)
6 changes: 6 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ def _nvml():
if current_platform.is_rocm()
else []
)
# Python-API equivalent of ROCM_EXTRA_ARGS for use with EngineArgs kwargs.
# Empty everywhere except ROCm, where prefix caching is disabled and
# requests are serialized (max_num_seqs=1).
ROCM_ENGINE_KWARGS: dict = {}
if current_platform.is_rocm():
    ROCM_ENGINE_KWARGS = {"enable_prefix_caching": False, "max_num_seqs": 1}


class RemoteVLLMServer:
Expand Down
Loading