diff --git a/tests/models/multimodal/generation/test_voxtral_realtime.py b/tests/models/multimodal/generation/test_voxtral_realtime.py
index b38345dc4fbf..cac79b237171 100644
--- a/tests/models/multimodal/generation/test_voxtral_realtime.py
+++ b/tests/models/multimodal/generation/test_voxtral_realtime.py
@@ -1,8 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import contextlib
 from dataclasses import asdict
 
 import pytest
+import pytest_asyncio
 from mistral_common.audio import Audio
 from mistral_common.protocol.instruct.chunk import RawAudio
 from mistral_common.protocol.transcription.request import (
@@ -17,18 +19,21 @@
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.v1.engine.async_llm import AsyncLLM
 
+from ....utils import ROCM_ENGINE_KWARGS
+
 MODEL_NAME = "mistralai/Voxtral-Mini-4B-Realtime-2602"
 
-ENGINE_CONFIG = dict(
-    model=MODEL_NAME,
-    max_model_len=8192,
-    max_num_seqs=4,
-    limit_mm_per_prompt={"audio": 1},
-    config_format="mistral",
-    load_format="mistral",
-    tokenizer_mode="mistral",
-    enforce_eager=True,
-    gpu_memory_utilization=0.9,
-)
+ENGINE_CONFIG = {
+    "model": MODEL_NAME,
+    "max_model_len": 8192,
+    "max_num_seqs": 4,
+    "limit_mm_per_prompt": {"audio": 1},
+    "config_format": "mistral",
+    "load_format": "mistral",
+    "tokenizer_mode": "mistral",
+    "enforce_eager": True,
+    "gpu_memory_utilization": 0.9,
+    **ROCM_ENGINE_KWARGS,
+}
 
 EXPECTED_TEXT = [
@@ -49,6 +54,14 @@
 ]
 
+
+def _normalize(texts: list[str]) -> list[str]:
+    # The model occasionally transcribes "OBS" as "a base hit" and
+    # "oh, my" as "oh my", but both are acoustically valid. Normalise so
+    # the assertion is stable across runs and hardware.
+    texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my")
+    return texts
+
 
 @pytest.fixture
 def audio_assets() -> list[AudioAsset]:
     return [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
@@ -60,15 +73,27 @@ def tokenizer() -> MistralTokenizer:
 
 
 @pytest.fixture
-def engine() -> LLM:
+def engine():
     engine_args = EngineArgs(**ENGINE_CONFIG)
-    return LLM(**asdict(engine_args))
+    llm = LLM(**asdict(engine_args))
+    try:
+        yield llm
+    finally:
+        with contextlib.suppress(Exception):
+            llm.llm_engine.engine_core.shutdown()
+        import torch
+        torch.accelerator.empty_cache()
 
 
-@pytest.fixture
-def async_engine() -> AsyncLLM:
+
+@pytest_asyncio.fixture
+async def async_engine():
     engine_args = AsyncEngineArgs(**ENGINE_CONFIG)
-    return AsyncLLM.from_engine_args(engine_args)
+    llm = AsyncLLM.from_engine_args(engine_args)
+    try:
+        yield llm
+    finally:
+        llm.shutdown()
 
 
 def test_voxtral_realtime_forward(audio_assets, tokenizer, engine):
@@ -108,8 +133,13 @@ def from_file(file_path: str):
         sampling_params=sampling_params,
     )
 
-    texts = [out.outputs[0].text for out in outputs]
-    assert texts == EXPECTED_TEXT
+    texts = _normalize([out.outputs[0].text for out in outputs])
+    for i, (got, expected) in enumerate(zip(texts, EXPECTED_TEXT)):
+        assert got == expected, (
+            f"Output mismatch at index {i}:\n"
+            f"  got: {got!r}\n"
+            f"  expected: {expected!r}"
+        )
 
 
 @pytest.mark.asyncio
@@ -149,9 +179,17 @@ async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine)
 
         output_tokens_list.append(output_tokens)
 
-    texts = [
-        tokenizer.decode(output_tokens, special_token_policy=SpecialTokenPolicy.IGNORE)
-        for output_tokens in output_tokens_list
-    ]
-    texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my")
-    assert texts == EXPECTED_TEXT
+    texts = _normalize(
+        [
+            tokenizer.decode(
+                output_tokens, special_token_policy=SpecialTokenPolicy.IGNORE
+            )
+            for output_tokens in output_tokens_list
+        ]
+    )
+    for i, (got, expected) in enumerate(zip(texts, EXPECTED_TEXT)):
+        assert got == expected, (
+            f"Output mismatch at index {i}:\n"
+            f"  got: {got!r}\n"
+            f"  expected: {expected!r}"
+        )
diff --git a/tests/utils.py b/tests/utils.py
index d14c32e29548..df0025256c88 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -122,6 +122,12 @@ def _nvml():
     if current_platform.is_rocm()
     else []
 )
+# Python-API equivalent of ROCM_EXTRA_ARGS for use with EngineArgs kwargs.
+ROCM_ENGINE_KWARGS: dict = (
+    {"enable_prefix_caching": False, "max_num_seqs": 1}
+    if current_platform.is_rocm()
+    else {}
+)
 
 
 class RemoteVLLMServer: