From b27df372941b2b7803cf4507220b7af68f68caf2 Mon Sep 17 00:00:00 2001 From: wangyu <410167048@qq.com> Date: Thu, 9 Apr 2026 10:31:53 +0800 Subject: [PATCH 01/19] refactor conftest Signed-off-by: wangyu <410167048@qq.com> --- docs/contributing/ci/tests_markers.md | 4 +- docs/contributing/ci/tests_style.md | 2 +- tests/benchmarks/metrics/test_metrics.py | 67 + tests/benchmarks/patch/test_patch.py | 54 + tests/conftest.py | 3245 +---------------- tests/dfx/{conftest.py => helpers.py} | 2 +- tests/dfx/perf/scripts/run_benchmark.py | 4 +- .../scripts/test_benchmark_stability.py | 4 +- .../lora/{conftest.py => helpers.py} | 0 tests/diffusion/lora/test_lora_manager.py | 2 +- .../diffusion/models/bagel/test_bagel_lora.py | 2 +- .../quantization/test_quantization_quality.py | 2 +- .../diffusion/test_diffusion_step_pipeline.py | 2 +- tests/diffusion/test_stage_diffusion_proc.py | 10 + tests/e2e/accuracy/conftest.py | 15 +- tests/e2e/accuracy/helpers.py | 15 + tests/e2e/accuracy/test_gebench_h100_smoke.py | 4 +- .../accuracy/test_gedit_bench_h100_smoke.py | 4 +- .../test_wan22_i2v_video_similarity.py | 4 +- .../test_async_omni_collective_rpc.py | 2 +- .../test_async_omni_qwen_image_generate.py | 2 +- .../custom_pipeline/test_worker_extension.py | 2 +- .../offline_inference/test_bagel_img2img.py | 4 +- .../offline_inference/test_bagel_text2img.py | 4 +- .../test_bagel_understanding.py | 4 +- tests/e2e/offline_inference/test_cache_dit.py | 2 +- .../e2e/offline_inference/test_cosyvoice3.py | 4 +- .../test_diffusion_cpu_offload.py | 2 +- .../test_diffusion_layerwise_offload.py | 2 +- .../offline_inference/test_expert_parallel.py | 2 +- .../test_flux_autoround_w4a16.py | 2 +- .../e2e/offline_inference/test_magi_human.py | 2 +- .../offline_inference/test_mammoth_moda2.py | 2 +- tests/e2e/offline_inference/test_omnivoice.py | 2 +- .../e2e/offline_inference/test_ovis_image.py | 2 +- .../test_quantization_fp8.py | 2 +- .../offline_inference/test_qwen2_5_omni.py | 10 +- .../e2e/offline_inference/test_qwen3_omni.py | 8 +- .../offline_inference/test_qwen3_tts_base.py | 4 +- .../test_qwen3_tts_customvoice.py | 4 +- .../test_qwen_image_diffusion_batching.py | 2 +- .../test_sequence_parallel.py | 2 +- .../test_stable_audio_model.py | 2 +- tests/e2e/offline_inference/test_t2i_model.py | 2 +- tests/e2e/offline_inference/test_teacache.py | 2 +- .../e2e/offline_inference/test_voxtral_tts.py | 4 +- .../test_zimage_parallelism.py | 2 +- .../online_serving/test_bagel_expansion.py | 10 +- tests/e2e/online_serving/test_bagel_online.py | 4 +- .../e2e/online_serving/test_cosyvoice3_tts.py | 4 +- .../online_serving/test_flux2_expansion.py | 8 +- .../test_flux_2_dev_expansion.py | 10 +- .../test_flux_kontext_expansion.py | 10 +- .../test_hunyuan_video_15_expansion.py | 8 +- .../e2e/online_serving/test_image_gen_edit.py | 2 +- .../test_images_generations_lora.py | 4 +- .../test_longcat_image_edit_expansion.py | 12 +- .../test_longcat_image_expansion.py | 10 +- tests/e2e/online_serving/test_mimo_audio.py | 11 +- tests/e2e/online_serving/test_omnivoice.py | 4 +- tests/e2e/online_serving/test_qwen2_5_omni.py | 13 +- tests/e2e/online_serving/test_qwen3_omni.py | 13 +- .../test_qwen3_omni_expansion.py | 13 +- .../e2e/online_serving/test_qwen3_tts_base.py | 4 +- .../test_qwen3_tts_base_expansion.py | 4 +- .../online_serving/test_qwen3_tts_batch.py | 9 +- .../test_qwen3_tts_customvoice.py | 4 +- .../test_qwen3_tts_customvoice_expansion.py | 4 +- .../test_qwen3_tts_speaker_embedding.py | 4 +- .../test_qwen3_tts_websocket.py | 4 +- .../test_qwen_image_edit_expansion.py | 12 +- .../test_qwen_image_expansion.py | 10 +- .../test_qwen_image_layered_expansion.py | 13 +- .../e2e/online_serving/test_sd3_expansion.py | 8 +- .../test_video_generation_api.py | 4 +- tests/e2e/online_serving/test_voxtral_tts.py | 4 +- .../online_serving/test_wan22_expansion.py | 10 +- .../test_wan_2_1_vace_expansion.py | 8 +- .../online_serving/test_zimage_expansion.py | 8 +- tests/engine/test_async_omni_engine_abort.py | 2 +- tests/engine/test_orchestrator.py | 510 +++ tests/examples/conftest.py | 354 +- tests/examples/helpers.py | 353 ++ .../offline_inference/test_text_to_image.py | 6 +- .../online_serving/test_qwen2_5_omni.py | 7 +- .../online_serving/test_qwen3_omni.py | 7 +- .../online_serving/test_text_to_image.py | 7 +- tests/helpers/__init__.py | 8 + tests/helpers/assertions.py | 414 +++ tests/helpers/env.py | 207 ++ tests/helpers/fixtures/__init__.py | 1 + tests/helpers/fixtures/env.py | 56 + tests/helpers/fixtures/log.py | 7 + tests/helpers/fixtures/run_args.py | 16 + tests/helpers/fixtures/runtime.py | 80 + tests/helpers/mark.py | 135 + tests/helpers/media.py | 550 +++ .../utils.py => helpers/process.py} | 92 +- tests/helpers/runtime.py | 1024 ++++++ tests/helpers/stage_config.py | 193 + tests/utils.py | 621 ---- tools/pre_commit/check_pickle_imports.py | 3 +- vllm_omni/benchmarks/metrics/metrics.py | 2 +- vllm_omni/benchmarks/patch/patch.py | 4 + 104 files changed, 3939 insertions(+), 4492 deletions(-) create mode 100644 tests/benchmarks/metrics/test_metrics.py rename tests/dfx/{conftest.py => helpers.py} (98%) rename tests/diffusion/lora/{conftest.py => helpers.py} (100%) create mode 100644 tests/e2e/accuracy/helpers.py create mode 100644 tests/engine/test_orchestrator.py create mode 100644 tests/examples/helpers.py create mode 100644 tests/helpers/__init__.py create mode 100644 tests/helpers/assertions.py create mode 100644 tests/helpers/env.py create mode 100644 tests/helpers/fixtures/__init__.py create mode 100644 tests/helpers/fixtures/env.py create mode 100644 tests/helpers/fixtures/log.py create mode 100644 tests/helpers/fixtures/run_args.py create mode 100644 tests/helpers/fixtures/runtime.py create mode 100644 tests/helpers/mark.py create mode 100644 tests/helpers/media.py rename tests/{e2e/offline_inference/utils.py => helpers/process.py} (58%) create mode 100644 tests/helpers/runtime.py create mode 100644 tests/helpers/stage_config.py delete mode 100644 tests/utils.py diff --git a/docs/contributing/ci/tests_markers.md b/docs/contributing/ci/tests_markers.md index 7c1ba1c73bd..f6145e81160 100644 --- a/docs/contributing/ci/tests_markers.md +++ b/docs/contributing/ci/tests_markers.md @@ -38,7 +38,7 @@ Defined in `pyproject.toml`: ### Example usage for markers ```python -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test @pytest.mark.core_model @pytest.mark.omni @@ -105,7 +105,7 @@ This decorator is intended to make hardware-aware, cross-platform test authoring `hardware_marks` returns a list of pytest mark objects with the same signature as `@hardware_test`. Use it when you need more flexibility, such as attaching hardware marks to individual `pytest.param` entries rather than an entire test function. ```python -from tests.utils import hardware_marks +from tests.helpers.mark import hardware_marks MULTI_CARD_MARKS = hardware_marks( res={"cuda": "H100", "rocm": "MI325", "npu": "A2"}, num_cards=2 diff --git a/docs/contributing/ci/tests_style.md b/docs/contributing/ci/tests_style.md index 8b10cf4cc1c..f32525264b3 100644 --- a/docs/contributing/ci/tests_style.md +++ b/docs/contributing/ci/tests_style.md @@ -428,7 +428,7 @@ from pathlib import Path import pytest from vllm.assets.video import VideoAsset -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test from ..multi_stages.conftest import OmniRunner # Optional: set process start method for workers diff --git a/tests/benchmarks/metrics/test_metrics.py b/tests/benchmarks/metrics/test_metrics.py new file mode 100644 index 00000000000..f531a5026a3 --- /dev/null +++ b/tests/benchmarks/metrics/test_metrics.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Unit tests for metrics.py +""" + +import pytest +from vllm.benchmarks.serve import TaskType + +from vllm_omni.benchmarks.metrics.metrics import calculate_metrics +from vllm_omni.benchmarks.patch.patch import MixRequestFuncOutput + +pytestmark = [pytest.mark.core_model, pytest.mark.benchmark, pytest.mark.cpu] + + +def _make_output(prompt_len: int, output_tokens: int = 10) -> MixRequestFuncOutput: + """Build a minimal successful MixRequestFuncOutput for metrics aggregation.""" + output = MixRequestFuncOutput() + output.success = True + output.prompt_len = prompt_len + output.output_tokens = output_tokens + output.generated_text = "x" * output_tokens + output.ttft = 0.1 + output.text_latency = 1.0 + output.latency = 1.0 + output.start_time = 0.0 + output.itl = [0.1] * max(output_tokens - 1, 0) + output.audio_ttfp = 0.0 + output.audio_rtf = 0.0 + output.audio_duration = 0.0 + output.audio_frames = 0 + output.input_audio_duration = 0.0 + output.error = "" + return output + + +# ============================================================================ +# total_input Tests +# ============================================================================ + + +def test_total_input_aggregated_from_output_prompt_len(): + """Test that total_input sums outputs[i].prompt_len, not input_requests[i].prompt_len.""" + outputs = [_make_output(4992), _make_output(3000)] + + metrics, _ = calculate_metrics( + input_requests=[], + outputs=outputs, + dur_s=10.0, + tokenizer=None, + selected_percentiles=[99.0], + goodput_config_dict={}, + task_type=TaskType.GENERATION, + selected_percentile_metrics=[], + max_concurrency=None, + request_rate=float("inf"), + benchmark_duration=10.0, + ) + + assert metrics.total_input == 7992, ( + "total_input should aggregate from outputs[i].prompt_len to reflect the true multimodal input token count" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/benchmarks/patch/test_patch.py b/tests/benchmarks/patch/test_patch.py index 39b7f84fb49..35a18aea33c 100644 --- a/tests/benchmarks/patch/test_patch.py +++ b/tests/benchmarks/patch/test_patch.py @@ -574,5 +574,59 @@ async def test_text_latency_value_consistency(self, mocker: MockerFixture): ) +# ============================================================================ +# prompt_len Tests +# ============================================================================ + + +@pytest.mark.asyncio +async def test_prompt_len_assigned_from_usage(mocker: MockerFixture): + # Arrange: request claims prompt_len=100, but server reports 4992 (multimodal). + request_input = RequestFuncInput( + model="test-model", + model_name="test-model", + prompt="test prompt", + api_url="http://test.com/v1/chat/completions", + prompt_len=100, + output_len=20, + ) + + chunks = [ + create_sse_chunk( + { + "choices": [{"delta": {"content": "Hello"}}], + "modality": "text", + } + ), + create_sse_chunk( + { + "choices": [{"delta": {"content": " world"}}], + "modality": "text", + } + ), + # Final usage chunk emitted because stream_options.include_usage=True. + create_sse_chunk( + { + "choices": [], + "usage": {"prompt_tokens": 4992, "completion_tokens": 2, "total_tokens": 4994}, + } + ), + b"data: [DONE]\n\n", + ] + + mock_response = MockResponse(200, chunks) + mock_session = mocker.AsyncMock() + mock_session.post = mocker.MagicMock(return_value=mock_response) + + # Act + output = await async_request_openai_chat_omni_completions(request_input, mock_session) + + # Assert + assert output.success is True + assert output.prompt_len == 4992, ( + "prompt_len should be overridden by usage.prompt_tokens to reflect the true multimodal input token count" + ) + + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) diff --git a/tests/conftest.py b/tests/conftest.py index 8e9a7bf9280..685a7c663a1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3180 +1,65 @@ -import base64 -import datetime -import io -import json -import math -import os -import random -import re -import tempfile - -import requests - -os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" -# Set CPU device for CI environments without GPU -if "VLLM_TARGET_DEVICE" not in os.environ: - os.environ["VLLM_TARGET_DEVICE"] = "cpu" - -import concurrent.futures -import gc -import multiprocessing -import socket -import subprocess -import sys -import threading -import time -import uuid -from collections.abc import Generator -from dataclasses import dataclass -from io import BytesIO -from pathlib import Path -from typing import Any, NamedTuple - -import cv2 -import numpy as np -import psutil -import pytest -import soundfile as sf -import torch -import yaml -from openai import OpenAI, omit -from PIL import Image -from transformers import pipeline -from vllm import TextPrompt -from vllm.distributed.parallel_state import cleanup_dist_env_and_memory -from vllm.logger import init_logger -from vllm.utils.network_utils import get_open_port - -from vllm_omni.entrypoints.omni import Omni -from vllm_omni.inputs.data import OmniSamplingParams -from vllm_omni.outputs import OmniRequestOutput -from vllm_omni.platforms import current_omni_platform - -logger = init_logger(__name__) - -PromptAudioInput = list[tuple[Any, int]] | tuple[Any, int] | None -PromptImageInput = list[Any] | Any | None -PromptVideoInput = list[Any] | Any | None - -_GENDER_PIPELINE = None -# transformers.Pipeline is not thread-safe; concurrent e2e requests must serialize inference. -_GENDER_PIPELINE_LOCK = threading.Lock() - -# int16 mono PCM from /v1/audio/speech when response_format=pcm (Qwen3-TTS code2wav output rate). -_PCM_SPEECH_SAMPLE_RATE_HZ = 24_000 - - -class OmniServerParams(NamedTuple): - model: str - port: int | None = None - stage_config_path: str | None = None - server_args: list[str] | None = None - env_dict: dict[str, str] | None = None - use_omni: bool = True - - -def assert_image_diffusion_response( - response, - request_config: dict[str, Any], - run_level: str = None, -) -> None: - """ - Validate image diffusion response. - - Expected request_config schema: - { - "request_type": "image", - "extra_body": { - "num_outputs_per_prompt": 1, - "width": ..., - "height": ..., - ... - } - } - """ - assert response.images is not None, "Image response is None" - assert len(response.images) > 0, "No images in response" - - extra_body = request_config.get("extra_body") or {} - - num_outputs_per_prompt = extra_body.get("num_outputs_per_prompt") - if num_outputs_per_prompt is not None: - assert len(response.images) == num_outputs_per_prompt, ( - f"Expected {num_outputs_per_prompt} images, got {len(response.images)}" - ) - - if run_level == "advanced_model": - width = extra_body.get("width") - height = extra_body.get("height") - - if width is not None or height is not None: - for img in response.images: - assert_image_valid(img, width=width, height=height) - - -def assert_video_diffusion_response( - response, - request_config: dict[str, Any], - run_level: str = None, -) -> None: - """ - Validate video diffusion response. - - Expected request_config schema: - { - "request_type": "video", - "form_data": { - "prompt": "...", - "num_frames": ..., - "width": ..., - "height": ..., - "fps": ..., - ... - } - } - """ - form_data = request_config.get("form_data", {}) - - assert response.videos is not None, "Video response is None" - assert len(response.videos) > 0, "No videos in response" - - expected_frames = _maybe_int(form_data.get("num_frames")) - expected_width = _maybe_int(form_data.get("width")) - expected_height = _maybe_int(form_data.get("height")) - expected_fps = _maybe_int(form_data.get("fps")) - - for vid_bytes in response.videos: - assert_video_valid( - vid_bytes, - num_frames=expected_frames, - width=expected_width, - height=expected_height, - fps=expected_fps, - ) - - -def assert_audio_diffusion_response( - response, - request_config: dict[str, Any], - run_level: str = None, -) -> None: - """ - Validate audio diffusion response. - """ - raise NotImplementedError("Audio validation is not implemented yet") - # consider using assert_audio_valid defined above - - -def _maybe_int(value: Any) -> int | None: - if value is None: - return None - return int(value) - - -def assert_image_valid(image: Path | Image.Image, *, width: int | None = None, height: int | None = None): - """Assert the file is a loadable image with optional exact dimensions.""" - if isinstance(image, Path): - assert image.exists(), f"Image not found: {image}" - image = Image.open(image) - image.load() - assert image.width > 0 and image.height > 0 - if width is not None: - assert image.width == width, f"Expected width={width}, got {image.width}" - if height is not None: - assert image.height == height, f"Expected height={height}, got {image.height}" - return image - - -def assert_video_valid( - video: Path | bytes | BytesIO, - *, - num_frames: int | None = None, - width: int | None = None, - height: int | None = None, - fps: float | None = None, -) -> dict[str, int | float]: - """Assert the MP4 has the expected resolution and exact frame count.""" - temp_path = None - cap = None - try: - # Normalize input to file path - if isinstance(video, Path): - if not video.exists(): - raise AssertionError(f"Video file not found: {video}") - video_path = str(video) - else: - # Create temp file for bytes/BytesIO - suffix = ".mp4" - with tempfile.NamedTemporaryFile(delete=False, suffix=suffix, mode="wb") as tmp: - if isinstance(video, bytes): - tmp.write(video) - elif isinstance(video, BytesIO): - tmp.write(video.getvalue()) - else: - raise TypeError(f"Unsupported video type: {type(video)}") - temp_path = Path(tmp.name) - video_path = str(temp_path) - - # Open video capture - cap = cv2.VideoCapture(video_path) - if not cap.isOpened(): - raise AssertionError(f"Failed to open video: {video_path}") - - # Extract properties - actual_num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - actual_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - actual_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - actual_fps = cap.get(cv2.CAP_PROP_FPS) - - actual_num_frames = 0 - while True: - ok, _frame = cap.read() - if not ok: - break - actual_num_frames += 1 - - # Basic validity checks - if actual_num_frames <= 0: - raise AssertionError(f"Invalid frame count: {actual_num_frames} (must be > 0)") - if actual_width <= 0 or actual_height <= 0: - raise AssertionError(f"Invalid dimensions: {actual_width}x{actual_height} (must be > 0)") - if actual_fps <= 0: - raise AssertionError(f"Invalid FPS: {actual_fps} (must be > 0)") - - # Validate against expectations - if num_frames is not None: - expected_num_frames = (num_frames // 4) * 4 + 1 - assert actual_num_frames == expected_num_frames, ( - f"Frame count mismatch: expected {num_frames}, got {actual_num_frames}" - ) - if width is not None: - assert actual_width == width, f"Width mismatch: expected {width}px, got {actual_width}px" - if height is not None: - assert actual_height == height, f"Height mismatch: expected {height}px, got {actual_height}px" - if fps is not None: - # Use tolerance for float comparison (codec rounding) - assert abs(actual_fps - fps) < 0.5, f"FPS mismatch: expected {fps}, got {actual_fps:.2f}" - - return {"num_frames": actual_num_frames, "width": actual_width, "height": actual_height, "fps": actual_fps} - - except Exception as e: - print(f"ERROR: {type(e).__name__}: {e}", flush=True) - raise - - finally: - # Cleanup resources - if cap is not None: - cap.release() - if temp_path and temp_path.exists(): - try: - temp_path.unlink() - except OSError: - pass - - -def assert_audio_valid(path: Path, *, sample_rate: int, channels: int, duration_s: float) -> None: - """Assert the WAV has the expected sample rate, channel count, and duration.""" - assert path.exists(), f"Audio not found: {path}" - info = sf.info(str(path)) - assert info.samplerate == sample_rate, f"Expected sample_rate={sample_rate}, got {info.samplerate}" - assert info.channels == channels, f"Expected {channels} channel(s), got {info.channels}" - expected_frames = int(duration_s * sample_rate) - assert info.frames == expected_frames, ( - f"Expected {expected_frames} frames ({duration_s}s @ {sample_rate} Hz), got {info.frames}" - ) - - -def decode_b64_image(b64: str): - img = Image.open(BytesIO(base64.b64decode(b64))) - img.load() - return img - - -@pytest.fixture(scope="session") -def model_prefix() -> str: - """Optional model-path prefix from MODEL_PREFIX env var. - Useful if models are downloaded to non-default local directories. - """ - prefix = os.environ.get("MODEL_PREFIX", "") - return f"{prefix.rstrip('/')}/" if prefix else "" - - -@pytest.fixture(autouse=True) -def default_vllm_config(): - """Set a default VllmConfig for all tests. - - This fixture is auto-used for all tests to ensure that any test - that directly instantiates vLLM CustomOps (e.g., RMSNorm, LayerNorm) - or model components has the required VllmConfig context. - - This fixture is required for vLLM 0.14.0+ where CustomOp initialization - requires a VllmConfig context set via set_current_vllm_config(). - """ - from vllm.config import DeviceConfig, VllmConfig, set_current_vllm_config - - # Use CPU device if no GPU is available (e.g., in CI environments) - has_gpu = torch.cuda.is_available() and torch.cuda.device_count() > 0 - device = "cuda" if has_gpu else "cpu" - device_config = DeviceConfig(device=device) - - with set_current_vllm_config(VllmConfig(device_config=device_config)): - yield - - -@pytest.fixture(autouse=True) -def clean_gpu_memory_between_tests(): - print("\n=== PRE-TEST GPU CLEANUP ===") - _run_pre_test_cleanup() - yield - _run_post_test_cleanup() - - -@pytest.fixture(autouse=True) -def log_test_name_before_test(request): - print(f"--- Running test: {request.node.name}") - yield - - -def _run_pre_test_cleanup(enable_force=False): - if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1" and not enable_force: - print("GPU cleanup disabled") - return - - print("Pre-test GPU status:") - - num_gpus = torch.cuda.device_count() - if num_gpus > 0: - try: - from tests.utils import wait_for_gpu_memory_to_clear - - wait_for_gpu_memory_to_clear( - devices=list(range(num_gpus)), - threshold_ratio=0.05, - ) - except Exception as e: - print(f"Pre-test cleanup note: {e}") - - -def _run_post_test_cleanup(enable_force=False): - if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1" and not enable_force: - print("GPU cleanup disabled") - return - - if torch.cuda.is_available(): - gc.collect() - torch.cuda.empty_cache() - - print("Post-test GPU status:") - _print_gpu_processes() - - -def _print_gpu_processes(): - """Print GPU information including nvidia-smi and system processes""" - - print("\n" + "=" * 80) - print("NVIDIA GPU Information (nvidia-smi)") - print("=" * 80) - - try: - nvidia_result = subprocess.run( - ["nvidia-smi"], - capture_output=True, - text=True, - timeout=5, - ) - - if nvidia_result.returncode == 0: - lines = nvidia_result.stdout.strip().split("\n") - for line in lines[:20]: - print(line) - - if len(lines) > 20: - print(f"... (showing first 20 of {len(lines)} lines)") - else: - print("nvidia-smi command failed") - - except (subprocess.TimeoutExpired, FileNotFoundError): - print("nvidia-smi not available or timed out") - except Exception as e: - print(f"Error running nvidia-smi: {e}") - - print("\n" + "=" * 80) - print("Detailed GPU Processes (nvidia-smi pmon)") - print("=" * 80) - - try: - pmon_result = subprocess.run( - ["nvidia-smi", "pmon", "-c", "1"], - capture_output=True, - text=True, - timeout=3, - ) - - if pmon_result.returncode == 0 and pmon_result.stdout.strip(): - print(pmon_result.stdout) - else: - print("No active GPU processes found via nvidia-smi pmon") - - except Exception: - print("nvidia-smi pmon not available") - - print("\n" + "=" * 80) - print("System Processes with GPU keywords") - print("=" * 80) - - -def dummy_messages_from_mix_data( - system_prompt: dict[str, Any] = None, - video_data_url: Any = None, - audio_data_url: Any = None, - image_data_url: Any = None, - content_text: str = None, -): - """Create messages with video、image、audio data URL for OpenAI API.""" - - if content_text is not None: - content = [{"type": "text", "text": content_text}] - else: - content = [] - - media_items = [] - if isinstance(video_data_url, list): - for video_url in video_data_url: - media_items.append((video_url, "video")) - else: - media_items.append((video_data_url, "video")) - - if isinstance(image_data_url, list): - for url in image_data_url: - media_items.append((url, "image")) - else: - media_items.append((image_data_url, "image")) - - if isinstance(audio_data_url, list): - for url in audio_data_url: - media_items.append((url, "audio")) - else: - media_items.append((audio_data_url, "audio")) - - content.extend( - {"type": f"{media_type}_url", f"{media_type}_url": {"url": url}} - for url, media_type in media_items - if url is not None - ) - messages = [{"role": "user", "content": content}] - if system_prompt is not None: - messages = [system_prompt] + messages - return messages - - -def generate_synthetic_audio( - duration: int, # seconds - num_channels: int, # 1:Mono,2:Stereo 5:5.1 surround sound - sample_rate: int = 48000, # Default use 48000Hz. - save_to_file: bool = False, -) -> dict[str, Any]: - """ - Generate TTS speech with pyttsx3 and return base64 string. - """ - - import pyttsx3 - import soundfile as sf - - def _pick_voice(engine: pyttsx3.Engine) -> str | None: - voices = engine.getProperty("voices") - if not voices: - return None - - preferred_tokens = ( - "natural", - "jenny", - "sonia", - "susan", - "zira", - "aria", - "hazel", - "samantha", - "ava", - "allison", - "female", - "woman", - "english-us", - "en-us", - "english", - ) - discouraged_tokens = ( - "espeak", - "robot", - "mbrola", - "microsoft david", - "male", - "man", - ) - - best_voice = voices[0] - best_score = float("-inf") - for voice in voices: - voice_text = f"{getattr(voice, 'id', '')} {getattr(voice, 'name', '')}".lower() - voice_languages = " ".join( - lang.decode(errors="ignore") if isinstance(lang, bytes) else str(lang) - for lang in getattr(voice, "languages", []) - ).lower() - combined_text = f"{voice_text} {voice_languages}" - score = 0 - for idx, token in enumerate(preferred_tokens): - if token in combined_text: - score += 20 - idx - for token in discouraged_tokens: - if token in combined_text: - score -= 10 - if "english" in combined_text or "en_" in combined_text or "en-" in combined_text: - score += 4 - if "en-us" in combined_text or "english-us" in combined_text: - score += 4 - if score > best_score: - best_score = score - best_voice = voice - - return best_voice.id - - def _resample_audio(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray: - if src_sr == dst_sr or len(audio) == 0: - return audio.astype(np.float32) - - src_len = audio.shape[0] - dst_len = max(1, int(round(src_len * float(dst_sr) / float(src_sr)))) - src_idx = np.arange(src_len, dtype=np.float32) - dst_idx = np.linspace(0, src_len - 1, dst_len, dtype=np.float32) - - resampled_channels: list[np.ndarray] = [] - for ch in range(audio.shape[1]): - resampled_channels.append(np.interp(dst_idx, src_idx, audio[:, ch]).astype(np.float32)) - return np.stack(resampled_channels, axis=1) - - def _match_channels(audio: np.ndarray, target_channels: int) -> np.ndarray: - current_channels = audio.shape[1] - if current_channels == target_channels: - return audio.astype(np.float32) - if target_channels == 1: - return np.mean(audio, axis=1, keepdims=True, dtype=np.float32) - if current_channels == 1: - return np.repeat(audio, target_channels, axis=1).astype(np.float32) - - collapsed = np.mean(audio, axis=1, keepdims=True, dtype=np.float32) - return np.repeat(collapsed, target_channels, axis=1).astype(np.float32) - - def _trim_silence(audio: np.ndarray, threshold: float = 0.01) -> np.ndarray: - if len(audio) == 0: - return audio - energy = np.max(np.abs(audio), axis=1) - voiced = np.where(energy > threshold)[0] - if len(voiced) == 0: - return audio - start = max(0, int(voiced[0]) - int(sample_rate * 0.02)) - end = min(len(audio), int(voiced[-1]) + int(sample_rate * 0.04) + 1) - return audio[start:end] - - def _enhance_speech(audio: np.ndarray) -> np.ndarray: - if len(audio) == 0: - return audio.astype(np.float32) - enhanced = audio.astype(np.float32).copy() - enhanced -= np.mean(enhanced, axis=0, keepdims=True, dtype=np.float32) - if len(enhanced) > 1: - preemphasis = enhanced.copy() - preemphasis[1:] = enhanced[1:] - 0.94 * enhanced[:-1] - enhanced = 0.7 * enhanced + 0.3 * preemphasis - # Mild dynamic-range compression for ASR/TTS robustness. - enhanced = np.sign(enhanced) * np.sqrt(np.abs(enhanced)) - # Light fade to avoid clicks after trimming/repeating. - fade = min(len(enhanced) // 4, max(1, int(sample_rate * 0.01))) - if fade > 1: - ramp_in = np.linspace(0.0, 1.0, fade, dtype=np.float32) - ramp_out = np.linspace(1.0, 0.0, fade, dtype=np.float32) - enhanced[:fade] *= ramp_in[:, None] - enhanced[-fade:] *= ramp_out[:, None] - peak = float(np.max(np.abs(enhanced))) - if peak > 1e-8: - enhanced = enhanced / peak * 0.95 - return enhanced.astype(np.float32) - - phrase_text = "test" - num_samples = int(sample_rate * max(1, duration)) - audio_data = np.zeros((num_samples, num_channels), dtype=np.float32) - - engine = pyttsx3.init() - engine.setProperty("rate", 112) - engine.setProperty("volume", 1.0) - selected_voice = _pick_voice(engine) - if selected_voice is not None: - engine.setProperty("voice", selected_voice) - - temp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) - temp_wav.close() - - try: - engine.save_to_file(phrase_text, temp_wav.name) - engine.runAndWait() - engine.stop() - - ready = False - for _ in range(50): - if os.path.exists(temp_wav.name) and os.path.getsize(temp_wav.name) > 44: - ready = True - break - time.sleep(0.1) - - if not ready: - raise RuntimeError("pyttsx3 did not produce a WAV file in time.") - - tts_audio, tts_sr = sf.read(temp_wav.name, dtype="float32", always_2d=True) - finally: - if os.path.exists(temp_wav.name): - os.unlink(temp_wav.name) - - if len(tts_audio) == 0: - raise RuntimeError("pyttsx3 produced an empty WAV file.") - - tts_audio = _resample_audio(tts_audio, tts_sr, sample_rate) - tts_audio = _match_channels(tts_audio, num_channels) - tts_audio = _trim_silence(tts_audio, threshold=0.012) - tts_audio = _enhance_speech(tts_audio) - - lead_silence = min(int(sample_rate * 0.02), num_samples // 8) - pause_samples = int(sample_rate * 0.18) - start = lead_silence - phrase_len = tts_audio.shape[0] - - while start < num_samples: - take = min(phrase_len, num_samples - start) - audio_data[start : start + take] = tts_audio[:take] - start += phrase_len + pause_samples - - max_amp = float(np.max(np.abs(audio_data))) - if max_amp > 0: - audio_data = audio_data / max_amp * 0.95 - - audio_bytes: bytes | None = None - output_path: str | None = None - result: dict[str, Any] = { - "np_array": audio_data.copy(), - } - - if save_to_file: - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - output_path = f"audio_{num_channels}ch_{timestamp}.wav" - - try: - sf.write(output_path, audio_data, sample_rate, format="WAV", subtype="PCM_16") - print(f"Audio saved: {output_path}") - - with open(output_path, "rb") as f: - audio_bytes = f.read() - except Exception as e: - print(f"Save failed: {e}") - save_to_file = False - - # If not saving or save failed, create in memory - if not save_to_file or audio_bytes is None: - buffer = io.BytesIO() - sf.write(buffer, audio_data, sample_rate, format="WAV", subtype="PCM_16") - buffer.seek(0) - audio_bytes = buffer.read() - - # Return result - base64_audio = base64.b64encode(audio_bytes).decode("utf-8") - result["base64"] = base64_audio - # Always include file_path to avoid KeyError in callers. - result["file_path"] = output_path if save_to_file and output_path else None - - return result - - -def _mux_mp4_bytes_with_synthetic_audio( - video_mp4_bytes: bytes, - *, - num_frames: int, - fps: float = 30.0, - sample_rate: int = 48000, -) -> bytes: - """ - Mux a video-only MP4 with mono TTS audio from :func:`generate_synthetic_audio` (AAC). - - Audio length is at least the video duration in whole seconds (rounded up); ffmpeg - ``-shortest`` trims to the video when the WAV is longer. - - Uses ffmpeg from ``imageio_ffmpeg`` when available, else ``ffmpeg`` on PATH. - If TTS or mux fails, returns ``video_mp4_bytes`` unchanged. - - Mux subprocess does **not** use ``capture_output=True``: ffmpeg can block writing - to a full stderr pipe while :func:`subprocess.run` waits for exit (classic deadlock). - """ - duration_sec = num_frames / fps if fps > 0 else 0.0 - # generate_synthetic_audio(duration=int) uses at least 1s of buffer internally - duration_int = max(1, int(math.ceil(duration_sec))) - - try: - audio_result = generate_synthetic_audio( - duration=duration_int, - num_channels=1, - sample_rate=sample_rate, - save_to_file=False, - ) - audio_pcm = audio_result["np_array"] - except Exception as e: - logger.warning("Synthetic video: generate_synthetic_audio failed (%s); using video-only MP4.", e) - return video_mp4_bytes - - try: - import imageio_ffmpeg - - ffmpeg_exe = imageio_ffmpeg.get_ffmpeg_exe() - except Exception: - ffmpeg_exe = "ffmpeg" - - import tempfile - - try: - with tempfile.TemporaryDirectory(prefix="syn_vid_mux_") as tmp: - vid_path = os.path.join(tmp, "video.mp4") - wav_path = os.path.join(tmp, "audio.wav") - out_path = os.path.join(tmp, "out.mp4") - with open(vid_path, "wb") as f: - f.write(video_mp4_bytes) - sf.write(wav_path, audio_pcm, sample_rate, format="WAV", subtype="PCM_16") - cmd = [ - ffmpeg_exe, - "-y", - "-nostdin", - "-hide_banner", - "-loglevel", - "error", - "-i", - vid_path, - "-i", - wav_path, - "-c:v", - "copy", - "-c:a", - "aac", - "-b:a", - "128k", - "-shortest", - "-movflags", - "+faststart", - out_path, - ] - subprocess.run( - cmd, - check=True, - stdin=subprocess.DEVNULL, - timeout=300, - ) - with open(out_path, "rb") as f: - return f.read() - except ( - FileNotFoundError, - subprocess.CalledProcessError, - subprocess.TimeoutExpired, - OSError, - ) as e: - logger.warning("Synthetic video: audio mux failed (%s); using video-only MP4.", e) - return video_mp4_bytes - - -def generate_synthetic_video( - width: int, - height: int, - num_frames: int, - save_to_file: bool = False, - *, - embed_audio: bool = False, -) -> dict[str, Any]: - """Generate synthetic video with bouncing balls and base64 MP4. - - When ``embed_audio`` is True, muxes mono AAC from :func:`generate_synthetic_audio` - (TTS + ffmpeg) into the MP4; otherwise returns video-only MP4 (faster when tests do - not need an audio track). - """ - - import cv2 - import imageio - - # Create random balls - num_balls = random.randint(3, 8) - balls = [] - - for _ in range(num_balls): - radius = min(width, height) // 8 - if radius < 1: - raise ValueError(f"Video dimensions ({width}x{height}) are too small for synthetic video generation") - x = random.randint(radius, width - radius) - y = random.randint(radius, height - radius) - - speed = random.uniform(3.0, 8.0) - angle = random.uniform(0, 2 * math.pi) - vx = speed * math.cos(angle) - vy = speed * math.sin(angle) - - # OpenCV uses BGR format, but imageio expects RGB - # We'll create in BGR first, then convert to RGB later - color_bgr = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255)) - - balls.append({"x": x, "y": y, "vx": vx, "vy": vy, "radius": radius, "color_bgr": color_bgr}) - - # Generate video frames - video_frames = [] - - for frame_idx in range(num_frames): - # Create black background (BGR format) - frame_bgr = np.zeros((height, width, 3), dtype=np.uint8) - - for ball in balls: - # Update position - ball["x"] += ball["vx"] - ball["y"] += ball["vy"] - - # Boundary collision detection - if ball["x"] - ball["radius"] <= 0 or ball["x"] + ball["radius"] >= width: - ball["vx"] = -ball["vx"] - ball["x"] = max(ball["radius"], min(width - ball["radius"], ball["x"])) - - if ball["y"] - ball["radius"] <= 0 or ball["y"] + ball["radius"] >= height: - ball["vy"] = -ball["vy"] - ball["y"] = max(ball["radius"], min(height - ball["radius"], ball["y"])) - - # Use cv2 to draw circle - x, y = int(ball["x"]), int(ball["y"]) - radius = ball["radius"] - - # Draw solid circle (main circle) - cv2.circle(frame_bgr, (x, y), radius, ball["color_bgr"], -1) - - # Add simple 3D effect: draw a brighter center - if radius > 3: # Only add highlight when radius is large enough - highlight_radius = max(1, radius // 2) - highlight_x = max(highlight_radius, min(x - radius // 4, width - highlight_radius)) - highlight_y = max(highlight_radius, min(y - radius // 4, height - highlight_radius)) - - # Create highlight color (brighter) - highlight_color = tuple(min(c + 40, 255) for c in ball["color_bgr"]) - cv2.circle(frame_bgr, (highlight_x, highlight_y), highlight_radius, highlight_color, -1) - - # Convert BGR to RGB for imageio - frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB) - video_frames.append(frame_rgb) - - video_array = np.array(video_frames) - result = { - "np_array": video_array, - } - saved_file_path = None - - fps = 30 - buffer = io.BytesIO() - writer_kwargs = { - "format": "mp4", - "fps": fps, - "codec": "libx264", - "quality": 7, - "pixelformat": "yuv420p", - "macro_block_size": 16, - "ffmpeg_params": [ - "-preset", - "medium", - "-crf", - "23", - "-movflags", - "+faststart", - "-pix_fmt", - "yuv420p", - "-vf", - f"scale={width}:{height}", - ], - } - - try: - with imageio.get_writer(buffer, **writer_kwargs) as writer: - for frame in video_frames: - writer.append_data(frame) - buffer.seek(0) - video_only_bytes = buffer.read() - except Exception as e: - print(f"Warning: Failed to encode synthetic video: {e}") - raise - - if embed_audio: - video_bytes = _mux_mp4_bytes_with_synthetic_audio(video_only_bytes, num_frames=num_frames, fps=float(fps)) - else: - video_bytes = video_only_bytes - - if save_to_file: - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - output_path = f"video_{width}x{height}_{timestamp}.mp4" - try: - with open(output_path, "wb") as f: - f.write(video_bytes) - saved_file_path = output_path - print(f"Video saved to: {saved_file_path}") - except Exception as e: - print(f"Warning: Failed to save video to file {output_path}: {e}") - - base64_video = base64.b64encode(video_bytes).decode("utf-8") - - result["base64"] = base64_video - if save_to_file and saved_file_path: - result["file_path"] = saved_file_path - - return result - - -def generate_synthetic_image(width: int, height: int, save_to_file: bool = False) -> dict[str, Any]: - """Generate synthetic image with randomly colored squares and return base64 string.""" - from PIL import Image, ImageDraw - - # Create white background - image = Image.new("RGB", (width, height), (255, 255, 255)) - draw = ImageDraw.Draw(image) - - # Generate random number of squares - num_squares = random.randint(3, 8) - - for _ in range(num_squares): - # Random square size - square_size = random.randint(min(width, height) // 8, min(width, height) // 4) - - # Random position - x = random.randint(0, width - square_size - 1) - y = random.randint(0, height - square_size - 1) - - # Random color - color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) - - # Random border width - border_width = random.randint(1, 5) - - # Draw square - draw.rectangle([x, y, x + square_size, y + square_size], fill=color, outline=(0, 0, 0), width=border_width) - - image_array = np.array(image) - result = {"np_array": image_array.copy()} - - # Handle file saving - image_bytes = None - saved_file_path = None - - if save_to_file: - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - output_path = f"image_{width}x{height}_{timestamp}.jpg" - - try: - # Save image to file - image.save(output_path, format="JPEG", quality=85, optimize=True) - saved_file_path = output_path - print(f"Image saved to: {saved_file_path}") - - # Read file for base64 encoding - with open(output_path, "rb") as f: - image_bytes = f.read() - - except Exception as e: - print(f"Warning: Failed to save image to file {output_path}: {e}") - save_to_file = False - - # If not saving or save failed, create in memory - if not save_to_file or image_bytes is None: - buffer = io.BytesIO() - image.save(buffer, format="JPEG", quality=85, optimize=True) - buffer.seek(0) - image_bytes = buffer.read() - - # Generate base64 - base64_image = base64.b64encode(image_bytes).decode("utf-8") - - # Return result - result["base64"] = base64_image - if save_to_file and saved_file_path: - result["file_path"] = saved_file_path - - return result - - -def preprocess_text(text): - import opencc - - word_to_num = { - "zero": "0", - "one": "1", - "two": "2", - "three": "3", - "four": "4", - "five": "5", - "six": "6", - "seven": "7", - "eight": "8", - "nine": "9", - "ten": "10", - } - - for word, num in word_to_num.items(): - pattern = r"\b" + re.escape(word) + r"\b" - text = re.sub(pattern, num, text, flags=re.IGNORECASE) - - text = re.sub(r"[^\w\s]", "", text) - text = re.sub(r"\s+", " ", text) - cc = opencc.OpenCC("t2s") - text = cc.convert(text) - - # Special handling for spaces between Chinese characters: - # - Keep single spaces between English words/numbers - # - Remove spaces only when surrounded by Chinese characters on both sides to prevent incorrect word segmentation - text = re.sub(r"(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])", "", text) - - return text.lower().strip() - - -def cosine_similarity_text(text1, text2, n: int = 3): - from collections import Counter - - if not text1 or not text2: - return 0.0 - - text1 = preprocess_text(text1) - text2 = preprocess_text(text2) - print(f"cosine similarity text1 is: {text1}, text2 is: {text2}") - - ngrams1 = [text1[i : i + n] for i in range(len(text1) - n + 1)] - ngrams2 = [text2[i : i + n] for i in range(len(text2) - n + 1)] - - counter1 = Counter(ngrams1) - counter2 = Counter(ngrams2) - - all_ngrams = set(counter1.keys()) | set(counter2.keys()) - vec1 = [counter1.get(ng, 0) for ng in all_ngrams] - vec2 = [counter2.get(ng, 0) for ng in all_ngrams] - - dot_product = sum(a * b for a, b in zip(vec1, vec2)) - norm1 = sum(a * a for a in vec1) ** 0.5 - norm2 = sum(b * b for b in vec2) ** 0.5 - - if norm1 == 0 or norm2 == 0: - return 0.0 - return dot_product / (norm1 * norm2) - - -def convert_audio_to_text(audio_data): - """ - Convert base64 encoded audio data to text using speech recognition. - """ - audio_data = base64.b64decode(audio_data) - output_path = f"./test_{uuid.uuid4().hex}.wav" - with open(output_path, "wb") as audio_file: - audio_file.write(audio_data) - - print(f"audio data is saved: {output_path}") - text = convert_audio_file_to_text(output_path=output_path) - return text - - -def _merge_base64_audio_to_segment(base64_list: list[str]): - """Merge a list of base64-encoded audio chunks into one pydub AudioSegment.""" - from pydub import AudioSegment - - merged = None - for b64 in base64_list: - raw = base64.b64decode(b64.split(",", 1)[-1]) - seg = AudioSegment.from_file(io.BytesIO(raw)) - merged = seg if merged is None else merged + seg - return merged - - -def _whisper_transcribe_in_current_process(output_path: str) -> str: - import whisper - - # Multi-GPU: use last visible device to avoid colliding with default device 0; single device uses 0. - device_index = None - if current_omni_platform.is_available(): - n = current_omni_platform.get_device_count() - if n == 1: - device_index = 0 - elif n > 1: - device_index = n - 1 - - if device_index is not None: - torch_device = current_omni_platform.get_torch_device(device_index) - current_omni_platform.set_device(torch_device) - device = str(torch_device) - use_accelerator = True - else: - use_accelerator = False - device = "cpu" - model = whisper.load_model("small", device=device) - try: - text = model.transcribe( - output_path, - temperature=0.0, - word_timestamps=True, - condition_on_previous_text=False, - )["text"] - finally: - del model - gc.collect() - if use_accelerator: - current_omni_platform.synchronize() - current_omni_platform.empty_cache() - - return text or "" - - -def convert_audio_file_to_text(output_path: str) -> str: - """Convert an audio file to text in an isolated subprocess.""" - # Import locally to avoid impacting test module import time. - ctx = multiprocessing.get_context("spawn") - with concurrent.futures.ProcessPoolExecutor(max_workers=1, mp_context=ctx) as executor: - future = executor.submit(_whisper_transcribe_in_current_process, output_path) - return future.result() - - -def convert_audio_bytes_to_text(raw_bytes: bytes) -> str: - """ - Write container audio bytes (WAV, etc.) to a temp WAV file suitable for Whisper/ffmpeg. - Normalizes with soundfile to PCM_16 WAV when possible to avoid codec issues. - """ - output_path = f"./test_{uuid.uuid4().hex}.wav" - data, samplerate = sf.read(io.BytesIO(raw_bytes)) - sf.write(output_path, data, samplerate, format="WAV", subtype="PCM_16") - text = convert_audio_file_to_text(output_path) - return text - - -def modify_stage_config( - yaml_path: str, - updates: dict[str, Any] = None, - deletes: dict[str, Any] = None, -) -> str: - """ - Modify configurations in a YAML file, supporting both top-level and stage-specific modifications, - including addition, modification, and deletion of configurations. - - Args: - yaml_path: Path to the YAML configuration file. - updates: Dictionary containing both top-level and stage-specific modifications to add or update. - Format: { - 'async_chunk': True, - 'stage_args': { - 0: {'engine_args.max_model_len': 5800}, - 1: {'engine_args.max_num_seqs': 2} - } - } - deletes: Dictionary containing configurations to delete. - Format: { - 'old_config': None, # Delete entire key - 'stage_args': { - 0: ['engine_args.old_param'], - 1: ['runtime.unused_setting'] - } - } - - Returns: - str: Path to the newly created modified YAML file with timestamp suffix. - """ - path = Path(yaml_path) - if not path.exists(): - raise FileNotFoundError(f"yaml does not exist: {path}") - - try: - with open(yaml_path, encoding="utf-8") as f: - config = yaml.safe_load(f) or {} - except Exception as e: - raise ValueError(f"Cannot parse YAML file: {e}") - - # Helper function to apply update - def apply_update(config_dict: dict, key_path: str, value: Any) -> None: - """Apply update to dictionary using dot-separated path.""" - # Handle direct list assignment (e.g., engine_input_source: [1, 2]) - if "." not in key_path: - # Simple key, set directly - config_dict[key_path] = value - return - - current = config_dict - keys = key_path.split(".") - - for i in range(len(keys) - 1): - key = keys[i] - - # Handle list indices - if key.isdigit() and isinstance(current, list): - index = int(key) - if index < 0: - raise ValueError(f"Negative list index not allowed: {index}") - if index >= len(current): - # Expand list if needed - while len(current) <= index: - # If we need to go deeper (more keys after this), create a dict - # Otherwise, create None placeholder - current.append({} if i < len(keys) - 2 else None) - current = current[index] - elif isinstance(current, dict): - # Handle dictionary keys - if key not in current: - # If there are more keys after this, create appropriate structure - if i < len(keys) - 1: - # Check if next key is a digit (list index) or string (dict key) - if keys[i + 1].isdigit(): - current[key] = [] - else: - current[key] = {} - else: - # This is the last key, create based on value type - current[key] = [] if isinstance(value, list) else {} - elif not isinstance(current[key], (dict, list)) and i < len(keys) - 1: - # If current value is not dict/list but we need to go deeper, replace it - if keys[i + 1].isdigit(): - current[key] = [] - else: - current[key] = {} - current = current[key] - else: - # Current is not a dict or list, cannot traverse further - raise TypeError( - f"Cannot access {'.'.join(keys[: i + 1])} as a dict/list. It's a {type(current).__name__}" - ) - - # Set the final value - last_key = keys[-1] - if isinstance(current, list) and last_key.isdigit(): - # Setting a value in a list by index - index = int(last_key) - if index < 0: - raise ValueError(f"Negative list index not allowed: {index}") - if index >= len(current): - # Expand list if needed - while len(current) <= index: - current.append(None) - current[index] = value - elif isinstance(current, dict): - # Special case: if the value is a list and we're setting a top-level key - # Example: updating engine_input_source with [1, 2] - current[last_key] = value - else: - # Current is not a dict, cannot set key - raise TypeError(f"Cannot set value at {key_path}. Current type is {type(current).__name__}, expected dict.") - - # Helper function to delete by path - def delete_by_path(config_dict: dict, path: str) -> None: - """Delete configuration by dot-separated path.""" - if not path: - return - - current = config_dict - keys = path.split(".") - - # Traverse to the parent - for i in range(len(keys) - 1): - key = keys[i] - - # Handle list indices - if key.isdigit() and isinstance(current, list): - index = int(key) - if index < 0 or index >= len(current): - raise KeyError(f"List index {index} out of bounds") - current = current[index] - elif isinstance(current, dict): - if key not in current: - raise KeyError(f"Path {'.'.join(keys[: i + 1])} does not exist") - current = current[key] - else: - raise TypeError( - f"Cannot access {'.'.join(keys[: i + 1])} as a dict/list. It's a {type(current).__name__}" - ) - - # Delete the item - last_key = keys[-1] - - if isinstance(current, list) and last_key.isdigit(): - index = int(last_key) - if index < 0 or index >= len(current): - raise KeyError(f"List index {index} out of bounds") - del current[index] - elif isinstance(current, dict) and last_key in current: - del current[last_key] - else: - print(f"Path {path} does not exist") - - # Apply deletions first - if deletes: - for key, value in deletes.items(): - if key == "stage_args": - if value and isinstance(value, dict): - stage_args = config.get("stage_args", []) - if not stage_args: - raise ValueError("stage_args does not exist in config") - - for stage_id, delete_paths in value.items(): - if not delete_paths: - continue - - # Find stage by ID - target_stage = None - for stage in stage_args: - if stage.get("stage_id") == int(stage_id): - target_stage = stage - break - - if target_stage is None: - continue - - # Delete specified paths in this stage - for path in delete_paths: - if path: # Skip empty paths - delete_by_path(target_stage, path) - elif "." in key: - # Delete using dot-separated path - delete_by_path(config, key) - elif value is None and key in config: - # Delete entire key - del config[key] - - # Apply updates - if updates: - for key, value in updates.items(): - if key == "stage_args": - if value and isinstance(value, dict): - stage_args = config.get("stage_args", []) - if not stage_args: - raise ValueError("stage_args does not exist in config") - - for stage_id, stage_updates in value.items(): - # Find stage by ID - target_stage = None - for stage in stage_args: - if stage.get("stage_id") == int(stage_id): - target_stage = stage - break - - if target_stage is None: - available_ids = [s.get("stage_id") for s in stage_args if "stage_id" in s] - raise KeyError(f"Stage ID {stage_id} not found, available: {available_ids}") - - # Apply updates to this stage - for path, val in stage_updates.items(): - # Check if this is a simple key (not dot-separated) - # Example: 'engine_input_source' vs 'engine_args.max_model_len' - if "." not in path: - # Direct key assignment (e.g., updating a list value) - target_stage[path] = val - else: - # Dot-separated path (e.g., nested dict access) - apply_update(target_stage, path, val) - elif "." in key: - # Apply using dot-separated path - apply_update(config, key, value) - else: - # Direct top-level key - config[key] = value - - # Unique suffix: multiple modify_stage_config calls in one process often run - # within the same second (e.g. test_qwen3_omni_expansion imports both - # get_chunk_config and get_batch_token_config). int(time.time()) would collide - # and the later write would overwrite the earlier YAML on disk. - base_name = yaml_path.rsplit(".", 1)[0] if "." in yaml_path else yaml_path - output_path = f"{base_name}_{time.time_ns()}.yaml" - - with open(output_path, "w", encoding="utf-8") as f: - yaml.dump(config, f, default_flow_style=None, sort_keys=False, allow_unicode=True, indent=2) - - return output_path - - -class OmniServer: - """Omniserver for vLLM-Omni tests.""" - - def __init__( - self, - model: str, - serve_args: list[str], - *, - port: int | None = None, - env_dict: dict[str, str] | None = None, - use_omni: bool = True, - ) -> None: - _run_pre_test_cleanup(enable_force=True) - _run_post_test_cleanup(enable_force=True) - cleanup_dist_env_and_memory() - self.model = model - self.serve_args = serve_args - self.env_dict = env_dict - self.use_omni = use_omni - self.proc: subprocess.Popen | None = None - self.host = "127.0.0.1" - if port is None: - self.port = get_open_port() - else: - self.port = port - - def _start_server(self) -> None: - """Start the vLLM-Omni server subprocess.""" - env = os.environ.copy() - env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - if self.env_dict is not None: - env.update(self.env_dict) - - cmd = [ - sys.executable, - "-m", - "vllm_omni.entrypoints.cli.main", - "serve", - self.model, - "--host", - self.host, - "--port", - str(self.port), - ] - if self.use_omni: - cmd.append("--omni") - cmd += self.serve_args - - print(f"Launching OmniServer with: {' '.join(cmd)}") - self.proc = subprocess.Popen( - cmd, - env=env, - cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))), # Set working directory to vllm-omni root - ) - - # Wait for server to be ready - max_wait = 1200 # 20 minutes - start_time = time.time() - while time.time() - start_time < max_wait: - # Check for process status - ret = self.proc.poll() - if ret is not None: - raise RuntimeError(f"Server processes exited with code {ret} before becoming ready.") - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.settimeout(1) - result = sock.connect_ex((self.host, self.port)) - if result == 0: - print(f"Server ready on {self.host}:{self.port}") - return - time.sleep(2) - - raise RuntimeError(f"Server failed to start within {max_wait} seconds") - - def _kill_process_tree(self, pid): - """kill process and its children with verification""" - try: - parent = psutil.Process(pid) - children = parent.children(recursive=True) - - # Get all PIDs first - all_pids = [pid] + [child.pid for child in children] - - # Terminate children - for child in children: - try: - child.terminate() - except psutil.NoSuchProcess: - pass - - # Wait for children - gone, still_alive = psutil.wait_procs(children, timeout=10) - - # Kill remaining children - for child in still_alive: - try: - child.kill() - except psutil.NoSuchProcess: - pass - - # Terminate parent - try: - parent.terminate() - parent.wait(timeout=10) - except (psutil.NoSuchProcess, psutil.TimeoutExpired): - try: - parent.kill() - except psutil.NoSuchProcess: - pass - - # VERIFICATION: Check if all processes are gone - time.sleep(1) # Give system time - alive_processes = [] - for check_pid in all_pids: - if psutil.pid_exists(check_pid): - alive_processes.append(check_pid) - - if alive_processes: - print(f"Warning: Processes still alive: {alive_processes}") - # Optional: Try system kill - import subprocess - - for alive_pid in alive_processes: - try: - subprocess.run(["kill", "-9", str(alive_pid)], timeout=2) - except Exception as e: - print(f"Cleanup failed: {e}") - - except psutil.NoSuchProcess: - pass - - def __enter__(self): - self._start_server() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.proc: - self._kill_process_tree(self.proc.pid) - _run_pre_test_cleanup(enable_force=True) - _run_post_test_cleanup(enable_force=True) - cleanup_dist_env_and_memory() - - -def pytest_addoption(parser): - parser.addoption( - "--run-level", - action="store", - default="core_model", - choices=["core_model", "advanced_model"], - help="Test level to run: L2, L3", - ) - - -@pytest.fixture(scope="session") -def run_level(request) -> str: - """A command-line argument that specifies the level of tests to run in this session. - See https://docs.vllm.ai/projects/vllm-omni/en/latest/contributing/ci/CI_5levels/""" - return request.config.getoption("--run-level") - - -_omni_server_lock = threading.Lock() - - -@pytest.fixture(scope="module") -def omni_server(request: pytest.FixtureRequest, run_level: str, model_prefix: str) -> Generator[OmniServer, Any, None]: - """Start vLLM-Omni server as a subprocess with actual model weights. - Uses session scope so the server starts only once for the entire test session. - Multi-stage initialization can take 10-20+ minutes. - """ - with _omni_server_lock: - params: OmniServerParams = request.param - model = model_prefix + params.model - port = params.port - stage_config_path = params.stage_config_path - if run_level == "advanced_model" and stage_config_path is not None: - with open(stage_config_path, encoding="utf-8") as f: - cfg = yaml.safe_load(f) or {} - stage_ids = [stage["stage_id"] for stage in cfg.get("stage_args", []) if "stage_id" in stage] - stage_config_path = modify_stage_config( - stage_config_path, - deletes={"stage_args": {stage_id: ["engine_args.load_format"] for stage_id in stage_ids}}, - ) - - server_args = params.server_args or [] - if params.use_omni: - server_args = ["--stage-init-timeout", "120", *server_args] - if stage_config_path is not None: - server_args += ["--stage-configs-path", stage_config_path] - - with ( - OmniServer( - model, - server_args, - port=port, - env_dict=params.env_dict, - use_omni=params.use_omni, - ) - if port - else OmniServer( - model, - server_args, - env_dict=params.env_dict, - use_omni=params.use_omni, - ) - ) as server: - print("OmniServer started successfully") - yield server - print("OmniServer stopping...") - - print("OmniServer stopped") - - -@dataclass -class OmniResponse: - text_content: str | None = None - audio_data: list[str] | None = None - audio_content: str | None = None - audio_format: str | None = None - audio_bytes: bytes | None = None - similarity: float | None = None - e2e_latency: float | None = None - success: bool = False - error_message: str | None = None - - -@dataclass -class DiffusionResponse: - text_content: str | None = None - images: list[Image.Image] | None = None - audios: list[Any] | None = None - videos: list[Any] | None = None - e2e_latency: float | None = None - success: bool = False - error_message: str | None = None - - -def _load_gender_pipeline(): - """ - Lazy-load a cached audio-classification pipeline for gender. - - We prefer the pipeline wrapper because it encapsulates processor/model loading - and avoids direct AutoProcessor.from_pretrained call sites in this file. - """ - global _GENDER_PIPELINE - if _GENDER_PIPELINE is not None: - return _GENDER_PIPELINE - - model_name = "7wolf/wav2vec2-base-gender-classification" - try: - # device=-1 forces CPU for pipeline. - _GENDER_PIPELINE = pipeline( - task="audio-classification", - model=model_name, - device=-1, - ) - return _GENDER_PIPELINE - except Exception as exc: # pragma: no cover - best-effort fallback - print(f"Warning: failed to create gender pipeline '{model_name}': {exc}") - _GENDER_PIPELINE = None - return None - - -def _median_pitch_hz_from_autocorr(mono: np.ndarray, sr: int) -> float | None: - """ - Rough median F0 (Hz) over short-time frames. Used to debias wav2vec2 gender head on TTS, - which often labels lower-pitched synthetic speech as female under load or on clean signals. - Returns None if the clip is too short or mostly unvoiced. - """ - x = np.asarray(mono, dtype=np.float64) - x = x - np.mean(x) - if x.size < int(0.15 * sr): - return None - frame_len = int(0.04 * sr) - hop = max(frame_len // 2, 1) - f0_min_hz, f0_max_hz = 70.0, 400.0 - lag_min = max(1, int(sr / f0_max_hz)) - lag_max = min(frame_len - 2, int(sr / f0_min_hz)) - if lag_max <= lag_min: - return None - win = np.hamming(frame_len) - pitches: list[float] = [] - for start in range(0, int(x.shape[0]) - frame_len, hop): - frame = x[start : start + frame_len] * win - frame = frame - np.mean(frame) - if float(np.sqrt(np.mean(frame**2))) < 1e-4: - continue - ac = np.correlate(frame, frame, mode="full")[frame_len - 1 :] - ac = ac / (float(ac[0]) + 1e-12) - region = ac[lag_min : lag_max + 1] - peak_rel = int(np.argmax(region)) - peak_lag = peak_rel + lag_min - if peak_lag <= 0: - continue - f0 = float(sr) / float(peak_lag) - if f0_min_hz <= f0 <= f0_max_hz: - pitches.append(f0) - if len(pitches) < 4: - return None - return float(np.median(np.asarray(pitches, dtype=np.float64))) - - -def _estimate_voice_gender_from_audio(audio_bytes: bytes) -> str: - """ - Estimate voice gender from audio using a small pre-trained classification model. - - Uses a cached `audio-classification` pipeline to classify the clip. - Returns 'male' / 'female' when the model confidence is >= 0.9 and the label - maps to one of these; otherwise returns 'unknown'. If the model is unavailable - or inference fails, returns 'unknown' to keep tests stable. - - Under concurrent tests, a global lock serializes pipeline calls (the HF pipeline is not - thread-safe). A coarse F0 median can correct systematic "male -> female" errors on TTS audio. - """ - data, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=True) - if data.size == 0: - raise ValueError("Empty audio") - mono = np.mean(data, axis=1) - - try: - target_sr = 16000 - if int(sr) != target_sr and mono.size > 1: - src_len = int(mono.shape[0]) - dst_len = max(1, int(round(src_len * float(target_sr) / float(sr)))) - src_idx = np.arange(src_len, dtype=np.float32) - dst_idx = np.linspace(0, src_len - 1, dst_len, dtype=np.float32) - mono = np.interp(dst_idx, src_idx, mono.astype(np.float32, copy=False)).astype(np.float32) - sr = target_sr - - median_f0 = _median_pitch_hz_from_autocorr(mono, sr) - - clf = _load_gender_pipeline() - if clf is None: - print("gender model not available, returning 'unknown'") - return "unknown" - - # transformers pipeline returns a list of {label, score} (highest score first). - with _GENDER_PIPELINE_LOCK: - outputs = clf(mono, sampling_rate=sr) - if not outputs: - return "unknown" - - top = outputs[0] - label = str(top.get("label", "")).lower() - conf = float(top.get("score", 0.0)) - - if conf < 0.5: - gender = "unknown" - # Some models use non-English labels (e.g., Russian). Normalize to 'male'/'female'. - elif ("female" in label) or ("жен" in label): - gender = "female" - elif ("male" in label) or ("муж" in label): - gender = "male" - else: - gender = "unknown" - - # Debias: wav2vec2 gender heads often call TTS / band-limited male speech "female". - # Low median F0 (~speech male range) + female label -> trust pitch when score is not overwhelming. - if gender == "female" and median_f0 is not None and median_f0 < 165.0 and conf < 0.88: - print(f"gender pitch assist: reclassifying female->male (median_f0={median_f0:.1f} Hz, conf={conf:.3f})") - gender = "male" - elif gender == "male" and median_f0 is not None and median_f0 > 230.0 and conf < 0.88: - print(f"gender pitch assist: reclassifying male->female (median_f0={median_f0:.1f} Hz, conf={conf:.3f})") - gender = "female" - - print( - f"gender classifier: label={label}, conf={conf:.3f}, gender={gender}" - + (f", median_f0={median_f0:.1f}Hz" if median_f0 is not None else "") - ) - return gender - except Exception as exc: # pragma: no cover - best-effort fallback - print(f"Warning: gender classification failed, returning 'unknown': {exc}") - return "unknown" - - -_PRESET_VOICE_GENDER_MAP: dict[str, str] = { - "serena": "female", - "uncle_fu": "male", - "chelsie": "female", - "clone": "female", - "ethan": "male", -} - - -def _assert_preset_voice_gender_from_audio( - audio_bytes: bytes | None, - voice_name: str | None, -) -> None: - """If ``voice_name`` matches a known preset, assert classifier gender matches (skip when unknown).""" - if not voice_name or not audio_bytes: - return - key = str(voice_name).lower() - expected_gender = _PRESET_VOICE_GENDER_MAP.get(key) - if expected_gender is None: - return - estimated_gender = _estimate_voice_gender_from_audio(audio_bytes) - print(f"Preset voice gender check: preset={key!r}, estimated={estimated_gender!r}, expected={expected_gender!r}") - if estimated_gender != "unknown": - assert estimated_gender == expected_gender, ( - f"{voice_name!r} is expected {expected_gender}, but estimated gender is {estimated_gender!r}" - ) - - -# Threshold aligned with _compute_pcm_hnr_db docstring (clean clone vs distorted). -_MIN_PCM_SPEECH_HNR_DB = 1.0 - - -def _compute_pcm_hnr_db(pcm_samples: np.ndarray, sr: int = _PCM_SPEECH_SAMPLE_RATE_HZ) -> float: - """Compute mean Harmonic-to-Noise Ratio (dB) for speech quality. - - Clean cloned speech has HNR > 1.2 dB; distorted speech (e.g. lost - ref_code decoder context) drops below 1.0 dB. - """ - frame_len = int(0.03 * sr) # 30ms frames - hop = frame_len // 2 - hnr_values: list[float] = [] - - for start in range(0, len(pcm_samples) - frame_len, hop): - frame = pcm_samples[start : start + frame_len].astype(np.float32, copy=False) - frame = frame - np.mean(frame) - if np.max(np.abs(frame)) < 0.01: - continue - ac = np.correlate(frame, frame, mode="full")[len(frame) - 1 :] - ac = ac / (ac[0] + 1e-10) - min_lag = int(sr / 400) - max_lag = min(int(sr / 80), len(ac)) - if min_lag >= max_lag: - continue - peak = float(np.max(ac[min_lag:max_lag])) - if 0 < peak < 1: - hnr_values.append(10 * np.log10(peak / (1 - peak + 1e-10))) - - return float(np.mean(hnr_values)) if hnr_values else 0.0 - - -def _assert_pcm_int16_speech_hnr(audio_bytes: bytes) -> None: - """Validate harmonic-to-noise ratio on raw int16 PCM from /v1/audio/speech.""" - assert audio_bytes is not None and len(audio_bytes) >= 2, "missing PCM bytes" - assert len(audio_bytes) % 2 == 0, "PCM byte length must be aligned to int16" - pcm_samples = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0 - hnr = _compute_pcm_hnr_db(pcm_samples) - print(f"PCM speech HNR: {hnr:.2f} dB (threshold: {_MIN_PCM_SPEECH_HNR_DB} dB)") - assert hnr >= _MIN_PCM_SPEECH_HNR_DB, ( - f"Audio distortion detected: HNR={hnr:.2f} dB < {_MIN_PCM_SPEECH_HNR_DB} dB. " - "Voice clone decoder may be losing ref_code speaker context on later chunks." - ) - - -def assert_omni_response(response: OmniResponse, request_config: dict[str, Any], run_level): - """ - Validate response results. - - Args: - response: OmniResponse object - - Raises: - AssertionError: When the response does not meet validation criteria - """ - assert response.success, "The request failed." - e2e_latency = response.e2e_latency - if e2e_latency is not None: - print(f"the e2e latency is: {e2e_latency}") - - modalities = request_config.get("modalities", ["text", "audio"]) - - if run_level == "advanced_model": - if "audio" in modalities: - assert response.audio_content is not None, "No audio output is generated" - print(f"audio content is: {response.audio_content}") - speaker = request_config.get("speaker") - if speaker: - _assert_preset_voice_gender_from_audio( - response.audio_bytes, - speaker, - ) - - if "text" in modalities: - assert response.text_content is not None, "No text output is generated" - print(f"text content is: {response.text_content}") - - # Verify image description - word_types = ["text", "image", "audio", "video"] - keywords_dict = request_config.get("key_words", {}) - for word_type in word_types: - keywords = keywords_dict.get(word_type) - if "text" in modalities: - if keywords: - text_lower = response.text_content.lower() - assert any(str(kw).lower() in text_lower for kw in keywords), ( - "The output does not contain any of the keywords." - ) - else: - if keywords: - audio_lower = response.audio_content.lower() - assert any(str(kw).lower() in audio_lower for kw in keywords), ( - "The output does not contain any of the keywords." - ) - - # Verify similarity (Whisper transcript vs streamed/detokenized text) - if "text" in modalities and "audio" in modalities: - assert response.similarity is not None and response.similarity > 0.9, ( - "The audio content is not same as the text" - ) - print(f"similarity is: {response.similarity}") - - -def assert_audio_speech_response( - response: OmniResponse, - request_config: dict[str, Any], - run_level: str, -) -> None: - """ - Validate /v1/audio/speech response: success, optional format check, transcription similarity - and gender (non-PCM only for advanced_model), and int16 PCM HNR when response_format is pcm. - """ - assert response.success, "The request failed." - - req_fmt = request_config.get("response_format") - - if req_fmt == "pcm" and response.audio_bytes: - _assert_pcm_int16_speech_hnr(response.audio_bytes) - if response.audio_format: - assert "pcm" in response.audio_format.lower(), ( - f"Expected audio/pcm content-type, got {response.audio_format!r}" - ) - - elif req_fmt == "wav" and response.audio_format: - assert req_fmt in response.audio_format, ( - f"The response audio format {response.audio_format} don't match the request audio format {req_fmt}" - ) - - e2e_latency = response.e2e_latency - if e2e_latency is not None: - print(f"the avg e2e latency is: {e2e_latency}") - - if run_level == "advanced_model" and req_fmt != "pcm": - # Text–audio semantic similarity check (skipped for raw PCM: no Whisper transcript). - expected_text = request_config.get("input") - if expected_text: - transcript = (response.audio_content or "").strip() - print(f"audio content is: {transcript}") - print(f"input text is: {expected_text}") - similarity = cosine_similarity_text(transcript.lower(), expected_text.lower()) - print(f"Cosine similarity: {similarity:.3f}") - assert similarity > 0.9, ( - f"Transcript doesn't match input: similarity={similarity:.2f}, transcript='{transcript}'" - ) - - # Voice gender consistency check (preset names in ``_PRESET_VOICE_GENDER_MAP``). - # When the estimator returns 'unknown', we treat it as inconclusive and do NOT fail the test. - _assert_preset_voice_gender_from_audio( - response.audio_bytes, - request_config.get("voice"), - ) - - -def assert_diffusion_response(response: DiffusionResponse, request_config: dict[str, Any], run_level: str = None): - """ - Validate diffusion response results. - - Dispatcher that routes validation to modality-specific assert functions. - - Args: - response: DiffusionResponse object. - request_config: Request configuration dictionary. - run_level: Test run level (e.g. "core_model", "advanced_model") - - Raises: - AssertionError: When the response does not meet validation criteria - KeyError: When the request_config does not contain necessary parameters for validation - """ - assert response.success, "The request failed." - - e2e_latency = response.e2e_latency - if e2e_latency is not None: - print(f"the avg e2e is: {e2e_latency}") - - has_any_content = any(content is not None for content in (response.images, response.videos, response.audios)) - assert has_any_content, "Response contains no images, videos, or audios" - - if response.images is not None: - assert_image_diffusion_response( - response=response, - request_config=request_config, - run_level=run_level, - ) - - if response.videos is not None: - assert_video_diffusion_response( - response=response, - request_config=request_config, - run_level=run_level, - ) - - if response.audios is not None: - assert_audio_diffusion_response( - response=response, - request_config=request_config, - run_level=run_level, - ) - - -class OpenAIClientHandler: - """ - OpenAI client handler class, encapsulating both streaming and non-streaming response processing logic. - - This class integrates OpenAI API request sending, response handling, and validation functionality, - supporting both single request and concurrent request modes. - """ - - def __init__( - self, host: str = "127.0.0.1", port: int = get_open_port(), api_key: str = "EMPTY", run_level: str = None - ): - """ - Initialize the OpenAI client. - - Args: - host: vLLM-Omni server host address - port: vLLM-Omni server port - api_key: API key (defaults to "EMPTY") - """ - self.base_url = f"http://{host}:{port}" - self.client = OpenAI(base_url=f"http://{host}:{port}/v1", api_key=api_key) - self.run_level = run_level - - def _process_stream_omni_response(self, chat_completion) -> OmniResponse: - """ - Process streaming responses. - - Args: - chat_completion: OpenAI streaming response object - request_config: Request configuration dictionary - - Returns: - OmniResponse: Processed response object - """ - result = OmniResponse() - start_time = time.perf_counter() - - try: - text_content = "" - audio_data = [] - - for chunk in chat_completion: - for choice in chunk.choices: - # Get content data - if hasattr(choice, "delta"): - content = getattr(choice.delta, "content", None) - else: - content = None - - # Get modality type - modality = getattr(chunk, "modality", None) - - # Process content based on modality type - if modality == "audio" and content: - audio_data.append(content) - elif modality == "text" and content: - text_content += content if content else "" - - # Calculate end-to-end latency - result.e2e_latency = time.perf_counter() - start_time - - # Process audio and text content - audio_content = None - similarity = None - - if audio_data or text_content: - if audio_data: - merged_seg = _merge_base64_audio_to_segment(audio_data) - wav_buf = BytesIO() - merged_seg.export(wav_buf, format="wav") - result.audio_bytes = wav_buf.getvalue() - audio_content = convert_audio_bytes_to_text(result.audio_bytes) - if audio_content and text_content: - similarity = cosine_similarity_text(audio_content.lower(), text_content.lower()) - - # Populate result object - result.text_content = text_content - result.audio_data = audio_data - result.audio_content = audio_content - result.similarity = similarity - result.success = True - - except Exception as e: - result.error_message = f"Stream processing error: {str(e)}" - print(f"Error: {result.error_message}") - - return result - - def _process_non_stream_omni_response(self, chat_completion) -> OmniResponse: - """ - Process non-streaming responses. - - Args: - chat_completion: OpenAI non-streaming response object - request_config: Request configuration dictionary - - Returns: - OmniResponse: Processed response object - """ - result = OmniResponse() - start_time = time.perf_counter() - - try: - audio_data = None - text_content = None - - # Iterate through all choices - for choice in chat_completion.choices: - # Process audio data - if hasattr(choice.message, "audio") and choice.message.audio is not None: - audio_message = choice.message - audio_data = audio_message.audio.data - - # Process text content - if hasattr(choice.message, "content") and choice.message.content is not None: - text_content = choice.message.content - - # Calculate end-to-end latency - result.e2e_latency = time.perf_counter() - start_time - - # Process audio and text content - audio_content = None - similarity = None - - if audio_data or text_content: - if audio_data: - result.audio_bytes = base64.b64decode(audio_data) - audio_content = convert_audio_bytes_to_text(result.audio_bytes) - if audio_content and text_content: - similarity = cosine_similarity_text(audio_content.lower(), text_content.lower()) - - # Populate result object - result.text_content = text_content - result.audio_content = audio_content - result.similarity = similarity - result.success = True - - except Exception as e: - result.error_message = f"Non-stream processing error: {str(e)}" - print(f"Error: {result.error_message}") - - return result - - def _process_diffusion_response(self, chat_completion) -> DiffusionResponse: - """ - Process diffusion responses (image generation/editing). - - Args: - chat_completion: OpenAI response object - - Returns: - DiffusionResponse: Processed response object - """ - result = DiffusionResponse() - start_time = time.perf_counter() - - try: - images = [] - # [TODO] reading video and audio output from API response for later validation - - for choice in chat_completion.choices: - if hasattr(choice.message, "content") and choice.message.content is not None: - content = choice.message.content - if isinstance(content, list): - for item in content: - if isinstance(item, dict): - image_url = item.get("image_url", {}).get("url") - else: - image_url_obj = getattr(item, "image_url", None) - image_url = hasattr(image_url_obj, "url", None) if image_url_obj else None - if image_url and image_url.startswith("data:image"): - b64_data = image_url.split(",", 1)[1] - img = decode_b64_image(b64_data) - images.append(img) - - result.e2e_latency = time.perf_counter() - start_time - result.images = images if images else None - result.success = True - - except Exception as e: - result.error_message = f"Diffusion response processing error: {str(e)}" - print(f"Error: {result.error_message}") - - return result - - def _process_stream_audio_speech_response(self, response, *, response_format: str | None = None) -> OmniResponse: - """ - Process streaming /v1/audio/speech responses into an OmniResponse. - - This mirrors _process_stream_omni_response but operates on low-level - audio bytes and produces an OmniResponse with audio_content filled - from Whisper transcription. - """ - result = OmniResponse() - start_time = time.perf_counter() - - try: - # Aggregate all audio bytes from the streaming response. - data = bytearray() - - # Preferred OpenAI helper. - if hasattr(response, "iter_bytes") and callable(getattr(response, "iter_bytes")): - for chunk in response.iter_bytes(): - if chunk: - data.extend(chunk) - else: - # Generic iterable-of-bytes fallback (e.g., generator or list of chunks). - try: - iterator = iter(response) - except TypeError: - iterator = None - - if iterator is not None: - for chunk in iterator: - if not chunk: - continue - if isinstance(chunk, (bytes, bytearray)): - data.extend(chunk) - elif hasattr(chunk, "data"): - data.extend(chunk.data) # type: ignore[arg-type] - elif hasattr(chunk, "content"): - data.extend(chunk.content) # type: ignore[arg-type] - else: - raise TypeError(f"Unsupported stream chunk type: {type(chunk)}") - else: - raise TypeError(f"Unsupported audio speech streaming response type: {type(response)}") - - raw_bytes = bytes(data) - if response_format == "pcm": - transcript = None - else: - transcript = convert_audio_bytes_to_text(raw_bytes) - - # Populate OmniResponse. - result.audio_bytes = raw_bytes - result.audio_content = transcript - result.e2e_latency = time.perf_counter() - start_time - result.success = True - result.audio_format = getattr(response, "response", None) - if result.audio_format is not None: - result.audio_format = result.audio_format.headers.get("content-type", "") - - except Exception as e: - result.error_message = f"Audio speech stream processing error: {str(e)}" - print(f"Error: {result.error_message}") - - return result - - def _process_non_stream_audio_speech_response( - self, response, *, response_format: str | None = None - ) -> OmniResponse: - """ - Process non-streaming /v1/audio/speech responses into an OmniResponse. - - This mirrors _process_non_stream_omni_response but for the binary - audio payload returned by audio.speech.create. - """ - result = OmniResponse() - start_time = time.perf_counter() - - try: - # OpenAI non-streaming audio.speech.create returns HttpxBinaryResponseContent (.read() or .content) - if hasattr(response, "read") and callable(getattr(response, "read")): - raw_bytes = response.read() - elif hasattr(response, "content"): - raw_bytes = response.content # type: ignore[assignment] - else: - raise TypeError(f"Unsupported audio speech response type: {type(response)}") - - if response_format == "pcm": - transcript = None - else: - transcript = convert_audio_bytes_to_text(raw_bytes) - - result.audio_bytes = raw_bytes - result.audio_content = transcript - result.e2e_latency = time.perf_counter() - start_time - result.success = True - result.audio_format = getattr(response, "response", None) - if result.audio_format is not None: - result.audio_format = result.audio_format.headers.get("content-type", "") - - except Exception as e: - result.error_message = f"Audio speech non-stream processing error: {str(e)}" - print(f"Error: {result.error_message}") - - return result - - def send_omni_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]: - """ - Send OpenAI requests. - - Args: - request_config: Request configuration dictionary containing parameters like model, messages, stream. - Optional ``use_audio_in_video`` (bool): when true, sets - ``extra_body["mm_processor_kwargs"] = {"use_audio_in_video": True}`` for Qwen-Omni video+audio - extraction. - Optional top-level ``speaker`` (str): Qwen3-Omni preset TTS speaker name; sent as - ``extra_body["speaker"]`` to ``chat.completions.create``. - request_num: Number of requests, defaults to 1 (single request) - - Returns: - List[OmniResponse]: List of response objects - """ - - responses = [] - stream = request_config.get("stream", False) - modalities = request_config.get("modalities", ["text", "audio"]) - - extra_body: dict[str, Any] = {} - if "speaker" in request_config: - extra_body["speaker"] = request_config["speaker"] - if request_config.get("use_audio_in_video"): - mm = dict(extra_body.get("mm_processor_kwargs") or {}) - mm["use_audio_in_video"] = True - extra_body["mm_processor_kwargs"] = mm - extra_body_arg: dict[str, Any] | None = extra_body if extra_body else None - - create_kwargs: dict[str, Any] = { - "model": request_config.get("model"), - "messages": request_config.get("messages"), - "stream": stream, - "modalities": modalities, - } - if extra_body_arg is not None: - create_kwargs["extra_body"] = extra_body_arg - - if request_num == 1: - # Send single request - chat_completion = self.client.chat.completions.create(**create_kwargs) - - if stream: - response = self._process_stream_omni_response(chat_completion) - else: - response = self._process_non_stream_omni_response(chat_completion) - - assert_omni_response(response, request_config, run_level=self.run_level) - responses.append(response) - - else: - # Send concurrent requests: run create + process in worker so e2e_latency includes full round-trip. - def _one_omni_request(): - start = time.perf_counter() - worker_kwargs: dict[str, Any] = { - "model": request_config.get("model"), - "messages": request_config.get("messages"), - "modalities": modalities, - "stream": stream, - } - if extra_body_arg is not None: - worker_kwargs["extra_body"] = extra_body_arg - chat_completion = self.client.chat.completions.create(**worker_kwargs) - if stream: - response = self._process_stream_omni_response(chat_completion) - else: - response = self._process_non_stream_omni_response(chat_completion) - response.e2e_latency = time.perf_counter() - start - return response - - with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor: - futures = [executor.submit(_one_omni_request) for _ in range(request_num)] - for future in concurrent.futures.as_completed(futures): - response = future.result() - assert_omni_response(response, request_config, run_level=self.run_level) - responses.append(response) - - return responses - - def send_audio_speech_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]: - """ - Call the /v1/audio/speech endpoint using the same configuration-dict - style as send_omni_request, but via the OpenAI Python client's - audio.speech APIs. - - Expected keys in request_config: - - model: model name/path (required) - - input: text to synthesize (required) - - response_format: audio format such as "wav" or "pcm" (optional) - - task_type, ref_text, ref_audio: TTS-specific extras (optional, passed via extra_body) - - timeout: request timeout in seconds (float, optional, default 120.0) - - stream: whether to use streaming API (bool, optional, default False) - """ - timeout = float(request_config.get("timeout", 120.0)) - - model = request_config["model"] - text_input = request_config["input"] - stream = bool(request_config.get("stream", False)) - voice = request_config.get("voice", None) - - # Standard OpenAI param: use omit when not provided to keep default behavior. - response_format = request_config.get("response_format", omit) - - # Qwen3-TTS custom fields, forwarded via extra_body. - extra_body: dict[str, Any] = {} - # Keep this list aligned with vllm_omni.entrypoints.openai.protocol.audio params. - for key in ("task_type", "ref_text", "ref_audio", "language", "max_new_tokens"): - if key in request_config: - extra_body[key] = request_config[key] - - responses: list[OmniResponse] = [] - - speech_fmt: str | None = None if response_format is omit else str(response_format).lower() - - if request_num == 1: - if stream: - # Use streaming response helper. - with self.client.audio.speech.with_streaming_response.create( - model=model, - input=text_input, - response_format=response_format, - extra_body=extra_body or None, - timeout=timeout, - voice=voice, - ) as resp: - omni_resp = self._process_stream_audio_speech_response(resp, response_format=speech_fmt) - else: - # Non-streaming response. - resp = self.client.audio.speech.create( - model=model, - input=text_input, - response_format=response_format, - extra_body=extra_body or None, - timeout=timeout, - voice=voice, - ) - omni_resp = self._process_non_stream_audio_speech_response(resp, response_format=speech_fmt) - - assert_audio_speech_response(omni_resp, request_config, run_level=self.run_level) - responses.append(omni_resp) - return responses - else: - # request_num > 1: concurrent requests (use same params as single-request path) - - if stream: - - def _stream_task(): - with self.client.audio.speech.with_streaming_response.create( - model=model, - input=text_input, - response_format=response_format, - extra_body=extra_body or None, - timeout=timeout, - voice=voice, - ) as resp: - return self._process_stream_audio_speech_response(resp, response_format=speech_fmt) - - with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor: - futures = [executor.submit(_stream_task) for _ in range(request_num)] - for future in concurrent.futures.as_completed(futures): - omni_resp = future.result() - assert_audio_speech_response(omni_resp, request_config, run_level=self.run_level) - responses.append(omni_resp) - else: - with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor: - futures = [] - for _ in range(request_num): - future = executor.submit( - self.client.audio.speech.create, - model=model, - input=text_input, - response_format=response_format, - extra_body=extra_body or None, - timeout=timeout, - voice=voice, - ) - futures.append(future) - - for future in concurrent.futures.as_completed(futures): - resp = future.result() - omni_resp = self._process_non_stream_audio_speech_response(resp, response_format=speech_fmt) - assert_audio_speech_response(omni_resp, request_config, run_level=self.run_level) - responses.append(omni_resp) - - return responses - - def send_diffusion_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]: - """ - Send OpenAI requests for diffusion models. - - Args: - request_config: Request configuration dictionary containing parameters like model, messages - request_num: Number of requests to send concurrently, defaults to 1 (single request) - Returns: - List[OmniResponse]: List of response objects - """ - responses = [] - stream = request_config.get("stream", False) - modalities = request_config.get("modalities", omit) # Most diffusion models don't require modalities param - extra_body = request_config.get("extra_body", None) - - if stream: - raise NotImplementedError("Streaming is not currently implemented for diffusion model e2e test") - - if request_num == 1: - # Send single request - chat_completion = self.client.chat.completions.create( - model=request_config.get("model"), - messages=request_config.get("messages"), - extra_body=extra_body, - modalities=modalities, - ) - - response = self._process_diffusion_response(chat_completion) - assert_diffusion_response(response, request_config, run_level=self.run_level) - responses.append(response) - - else: - # Send concurrent requests - with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor: - futures = [] - - # Submit all request tasks - for _ in range(request_num): - future = executor.submit( - self.client.chat.completions.create, - model=request_config.get("model"), - messages=request_config.get("messages"), - modalities=modalities, - extra_body=extra_body, - ) - futures.append(future) - - # Process completed tasks - for future in concurrent.futures.as_completed(futures): - chat_completion = future.result() - response = self._process_diffusion_response(chat_completion) - assert_diffusion_response(response, request_config, run_level=self.run_level) - responses.append(response) - - return responses - - def send_video_diffusion_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]: - """ - Send native /v1/videos requests. - """ - if request_num != 1: - raise NotImplementedError("Concurrent video diffusion requests are not currently implemented") - - if request_config.get("stream", False): - raise NotImplementedError("Streaming is not currently implemented for video diffusion e2e test") - - form_data = request_config.get("form_data") - if not isinstance(form_data, dict): - raise ValueError("Video request_config must contain 'form_data'") - - if not form_data.get("prompt"): - raise ValueError("Video request_config['form_data'] must contain 'prompt'") - - normalized_form_data = {key: str(value) for key, value in form_data.items() if value is not None} - - files: dict[str, tuple[str, BytesIO, str]] = {} - image_reference = request_config.get("image_reference") - if image_reference: - if image_reference.startswith("data:image"): - header, encoded = image_reference.split(",", 1) - content_type = header.split(";")[0].removeprefix("data:") - extension = content_type.split("/")[-1] - file_data = base64.b64decode(encoded) - - files["input_reference"] = ( - f"reference.{extension}", - BytesIO(file_data), - content_type, - ) - else: - normalized_form_data["image_reference"] = json.dumps({"image_url": image_reference}) - - result = DiffusionResponse() - start_time = time.perf_counter() - - try: - create_url = self._build_url("/v1/videos") - response = requests.post( - create_url, - data=normalized_form_data, - files=files, - headers={"Accept": "application/json"}, - timeout=60, - ) - response.raise_for_status() - - job_data = response.json() - video_id = job_data["id"] - - self._wait_until_video_completed(video_id) - - video_content = self._download_video_content(video_id) - - result.success = True - result.videos = [video_content] - result.e2e_latency = time.perf_counter() - start_time - - assert_diffusion_response(result, request_config, run_level=self.run_level) - - except Exception as e: - result.success = False - result.error_message = f"Diffusion response processing error: {e}" - assert False, result.error_message - - return [result] - - def _wait_until_video_completed( - self, - video_id: str, - poll_interval_seconds: int = 2, - timeout_seconds: int = 300, - ) -> None: - status_url = self._build_url(f"/v1/videos/{video_id}") - deadline = time.monotonic() + timeout_seconds - - while time.monotonic() < deadline: - status_resp = requests.get( - status_url, - headers={"Accept": "application/json"}, - timeout=30, - ) - status_resp.raise_for_status() - - status_data = status_resp.json() - current_status = status_data["status"] - - if current_status == "completed": - return - - if current_status == "failed": - error_msg = status_data.get("last_error", "Unknown error") - raise RuntimeError(f"Job failed: {error_msg}") - - time.sleep(poll_interval_seconds) - - raise TimeoutError(f"Video job {video_id} did not complete within {timeout_seconds}s") - - def _download_video_content(self, video_id: str) -> bytes: - download_url = self._build_url(f"/v1/videos/{video_id}/content") - video_resp = requests.get(download_url, stream=True, timeout=60) - video_resp.raise_for_status() - - video_bytes = BytesIO() - for chunk in video_resp.iter_content(chunk_size=8192): - if chunk: - video_bytes.write(chunk) - - return video_bytes.getvalue() - - def _build_url(self, path: str) -> str: - return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}" - - -@pytest.fixture -def openai_client(omni_server: OmniServer, run_level: str): - """Create OpenAIClientHandler fixture to facilitate communication with OmniServer - with encapsulated request sending, concurrent requests, response handling, and validation.""" - return OpenAIClientHandler(host=omni_server.host, port=omni_server.port, api_key="EMPTY", run_level=run_level) - - -class OmniRunner: - """ - Offline test runner for Omni models. - """ - - def __init__( - self, - model_name: str, - seed: int = 42, - stage_init_timeout: int = 300, - batch_timeout: int = 10, - init_timeout: int = 300, - shm_threshold_bytes: int = 65536, - log_stats: bool = False, - stage_configs_path: str | None = None, - **kwargs, - ) -> None: - """ - Initialize an OmniRunner for testing. - - Args: - model_name: The model name or path - seed: Random seed for reproducibility - stage_init_timeout: Timeout for initializing a single stage in seconds - batch_timeout: Timeout for batching in seconds - init_timeout: Timeout for initializing stages in seconds - shm_threshold_bytes: Threshold for using shared memory - log_stats: Enable detailed statistics logging - stage_configs_path: Optional path to YAML stage config file - **kwargs: Additional arguments passed to Omni - """ - cleanup_dist_env_and_memory() - _run_pre_test_cleanup(enable_force=True) - _run_post_test_cleanup(enable_force=True) - self.model_name = model_name - self.seed = seed - - self.omni = Omni( - model=model_name, - log_stats=log_stats, - stage_init_timeout=stage_init_timeout, - batch_timeout=batch_timeout, - init_timeout=init_timeout, - shm_threshold_bytes=shm_threshold_bytes, - stage_configs_path=stage_configs_path, - **kwargs, - ) - - def _estimate_prompt_len( - self, - additional_information: dict[str, Any], - model_name: str, - _cache: dict[str, Any] = {}, - ) -> int: - """Estimate prompt_token_ids placeholder length for the Talker stage. - - The AR Talker replaces all input embeddings via ``preprocess``, so the - placeholder values are irrelevant but the **length** must match the - embeddings that ``preprocess`` will produce. - """ - try: - from vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts import Qwen3TTSConfig - from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import ( - Qwen3TTSTalkerForConditionalGeneration, - ) - - if model_name not in _cache: - from transformers import AutoTokenizer - - tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side="left") - cfg = Qwen3TTSConfig.from_pretrained(model_name, trust_remote_code=True) - _cache[model_name] = (tok, getattr(cfg, "talker_config", None)) - - tok, tcfg = _cache[model_name] - task_type = (additional_information.get("task_type") or ["CustomVoice"])[0] - return Qwen3TTSTalkerForConditionalGeneration.estimate_prompt_len_from_additional_information( - additional_information=additional_information, - task_type=task_type, - tokenize_prompt=lambda t: tok(t, padding=False)["input_ids"], - codec_language_id=getattr(tcfg, "codec_language_id", None), - spk_is_dialect=getattr(tcfg, "spk_is_dialect", None), - ) - except Exception as exc: - logger.warning("Failed to estimate prompt length, using fallback 2048: %s", exc) - return 2048 - - def get_default_sampling_params_list(self) -> list[OmniSamplingParams]: - """ - Get a list of default sampling parameters for all stages. - - Returns: - List of SamplingParams with default decoding for each stage - """ - if not hasattr(self.omni, "default_sampling_params_list"): - raise AttributeError("Omni.default_sampling_params_list is not available") - return list(self.omni.default_sampling_params_list) - - def get_omni_inputs( - self, - prompts: list[str] | str, - system_prompt: str | None = None, - audios: PromptAudioInput = None, - images: PromptImageInput = None, - videos: PromptVideoInput = None, - mm_processor_kwargs: dict[str, Any] | None = None, - modalities: list[str] | None = None, - ) -> list[TextPrompt]: - """ - Construct Omni input format from prompts and multimodal data. - - Args: - prompts: Text prompt(s) - either a single string or list of strings - system_prompt: Optional system prompt (defaults to Qwen system prompt) - audios: Audio input(s) - tuple of (audio_array, sample_rate) or list of tuples - images: Image input(s) - PIL Image or list of PIL Images - videos: Video input(s) - numpy array or list of numpy arrays - mm_processor_kwargs: Optional processor kwargs (e.g., use_audio_in_video) - - Returns: - List of prompt dictionaries suitable for Omni.generate() - """ - if system_prompt is None: - system_prompt = ( - "You are Qwen, a virtual human developed by the Qwen Team, Alibaba " - "Group, capable of perceiving auditory and visual inputs, as well as " - "generating text and speech." - ) - - video_padding_token = "<|VIDEO|>" - image_padding_token = "<|IMAGE|>" - audio_padding_token = "<|AUDIO|>" - - if "Qwen3-Omni-30B-A3B-Instruct" in self.model_name: - video_padding_token = "<|video_pad|>" - image_padding_token = "<|image_pad|>" - audio_padding_token = "<|audio_pad|>" - - if isinstance(prompts, str): - prompts = [prompts] - - # Qwen-TTS: follow examples/offline_inference/qwen3_tts/end2end.py style. - # Stage 0 expects token placeholders + additional_information (text/speaker/task_type/...), - # and Talker replaces embeddings in preprocess based on additional_information only. - is_tts_model = "Qwen3-TTS" in self.model_name or "qwen3_tts" in self.model_name.lower() - if is_tts_model and modalities == ["audio"]: - tts_kw = mm_processor_kwargs or {} - task_type = tts_kw.get("task_type", "CustomVoice") - speaker = tts_kw.get("speaker", "Vivian") - language = tts_kw.get("language", "Auto") - max_new_tokens = int(tts_kw.get("max_new_tokens", 2048)) - ref_audio = tts_kw.get("ref_audio", None) - ref_text = tts_kw.get("ref_text", None) - - omni_inputs: list[TextPrompt] = [] - for prompt_text in prompts: - text_str = str(prompt_text).strip() or " " - additional_information: dict[str, Any] = { - "task_type": [task_type], - "text": [text_str], - "language": [language], - "speaker": [speaker], - "max_new_tokens": [max_new_tokens], - } - if ref_audio is not None: - additional_information["ref_audio"] = [ref_audio] - if ref_text is not None: - additional_information["ref_text"] = [ref_text] - # Use official helper to get correct placeholder length - plen = self._estimate_prompt_len(additional_information, self.model_name) - input_dict: TextPrompt = { - "prompt_token_ids": [0] * plen, - "additional_information": additional_information, - } - omni_inputs.append(input_dict) - return omni_inputs - - def _normalize_mm_input(mm_input, num_prompts): - if mm_input is None: - return [None] * num_prompts - if isinstance(mm_input, list): - if len(mm_input) != num_prompts: - raise ValueError( - f"Multimodal input list length ({len(mm_input)}) must match prompts length ({num_prompts})" - ) - return mm_input - return [mm_input] * num_prompts - - num_prompts = len(prompts) - audios_list = _normalize_mm_input(audios, num_prompts) - images_list = _normalize_mm_input(images, num_prompts) - videos_list = _normalize_mm_input(videos, num_prompts) - - omni_inputs = [] - for i, prompt_text in enumerate(prompts): - user_content = "" - multi_modal_data = {} - - audio = audios_list[i] - if audio is not None: - if isinstance(audio, list): - for _ in audio: - user_content += f"<|audio_bos|>{audio_padding_token}<|audio_eos|>" - multi_modal_data["audio"] = audio - else: - user_content += f"<|audio_bos|>{audio_padding_token}<|audio_eos|>" - multi_modal_data["audio"] = audio - - image = images_list[i] - if image is not None: - if isinstance(image, list): - for _ in image: - user_content += f"<|vision_bos|>{image_padding_token}<|vision_eos|>" - multi_modal_data["image"] = image - else: - user_content += f"<|vision_bos|>{image_padding_token}<|vision_eos|>" - multi_modal_data["image"] = image - - video = videos_list[i] - if video is not None: - if isinstance(video, list): - for _ in video: - user_content += f"<|vision_bos|>{video_padding_token}<|vision_eos|>" - multi_modal_data["video"] = video - else: - user_content += f"<|vision_bos|>{video_padding_token}<|vision_eos|>" - multi_modal_data["video"] = video - - user_content += prompt_text - - full_prompt = ( - f"<|im_start|>system\n{system_prompt}<|im_end|>\n" - f"<|im_start|>user\n{user_content}<|im_end|>\n" - f"<|im_start|>assistant\n" - ) - - input_dict: TextPrompt = {"prompt": full_prompt} - if multi_modal_data: - input_dict["multi_modal_data"] = multi_modal_data - if modalities: - input_dict["modalities"] = modalities - if mm_processor_kwargs: - input_dict["mm_processor_kwargs"] = mm_processor_kwargs - - omni_inputs.append(input_dict) - - return omni_inputs - - def generate( - self, - prompts: list[TextPrompt], - sampling_params_list: list[OmniSamplingParams] | None = None, - ) -> list[OmniRequestOutput]: - """ - Generate outputs for the given prompts. - - Args: - prompts: List of prompt dictionaries with 'prompt' and optionally - 'multi_modal_data' keys - sampling_params_list: List of sampling parameters for each stage. - If None, uses default parameters. - - Returns: - List of OmniRequestOutput objects from stages with final_output=True - """ - if sampling_params_list is None: - sampling_params_list = self.get_default_sampling_params_list() - - return self.omni.generate(prompts, sampling_params_list) - - def generate_multimodal( - self, - prompts: list[str] | str, - sampling_params_list: list[OmniSamplingParams] | None = None, - system_prompt: str | None = None, - audios: PromptAudioInput = None, - images: PromptImageInput = None, - videos: PromptVideoInput = None, - mm_processor_kwargs: dict[str, Any] | None = None, - modalities: list[str] | None = None, - ) -> list[OmniRequestOutput]: - """ - Convenience method to generate with multimodal inputs. - - Args: - prompts: Text prompt(s) - sampling_params_list: List of sampling parameters for each stage - system_prompt: Optional system prompt - audios: Audio input(s) - images: Image input(s) - videos: Video input(s) - mm_processor_kwargs: Optional processor kwargs - - Returns: - List of OmniRequestOutput objects from stages with final_output=True - """ - omni_inputs = self.get_omni_inputs( - prompts=prompts, - system_prompt=system_prompt, - audios=audios, - images=images, - videos=videos, - mm_processor_kwargs=mm_processor_kwargs, - modalities=modalities, - ) - return self.generate(omni_inputs, sampling_params_list) - - def start_profile( - self, - profile_prefix: str | None = None, - stages: list[int] | None = None, - ) -> list[Any]: - """Start profiling specified stages. - - Args: - profile_prefix: Optional prefix for the trace file names. - stages: List of stage IDs to profile. If None, profiles all stages. - - Returns: - List of results from each stage. - """ - return self.omni.start_profile(profile_prefix=profile_prefix, stages=stages) - - def stop_profile(self, stages: list[int] | None = None) -> list[Any]: - """Stop profiling specified stages. - - Args: - stages: List of stage IDs to profile. If None, stops all stages. - - Returns: - List of results from each stage. - """ - return self.omni.stop_profile(stages=stages) - - def _cleanup_process(self): - try: - keywords = ["enginecore"] - matched = [] - - for proc in psutil.process_iter(["pid", "name", "cmdline", "username"]): - try: - cmdline = " ".join(proc.cmdline()).lower() if proc.cmdline() else "" - name = proc.name().lower() - - is_process = any(keyword in cmdline for keyword in keywords) or any( - keyword in name for keyword in keywords - ) - - if is_process: - print(f"Found vllm process: PID={proc.pid}, cmd={cmdline[:100]}") - matched.append(proc) - except (psutil.NoSuchProcess, psutil.AccessDenied): - pass - - for proc in matched: - try: - proc.terminate() - except (psutil.NoSuchProcess, psutil.AccessDenied): - pass - - _, still_alive = psutil.wait_procs(matched, timeout=5) - for proc in still_alive: - try: - proc.kill() - except (psutil.NoSuchProcess, psutil.AccessDenied): - pass - - if still_alive: - _, stubborn = psutil.wait_procs(still_alive, timeout=3) - if stubborn: - print(f"Warning: failed to kill residual vllm pids: {[p.pid for p in stubborn]}") - else: - print(f"Force-killed residual vllm pids: {[p.pid for p in still_alive]}") - elif matched: - print(f"Terminated vllm pids: {[p.pid for p in matched]}") - - except Exception as e: - print(f"Error in psutil vllm cleanup: {e}") - - def __enter__(self): - """Context manager entry.""" - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Context manager exit - cleanup resources.""" - if hasattr(self.omni, "close"): - self.omni.close() - self._cleanup_process() - _run_pre_test_cleanup(enable_force=True) - _run_post_test_cleanup(enable_force=True) - cleanup_dist_env_and_memory() - - -@pytest.fixture(scope="module") -def omni_runner(request, model_prefix): - with _omni_server_lock: - model, stage_config_path = request.param - model = model_prefix + model - with OmniRunner(model, seed=42, stage_configs_path=stage_config_path, stage_init_timeout=300) as runner: - print("OmniRunner started successfully") - yield runner - print("OmniRunner stopping...") - - print("OmniRunner stopped") - - -class OmniRunnerHandler: - def __init__(self, omni_runner): - self.runner = omni_runner - - def _process_output(self, outputs: list[Any]) -> OmniResponse: - result = OmniResponse() - try: - text_content = None - audio_content = None - for stage_output in outputs: - if getattr(stage_output, "final_output_type", None) == "text": - text_content = stage_output.request_output.outputs[0].text - if getattr(stage_output, "final_output_type", None) == "audio": - audio_content = stage_output.request_output.outputs[0].multimodal_output["audio"] - - result.audio_content = audio_content - result.text_content = text_content - result.success = True - - except Exception as e: - result.error_message = f"Output processing error: {str(e)}" - result.success = False - print(f"Error: {result.error_message}") - - return result - - def send_request(self, request_config: dict[str, Any] | None = None) -> OmniResponse: - if request_config is None: - request_config = {} - prompts = request_config.get("prompts") - videos = request_config.get("videos") - images = request_config.get("images") - audios = request_config.get("audios") - modalities = request_config.get("modalities", ["text", "audio"]) - outputs = self.runner.generate_multimodal( - prompts=prompts, videos=videos, images=images, audios=audios, modalities=modalities - ) - response = self._process_output(outputs) - assert_omni_response(response, request_config, run_level="core_model") - return response - - def send_audio_speech_request( - self, - request_config: dict[str, Any], - ) -> OmniResponse: - """ - Offline TTS: text -> audio via generate_multimodal, then validate with assert_audio_speech_response. - - request_config must contain: - - 'input' or 'prompts': text to synthesize. - Optional keys: - - 'voice' -> speaker (CustomVoice) - - 'task_type' -> task_type in additional_information (default: "CustomVoice") - - 'language' -> language in additional_information (default: "Auto") - - 'max_new_tokens' -> max_new_tokens in additional_information (default: 2048) - - 'response_format' -> desired audio format (used only for assertion) - """ - input_text = request_config.get("input") or request_config.get("prompts") - if input_text is None: - raise ValueError("request_config must contain 'input' or 'prompts' for TTS") - if isinstance(input_text, list): - input_text = input_text[0] if input_text else "" - - # Build TTS-specific kwargs passed through to get_omni_inputs for Qwen3-TTS, - # matching examples/offline_inference/qwen3_tts/end2end.py. - mm_processor_kwargs: dict[str, Any] = {} - if "voice" in request_config: - mm_processor_kwargs["speaker"] = request_config["voice"] - if "task_type" in request_config: - mm_processor_kwargs["task_type"] = request_config["task_type"] - if "ref_audio" in request_config: - mm_processor_kwargs["ref_audio"] = request_config["ref_audio"] - if "ref_text" in request_config: - mm_processor_kwargs["ref_text"] = request_config["ref_text"] - if "language" in request_config: - mm_processor_kwargs["language"] = request_config["language"] - if "max_new_tokens" in request_config: - mm_processor_kwargs["max_new_tokens"] = request_config["max_new_tokens"] - - outputs = self.runner.generate_multimodal( - prompts=input_text, - modalities=["audio"], - mm_processor_kwargs=mm_processor_kwargs or None, - ) - mm_out: dict[str, Any] | None = None - for stage_out in outputs: - if getattr(stage_out, "final_output_type", None) == "audio": - mm_out = stage_out.request_output.outputs[0].multimodal_output - break - if mm_out is None: - result = OmniResponse(success=False, error_message="No audio output from pipeline") - assert result.success, result.error_message - return result - - audio_data = mm_out.get("audio") - if audio_data is None: - result = OmniResponse(success=False, error_message="No audio tensor in multimodal output") - assert result.success, result.error_message - return result - - sr_raw = mm_out.get("sr") - sr_val = sr_raw[-1] if isinstance(sr_raw, list) and sr_raw else sr_raw - sr = int(sr_val.item() if hasattr(sr_val, "item") else sr_val) - wav_tensor = torch.cat(audio_data, dim=-1) if isinstance(audio_data, list) else audio_data - wav_buf = io.BytesIO() - sf.write( - wav_buf, - wav_tensor.float().cpu().numpy().reshape(-1), - samplerate=sr, - format="WAV", - subtype="PCM_16", - ) - result = OmniResponse(success=True, audio_bytes=wav_buf.getvalue(), audio_format="audio/wav") - assert_audio_speech_response(result, request_config, run_level="core_model") - return result - - def start_profile( - self, - profile_prefix: str | None = None, - stages: list[int] | None = None, - ) -> list[Any]: - """Start profiling specified stages.""" - return self.runner.start_profile(profile_prefix=profile_prefix, stages=stages) - - def stop_profile(self, stages: list[int] | None = None) -> list[Any]: - """Stop profiling specified stages.""" - return self.runner.stop_profile(stages=stages) - - -@pytest.fixture -def omni_runner_handler(omni_runner): - return OmniRunnerHandler(omni_runner) +""" +Root pytest entrypoint for the vLLM-Omni test suite. + +- `tests/conftest.py` stays thin: plugin registration + compatibility re-exports. +- Importable utilities live under `tests/helpers/`. +- Fixtures live under `tests/fixtures/` and are loaded via `pytest_plugins`. +""" + +from __future__ import annotations + +pytest_plugins = ( + "tests.helpers.fixtures.env", + "tests.helpers.fixtures.log", + "tests.helpers.fixtures.run_args", + "tests.helpers.fixtures.runtime", +) + +# Backward-compatible re-exports. +# (Many tests still import from `tests.conftest`; migrate these imports to `tests.helpers.*` over time.) +from tests.helpers.assertions import ( # noqa: F401,E402 + assert_audio_speech_response, + assert_diffusion_response, + assert_image_diffusion_response, + assert_image_valid, + assert_omni_response, + assert_video_diffusion_response, + assert_video_valid, +) +from tests.helpers.media import ( # noqa: F401,E402 + convert_audio_bytes_to_text, + convert_audio_file_to_text, + cosine_similarity_text, + decode_b64_image, + generate_synthetic_audio, + generate_synthetic_image, + generate_synthetic_video, +) +from tests.helpers.stage_config import ( # noqa: F401,E402 + dummy_messages_from_mix_data, + modify_stage_config, +) + +# Lazy: importing `tests.helpers.runtime` at conftest load runs before session +# autouse fixtures and can scramble vLLM/vllm_omni init order. +_RUNTIME_EXPORT_NAMES = ( + "DiffusionResponse", + "OmniResponse", + "OmniRunner", + "OmniRunnerHandler", + "OmniServer", + "OmniServerParams", + "OpenAIClientHandler", +) + + +def __getattr__(name: str): + if name in _RUNTIME_EXPORT_NAMES: + import tests.helpers.runtime as _runtime + + return getattr(_runtime, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__(): + return sorted({*globals(), *_RUNTIME_EXPORT_NAMES}) diff --git a/tests/dfx/conftest.py b/tests/dfx/helpers.py similarity index 98% rename from tests/dfx/conftest.py rename to tests/dfx/helpers.py index e54141b3442..2983e364176 100644 --- a/tests/dfx/conftest.py +++ b/tests/dfx/helpers.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Any -from tests.conftest import modify_stage_config +from tests.helpers.stage_config import modify_stage_config def load_configs(config_path: str) -> list[dict[str, Any]]: diff --git a/tests/dfx/perf/scripts/run_benchmark.py b/tests/dfx/perf/scripts/run_benchmark.py index 9e375fa9fec..565de8e68a7 100644 --- a/tests/dfx/perf/scripts/run_benchmark.py +++ b/tests/dfx/perf/scripts/run_benchmark.py @@ -8,14 +8,14 @@ import pytest -from tests.conftest import OmniServer -from tests.dfx.conftest import ( +from tests.dfx.helpers import ( create_benchmark_indices, create_test_parameter_mapping, create_unique_server_params, get_benchmark_params_for_server, load_configs, ) +from tests.helpers.runtime import OmniServer os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0" diff --git a/tests/dfx/stability/scripts/test_benchmark_stability.py b/tests/dfx/stability/scripts/test_benchmark_stability.py index e8568652d18..6d0a1525691 100644 --- a/tests/dfx/stability/scripts/test_benchmark_stability.py +++ b/tests/dfx/stability/scripts/test_benchmark_stability.py @@ -24,8 +24,7 @@ import pytest -from tests.conftest import OmniServer -from tests.dfx.conftest import ( +from tests.dfx.helpers import ( create_benchmark_indices, create_test_parameter_mapping, create_unique_server_params, @@ -33,6 +32,7 @@ load_configs, ) from tests.dfx.perf.scripts.run_benchmark import run_benchmark +from tests.helpers.runtime import OmniServer STABILITY_DIR = Path(__file__).resolve().parent.parent STAGE_CONFIGS_DIR = STABILITY_DIR / "stage_configs" diff --git a/tests/diffusion/lora/conftest.py b/tests/diffusion/lora/helpers.py similarity index 100% rename from tests/diffusion/lora/conftest.py rename to tests/diffusion/lora/helpers.py diff --git a/tests/diffusion/lora/test_lora_manager.py b/tests/diffusion/lora/test_lora_manager.py index 83ac7a1144b..785f5d84217 100644 --- a/tests/diffusion/lora/test_lora_manager.py +++ b/tests/diffusion/lora/test_lora_manager.py @@ -8,7 +8,7 @@ from vllm.lora.lora_weights import LoRALayerWeights from vllm.lora.utils import get_supported_lora_modules -from tests.diffusion.lora.conftest import ( +from tests.diffusion.lora.helpers import ( DummyBaseLayerWithLoRA, FakeLinearBase, fake_replace_submodule, diff --git a/tests/diffusion/models/bagel/test_bagel_lora.py b/tests/diffusion/models/bagel/test_bagel_lora.py index 8cb3446ed53..c285758fe86 100644 --- a/tests/diffusion/models/bagel/test_bagel_lora.py +++ b/tests/diffusion/models/bagel/test_bagel_lora.py @@ -11,7 +11,7 @@ import torch from safetensors.torch import save_file -from tests.diffusion.lora.conftest import ( +from tests.diffusion.lora.helpers import ( DummyBaseLayerWithLoRA, FakeLinearBase, fake_replace_submodule, diff --git a/tests/diffusion/quantization/test_quantization_quality.py b/tests/diffusion/quantization/test_quantization_quality.py index 3d8f1873698..ba6a150c4bb 100644 --- a/tests/diffusion/quantization/test_quantization_quality.py +++ b/tests/diffusion/quantization/test_quantization_quality.py @@ -32,7 +32,7 @@ import pytest import torch -from tests.utils import hardware_marks +from tests.helpers.mark import hardware_marks # --------------------------------------------------------------------------- # Configuration — add new quantization methods / models here diff --git a/tests/diffusion/test_diffusion_step_pipeline.py b/tests/diffusion/test_diffusion_step_pipeline.py index 68aba9ba3bf..f3c41e9eef0 100644 --- a/tests/diffusion/test_diffusion_step_pipeline.py +++ b/tests/diffusion/test_diffusion_step_pipeline.py @@ -13,7 +13,7 @@ import torch import vllm_omni.diffusion.worker.diffusion_model_runner as model_runner_module -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test from vllm_omni.diffusion.data import DiffusionOutput from vllm_omni.diffusion.diffusion_engine import DiffusionEngine from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin diff --git a/tests/diffusion/test_stage_diffusion_proc.py b/tests/diffusion/test_stage_diffusion_proc.py index c26070ad43f..f1cf4f9b7d1 100644 --- a/tests/diffusion/test_stage_diffusion_proc.py +++ b/tests/diffusion/test_stage_diffusion_proc.py @@ -24,19 +24,29 @@ def step(request): SimpleNamespace( images=["img-1"], _multimodal_output={}, + _custom_output={}, metrics={}, stage_durations={}, peak_memory_mb=0.0, latents=None, + trajectory_latents=None, + trajectory_timesteps=None, + trajectory_log_probs=None, + trajectory_decoded=None, final_output_type="image", ), SimpleNamespace( images=["img-2"], _multimodal_output={}, + _custom_output={}, metrics={}, stage_durations={}, peak_memory_mb=0.0, latents=None, + trajectory_latents=None, + trajectory_timesteps=None, + trajectory_log_probs=None, + trajectory_decoded=None, final_output_type="image", ), ] diff --git a/tests/e2e/accuracy/conftest.py b/tests/e2e/accuracy/conftest.py index 0a81b02075b..c426abb51b0 100644 --- a/tests/e2e/accuracy/conftest.py +++ b/tests/e2e/accuracy/conftest.py @@ -1,7 +1,6 @@ from __future__ import annotations import os -import shutil import subprocess from contextlib import contextmanager from dataclasses import dataclass @@ -10,7 +9,7 @@ import pytest import torch -from tests.conftest import OmniServer, OmniServerParams +from tests.helpers.runtime import OmniServer, OmniServerParams def pytest_addoption(parser): @@ -183,18 +182,6 @@ def accuracy_artifact_root() -> Path: return root -def reset_artifact_dir(path: Path) -> Path: - if path.exists(): - shutil.rmtree(path) - path.mkdir(parents=True, exist_ok=True) - return path - - -def infer_model_label(model: str) -> str: - label = Path(model.rstrip("/\\")).name or "model" - return "".join(char if char.isalnum() or char in {"-", "_"} else "_" for char in label) - - def _build_accuracy_server_config( *, generate_model: str, diff --git a/tests/e2e/accuracy/helpers.py b/tests/e2e/accuracy/helpers.py new file mode 100644 index 00000000000..24b71a471e5 --- /dev/null +++ b/tests/e2e/accuracy/helpers.py @@ -0,0 +1,15 @@ +from pathlib import Path + + +def reset_artifact_dir(path: Path) -> Path: + import shutil + + if path.exists(): + shutil.rmtree(path) + path.mkdir(parents=True, exist_ok=True) + return path + + +def infer_model_label(model: str) -> str: + label = Path(model.rstrip("/\\")).name or "model" + return "".join(char if char.isalnum() or char in {"-", "_"} else "_" for char in label) diff --git a/tests/e2e/accuracy/test_gebench_h100_smoke.py b/tests/e2e/accuracy/test_gebench_h100_smoke.py index b4b83187135..74891926910 100644 --- a/tests/e2e/accuracy/test_gebench_h100_smoke.py +++ b/tests/e2e/accuracy/test_gebench_h100_smoke.py @@ -6,8 +6,8 @@ import pytest from benchmarks.accuracy.text_to_image.gbench import main as gbench_main -from tests.e2e.accuracy.conftest import infer_model_label, reset_artifact_dir -from tests.utils import hardware_test +from tests.e2e.accuracy.helpers import infer_model_label, reset_artifact_dir +from tests.helpers.mark import hardware_test @pytest.mark.advanced_model diff --git a/tests/e2e/accuracy/test_gedit_bench_h100_smoke.py b/tests/e2e/accuracy/test_gedit_bench_h100_smoke.py index ac5f2cb3cfd..5f572a0d788 100644 --- a/tests/e2e/accuracy/test_gedit_bench_h100_smoke.py +++ b/tests/e2e/accuracy/test_gedit_bench_h100_smoke.py @@ -7,8 +7,8 @@ from benchmarks.accuracy.image_to_image.gedit_bench import GROUPS from benchmarks.accuracy.image_to_image.gedit_bench import main as gedit_main -from tests.e2e.accuracy.conftest import infer_model_label, reset_artifact_dir -from tests.utils import hardware_test +from tests.e2e.accuracy.helpers import infer_model_label, reset_artifact_dir +from tests.helpers.mark import hardware_test @pytest.mark.advanced_model diff --git a/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py b/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py index 3cdda1f9ffa..3aa5da85c24 100644 --- a/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py +++ b/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py @@ -22,7 +22,6 @@ from diffusers import UniPCMultistepScheduler from PIL import Image -from tests.conftest import OmniServerParams from tests.e2e.accuracy.wan22_i2v.run_wan22_i2v_diffusers_cp import ( _configure_scheduler, _ensure_wan_ftfy_fallback, @@ -48,7 +47,8 @@ SSIM_THRESHOLD, WIDTH, ) -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniServerParams def test_parse_video_metadata_extracts_dimensions_and_fps() -> None: diff --git a/tests/e2e/offline_inference/custom_pipeline/test_async_omni_collective_rpc.py b/tests/e2e/offline_inference/custom_pipeline/test_async_omni_collective_rpc.py index 57743d62bf6..bd3f2e09975 100644 --- a/tests/e2e/offline_inference/custom_pipeline/test_async_omni_collective_rpc.py +++ b/tests/e2e/offline_inference/custom_pipeline/test_async_omni_collective_rpc.py @@ -26,7 +26,7 @@ import pytest -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test from vllm_omni.entrypoints.async_omni import AsyncOmni from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput diff --git a/tests/e2e/offline_inference/custom_pipeline/test_async_omni_qwen_image_generate.py b/tests/e2e/offline_inference/custom_pipeline/test_async_omni_qwen_image_generate.py index f1b4595c9df..4985050eba7 100644 --- a/tests/e2e/offline_inference/custom_pipeline/test_async_omni_qwen_image_generate.py +++ b/tests/e2e/offline_inference/custom_pipeline/test_async_omni_qwen_image_generate.py @@ -14,7 +14,7 @@ import pytest from transformers import AutoTokenizer -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test from vllm_omni.entrypoints.async_omni import AsyncOmni from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput diff --git a/tests/e2e/offline_inference/custom_pipeline/test_worker_extension.py b/tests/e2e/offline_inference/custom_pipeline/test_worker_extension.py index ffbe703ca78..653b35d7e2f 100644 --- a/tests/e2e/offline_inference/custom_pipeline/test_worker_extension.py +++ b/tests/e2e/offline_inference/custom_pipeline/test_worker_extension.py @@ -10,7 +10,7 @@ from tests.e2e.offline_inference.custom_pipeline.worker_extension import ( vLLMOmniColocateWorkerExtensionForTest, ) -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test from vllm_omni.diffusion.worker.diffusion_worker import CustomPipelineWorkerExtension from vllm_omni.entrypoints.async_omni import AsyncOmni diff --git a/tests/e2e/offline_inference/test_bagel_img2img.py b/tests/e2e/offline_inference/test_bagel_img2img.py index a0c3f6cc9fc..04e761a5a65 100644 --- a/tests/e2e/offline_inference/test_bagel_img2img.py +++ b/tests/e2e/offline_inference/test_bagel_img2img.py @@ -22,8 +22,8 @@ from PIL import Image from vllm.assets.image import ImageAsset -from tests.conftest import modify_stage_config -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.stage_config import modify_stage_config from vllm_omni.entrypoints.omni import Omni from vllm_omni.platforms import current_omni_platform diff --git a/tests/e2e/offline_inference/test_bagel_text2img.py b/tests/e2e/offline_inference/test_bagel_text2img.py index 7cce8da3a73..6d5447b82fb 100644 --- a/tests/e2e/offline_inference/test_bagel_text2img.py +++ b/tests/e2e/offline_inference/test_bagel_text2img.py @@ -28,8 +28,8 @@ import pytest from PIL import Image -from tests.conftest import modify_stage_config -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.stage_config import modify_stage_config from vllm_omni.entrypoints.omni import Omni from vllm_omni.platforms import current_omni_platform diff --git a/tests/e2e/offline_inference/test_bagel_understanding.py b/tests/e2e/offline_inference/test_bagel_understanding.py index 6f95e7ee00f..99d297ef2ca 100644 --- a/tests/e2e/offline_inference/test_bagel_understanding.py +++ b/tests/e2e/offline_inference/test_bagel_understanding.py @@ -27,8 +27,8 @@ import pytest from vllm.assets.image import ImageAsset -from tests.conftest import modify_stage_config -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.stage_config import modify_stage_config from vllm_omni.entrypoints.omni import Omni MODEL_NAME = "ByteDance-Seed/BAGEL-7B-MoT" diff --git a/tests/e2e/offline_inference/test_cache_dit.py b/tests/e2e/offline_inference/test_cache_dit.py index 0e31413dc07..973210d00d1 100644 --- a/tests/e2e/offline_inference/test_cache_dit.py +++ b/tests/e2e/offline_inference/test_cache_dit.py @@ -15,7 +15,7 @@ import pytest import torch -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test from vllm_omni.inputs.data import OmniDiffusionSamplingParams # ruff: noqa: E402 diff --git a/tests/e2e/offline_inference/test_cosyvoice3.py b/tests/e2e/offline_inference/test_cosyvoice3.py index 8c88d972d5e..db5debac828 100644 --- a/tests/e2e/offline_inference/test_cosyvoice3.py +++ b/tests/e2e/offline_inference/test_cosyvoice3.py @@ -26,8 +26,8 @@ from huggingface_hub import snapshot_download from vllm.sampling_params import SamplingParams -from tests.conftest import OmniRunner -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniRunner from vllm_omni.model_executor.models.cosyvoice3.config import CosyVoice3Config from vllm_omni.model_executor.models.cosyvoice3.tokenizer import get_qwen_tokenizer diff --git a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py index f3830f02e97..e697d2da6b1 100644 --- a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py +++ b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py @@ -6,7 +6,7 @@ import torch from vllm.distributed.parallel_state import cleanup_dist_env_and_memory -from tests.utils import DeviceMemoryMonitor, hardware_test +from tests.helpers.mark import DeviceMemoryMonitor, hardware_test from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.platforms import current_omni_platform diff --git a/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py b/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py index 6132f1bd0eb..9b8ade781b0 100644 --- a/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py +++ b/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py @@ -5,7 +5,7 @@ import torch from vllm.distributed.parallel_state import cleanup_dist_env_and_memory -from tests.utils import DeviceMemoryMonitor +from tests.helpers.mark import DeviceMemoryMonitor from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.platforms import current_omni_platform diff --git a/tests/e2e/offline_inference/test_expert_parallel.py b/tests/e2e/offline_inference/test_expert_parallel.py index ba126986ec7..4854894301c 100644 --- a/tests/e2e/offline_inference/test_expert_parallel.py +++ b/tests/e2e/offline_inference/test_expert_parallel.py @@ -18,7 +18,7 @@ import torch.distributed as dist from PIL import Image -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test from vllm_omni import Omni from vllm_omni.diffusion.data import DiffusionParallelConfig from vllm_omni.inputs.data import OmniDiffusionSamplingParams diff --git a/tests/e2e/offline_inference/test_flux_autoround_w4a16.py b/tests/e2e/offline_inference/test_flux_autoround_w4a16.py index 42aab7f26a8..dc49b4c4854 100644 --- a/tests/e2e/offline_inference/test_flux_autoround_w4a16.py +++ b/tests/e2e/offline_inference/test_flux_autoround_w4a16.py @@ -15,7 +15,7 @@ import torch from vllm.distributed.parallel_state import cleanup_dist_env_and_memory -from tests.utils import DeviceMemoryMonitor, hardware_test +from tests.helpers.mark import DeviceMemoryMonitor, hardware_test from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput from vllm_omni.platforms import current_omni_platform diff --git a/tests/e2e/offline_inference/test_magi_human.py b/tests/e2e/offline_inference/test_magi_human.py index cb711edb572..9d45f84c391 100644 --- a/tests/e2e/offline_inference/test_magi_human.py +++ b/tests/e2e/offline_inference/test_magi_human.py @@ -8,7 +8,7 @@ import numpy as np import pytest -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test from vllm_omni.diffusion.utils.media_utils import mux_video_audio_bytes from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniDiffusionSamplingParams diff --git a/tests/e2e/offline_inference/test_mammoth_moda2.py b/tests/e2e/offline_inference/test_mammoth_moda2.py index 5293b5ed1b7..1a4657dfbec 100644 --- a/tests/e2e/offline_inference/test_mammoth_moda2.py +++ b/tests/e2e/offline_inference/test_mammoth_moda2.py @@ -23,7 +23,7 @@ import torch from vllm.sampling_params import SamplingParams -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1" diff --git a/tests/e2e/offline_inference/test_omnivoice.py b/tests/e2e/offline_inference/test_omnivoice.py index 4b093e357d9..29414466f59 100644 --- a/tests/e2e/offline_inference/test_omnivoice.py +++ b/tests/e2e/offline_inference/test_omnivoice.py @@ -16,7 +16,7 @@ import numpy as np import pytest -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test MODEL = "k2-fsa/OmniVoice" diff --git a/tests/e2e/offline_inference/test_ovis_image.py b/tests/e2e/offline_inference/test_ovis_image.py index 41e21bca3a9..70fab4fe101 100644 --- a/tests/e2e/offline_inference/test_ovis_image.py +++ b/tests/e2e/offline_inference/test_ovis_image.py @@ -16,7 +16,7 @@ import torch from pytest_mock import MockerFixture -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig # Mock the OvisImageTransformer2DModel to avoid complex init if needed, diff --git a/tests/e2e/offline_inference/test_quantization_fp8.py b/tests/e2e/offline_inference/test_quantization_fp8.py index f71c53de74c..1632df25159 100644 --- a/tests/e2e/offline_inference/test_quantization_fp8.py +++ b/tests/e2e/offline_inference/test_quantization_fp8.py @@ -37,7 +37,7 @@ import pytest import torch -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput diff --git a/tests/e2e/offline_inference/test_qwen2_5_omni.py b/tests/e2e/offline_inference/test_qwen2_5_omni.py index 4c4315aab9c..ca5187848d7 100644 --- a/tests/e2e/offline_inference/test_qwen2_5_omni.py +++ b/tests/e2e/offline_inference/test_qwen2_5_omni.py @@ -6,13 +6,9 @@ import pytest -from tests.conftest import ( - generate_synthetic_audio, - generate_synthetic_image, - generate_synthetic_video, - modify_stage_config, -) -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.media import generate_synthetic_audio, generate_synthetic_image, generate_synthetic_video +from tests.helpers.stage_config import modify_stage_config from vllm_omni.platforms import current_omni_platform models = ["Qwen/Qwen2.5-Omni-7B"] diff --git a/tests/e2e/offline_inference/test_qwen3_omni.py b/tests/e2e/offline_inference/test_qwen3_omni.py index cc0af437eca..2ed5875b167 100644 --- a/tests/e2e/offline_inference/test_qwen3_omni.py +++ b/tests/e2e/offline_inference/test_qwen3_omni.py @@ -11,11 +11,9 @@ import pytest -from tests.conftest import ( - generate_synthetic_video, - modify_stage_config, -) -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.media import generate_synthetic_video +from tests.helpers.stage_config import modify_stage_config from vllm_omni.platforms import current_omni_platform models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"] diff --git a/tests/e2e/offline_inference/test_qwen3_tts_base.py b/tests/e2e/offline_inference/test_qwen3_tts_base.py index be7bd50a36a..90064dec9bd 100644 --- a/tests/e2e/offline_inference/test_qwen3_tts_base.py +++ b/tests/e2e/offline_inference/test_qwen3_tts_base.py @@ -17,8 +17,8 @@ import pytest -from tests.conftest import modify_stage_config -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.stage_config import modify_stage_config MODEL = "Qwen/Qwen3-TTS-12Hz-0.6B-Base" REF_AUDIO_URL = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-TTS-Repo/clone_2.wav" diff --git a/tests/e2e/offline_inference/test_qwen3_tts_customvoice.py b/tests/e2e/offline_inference/test_qwen3_tts_customvoice.py index 67d72df908c..4283cc3b41a 100644 --- a/tests/e2e/offline_inference/test_qwen3_tts_customvoice.py +++ b/tests/e2e/offline_inference/test_qwen3_tts_customvoice.py @@ -17,8 +17,8 @@ import pytest -from tests.conftest import modify_stage_config -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.stage_config import modify_stage_config MODEL = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice" diff --git a/tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py b/tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py index d5f82f893e6..79d7b517cd0 100644 --- a/tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py +++ b/tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py @@ -37,7 +37,7 @@ import pytest import torch -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test from vllm_omni.entrypoints.async_omni import AsyncOmni from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput diff --git a/tests/e2e/offline_inference/test_sequence_parallel.py b/tests/e2e/offline_inference/test_sequence_parallel.py index 16239a1c52f..463685be82c 100644 --- a/tests/e2e/offline_inference/test_sequence_parallel.py +++ b/tests/e2e/offline_inference/test_sequence_parallel.py @@ -20,7 +20,7 @@ import torch.distributed as dist from PIL import Image -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test from vllm_omni import Omni from vllm_omni.diffusion.data import DiffusionParallelConfig from vllm_omni.inputs.data import OmniDiffusionSamplingParams diff --git a/tests/e2e/offline_inference/test_stable_audio_model.py b/tests/e2e/offline_inference/test_stable_audio_model.py index ff4d9b40172..511518ca8d4 100644 --- a/tests/e2e/offline_inference/test_stable_audio_model.py +++ b/tests/e2e/offline_inference/test_stable_audio_model.py @@ -5,7 +5,7 @@ import pytest import torch -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput from vllm_omni.platforms import current_omni_platform diff --git a/tests/e2e/offline_inference/test_t2i_model.py b/tests/e2e/offline_inference/test_t2i_model.py index 77b2b3aaf20..d30ccadf4de 100644 --- a/tests/e2e/offline_inference/test_t2i_model.py +++ b/tests/e2e/offline_inference/test_t2i_model.py @@ -5,7 +5,7 @@ import pytest import torch -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput from vllm_omni.platforms import current_omni_platform diff --git a/tests/e2e/offline_inference/test_teacache.py b/tests/e2e/offline_inference/test_teacache.py index efc0e43e86f..6b145ed1b74 100644 --- a/tests/e2e/offline_inference/test_teacache.py +++ b/tests/e2e/offline_inference/test_teacache.py @@ -15,7 +15,7 @@ import pytest import torch -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.platforms import current_omni_platform diff --git a/tests/e2e/offline_inference/test_voxtral_tts.py b/tests/e2e/offline_inference/test_voxtral_tts.py index b559cc252dc..776be2a56ec 100644 --- a/tests/e2e/offline_inference/test_voxtral_tts.py +++ b/tests/e2e/offline_inference/test_voxtral_tts.py @@ -30,8 +30,8 @@ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from vllm import SamplingParams -from tests.conftest import modify_stage_config -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.stage_config import modify_stage_config from vllm_omni.entrypoints.async_omni import AsyncOmni from vllm_omni.entrypoints.omni import Omni diff --git a/tests/e2e/offline_inference/test_zimage_parallelism.py b/tests/e2e/offline_inference/test_zimage_parallelism.py index 9d9db16a408..4b386b9a6e7 100644 --- a/tests/e2e/offline_inference/test_zimage_parallelism.py +++ b/tests/e2e/offline_inference/test_zimage_parallelism.py @@ -22,7 +22,7 @@ from PIL import Image from vllm.distributed.parallel_state import cleanup_dist_env_and_memory -from tests.utils import DeviceMemoryMonitor, hardware_test +from tests.helpers.mark import DeviceMemoryMonitor, hardware_test from vllm_omni import Omni from vllm_omni.diffusion.data import DiffusionParallelConfig from vllm_omni.inputs.data import OmniDiffusionSamplingParams diff --git a/tests/e2e/online_serving/test_bagel_expansion.py b/tests/e2e/online_serving/test_bagel_expansion.py index e2d75e0d199..342cd60351d 100644 --- a/tests/e2e/online_serving/test_bagel_expansion.py +++ b/tests/e2e/online_serving/test_bagel_expansion.py @@ -16,13 +16,9 @@ import pytest -from tests.conftest import ( - OmniServer, - OmniServerParams, - OpenAIClientHandler, - dummy_messages_from_mix_data, -) -from tests.utils import hardware_marks +from tests.helpers.mark import hardware_marks +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler +from tests.helpers.stage_config import dummy_messages_from_mix_data PROMPT = "A futuristic city skyline at twilight, cyberpunk style, ultra-detailed, high resolution." NEGATIVE_PROMPT = "low quality, blurry, distorted, deformed, watermark" diff --git a/tests/e2e/online_serving/test_bagel_online.py b/tests/e2e/online_serving/test_bagel_online.py index ca24f5f81f7..48a0cfc3e33 100644 --- a/tests/e2e/online_serving/test_bagel_online.py +++ b/tests/e2e/online_serving/test_bagel_online.py @@ -28,8 +28,8 @@ import pytest from vllm.assets.image import ImageAsset -from tests.conftest import OmniServerParams -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniServerParams os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0" diff --git a/tests/e2e/online_serving/test_cosyvoice3_tts.py b/tests/e2e/online_serving/test_cosyvoice3_tts.py index 276b1782f52..e05b5e34f41 100644 --- a/tests/e2e/online_serving/test_cosyvoice3_tts.py +++ b/tests/e2e/online_serving/test_cosyvoice3_tts.py @@ -16,8 +16,8 @@ import pytest -from tests.conftest import OmniServerParams -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniServerParams MODEL = "FunAudioLLM/Fun-CosyVoice3-0.5B-2512" diff --git a/tests/e2e/online_serving/test_flux2_expansion.py b/tests/e2e/online_serving/test_flux2_expansion.py index 336bd83a1d2..ce06ad56461 100644 --- a/tests/e2e/online_serving/test_flux2_expansion.py +++ b/tests/e2e/online_serving/test_flux2_expansion.py @@ -11,12 +11,8 @@ import pytest -from tests.conftest import ( - OmniServer, - OmniServerParams, - OpenAIClientHandler, -) -from tests.utils import hardware_marks +from tests.helpers.mark import hardware_marks +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler FOUR_CARD_FEATURE_MARKS = hardware_marks(res={"cuda": "L4"}, num_cards=4) POSITIVE_PROMPT = "A cat sitting on a windowsill" diff --git a/tests/e2e/online_serving/test_flux_2_dev_expansion.py b/tests/e2e/online_serving/test_flux_2_dev_expansion.py index eba0fbda225..44603a8e93b 100644 --- a/tests/e2e/online_serving/test_flux_2_dev_expansion.py +++ b/tests/e2e/online_serving/test_flux_2_dev_expansion.py @@ -16,13 +16,9 @@ import pytest -from tests.conftest import ( - OmniServer, - OmniServerParams, - OpenAIClientHandler, - dummy_messages_from_mix_data, -) -from tests.utils import hardware_marks +from tests.helpers.mark import hardware_marks +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler +from tests.helpers.stage_config import dummy_messages_from_mix_data MODEL = "black-forest-labs/FLUX.2-dev" PROMPT = "A cinematic mountain landscape at sunrise, dramatic clouds, ultra-detailed, realistic photography." diff --git a/tests/e2e/online_serving/test_flux_kontext_expansion.py b/tests/e2e/online_serving/test_flux_kontext_expansion.py index c13e1e8189d..c85d8a3c3c0 100644 --- a/tests/e2e/online_serving/test_flux_kontext_expansion.py +++ b/tests/e2e/online_serving/test_flux_kontext_expansion.py @@ -5,13 +5,9 @@ import pytest -from tests.conftest import ( - OmniServer, - OmniServerParams, - OpenAIClientHandler, - dummy_messages_from_mix_data, - generate_synthetic_image, -) +from tests.helpers.media import generate_synthetic_image +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler +from tests.helpers.stage_config import dummy_messages_from_mix_data EDIT_PROMPT = "Transform this modern, geometrist image into a Vincent van Gogh style impressionist painting." NEGATIVE_PROMPT = "blurry, low quality, modern, geometrist" diff --git a/tests/e2e/online_serving/test_hunyuan_video_15_expansion.py b/tests/e2e/online_serving/test_hunyuan_video_15_expansion.py index de950edb900..7af4c43443d 100644 --- a/tests/e2e/online_serving/test_hunyuan_video_15_expansion.py +++ b/tests/e2e/online_serving/test_hunyuan_video_15_expansion.py @@ -11,12 +11,8 @@ import pytest -from tests.conftest import ( - OmniServer, - OmniServerParams, - OpenAIClientHandler, -) -from tests.utils import hardware_marks +from tests.helpers.mark import hardware_marks +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler PROMPT = "A cat walking across a sunlit garden, cinematic lighting, slow motion." NEGATIVE_PROMPT = "low quality, blurry, distorted" diff --git a/tests/e2e/online_serving/test_image_gen_edit.py b/tests/e2e/online_serving/test_image_gen_edit.py index 7db740f2037..56747abd16e 100644 --- a/tests/e2e/online_serving/test_image_gen_edit.py +++ b/tests/e2e/online_serving/test_image_gen_edit.py @@ -22,7 +22,7 @@ from vllm.assets.image import ImageAsset from vllm.utils.network_utils import get_open_port -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" # Increase timeout for downloading assets from S3 (default 5s is too short for CI) diff --git a/tests/e2e/online_serving/test_images_generations_lora.py b/tests/e2e/online_serving/test_images_generations_lora.py index 8c826591a56..c011f2714f0 100644 --- a/tests/e2e/online_serving/test_images_generations_lora.py +++ b/tests/e2e/online_serving/test_images_generations_lora.py @@ -22,8 +22,8 @@ from PIL import Image from safetensors.torch import save_file -from tests.conftest import OmniServer -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniServer os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" diff --git a/tests/e2e/online_serving/test_longcat_image_edit_expansion.py b/tests/e2e/online_serving/test_longcat_image_edit_expansion.py index 8a2cfbcc145..28f6c3de005 100644 --- a/tests/e2e/online_serving/test_longcat_image_edit_expansion.py +++ b/tests/e2e/online_serving/test_longcat_image_edit_expansion.py @@ -13,14 +13,10 @@ import pytest -from tests.conftest import ( - OmniServer, - OmniServerParams, - OpenAIClientHandler, - dummy_messages_from_mix_data, - generate_synthetic_image, -) -from tests.utils import hardware_marks +from tests.helpers.mark import hardware_marks +from tests.helpers.media import generate_synthetic_image +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler +from tests.helpers.stage_config import dummy_messages_from_mix_data EDIT_PROMPT = "Transform this modern image into a cinematic animation style with vibrant colors and soft lighting." NEGATIVE_PROMPT = "blurry, low quality, distorted, oversaturated" diff --git a/tests/e2e/online_serving/test_longcat_image_expansion.py b/tests/e2e/online_serving/test_longcat_image_expansion.py index 161e7cd2e65..f0b0ca905d0 100644 --- a/tests/e2e/online_serving/test_longcat_image_expansion.py +++ b/tests/e2e/online_serving/test_longcat_image_expansion.py @@ -13,13 +13,9 @@ import pytest -from tests.conftest import ( - OmniServer, - OmniServerParams, - OpenAIClientHandler, - dummy_messages_from_mix_data, -) -from tests.utils import hardware_marks +from tests.helpers.mark import hardware_marks +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler +from tests.helpers.stage_config import dummy_messages_from_mix_data TEXT_TO_IMAGE_PROMPT = ( "A cinematic illustration of a cat typing on a silver laptop, soft window light, highly detailed." diff --git a/tests/e2e/online_serving/test_mimo_audio.py b/tests/e2e/online_serving/test_mimo_audio.py index 43eeb773355..3349e0a8a63 100644 --- a/tests/e2e/online_serving/test_mimo_audio.py +++ b/tests/e2e/online_serving/test_mimo_audio.py @@ -9,13 +9,10 @@ import pytest -from tests.conftest import ( - OmniServerParams, - dummy_messages_from_mix_data, - generate_synthetic_audio, - modify_stage_config, -) -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.media import generate_synthetic_audio +from tests.helpers.runtime import OmniServerParams +from tests.helpers.stage_config import dummy_messages_from_mix_data, modify_stage_config from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) diff --git a/tests/e2e/online_serving/test_omnivoice.py b/tests/e2e/online_serving/test_omnivoice.py index ec1981aab22..f2d88e4f05b 100644 --- a/tests/e2e/online_serving/test_omnivoice.py +++ b/tests/e2e/online_serving/test_omnivoice.py @@ -17,8 +17,8 @@ import httpx import pytest -from tests.conftest import OmniServerParams -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniServerParams MODEL = "k2-fsa/OmniVoice" diff --git a/tests/e2e/online_serving/test_qwen2_5_omni.py b/tests/e2e/online_serving/test_qwen2_5_omni.py index e2913ce0215..e7e5338207b 100644 --- a/tests/e2e/online_serving/test_qwen2_5_omni.py +++ b/tests/e2e/online_serving/test_qwen2_5_omni.py @@ -7,15 +7,10 @@ import pytest -from tests.conftest import ( - OmniServerParams, - dummy_messages_from_mix_data, - generate_synthetic_audio, - generate_synthetic_image, - generate_synthetic_video, - modify_stage_config, -) -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.media import generate_synthetic_audio, generate_synthetic_image, generate_synthetic_video +from tests.helpers.runtime import OmniServerParams +from tests.helpers.stage_config import dummy_messages_from_mix_data, modify_stage_config from vllm_omni.platforms import current_omni_platform os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" diff --git a/tests/e2e/online_serving/test_qwen3_omni.py b/tests/e2e/online_serving/test_qwen3_omni.py index fcda20ba388..326a04c67d9 100644 --- a/tests/e2e/online_serving/test_qwen3_omni.py +++ b/tests/e2e/online_serving/test_qwen3_omni.py @@ -7,15 +7,10 @@ import pytest -from tests.conftest import ( - OmniServerParams, - dummy_messages_from_mix_data, - generate_synthetic_audio, - generate_synthetic_image, - generate_synthetic_video, - modify_stage_config, -) -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.media import generate_synthetic_audio, generate_synthetic_image, generate_synthetic_video +from tests.helpers.runtime import OmniServerParams +from tests.helpers.stage_config import dummy_messages_from_mix_data, modify_stage_config from vllm_omni.platforms import current_omni_platform os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" diff --git a/tests/e2e/online_serving/test_qwen3_omni_expansion.py b/tests/e2e/online_serving/test_qwen3_omni_expansion.py index 0bcc86840ba..a9d3a515d4b 100644 --- a/tests/e2e/online_serving/test_qwen3_omni_expansion.py +++ b/tests/e2e/online_serving/test_qwen3_omni_expansion.py @@ -13,15 +13,10 @@ import pytest -from tests.conftest import ( - OmniServerParams, - dummy_messages_from_mix_data, - generate_synthetic_audio, - generate_synthetic_image, - generate_synthetic_video, - modify_stage_config, -) -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.media import generate_synthetic_audio, generate_synthetic_image, generate_synthetic_video +from tests.helpers.runtime import OmniServerParams +from tests.helpers.stage_config import dummy_messages_from_mix_data, modify_stage_config model = "Qwen/Qwen3-Omni-30B-A3B-Instruct" diff --git a/tests/e2e/online_serving/test_qwen3_tts_base.py b/tests/e2e/online_serving/test_qwen3_tts_base.py index 002f9d99724..15fe21e5cd9 100644 --- a/tests/e2e/online_serving/test_qwen3_tts_base.py +++ b/tests/e2e/online_serving/test_qwen3_tts_base.py @@ -16,8 +16,8 @@ import pytest -from tests.conftest import OmniServerParams -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniServerParams MODEL = "Qwen/Qwen3-TTS-12Hz-0.6B-Base" diff --git a/tests/e2e/online_serving/test_qwen3_tts_base_expansion.py b/tests/e2e/online_serving/test_qwen3_tts_base_expansion.py index 3c33485e4f4..e4f03f20fbc 100644 --- a/tests/e2e/online_serving/test_qwen3_tts_base_expansion.py +++ b/tests/e2e/online_serving/test_qwen3_tts_base_expansion.py @@ -16,8 +16,8 @@ import pytest -from tests.conftest import OmniServerParams -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniServerParams MODEL = "Qwen/Qwen3-TTS-12Hz-0.6B-Base" diff --git a/tests/e2e/online_serving/test_qwen3_tts_batch.py b/tests/e2e/online_serving/test_qwen3_tts_batch.py index d0d6336618e..3f0b2cd27f2 100644 --- a/tests/e2e/online_serving/test_qwen3_tts_batch.py +++ b/tests/e2e/online_serving/test_qwen3_tts_batch.py @@ -22,12 +22,9 @@ import pytest import yaml -from tests.conftest import ( - OmniServer, - convert_audio_file_to_text, - cosine_similarity_text, -) -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.media import convert_audio_file_to_text, cosine_similarity_text +from tests.helpers.runtime import OmniServer MODEL = "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice" diff --git a/tests/e2e/online_serving/test_qwen3_tts_customvoice.py b/tests/e2e/online_serving/test_qwen3_tts_customvoice.py index fb60df725ba..825893a5560 100644 --- a/tests/e2e/online_serving/test_qwen3_tts_customvoice.py +++ b/tests/e2e/online_serving/test_qwen3_tts_customvoice.py @@ -16,8 +16,8 @@ import pytest -from tests.conftest import OmniServerParams -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniServerParams MODEL = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice" diff --git a/tests/e2e/online_serving/test_qwen3_tts_customvoice_expansion.py b/tests/e2e/online_serving/test_qwen3_tts_customvoice_expansion.py index 03a985896e4..8503132dfe9 100644 --- a/tests/e2e/online_serving/test_qwen3_tts_customvoice_expansion.py +++ b/tests/e2e/online_serving/test_qwen3_tts_customvoice_expansion.py @@ -16,8 +16,8 @@ import pytest -from tests.conftest import OmniServerParams -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniServerParams MODEL = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice" diff --git a/tests/e2e/online_serving/test_qwen3_tts_speaker_embedding.py b/tests/e2e/online_serving/test_qwen3_tts_speaker_embedding.py index 64e13e1557d..8cd48fff0a4 100644 --- a/tests/e2e/online_serving/test_qwen3_tts_speaker_embedding.py +++ b/tests/e2e/online_serving/test_qwen3_tts_speaker_embedding.py @@ -18,8 +18,8 @@ import httpx import pytest -from tests.conftest import OmniServer -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniServer MODEL_BASE = "Qwen/Qwen3-TTS-12Hz-0.6B-Base" MODEL_BASE_1_7B = "Qwen/Qwen3-TTS-12Hz-1.7B-Base" diff --git a/tests/e2e/online_serving/test_qwen3_tts_websocket.py b/tests/e2e/online_serving/test_qwen3_tts_websocket.py index df051460119..52b1a7474e4 100644 --- a/tests/e2e/online_serving/test_qwen3_tts_websocket.py +++ b/tests/e2e/online_serving/test_qwen3_tts_websocket.py @@ -12,8 +12,8 @@ import pytest import websockets -from tests.conftest import OmniServer -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniServer os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0" diff --git a/tests/e2e/online_serving/test_qwen_image_edit_expansion.py b/tests/e2e/online_serving/test_qwen_image_edit_expansion.py index 14e4c915b6b..8b3f9b0e95b 100644 --- a/tests/e2e/online_serving/test_qwen_image_edit_expansion.py +++ b/tests/e2e/online_serving/test_qwen_image_edit_expansion.py @@ -7,14 +7,10 @@ import pytest -from tests.conftest import ( - OmniServer, - OmniServerParams, - OpenAIClientHandler, - dummy_messages_from_mix_data, - generate_synthetic_image, -) -from tests.utils import hardware_marks +from tests.helpers.mark import hardware_marks +from tests.helpers.media import generate_synthetic_image +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler +from tests.helpers.stage_config import dummy_messages_from_mix_data EDIT_PROMPT = "Transform this modern, geometrist image into a Vincent van Gogh style impressionist painting." MULTI_EDIT_PROMPT = ( diff --git a/tests/e2e/online_serving/test_qwen_image_expansion.py b/tests/e2e/online_serving/test_qwen_image_expansion.py index 88e56cc3e10..7c31042694e 100644 --- a/tests/e2e/online_serving/test_qwen_image_expansion.py +++ b/tests/e2e/online_serving/test_qwen_image_expansion.py @@ -12,13 +12,9 @@ import pytest -from tests.conftest import ( - OmniServer, - OmniServerParams, - OpenAIClientHandler, - dummy_messages_from_mix_data, -) -from tests.utils import hardware_marks +from tests.helpers.mark import hardware_marks +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler +from tests.helpers.stage_config import dummy_messages_from_mix_data T2I_PROMPT = "A photo of a cat sitting on a laptop keyboard, digital art style." NEGATIVE_PROMPT = "blurry, low quality" diff --git a/tests/e2e/online_serving/test_qwen_image_layered_expansion.py b/tests/e2e/online_serving/test_qwen_image_layered_expansion.py index fc73801c0e0..8f08be928fc 100644 --- a/tests/e2e/online_serving/test_qwen_image_layered_expansion.py +++ b/tests/e2e/online_serving/test_qwen_image_layered_expansion.py @@ -14,15 +14,10 @@ import pytest -from tests.conftest import ( - OmniServer, - OmniServerParams, - OpenAIClientHandler, - decode_b64_image, - dummy_messages_from_mix_data, - generate_synthetic_image, -) -from tests.utils import hardware_marks +from tests.helpers.mark import hardware_marks +from tests.helpers.media import decode_b64_image, generate_synthetic_image +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler +from tests.helpers.stage_config import dummy_messages_from_mix_data MODEL = "Qwen/Qwen-Image-Layered" EDIT_PROMPT = "Decompose this image into layers." diff --git a/tests/e2e/online_serving/test_sd3_expansion.py b/tests/e2e/online_serving/test_sd3_expansion.py index 3ed5cc5f308..37e590ba3e1 100644 --- a/tests/e2e/online_serving/test_sd3_expansion.py +++ b/tests/e2e/online_serving/test_sd3_expansion.py @@ -4,12 +4,8 @@ import pytest -from tests.conftest import ( - OmniServer, - OmniServerParams, - OpenAIClientHandler, -) -from tests.utils import hardware_marks +from tests.helpers.mark import hardware_marks +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler FOUR_CARD_FEATURE_MARKS = hardware_marks(res={"cuda": "L4"}, num_cards=4) POSITIVE_PROMPT = "A serene mountain landscape at sunset" diff --git a/tests/e2e/online_serving/test_video_generation_api.py b/tests/e2e/online_serving/test_video_generation_api.py index 0711a1048e3..6a8fe45875a 100644 --- a/tests/e2e/online_serving/test_video_generation_api.py +++ b/tests/e2e/online_serving/test_video_generation_api.py @@ -16,8 +16,8 @@ import pytest import requests -from tests.conftest import OmniServer -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniServer os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0" diff --git a/tests/e2e/online_serving/test_voxtral_tts.py b/tests/e2e/online_serving/test_voxtral_tts.py index f795288f375..2dd46fcfaa8 100644 --- a/tests/e2e/online_serving/test_voxtral_tts.py +++ b/tests/e2e/online_serving/test_voxtral_tts.py @@ -17,8 +17,8 @@ import httpx import pytest -from tests.conftest import OmniServerParams -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniServerParams MODEL = "mistralai/Voxtral-4B-TTS-2603" diff --git a/tests/e2e/online_serving/test_wan22_expansion.py b/tests/e2e/online_serving/test_wan22_expansion.py index e5e2d748d58..7e5bc912113 100644 --- a/tests/e2e/online_serving/test_wan22_expansion.py +++ b/tests/e2e/online_serving/test_wan22_expansion.py @@ -19,13 +19,9 @@ import pytest -from tests.conftest import ( - OmniServer, - OmniServerParams, - OpenAIClientHandler, - generate_synthetic_image, -) -from tests.utils import hardware_marks +from tests.helpers.mark import hardware_marks +from tests.helpers.media import generate_synthetic_image +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler PROMPT = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." NEGATIVE_PROMPT = "low quality, blurry, distorted face, extra limbs, bad anatomy, watermark, logo, text, ugly, deformed, mutated, jpeg artifacts" diff --git a/tests/e2e/online_serving/test_wan_2_1_vace_expansion.py b/tests/e2e/online_serving/test_wan_2_1_vace_expansion.py index 0de70afe862..1f7a8c0722f 100644 --- a/tests/e2e/online_serving/test_wan_2_1_vace_expansion.py +++ b/tests/e2e/online_serving/test_wan_2_1_vace_expansion.py @@ -23,12 +23,8 @@ import pytest -from tests.conftest import ( - OmniServer, - OmniServerParams, - OpenAIClientHandler, -) -from tests.utils import hardware_marks +from tests.helpers.mark import hardware_marks +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler MODEL = "Wan-AI/Wan2.1-VACE-1.3B-diffusers" PROMPT = "A cat walking slowly across a sunlit garden path" diff --git a/tests/e2e/online_serving/test_zimage_expansion.py b/tests/e2e/online_serving/test_zimage_expansion.py index 9f90ec855b6..679233c82a9 100644 --- a/tests/e2e/online_serving/test_zimage_expansion.py +++ b/tests/e2e/online_serving/test_zimage_expansion.py @@ -13,12 +13,8 @@ import pytest -from tests.conftest import ( - OmniServer, - OmniServerParams, - OpenAIClientHandler, -) -from tests.utils import hardware_marks +from tests.helpers.mark import hardware_marks +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler MODEL = "Tongyi-MAI/Z-Image-Turbo" PROMPT = "A high-detail studio photo of an orange tabby cat sitting on a laptop keyboard." diff --git a/tests/engine/test_async_omni_engine_abort.py b/tests/engine/test_async_omni_engine_abort.py index 34fdf45ea25..20fa9e3f82e 100644 --- a/tests/engine/test_async_omni_engine_abort.py +++ b/tests/engine/test_async_omni_engine_abort.py @@ -8,7 +8,7 @@ from vllm import SamplingParams from vllm.inputs import PromptType -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test from vllm_omni.entrypoints.async_omni import AsyncOmni os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" diff --git a/tests/engine/test_orchestrator.py b/tests/engine/test_orchestrator.py new file mode 100644 index 00000000000..7bf2eccf7f5 --- /dev/null +++ b/tests/engine/test_orchestrator.py @@ -0,0 +1,510 @@ +from __future__ import annotations + +import asyncio +import concurrent.futures +import queue +import threading +import time +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Any + +import janus +import pytest +from vllm.outputs import CompletionOutput, RequestOutput +from vllm.sampling_params import SamplingParams + +from vllm_omni.engine.orchestrator import Orchestrator +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +@dataclass +class OrchestratorFixture: + orchestrator: Orchestrator + request_sync_q: Any + output_sync_q: Any + queues: tuple[janus.Queue, ...] + thread: threading.Thread + result_future: concurrent.futures.Future[None] + + +class FakeStageClient: + def __init__( + self, + *, + stage_type: str = "llm", + final_output: bool = False, + final_output_type: str = "text", + next_inputs: list[dict] | None = None, + ) -> None: + self.stage_type = stage_type + self.final_output = final_output + self.final_output_type = final_output_type + self.next_inputs = list(next_inputs or []) + self.custom_process_input_func = None + self.add_request_calls: list[tuple] = [] + self.abort_calls: list[list[str]] = [] + self.shutdown_calls = 0 + self._engine_core_outputs = queue.Queue() + self._diffusion_outputs = queue.Queue() + + # Orchestrator-facing interface. + async def add_request_async(self, *args, **_kwargs) -> None: + self.add_request_calls.append(args) + + async def get_output_async(self): + try: + return self._engine_core_outputs.get_nowait() + except queue.Empty: + return SimpleNamespace(outputs=[]) + + def get_diffusion_output_nowait(self): + try: + return self._diffusion_outputs.get_nowait() + except queue.Empty: + return None + + def set_engine_outputs(self, outputs) -> None: + return None + + def process_engine_inputs(self, stage_list, prompt=None): + return list(self.next_inputs) + + async def abort_requests_async(self, request_ids: list[str]) -> None: + self.abort_calls.append(list(request_ids)) + + def shutdown(self) -> None: + self.shutdown_calls += 1 + + # Test helpers for seeding fake stage outputs. + def push_engine_core_outputs(self, outputs) -> None: + self._engine_core_outputs.put_nowait(outputs) + + def push_diffusion_output(self, output) -> None: + self._diffusion_outputs.put_nowait(output) + + +class FakeOutputProcessor: + def __init__(self, *, request_outputs: list[object] | None = None) -> None: + self.request_outputs = list(request_outputs or []) + + def add_request(self, *_args, **_kwargs) -> None: + return None + + def process_outputs(self, *_args, **_kwargs): + return SimpleNamespace( + request_outputs=list(self.request_outputs), + reqs_to_abort=[], + ) + + def update_scheduler_stats(self, _scheduler_stats) -> None: + return None + + +def _sampling_params(max_tokens: int = 4) -> SamplingParams: + return SamplingParams(max_tokens=max_tokens) + + +def _engine_core_outputs(tag: str, timestamp: float) -> SimpleNamespace: + return SimpleNamespace(outputs=[tag], timestamp=timestamp, scheduler_stats=None) + + +def _build_request_output( + request_id: str, + *, + token_ids: list[int] | None = None, + prompt_token_ids: list[int] | None = None, + finished: bool = True, + text: str = "test", +) -> RequestOutput: + completion = CompletionOutput( + index=0, + text=text, + token_ids=list(token_ids or [1, 2]), + cumulative_logprob=0.0, + logprobs=None, + finish_reason="stop" if finished else None, + stop_reason=None, + ) + return RequestOutput( + request_id=request_id, + prompt="prompt", + prompt_token_ids=list(prompt_token_ids or [10, 11]), + prompt_logprobs=None, + outputs=[completion], + finished=finished, + metrics=None, + lora_request=None, + ) + + +def _build_harness( + stage_clients: list[object], + *, + output_processors: list[object] | None = None, + stage_vllm_configs: list[object] | None = None, + async_chunk: bool = False, +) -> OrchestratorFixture: + if output_processors is None: + output_processors = [FakeOutputProcessor() for _ in stage_clients] + if stage_vllm_configs is None: + stage_vllm_configs = [SimpleNamespace(model_config=SimpleNamespace(max_model_len=64)) for _ in stage_clients] + + ready_future: concurrent.futures.Future[tuple[Orchestrator, janus.Queue, janus.Queue, janus.Queue]] = ( + concurrent.futures.Future() + ) + result_future: concurrent.futures.Future[None] = concurrent.futures.Future() + + def _runner() -> None: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + async def _run() -> None: + request_queue = janus.Queue() + output_queue = janus.Queue() + rpc_queue = janus.Queue() + orchestrator = Orchestrator( + request_async_queue=request_queue.async_q, + output_async_queue=output_queue.async_q, + rpc_async_queue=rpc_queue.async_q, + stage_clients=stage_clients, + output_processors=output_processors, + stage_vllm_configs=stage_vllm_configs, + async_chunk=async_chunk, + ) + ready_future.set_result((orchestrator, request_queue, output_queue, rpc_queue)) + await orchestrator.run() + + try: + loop.run_until_complete(_run()) + result_future.set_result(None) + except Exception as exc: + result_future.set_exception(exc) + finally: + try: + pending = [task for task in asyncio.all_tasks(loop) if not task.done()] + for task in pending: + task.cancel() + if pending: + loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) + loop.run_until_complete(loop.shutdown_asyncgens()) + finally: + asyncio.set_event_loop(None) + loop.close() + + thread = threading.Thread(target=_runner, daemon=True, name="test-orchestrator") + thread.start() + + orchestrator, request_queue, output_queue, rpc_queue = ready_future.result(timeout=5) + return OrchestratorFixture( + orchestrator=orchestrator, + request_sync_q=request_queue.sync_q, + output_sync_q=output_queue.sync_q, + queues=(request_queue, output_queue, rpc_queue), + thread=thread, + result_future=result_future, + ) + + +async def _shutdown_orchestrator(orchestrator_fixture: OrchestratorFixture) -> None: + orchestrator_fixture.request_sync_q.put_nowait({"type": "shutdown"}) + await asyncio.to_thread(orchestrator_fixture.thread.join, 5) + if orchestrator_fixture.thread.is_alive(): + raise AssertionError("Timed out waiting for orchestrator thread shutdown") + orchestrator_fixture.result_future.result(timeout=0) + + +async def _wait_for(predicate, *, timeout: float = 2.0) -> None: + deadline = time.monotonic() + timeout + while not predicate(): + if time.monotonic() >= deadline: + raise AssertionError("Timed out waiting for predicate") + await asyncio.sleep(0.01) + + +async def _get_output_message(orchestrator_fixture: OrchestratorFixture, *, timeout: float = 2.0) -> dict: + deadline = time.monotonic() + timeout + while True: + if time.monotonic() >= deadline: + raise AssertionError("Timed out waiting for orchestrator output") + try: + msg = orchestrator_fixture.output_sync_q.get_nowait() + except queue.Empty: + await asyncio.sleep(0.01) + continue + if msg.get("type") == "output": + return msg + + +async def _enqueue_add_request( + orchestrator_fixture: OrchestratorFixture, + *, + request_id: str, + prompt, + original_prompt, + sampling_params_list, + final_stage_id: int, +) -> None: + orchestrator_fixture.request_sync_q.put_nowait( + { + "type": "add_request", + "request_id": request_id, + "prompt": prompt, + "original_prompt": original_prompt, + "sampling_params_list": sampling_params_list, + "final_stage_id": final_stage_id, + } + ) + + +async def _enqueue_abort_request(orchestrator_fixture: OrchestratorFixture, request_ids: list[str]) -> None: + orchestrator_fixture.request_sync_q.put_nowait( + { + "type": "abort", + "request_ids": request_ids, + } + ) + + +@pytest.fixture +def orchestrator_factory(): + fixtures: list[OrchestratorFixture] = [] + + def _factory(*args, **kwargs) -> OrchestratorFixture: + fixture = _build_harness(*args, **kwargs) + fixtures.append(fixture) + return fixture + + yield _factory + + for fixture in fixtures: + if fixture.thread.is_alive(): + fixture.request_sync_q.put_nowait({"type": "shutdown"}) + fixture.thread.join(timeout=5) + for q in fixture.queues: + q.close() + + +@pytest.mark.asyncio +async def test_run_two_stage_llm(orchestrator_factory) -> None: + stage0 = FakeStageClient(stage_type="llm", final_output=False) + stage1 = FakeStageClient( + stage_type="llm", + final_output=True, + next_inputs=[{"prompt_token_ids": [7, 8, 9]}], + ) + processors = [ + FakeOutputProcessor(request_outputs=[_build_request_output("req-llm", token_ids=[3, 4], finished=True)]), + FakeOutputProcessor(request_outputs=[_build_request_output("req-llm", token_ids=[10, 11], finished=True)]), + ] + orchestrator_fixture = orchestrator_factory([stage0, stage1], output_processors=processors) + request = SimpleNamespace(request_id="req-llm", prompt_token_ids=[1, 2, 3]) + + try: + await _enqueue_add_request( + orchestrator_fixture, + request_id="req-llm", + prompt=request, + original_prompt={"prompt": "hello"}, + sampling_params_list=[_sampling_params(), _sampling_params()], + final_stage_id=1, + ) + + await _wait_for(lambda: len(stage0.add_request_calls) == 1) + stage0.push_engine_core_outputs(_engine_core_outputs("stage0-raw", 1.0)) + + await _wait_for(lambda: len(stage1.add_request_calls) == 1) + stage1_request = stage1.add_request_calls[0][0] + assert stage1_request.request_id == "req-llm" + assert stage1_request.prompt_token_ids == [7, 8, 9] + + stage1.push_engine_core_outputs(_engine_core_outputs("stage1-raw", 2.0)) + + output_msg = await _get_output_message(orchestrator_fixture) + + assert output_msg["request_id"] == "req-llm" + assert output_msg["stage_id"] == 1 + assert output_msg["finished"] is True + assert output_msg["engine_outputs"].request_id == "req-llm" + assert "req-llm" not in orchestrator_fixture.orchestrator.request_states + finally: + await _shutdown_orchestrator(orchestrator_fixture) + + +@pytest.mark.asyncio +async def test_run_single_stage_diffusion(orchestrator_factory) -> None: + stage0 = FakeStageClient(stage_type="diffusion", final_output=True, final_output_type="image") + orchestrator_fixture = orchestrator_factory([stage0]) + params = OmniDiffusionSamplingParams() + + try: + await _enqueue_add_request( + orchestrator_fixture, + request_id="req-diff", + prompt={"prompt": "draw a cat"}, + original_prompt={"prompt": "draw a cat"}, + sampling_params_list=[params], + final_stage_id=0, + ) + + await _wait_for(lambda: len(stage0.add_request_calls) == 1) + stage0.push_diffusion_output( + OmniRequestOutput.from_diffusion( + request_id="req-diff", + images=[], + final_output_type="image", + ) + ) + + output_msg = await _get_output_message(orchestrator_fixture) + + assert output_msg["request_id"] == "req-diff" + assert output_msg["stage_id"] == 0 + assert output_msg["finished"] is True + assert output_msg["engine_outputs"].request_id == "req-diff" + assert "req-diff" not in orchestrator_fixture.orchestrator.request_states + finally: + await _shutdown_orchestrator(orchestrator_fixture) + + +@pytest.mark.asyncio +async def test_run_llm_to_diffusion(orchestrator_factory) -> None: + stage0 = FakeStageClient(stage_type="llm", final_output=False) + stage1 = FakeStageClient(stage_type="diffusion", final_output=True, final_output_type="image") + processors = [ + FakeOutputProcessor(request_outputs=[_build_request_output("req-img", token_ids=[3, 4], finished=True)]), + FakeOutputProcessor(), + ] + orchestrator_fixture = orchestrator_factory([stage0, stage1], output_processors=processors) + request = SimpleNamespace(request_id="req-img", prompt_token_ids=[1, 2, 3]) + params = OmniDiffusionSamplingParams() + original_prompt = {"prompt": "draw a fox"} + + try: + await _enqueue_add_request( + orchestrator_fixture, + request_id="req-img", + prompt=request, + original_prompt=original_prompt, + sampling_params_list=[_sampling_params(), params], + final_stage_id=1, + ) + + await _wait_for(lambda: len(stage0.add_request_calls) == 1) + stage0.push_engine_core_outputs(_engine_core_outputs("stage0-raw", 1.0)) + + await _wait_for(lambda: len(stage1.add_request_calls) == 1) + assert stage1.add_request_calls[0] == ("req-img", original_prompt, params) + + stage1.push_diffusion_output( + OmniRequestOutput.from_diffusion( + request_id="req-img", + images=[], + final_output_type="image", + ) + ) + + output_msg = await _get_output_message(orchestrator_fixture) + + assert output_msg["request_id"] == "req-img" + assert output_msg["stage_id"] == 1 + assert output_msg["finished"] is True + assert output_msg["engine_outputs"].request_id == "req-img" + assert "req-img" not in orchestrator_fixture.orchestrator.request_states + finally: + await _shutdown_orchestrator(orchestrator_fixture) + + +@pytest.mark.asyncio +async def test_run_async_chunk(orchestrator_factory) -> None: + stage0 = FakeStageClient(stage_type="llm", final_output=False) + stage1 = FakeStageClient(stage_type="llm", final_output=True) + processors = [ + FakeOutputProcessor(request_outputs=[_build_request_output("req-async", token_ids=[1], finished=True)]), + FakeOutputProcessor(request_outputs=[_build_request_output("req-async", token_ids=[20, 21], finished=True)]), + ] + orchestrator_fixture = orchestrator_factory( + [stage0, stage1], + output_processors=processors, + async_chunk=True, + ) + request = SimpleNamespace(request_id="req-async", prompt_token_ids=[1, 2, 3, 4]) + + try: + await _enqueue_add_request( + orchestrator_fixture, + request_id="req-async", + prompt=request, + original_prompt={"prompt": "hello async"}, + sampling_params_list=[_sampling_params(), _sampling_params()], + final_stage_id=1, + ) + + await _wait_for(lambda: len(stage1.add_request_calls) == 1) + prewarmed_request = stage1.add_request_calls[0][0] + assert prewarmed_request.request_id == "req-async" + assert prewarmed_request.prompt_token_ids + assert all(token_id == 0 for token_id in prewarmed_request.prompt_token_ids) + + stage1.push_engine_core_outputs(_engine_core_outputs("stage1-final", 3.0)) + + output_msg = await _get_output_message(orchestrator_fixture) + + assert output_msg["request_id"] == "req-async" + assert output_msg["stage_id"] == 1 + assert output_msg["finished"] is True + assert "req-async" not in orchestrator_fixture.orchestrator.request_states + finally: + await _shutdown_orchestrator(orchestrator_fixture) + + +@pytest.mark.asyncio +async def test_run_shutdown(orchestrator_factory) -> None: + stages = [ + FakeStageClient(stage_type="llm", final_output=False), + FakeStageClient(stage_type="diffusion", final_output=True, final_output_type="image"), + ] + orchestrator_fixture = orchestrator_factory(stages) + + await _shutdown_orchestrator(orchestrator_fixture) + + assert not orchestrator_fixture.thread.is_alive() + for stage in stages: + assert stage.shutdown_calls == 1 + + +@pytest.mark.asyncio +async def test_run_abort(orchestrator_factory) -> None: + stages = [ + FakeStageClient(stage_type="llm", final_output=False), + FakeStageClient(stage_type="llm", final_output=True), + ] + processors = [ + FakeOutputProcessor(request_outputs=[_build_request_output("req-abort", token_ids=[1], finished=True)]), + FakeOutputProcessor(request_outputs=[_build_request_output("req-abort", token_ids=[2], finished=True)]), + ] + orchestrator_fixture = orchestrator_factory(stages, output_processors=processors) + request = SimpleNamespace(request_id="req-abort", prompt_token_ids=[1, 2, 3]) + + try: + await _enqueue_add_request( + orchestrator_fixture, + request_id="req-abort", + prompt=request, + original_prompt={"prompt": "cancel me"}, + sampling_params_list=[_sampling_params(), _sampling_params()], + final_stage_id=1, + ) + await _wait_for(lambda: len(stages[0].add_request_calls) == 1) + + await _enqueue_abort_request(orchestrator_fixture, ["req-abort"]) + await _wait_for(lambda: all(stage.abort_calls for stage in stages)) + + for stage in stages: + assert stage.abort_calls == [["req-abort"]] + assert "req-abort" not in orchestrator_fixture.orchestrator.request_states + finally: + await _shutdown_orchestrator(orchestrator_fixture) diff --git a/tests/examples/conftest.py b/tests/examples/conftest.py index 137d15f163f..867731b21f9 100644 --- a/tests/examples/conftest.py +++ b/tests/examples/conftest.py @@ -1,353 +1,3 @@ -""" -Shared fixtures, helpers, and path constants for tests/examples/. -""" +"""Pytest fixtures for tests/examples.""" -import json -import os -import re -import shlex -import subprocess -import sys -import tempfile -from collections import defaultdict -from collections.abc import Callable -from pathlib import Path -from typing import Any, NamedTuple, cast - -import pytest -import torch -from safetensors.torch import save_file - -# --------------------------------------------------------------------------- -# Path constants and fixtures -# --------------------------------------------------------------------------- - -REPO_ROOT = Path(__file__).resolve().parents[2] -EXAMPLES = REPO_ROOT / "examples" - -# Use Python tempfile instead of pytest's tmp_path_factory because -# OUTPUT_DIR is needed in test collection time, but tmp_path_factory is only available in test running time. -# It is needed during test collection because extract_readme_snippets replaces LoRA path with a generated one under OUTPUT_DIR, -# and extract_readme_snippets is called at collection time to generate separate test cases for each README code block. -OUTPUT_DIR = ( - REPO_ROOT / prefix - if (prefix := os.environ.get("OUTPUT_DIR")) - else Path(tempfile.mkdtemp(prefix="vllm_omni_test_examples_")) -) - - -# --------------------------------------------------------------------------- -# Code snippet extraction and asset file helpers -# --------------------------------------------------------------------------- - -# parameters: language, code, h2_title -ReadmeSnippetExtractionSkipPredicate = Callable[[str, str, str], tuple[bool, str]] - - -class ReadmeSnippet(NamedTuple): - language: str - code: str - h2_title: str - index_in_section: int - output_file_path: Path | None = None - skip: tuple[bool, str] = (False, "") - - @property - def test_id(self) -> str: - return f"{ReadmeSnippet._slug(self.h2_title)}_{self.index_in_section:03d}" - - @staticmethod - def extract_readme_snippets( - readme_path: Path, - skipif: ReadmeSnippetExtractionSkipPredicate | None = None, - ) -> list["ReadmeSnippet"]: - import mistune - - markdown = mistune.create_markdown(renderer="ast") - tokens = markdown(readme_path.read_text(encoding="utf-8")) - tokens = cast(list[dict[str, Any]], tokens) # mistune's AST renderer always produces a list, not a str - - h2_title = "" - section_counts: defaultdict[str, int] = defaultdict(int) - snippets: list[ReadmeSnippet] = [] - - for token in tokens: - token_type = token.get("type") - - if token_type == "heading": - level = (token.get("attrs") or {}).get("level") - title = ReadmeSnippet._heading_text(token) - if level == 2: - h2_title = title - continue - - if token_type != "block_code": - continue - - try: - info = token.get("attrs").get("info") # type: ignore[reportOptionalMemberAccess] - language = info.strip().split()[0].lower() # type: ignore[reportOptionalMemberAccess] - - # Common shell aliases to "bash" in several markdown renderers. - if language in {"shell", "sh", "ksh", "zsh"}: - language = "bash" - - if language not in {"bash", "python"}: - continue - except AttributeError: - # The fence is missing explicit language info; skip it. - continue - - key = h2_title - section_counts[key] += 1 - code = token.get("raw", "") - output_file_path = None - if language == "bash": - argv = ReadmeSnippet._normalize_bash_command(code, Path(readme_path.parent)) - code = shlex.join(argv) - output_file_path = ReadmeSnippet._output_file_path_from_argv(argv) - if skipif is not None: - skip_config = skipif(language, code, h2_title) - else: - skip_config = (False, "") - snippet = ReadmeSnippet( - language=language, - code=code, - h2_title=h2_title, - index_in_section=section_counts[key], - output_file_path=output_file_path, - skip=skip_config, - ) - snippets.append(snippet) - - return snippets - - @staticmethod - def _normalize_bash_command(command: str, readme_dir: Path) -> list[str]: - line_joined_command = re.sub(r"\\\s*\n", " ", command).strip() - argv = shlex.split(line_joined_command, comments=True) - assert argv, "README bash fence produced an empty command" - - # Normalize python directory and example script location - if argv[0] in {"python", "python3"}: - argv[0] = sys.executable - if len(argv) > 1 and argv[1].endswith(".py"): - script_arg = argv[1] - script_path = Path(script_arg) - if script_path.is_absolute(): - resolved_script = script_path - else: - # Take the file name only, and append script_dir to its front - resolved_script = readme_dir / script_path.name - assert resolved_script.exists(), ( - f"README bash snippet references a script that does not exist: {script_arg} (resolved to {resolved_script})" - ) - argv[1] = str(resolved_script) - - # Normalize LoRA adapter path and ensure README LoRA assets exist. - try: - lora_arg_idx = argv.index("--lora-path") # Raise ValueError if not found - assert len(argv) > lora_arg_idx + 1, "README bash snippet uses --lora-path without a following value" - - lora_dir = OUTPUT_DIR / "lora" - adapter_model = lora_dir / "adapter_model.safetensors" - adapter_config = lora_dir / "adapter_config.json" - if not adapter_model.exists() or not adapter_config.exists(): - write_zimage_lora(lora_dir, v_scale=8.0) - - argv[lora_arg_idx + 1] = str(lora_dir) - except ValueError: - pass - - return argv - - @staticmethod - def _output_file_path_from_argv(argv: list[str]) -> Path | None: - if "--output" not in argv: - return None - output_param_idx = argv.index("--output") - assert len(argv) > output_param_idx + 1, "README bash snippet uses --output without a following value" - output_arg = argv[output_param_idx + 1] - return Path(output_arg) - - @staticmethod - def _slug(text: str) -> str: - return "".join(ch.lower() if ch.isalnum() else "_" for ch in text).strip("_") - - @staticmethod - def _heading_text(token: dict) -> str: - return "".join(child.get("raw", "") for child in token.get("children", [])).strip() - - -# [TODO] Duplicate `_write_zimage_lora` in tests/e2e/online_serving/test_images_generations_lora.py. Combine these helpers and tests/e2e/offline_inference/test_diffusion_lora.py to test/utils later -def write_zimage_lora(adapter_dir: Path, *, q_scale: float = 0.0, k_scale: float = 0.0, v_scale: float = 0.0): - adapter_dir.mkdir(parents=True, exist_ok=True) - - # Z-Image transformer uses dim=3840 by default. - dim = 3840 - module_name = "transformer.layers.0.attention.to_qkv" - rank = 1 - - lora_a = torch.zeros((rank, dim), dtype=torch.float32) - lora_a[0, 0] = 1.0 - - # QKVParallelLinear packs (Q, K, V) => out dim is 3 * dim (tp=1). - lora_b = torch.zeros((3 * dim, rank), dtype=torch.float32) - if q_scale: - lora_b[:dim, 0] = q_scale - if k_scale: - lora_b[dim : 2 * dim, 0] = k_scale - if v_scale: - lora_b[2 * dim :, 0] = v_scale - - save_file( - { - f"base_model.model.{module_name}.lora_A.weight": lora_a, - f"base_model.model.{module_name}.lora_B.weight": lora_b, - }, - str(adapter_dir / "adapter_model.safetensors"), - ) - (adapter_dir / "adapter_config.json").write_text( - json.dumps( - { - "r": rank, - "lora_alpha": rank, - "target_modules": [module_name], - } - ), - encoding="utf-8", - ) - - -# --------------------------------------------------------------------------- -# Code runner and subprocess helpers -# --------------------------------------------------------------------------- - - -class ExampleRunResult(NamedTuple): - run_dir: Path - assets: list[Path] - - -class ExampleRunner: - """Run extracted README snippets and return generated assets. - - The output materials are organized in a three-level directory structure: - - Set at init: `self.output_root` for all tests (from env OUTPUT_DIR) - - Set at `self.run(...)`: `output_subfolder` for a specific example page (e.g., `example_offline_t2i`) - - Generated by `extract_readme_snippets`: `snippet.test_id` for a specific code block (matching H2 titles, e.g., `basic_usage_001`) - """ - - IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp"} - - def __init__(self, output_root: Path) -> None: - self.output_root = output_root - - def run( - self, snippet: ReadmeSnippet, *, output_subfolder: Path = Path("."), env: dict[str, str] | None = None - ) -> ExampleRunResult: - run_dir = self.output_root / output_subfolder / snippet.test_id - run_dir.mkdir(parents=True, exist_ok=True) - - if snippet.language == "python": - assets = self._run_python_snippet(snippet, run_dir, env) - return ExampleRunResult(run_dir=run_dir, assets=assets) - - if snippet.language == "bash": - asset = self._run_bash_snippet(snippet, run_dir, env) - return ExampleRunResult(run_dir=run_dir, assets=[asset]) - - raise AssertionError(f"Unsupported snippet language: {snippet.language}") - - def _run_python_snippet( - self, snippet: ReadmeSnippet, run_dir: Path, env: dict[str, str] | None = None - ) -> list[Path]: - # Saving the script to a temporary file and `run_cmd` it. - # Not using `exec(snippet.code)` because the output is lost. - script_path = run_dir / "snippet.py" - script_path.write_text(snippet.code, encoding="utf-8") - - before = self._collect_images(run_dir) - run_cmd([sys.executable, str(script_path)], cwd=run_dir, env=env) - after = self._collect_images(run_dir) - - assets = sorted(after - before) - return assets - - def _run_bash_snippet(self, snippet: ReadmeSnippet, run_dir: Path, env: dict[str, str] | None = None) -> Path: - run_cmd(snippet.code, shell=True, cwd=run_dir, env=env) - - assert snippet.output_file_path is not None, ( - f"README bash snippet is missing --output argument: {snippet.test_id}. " - "The test script cannot guess the output file path." - ) - - # If the code snippet declares a relative path for the output file, append this path to the parent output collection directory. - # If the code snippet declares an absolute path (not likely but just in case), the return value resolution removes `run_dir`, also correctly pointing to this file. - return run_dir / snippet.output_file_path - - def _collect_images(self, root: Path) -> set[Path]: - return {path for path in root.rglob("*") if path.suffix.lower() in self.IMAGE_SUFFIXES} - - -@pytest.fixture -def example_runner() -> ExampleRunner: - return ExampleRunner(output_root=OUTPUT_DIR) - - -def run_cmd( - command: list[str] | str, - *, - shell: bool = False, - env: dict[str, str] | None = None, - cwd: Path | str | None = None, -) -> str: - """Run a command as a subprocess; assert zero exit code and return stdout. - - Output is fully captured and returned as a string so callers can parse it - (e.g. with :func:`extract_content_after_keyword`). - Use this for scripts whose printed output is part of the test assertion. - """ - if env is not None: - env = {**os.environ.copy(), **env} - result = subprocess.run(command, capture_output=True, text=True, shell=shell, env=env, cwd=cwd) - - if result.returncode != 0: - print(f"STDERR: {result.stderr}") - raise subprocess.CalledProcessError(result.returncode, command) - - all_output = result.stdout - print(f"All output:\n{all_output}") - return all_output - - -# --------------------------------------------------------------------------- -# Output validation helpers -# --------------------------------------------------------------------------- - - -def extract_content_after_keyword(keywords: str, text: str) -> str: - """Return the text that follows *keywords* in *text* (regex match). - - Raises ``AssertionError`` if the keyword is not found, so test failures - produce a clear message pointing at the missing keyword. - """ - matches = re.findall(rf"{keywords}\s*(.+)", text, re.DOTALL) - - if not matches: - raise AssertionError(f"Keywords {keywords} not found in provided text output") - return matches[0] - - -def strip_trailing_audio_saved_line(text: str) -> str: - """Drop trailing ``Audio saved to ...`` lines from captured client stdout. - - ``openai_chat_completion_client_for_multimodal_generation.py`` may print - ``Chat completion output from text:`` for one choice and ``Audio saved to`` - for another; :func:`extract_content_after_keyword` uses ``re.DOTALL`` and - would otherwise keep the audio progress line inside the *text* segment. - """ - lines = text.splitlines() - while lines and lines[-1].strip().startswith("Audio saved to"): - lines.pop() - return "\n".join(lines).strip() +from tests.examples.helpers import example_runner # noqa: F401 diff --git a/tests/examples/helpers.py b/tests/examples/helpers.py new file mode 100644 index 00000000000..137d15f163f --- /dev/null +++ b/tests/examples/helpers.py @@ -0,0 +1,353 @@ +""" +Shared fixtures, helpers, and path constants for tests/examples/. +""" + +import json +import os +import re +import shlex +import subprocess +import sys +import tempfile +from collections import defaultdict +from collections.abc import Callable +from pathlib import Path +from typing import Any, NamedTuple, cast + +import pytest +import torch +from safetensors.torch import save_file + +# --------------------------------------------------------------------------- +# Path constants and fixtures +# --------------------------------------------------------------------------- + +REPO_ROOT = Path(__file__).resolve().parents[2] +EXAMPLES = REPO_ROOT / "examples" + +# Use Python tempfile instead of pytest's tmp_path_factory because +# OUTPUT_DIR is needed in test collection time, but tmp_path_factory is only available in test running time. +# It is needed during test collection because extract_readme_snippets replaces LoRA path with a generated one under OUTPUT_DIR, +# and extract_readme_snippets is called at collection time to generate separate test cases for each README code block. +OUTPUT_DIR = ( + REPO_ROOT / prefix + if (prefix := os.environ.get("OUTPUT_DIR")) + else Path(tempfile.mkdtemp(prefix="vllm_omni_test_examples_")) +) + + +# --------------------------------------------------------------------------- +# Code snippet extraction and asset file helpers +# --------------------------------------------------------------------------- + +# parameters: language, code, h2_title +ReadmeSnippetExtractionSkipPredicate = Callable[[str, str, str], tuple[bool, str]] + + +class ReadmeSnippet(NamedTuple): + language: str + code: str + h2_title: str + index_in_section: int + output_file_path: Path | None = None + skip: tuple[bool, str] = (False, "") + + @property + def test_id(self) -> str: + return f"{ReadmeSnippet._slug(self.h2_title)}_{self.index_in_section:03d}" + + @staticmethod + def extract_readme_snippets( + readme_path: Path, + skipif: ReadmeSnippetExtractionSkipPredicate | None = None, + ) -> list["ReadmeSnippet"]: + import mistune + + markdown = mistune.create_markdown(renderer="ast") + tokens = markdown(readme_path.read_text(encoding="utf-8")) + tokens = cast(list[dict[str, Any]], tokens) # mistune's AST renderer always produces a list, not a str + + h2_title = "" + section_counts: defaultdict[str, int] = defaultdict(int) + snippets: list[ReadmeSnippet] = [] + + for token in tokens: + token_type = token.get("type") + + if token_type == "heading": + level = (token.get("attrs") or {}).get("level") + title = ReadmeSnippet._heading_text(token) + if level == 2: + h2_title = title + continue + + if token_type != "block_code": + continue + + try: + info = token.get("attrs").get("info") # type: ignore[reportOptionalMemberAccess] + language = info.strip().split()[0].lower() # type: ignore[reportOptionalMemberAccess] + + # Common shell aliases to "bash" in several markdown renderers. + if language in {"shell", "sh", "ksh", "zsh"}: + language = "bash" + + if language not in {"bash", "python"}: + continue + except AttributeError: + # The fence is missing explicit language info; skip it. + continue + + key = h2_title + section_counts[key] += 1 + code = token.get("raw", "") + output_file_path = None + if language == "bash": + argv = ReadmeSnippet._normalize_bash_command(code, Path(readme_path.parent)) + code = shlex.join(argv) + output_file_path = ReadmeSnippet._output_file_path_from_argv(argv) + if skipif is not None: + skip_config = skipif(language, code, h2_title) + else: + skip_config = (False, "") + snippet = ReadmeSnippet( + language=language, + code=code, + h2_title=h2_title, + index_in_section=section_counts[key], + output_file_path=output_file_path, + skip=skip_config, + ) + snippets.append(snippet) + + return snippets + + @staticmethod + def _normalize_bash_command(command: str, readme_dir: Path) -> list[str]: + line_joined_command = re.sub(r"\\\s*\n", " ", command).strip() + argv = shlex.split(line_joined_command, comments=True) + assert argv, "README bash fence produced an empty command" + + # Normalize python directory and example script location + if argv[0] in {"python", "python3"}: + argv[0] = sys.executable + if len(argv) > 1 and argv[1].endswith(".py"): + script_arg = argv[1] + script_path = Path(script_arg) + if script_path.is_absolute(): + resolved_script = script_path + else: + # Take the file name only, and append script_dir to its front + resolved_script = readme_dir / script_path.name + assert resolved_script.exists(), ( + f"README bash snippet references a script that does not exist: {script_arg} (resolved to {resolved_script})" + ) + argv[1] = str(resolved_script) + + # Normalize LoRA adapter path and ensure README LoRA assets exist. + try: + lora_arg_idx = argv.index("--lora-path") # Raise ValueError if not found + assert len(argv) > lora_arg_idx + 1, "README bash snippet uses --lora-path without a following value" + + lora_dir = OUTPUT_DIR / "lora" + adapter_model = lora_dir / "adapter_model.safetensors" + adapter_config = lora_dir / "adapter_config.json" + if not adapter_model.exists() or not adapter_config.exists(): + write_zimage_lora(lora_dir, v_scale=8.0) + + argv[lora_arg_idx + 1] = str(lora_dir) + except ValueError: + pass + + return argv + + @staticmethod + def _output_file_path_from_argv(argv: list[str]) -> Path | None: + if "--output" not in argv: + return None + output_param_idx = argv.index("--output") + assert len(argv) > output_param_idx + 1, "README bash snippet uses --output without a following value" + output_arg = argv[output_param_idx + 1] + return Path(output_arg) + + @staticmethod + def _slug(text: str) -> str: + return "".join(ch.lower() if ch.isalnum() else "_" for ch in text).strip("_") + + @staticmethod + def _heading_text(token: dict) -> str: + return "".join(child.get("raw", "") for child in token.get("children", [])).strip() + + +# [TODO] Duplicate `_write_zimage_lora` in tests/e2e/online_serving/test_images_generations_lora.py. Combine these helpers and tests/e2e/offline_inference/test_diffusion_lora.py to test/utils later +def write_zimage_lora(adapter_dir: Path, *, q_scale: float = 0.0, k_scale: float = 0.0, v_scale: float = 0.0): + adapter_dir.mkdir(parents=True, exist_ok=True) + + # Z-Image transformer uses dim=3840 by default. + dim = 3840 + module_name = "transformer.layers.0.attention.to_qkv" + rank = 1 + + lora_a = torch.zeros((rank, dim), dtype=torch.float32) + lora_a[0, 0] = 1.0 + + # QKVParallelLinear packs (Q, K, V) => out dim is 3 * dim (tp=1). + lora_b = torch.zeros((3 * dim, rank), dtype=torch.float32) + if q_scale: + lora_b[:dim, 0] = q_scale + if k_scale: + lora_b[dim : 2 * dim, 0] = k_scale + if v_scale: + lora_b[2 * dim :, 0] = v_scale + + save_file( + { + f"base_model.model.{module_name}.lora_A.weight": lora_a, + f"base_model.model.{module_name}.lora_B.weight": lora_b, + }, + str(adapter_dir / "adapter_model.safetensors"), + ) + (adapter_dir / "adapter_config.json").write_text( + json.dumps( + { + "r": rank, + "lora_alpha": rank, + "target_modules": [module_name], + } + ), + encoding="utf-8", + ) + + +# --------------------------------------------------------------------------- +# Code runner and subprocess helpers +# --------------------------------------------------------------------------- + + +class ExampleRunResult(NamedTuple): + run_dir: Path + assets: list[Path] + + +class ExampleRunner: + """Run extracted README snippets and return generated assets. + + The output materials are organized in a three-level directory structure: + - Set at init: `self.output_root` for all tests (from env OUTPUT_DIR) + - Set at `self.run(...)`: `output_subfolder` for a specific example page (e.g., `example_offline_t2i`) + - Generated by `extract_readme_snippets`: `snippet.test_id` for a specific code block (matching H2 titles, e.g., `basic_usage_001`) + """ + + IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp"} + + def __init__(self, output_root: Path) -> None: + self.output_root = output_root + + def run( + self, snippet: ReadmeSnippet, *, output_subfolder: Path = Path("."), env: dict[str, str] | None = None + ) -> ExampleRunResult: + run_dir = self.output_root / output_subfolder / snippet.test_id + run_dir.mkdir(parents=True, exist_ok=True) + + if snippet.language == "python": + assets = self._run_python_snippet(snippet, run_dir, env) + return ExampleRunResult(run_dir=run_dir, assets=assets) + + if snippet.language == "bash": + asset = self._run_bash_snippet(snippet, run_dir, env) + return ExampleRunResult(run_dir=run_dir, assets=[asset]) + + raise AssertionError(f"Unsupported snippet language: {snippet.language}") + + def _run_python_snippet( + self, snippet: ReadmeSnippet, run_dir: Path, env: dict[str, str] | None = None + ) -> list[Path]: + # Saving the script to a temporary file and `run_cmd` it. + # Not using `exec(snippet.code)` because the output is lost. + script_path = run_dir / "snippet.py" + script_path.write_text(snippet.code, encoding="utf-8") + + before = self._collect_images(run_dir) + run_cmd([sys.executable, str(script_path)], cwd=run_dir, env=env) + after = self._collect_images(run_dir) + + assets = sorted(after - before) + return assets + + def _run_bash_snippet(self, snippet: ReadmeSnippet, run_dir: Path, env: dict[str, str] | None = None) -> Path: + run_cmd(snippet.code, shell=True, cwd=run_dir, env=env) + + assert snippet.output_file_path is not None, ( + f"README bash snippet is missing --output argument: {snippet.test_id}. " + "The test script cannot guess the output file path." + ) + + # If the code snippet declares a relative path for the output file, append this path to the parent output collection directory. + # If the code snippet declares an absolute path (not likely but just in case), the return value resolution removes `run_dir`, also correctly pointing to this file. + return run_dir / snippet.output_file_path + + def _collect_images(self, root: Path) -> set[Path]: + return {path for path in root.rglob("*") if path.suffix.lower() in self.IMAGE_SUFFIXES} + + +@pytest.fixture +def example_runner() -> ExampleRunner: + return ExampleRunner(output_root=OUTPUT_DIR) + + +def run_cmd( + command: list[str] | str, + *, + shell: bool = False, + env: dict[str, str] | None = None, + cwd: Path | str | None = None, +) -> str: + """Run a command as a subprocess; assert zero exit code and return stdout. + + Output is fully captured and returned as a string so callers can parse it + (e.g. with :func:`extract_content_after_keyword`). + Use this for scripts whose printed output is part of the test assertion. + """ + if env is not None: + env = {**os.environ.copy(), **env} + result = subprocess.run(command, capture_output=True, text=True, shell=shell, env=env, cwd=cwd) + + if result.returncode != 0: + print(f"STDERR: {result.stderr}") + raise subprocess.CalledProcessError(result.returncode, command) + + all_output = result.stdout + print(f"All output:\n{all_output}") + return all_output + + +# --------------------------------------------------------------------------- +# Output validation helpers +# --------------------------------------------------------------------------- + + +def extract_content_after_keyword(keywords: str, text: str) -> str: + """Return the text that follows *keywords* in *text* (regex match). + + Raises ``AssertionError`` if the keyword is not found, so test failures + produce a clear message pointing at the missing keyword. + """ + matches = re.findall(rf"{keywords}\s*(.+)", text, re.DOTALL) + + if not matches: + raise AssertionError(f"Keywords {keywords} not found in provided text output") + return matches[0] + + +def strip_trailing_audio_saved_line(text: str) -> str: + """Drop trailing ``Audio saved to ...`` lines from captured client stdout. + + ``openai_chat_completion_client_for_multimodal_generation.py`` may print + ``Chat completion output from text:`` for one choice and ``Audio saved to`` + for another; :func:`extract_content_after_keyword` uses ``re.DOTALL`` and + would otherwise keep the audio progress line inside the *text* segment. + """ + lines = text.splitlines() + while lines and lines[-1].strip().startswith("Audio saved to"): + lines.pop() + return "\n".join(lines).strip() diff --git a/tests/examples/offline_inference/test_text_to_image.py b/tests/examples/offline_inference/test_text_to_image.py index a08d16f1614..f24506587c1 100644 --- a/tests/examples/offline_inference/test_text_to_image.py +++ b/tests/examples/offline_inference/test_text_to_image.py @@ -7,9 +7,9 @@ import pytest -from tests.conftest import assert_image_valid -from tests.examples.conftest import EXAMPLES, ExampleRunner, ReadmeSnippet -from tests.utils import hardware_marks +from tests.examples.helpers import EXAMPLES, ExampleRunner, ReadmeSnippet +from tests.helpers.assertions import assert_image_valid +from tests.helpers.mark import hardware_marks pytestmark = [pytest.mark.advanced_model, pytest.mark.example, *hardware_marks(res={"cuda": "H100"})] diff --git a/tests/examples/online_serving/test_qwen2_5_omni.py b/tests/examples/online_serving/test_qwen2_5_omni.py index a78ccf5924a..2f631c1fe8b 100644 --- a/tests/examples/online_serving/test_qwen2_5_omni.py +++ b/tests/examples/online_serving/test_qwen2_5_omni.py @@ -13,13 +13,14 @@ import pytest -from tests.conftest import OmniServerParams, convert_audio_file_to_text, cosine_similarity_text -from tests.examples.conftest import ( +from tests.examples.helpers import ( extract_content_after_keyword, run_cmd, strip_trailing_audio_saved_line, ) -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.media import convert_audio_file_to_text, cosine_similarity_text +from tests.helpers.runtime import OmniServerParams pytestmark = [pytest.mark.advanced_model, pytest.mark.example] diff --git a/tests/examples/online_serving/test_qwen3_omni.py b/tests/examples/online_serving/test_qwen3_omni.py index 65f99d7bf28..3a16b2f7633 100644 --- a/tests/examples/online_serving/test_qwen3_omni.py +++ b/tests/examples/online_serving/test_qwen3_omni.py @@ -13,13 +13,14 @@ import pytest -from tests.conftest import OmniServerParams, convert_audio_file_to_text, cosine_similarity_text -from tests.examples.conftest import ( +from tests.examples.helpers import ( extract_content_after_keyword, run_cmd, strip_trailing_audio_saved_line, ) -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.media import convert_audio_file_to_text, cosine_similarity_text +from tests.helpers.runtime import OmniServerParams pytestmark = [pytest.mark.advanced_model, pytest.mark.example] diff --git a/tests/examples/online_serving/test_text_to_image.py b/tests/examples/online_serving/test_text_to_image.py index 51b7ff61bc9..6f89cb5496b 100644 --- a/tests/examples/online_serving/test_text_to_image.py +++ b/tests/examples/online_serving/test_text_to_image.py @@ -13,9 +13,10 @@ import pytest -from tests.conftest import OmniServer, OmniServerParams, assert_image_valid -from tests.examples.conftest import EXAMPLES, OUTPUT_DIR, run_cmd, write_zimage_lora -from tests.utils import hardware_marks +from tests.examples.helpers import EXAMPLES, OUTPUT_DIR, run_cmd, write_zimage_lora +from tests.helpers.assertions import assert_image_valid +from tests.helpers.mark import hardware_marks +from tests.helpers.runtime import OmniServer, OmniServerParams pytestmark = [pytest.mark.advanced_model, pytest.mark.example, *hardware_marks(res={"cuda": "H100"})] diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py new file mode 100644 index 00000000000..a3348b07fe0 --- /dev/null +++ b/tests/helpers/__init__.py @@ -0,0 +1,8 @@ +"""Shared, importable test helper utilities. + +Submodules (``assertions``, ``env``, ``media``, ``runtime``, …) are imported +explicitly by callers. Avoid star-importing everything here: that ran before +refactor only inside the old monolithic ``conftest``; a greedy ``__init__`` +changes import order and can affect in-process Omni (``OmniRunner`` / offline +e2e) vs subprocess-based ``OmniServer`` tests. +""" diff --git a/tests/helpers/assertions.py b/tests/helpers/assertions.py new file mode 100644 index 00000000000..44346150d20 --- /dev/null +++ b/tests/helpers/assertions.py @@ -0,0 +1,414 @@ +"""Assertion and response validation helpers for tests.""" + +import io +import tempfile +import threading +from io import BytesIO +from pathlib import Path +from typing import Any + +import numpy as np +import soundfile as sf +from PIL import Image +from transformers import pipeline + +from tests.helpers.media import cosine_similarity_text + +_GENDER_PIPELINE = None +_GENDER_PIPELINE_LOCK = threading.Lock() +_PCM_SPEECH_SAMPLE_RATE_HZ = 24_000 +_MIN_PCM_SPEECH_HNR_DB = 1.0 +_PRESET_VOICE_GENDER_MAP: dict[str, str] = { + "serena": "female", + "uncle_fu": "male", + "chelsie": "female", + "clone": "female", + "ethan": "male", +} + + +def assert_image_diffusion_response( + response, + request_config: dict[str, Any], + run_level: str = None, +) -> None: + """ + Validate image diffusion response. + + Expected request_config schema: + { + "request_type": "image", + "extra_body": { + "num_outputs_per_prompt": 1, + "width": ..., + "height": ..., + ... + } + } + """ + assert response.images is not None, "Image response is None" + assert len(response.images) > 0, "No images in response" + + extra_body = request_config.get("extra_body") or {} + + num_outputs_per_prompt = extra_body.get("num_outputs_per_prompt") + if num_outputs_per_prompt is not None: + assert len(response.images) == num_outputs_per_prompt, ( + f"Expected {num_outputs_per_prompt} images, got {len(response.images)}" + ) + + if run_level == "advanced_model": + width = extra_body.get("width") + height = extra_body.get("height") + + if width is not None or height is not None: + for img in response.images: + assert_image_valid(img, width=width, height=height) + + +def assert_video_diffusion_response( + response, + request_config: dict[str, Any], + run_level: str = None, +) -> None: + """ + Validate video diffusion response. + + Expected request_config schema: + { + "request_type": "video", + "form_data": { + "prompt": "...", + "num_frames": ..., + "width": ..., + "height": ..., + "fps": ..., + ... + } + } + """ + form_data = request_config.get("form_data", {}) + + assert response.videos is not None, "Video response is None" + assert len(response.videos) > 0, "No videos in response" + + expected_frames = _maybe_int(form_data.get("num_frames")) + expected_width = _maybe_int(form_data.get("width")) + expected_height = _maybe_int(form_data.get("height")) + expected_fps = _maybe_int(form_data.get("fps")) + + for vid_bytes in response.videos: + assert_video_valid( + vid_bytes, + num_frames=expected_frames, + width=expected_width, + height=expected_height, + fps=expected_fps, + ) + + +def assert_audio_diffusion_response( + response, + request_config: dict[str, Any], + run_level: str = None, +) -> None: + """ + Validate audio diffusion response. + """ + raise NotImplementedError("Audio validation is not implemented yet") + + +def _maybe_int(value: Any) -> int | None: + if value is None: + return None + return int(value) + + +def assert_image_valid(image: Path | Image.Image, *, width: int | None = None, height: int | None = None): + """Assert the file is a loadable image with optional exact dimensions.""" + if isinstance(image, Path): + assert image.exists(), f"Image not found: {image}" + image = Image.open(image) + image.load() + assert image.width > 0 and image.height > 0 + if width is not None: + assert image.width == width, f"Expected width={width}, got {image.width}" + if height is not None: + assert image.height == height, f"Expected height={height}, got {image.height}" + return image + + +def assert_video_valid( + video: Path | bytes | BytesIO, + *, + num_frames: int | None = None, + width: int | None = None, + height: int | None = None, + fps: float | None = None, +) -> dict[str, int | float]: + """Assert the MP4 has the expected resolution and exact frame count.""" + temp_path = None + cap = None + try: + import cv2 + + if isinstance(video, Path): + if not video.exists(): + raise AssertionError(f"Video file not found: {video}") + video_path = str(video) + else: + suffix = ".mp4" + with tempfile.NamedTemporaryFile(delete=False, suffix=suffix, mode="wb") as tmp: + if isinstance(video, bytes): + tmp.write(video) + elif isinstance(video, BytesIO): + tmp.write(video.getvalue()) + else: + raise TypeError(f"Unsupported video type: {type(video)}") + temp_path = Path(tmp.name) + video_path = str(temp_path) + + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise AssertionError("Failed to open video capture") + + actual_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + actual_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + actual_fps = float(cap.get(cv2.CAP_PROP_FPS)) + actual_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + if width is not None: + assert actual_width == width, f"Expected width={width}, got {actual_width}" + if height is not None: + assert actual_height == height, f"Expected height={height}, got {actual_height}" + if fps is not None and actual_fps: + assert abs(actual_fps - float(fps)) < 1.0, f"Expected fps~={fps}, got {actual_fps}" + if num_frames is not None: + assert actual_frames == num_frames, f"Expected frames={num_frames}, got {actual_frames}" + + return { + "width": actual_width, + "height": actual_height, + "fps": actual_fps, + "num_frames": actual_frames, + } + finally: + if cap is not None: + cap.release() + if temp_path and temp_path.exists(): + try: + temp_path.unlink() + except OSError: + pass + + +def assert_audio_valid(path: Path, *, sample_rate: int, channels: int, duration_s: float) -> None: + """Assert the WAV has the expected sample rate, channel count, and duration.""" + assert path.exists(), f"Audio not found: {path}" + info = sf.info(str(path)) + assert info.samplerate == sample_rate, f"Expected sample_rate={sample_rate}, got {info.samplerate}" + assert info.channels == channels, f"Expected {channels} channel(s), got {info.channels}" + expected_frames = int(duration_s * sample_rate) + assert info.frames == expected_frames, ( + f"Expected {expected_frames} frames ({duration_s}s @ {sample_rate} Hz), got {info.frames}" + ) + + +def _load_gender_pipeline(): + global _GENDER_PIPELINE + if _GENDER_PIPELINE is not None: + return _GENDER_PIPELINE + model_name = "7wolf/wav2vec2-base-gender-classification" + try: + _GENDER_PIPELINE = pipeline(task="audio-classification", model=model_name, device=-1) + return _GENDER_PIPELINE + except Exception as exc: # pragma: no cover + print(f"Warning: failed to create gender pipeline '{model_name}': {exc}") + _GENDER_PIPELINE = None + return None + + +def _median_pitch_hz_from_autocorr(mono: np.ndarray, sr: int) -> float | None: + x = np.asarray(mono, dtype=np.float64) + x = x - np.mean(x) + if x.size < int(0.15 * sr): + return None + frame_len = int(0.04 * sr) + hop = max(frame_len // 2, 1) + f0_min_hz, f0_max_hz = 70.0, 400.0 + lag_min = max(1, int(sr / f0_max_hz)) + lag_max = min(frame_len - 2, int(sr / f0_min_hz)) + if lag_max <= lag_min: + return None + win = np.hamming(frame_len) + pitches: list[float] = [] + for start in range(0, int(x.shape[0]) - frame_len, hop): + frame = x[start : start + frame_len] * win + frame = frame - np.mean(frame) + if float(np.sqrt(np.mean(frame**2))) < 1e-4: + continue + ac = np.correlate(frame, frame, mode="full")[frame_len - 1 :] + ac = ac / (float(ac[0]) + 1e-12) + region = ac[lag_min : lag_max + 1] + peak_rel = int(np.argmax(region)) + peak_lag = peak_rel + lag_min + if peak_lag <= 0: + continue + f0 = float(sr) / float(peak_lag) + if f0_min_hz <= f0 <= f0_max_hz: + pitches.append(f0) + if len(pitches) < 4: + return None + return float(np.median(np.asarray(pitches, dtype=np.float64))) + + +def _estimate_voice_gender_from_audio(audio_bytes: bytes) -> str: + data, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=True) + if data.size == 0: + raise ValueError("Empty audio") + mono = np.mean(data, axis=1) + try: + target_sr = 16000 + if int(sr) != target_sr and mono.size > 1: + src_len = int(mono.shape[0]) + dst_len = max(1, int(round(src_len * float(target_sr) / float(sr)))) + src_idx = np.arange(src_len, dtype=np.float32) + dst_idx = np.linspace(0, src_len - 1, dst_len, dtype=np.float32) + mono = np.interp(dst_idx, src_idx, mono.astype(np.float32, copy=False)).astype(np.float32) + sr = target_sr + + median_f0 = _median_pitch_hz_from_autocorr(mono, sr) + clf = _load_gender_pipeline() + if clf is None: + return "unknown" + with _GENDER_PIPELINE_LOCK: + outputs = clf(mono, sampling_rate=sr) + if not outputs: + return "unknown" + top = outputs[0] + label = str(top.get("label", "")).lower() + conf = float(top.get("score", 0.0)) + if conf < 0.5: + gender = "unknown" + elif ("female" in label) or ("жен" in label): + gender = "female" + elif ("male" in label) or ("муж" in label): + gender = "male" + else: + gender = "unknown" + + if gender == "female" and median_f0 is not None and median_f0 < 165.0 and conf < 0.88: + gender = "male" + elif gender == "male" and median_f0 is not None and median_f0 > 230.0 and conf < 0.88: + gender = "female" + return gender + except Exception: # pragma: no cover + return "unknown" + + +def _assert_preset_voice_gender_from_audio(audio_bytes: bytes | None, voice_name: str | None) -> None: + if not voice_name or not audio_bytes: + return + expected_gender = _PRESET_VOICE_GENDER_MAP.get(str(voice_name).lower()) + if expected_gender is None: + return + estimated_gender = _estimate_voice_gender_from_audio(audio_bytes) + if estimated_gender != "unknown": + assert estimated_gender == expected_gender + + +def _compute_pcm_hnr_db(pcm_samples: np.ndarray, sr: int = _PCM_SPEECH_SAMPLE_RATE_HZ) -> float: + frame_len = int(0.03 * sr) + hop = frame_len // 2 + hnr_values: list[float] = [] + for start in range(0, len(pcm_samples) - frame_len, hop): + frame = pcm_samples[start : start + frame_len].astype(np.float32, copy=False) + frame = frame - np.mean(frame) + if np.max(np.abs(frame)) < 0.01: + continue + ac = np.correlate(frame, frame, mode="full")[len(frame) - 1 :] + ac = ac / (ac[0] + 1e-10) + min_lag = int(sr / 400) + max_lag = min(int(sr / 80), len(ac)) + if min_lag >= max_lag: + continue + peak = float(np.max(ac[min_lag:max_lag])) + if 0 < peak < 1: + hnr_values.append(10 * np.log10(peak / (1 - peak + 1e-10))) + return float(np.mean(hnr_values)) if hnr_values else 0.0 + + +def _assert_pcm_int16_speech_hnr(audio_bytes: bytes) -> None: + assert audio_bytes is not None and len(audio_bytes) >= 2, "missing PCM bytes" + assert len(audio_bytes) % 2 == 0, "PCM byte length must be aligned to int16" + pcm_samples = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0 + hnr = _compute_pcm_hnr_db(pcm_samples) + assert hnr >= _MIN_PCM_SPEECH_HNR_DB + + +def assert_omni_response(response: Any, request_config: dict[str, Any], run_level): + assert response.success, "The request failed." + modalities = request_config.get("modalities", ["text", "audio"]) + if run_level == "advanced_model": + if "audio" in modalities: + assert response.audio_content is not None, "No audio output is generated" + speaker = request_config.get("speaker") + if speaker: + _assert_preset_voice_gender_from_audio(response.audio_bytes, speaker) + if "text" in modalities: + assert response.text_content is not None, "No text output is generated" + keywords_dict = request_config.get("key_words", {}) + for word_type in ["text", "image", "audio", "video"]: + keywords = keywords_dict.get(word_type) + if not keywords: + continue + if "text" in modalities: + text_lower = (response.text_content or "").lower() + assert any(str(kw).lower() in text_lower for kw in keywords) + else: + audio_lower = (response.audio_content or "").lower() + assert any(str(kw).lower() in audio_lower for kw in keywords) + if "text" in modalities and "audio" in modalities: + assert response.similarity is not None and response.similarity > 0.9 + + +def assert_audio_speech_response(response: Any, request_config: dict[str, Any], run_level: str) -> None: + assert response.success, "The request failed." + req_fmt = request_config.get("response_format") + if req_fmt == "pcm" and response.audio_bytes: + _assert_pcm_int16_speech_hnr(response.audio_bytes) + elif req_fmt == "wav" and response.audio_format: + assert req_fmt in response.audio_format + + if run_level == "advanced_model" and req_fmt != "pcm": + expected_text = request_config.get("input") + if expected_text: + transcript = (response.audio_content or "").strip() + similarity = cosine_similarity_text(transcript.lower(), expected_text.lower()) + assert similarity > 0.9 + _assert_preset_voice_gender_from_audio(response.audio_bytes, request_config.get("voice")) + + +def assert_diffusion_response(response: Any, request_config: dict[str, Any], run_level: str = None): + assert response.success, "The request failed." + has_any_content = any(content is not None for content in (response.images, response.videos, response.audios)) + assert has_any_content, "Response contains no images, videos, or audios" + if response.images is not None: + assert_image_diffusion_response(response=response, request_config=request_config, run_level=run_level) + if response.videos is not None: + assert_video_diffusion_response(response=response, request_config=request_config, run_level=run_level) + if response.audios is not None: + assert_audio_diffusion_response(response=response, request_config=request_config, run_level=run_level) + + +__all__ = [ + "assert_audio_diffusion_response", + "assert_audio_speech_response", + "assert_diffusion_response", + "assert_image_diffusion_response", + "assert_image_valid", + "assert_omni_response", + "assert_video_diffusion_response", + "assert_video_valid", + "assert_audio_valid", +] diff --git a/tests/helpers/env.py b/tests/helpers/env.py new file mode 100644 index 00000000000..ac7f226b7ea --- /dev/null +++ b/tests/helpers/env.py @@ -0,0 +1,207 @@ +"""Test environment / lifecycle helpers (GPU cleanup hooks and memory monitoring for tests).""" + +from __future__ import annotations + +import gc +import os +import threading +import time +from contextlib import contextmanager + +import torch +from vllm.platforms import current_platform + +from vllm_omni.platforms import current_omni_platform + +if current_platform.is_rocm(): + from amdsmi import ( + amdsmi_get_gpu_vram_usage, + amdsmi_get_processor_handles, + amdsmi_init, + amdsmi_shut_down, + ) + + @contextmanager + def _nvml(): + try: + amdsmi_init() + yield + finally: + amdsmi_shut_down() +elif current_platform.is_cuda(): + from vllm.third_party.pynvml import ( + nvmlDeviceGetHandleByIndex, + nvmlDeviceGetMemoryInfo, + nvmlInit, + nvmlShutdown, + ) + + @contextmanager + def _nvml(): + try: + nvmlInit() + yield + finally: + nvmlShutdown() +else: + + @contextmanager + def _nvml(): + yield + + +def get_physical_device_indices(devices): + visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES") + if visible_devices is None: + return devices + visible_indices = [int(x) for x in visible_devices.split(",")] + index_mapping = {i: physical for i, physical in enumerate(visible_indices)} + return [index_mapping[i] for i in devices if i in index_mapping] + + +@_nvml() +def wait_for_gpu_memory_to_clear( + *, + devices: list[int], + threshold_bytes: int | None = None, + threshold_ratio: float | None = None, + timeout_s: float = 120, +) -> None: + assert threshold_bytes is not None or threshold_ratio is not None + devices = get_physical_device_indices(devices) + start_time = time.time() + + device_list = ", ".join(str(d) for d in devices) + if threshold_bytes is not None: + threshold_str = f"{threshold_bytes / 2**30:.2f} GiB" + condition_str = f"Memory usage ≤ {threshold_str}" + else: + threshold_percent = threshold_ratio * 100 + threshold_str = f"{threshold_percent:.1f}%" + condition_str = f"Memory usage ratio ≤ {threshold_str}" + + print(f"[GPU Memory Monitor] Waiting for GPU {device_list} to free memory, Condition: {condition_str}") + + if threshold_bytes is not None: + + def is_free(used, total): + return used <= threshold_bytes / 2**30 + else: + + def is_free(used, total): + return used / total <= threshold_ratio + + while True: + output: dict[int, str] = {} + output_raw: dict[int, tuple[float, float]] = {} + for device in devices: + if current_platform.is_rocm(): + dev_handle = amdsmi_get_processor_handles()[device] + mem_info = amdsmi_get_gpu_vram_usage(dev_handle) + gb_used = mem_info["vram_used"] / 2**10 + gb_total = mem_info["vram_total"] / 2**10 + else: + dev_handle = nvmlDeviceGetHandleByIndex(device) + mem_info = nvmlDeviceGetMemoryInfo(dev_handle) + gb_used = mem_info.used / 2**30 + gb_total = mem_info.total / 2**30 + output_raw[device] = (gb_used, gb_total) + usage_percent = (gb_used / gb_total) * 100 if gb_total > 0 else 0 + output[device] = f"{gb_used:.1f}GiB/{gb_total:.1f}GiB ({usage_percent:.1f}%)" + + print("[GPU Memory Status] Current usage:") + for device_id, mem_info in output.items(): + print(f" GPU {device_id}: {mem_info}") + + dur_s = time.time() - start_time + elapsed_minutes = dur_s / 60 + if all(is_free(used, total) for used, total in output_raw.values()): + print(f"[GPU Memory Freed] Devices {device_list} meet memory condition") + print(f" Condition: {condition_str}") + print(f" Wait time: {dur_s:.1f} seconds ({elapsed_minutes:.1f} minutes)") + break + + if dur_s >= timeout_s: + raise ValueError( + f"[GPU Memory Timeout] Devices {device_list} still don't meet memory condition after {dur_s:.1f} seconds\n" + f"Condition: {condition_str}\n" + f"Current status:\n" + "\n".join(f" GPU {device}: {output[device]}" for device in devices) + ) + + gc.collect() + torch.cuda.empty_cache() + time.sleep(5) + + +def _run_pre_test_cleanup(enable_force: bool = False) -> None: + if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1" and not enable_force: + return + + num_gpus = torch.cuda.device_count() + if num_gpus > 0: + try: + wait_for_gpu_memory_to_clear( + devices=list(range(num_gpus)), + threshold_ratio=0.05, + ) + except Exception as e: + print(f"Pre-test cleanup note: {e}") + + +def _run_post_test_cleanup(enable_force: bool = False) -> None: + if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1" and not enable_force: + return + + if torch.cuda.is_available(): + gc.collect() + torch.cuda.empty_cache() + + +class DeviceMemoryMonitor: + """Poll global device memory usage.""" + + def __init__(self, device_index: int, interval: float = 0.05): + self.device_index = device_index + self.interval = interval + self._peak_used_mb = 0.0 + self._stop_event = threading.Event() + self._thread: threading.Thread | None = None + + def start(self) -> None: + def monitor_loop() -> None: + while not self._stop_event.is_set(): + try: + with current_omni_platform.device(self.device_index): + free_bytes, total_bytes = current_omni_platform.mem_get_info() + used_mb = (total_bytes - free_bytes) / (1024**2) + self._peak_used_mb = max(self._peak_used_mb, used_mb) + except Exception: + pass + time.sleep(self.interval) + + self._thread = threading.Thread(target=monitor_loop, daemon=False) + self._thread.start() + + def stop(self) -> None: + if self._thread is None: + return + self._stop_event.set() + self._thread.join(timeout=2.0) + + @property + def peak_used_mb(self) -> float: + fallback_alloc = current_omni_platform.max_memory_allocated(device=self.device_index) / (1024**2) + fallback_reserved = current_omni_platform.max_memory_reserved(device=self.device_index) / (1024**2) + return max(self._peak_used_mb, fallback_alloc, fallback_reserved) + + def __del__(self): + self.stop() + + +__all__ = [ + "DeviceMemoryMonitor", + "_run_post_test_cleanup", + "_run_pre_test_cleanup", + "get_physical_device_indices", + "wait_for_gpu_memory_to_clear", +] diff --git a/tests/helpers/fixtures/__init__.py b/tests/helpers/fixtures/__init__.py new file mode 100644 index 00000000000..8bd090b7824 --- /dev/null +++ b/tests/helpers/fixtures/__init__.py @@ -0,0 +1 @@ +"""Pytest fixture modules under tests.helpers.""" diff --git a/tests/helpers/fixtures/env.py b/tests/helpers/fixtures/env.py new file mode 100644 index 00000000000..0126ff2782c --- /dev/null +++ b/tests/helpers/fixtures/env.py @@ -0,0 +1,56 @@ +import os + +import pytest +import torch + +from tests.helpers.env import _run_post_test_cleanup, _run_pre_test_cleanup + + +@pytest.fixture(scope="session", autouse=True) +def default_env(): + # Keep behavior but avoid import-time side effects (RFC #2299). + keys = ("VLLM_WORKER_MULTIPROC_METHOD", "VLLM_TARGET_DEVICE") + previous = {key: os.environ.get(key) for key in keys} + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = previous["VLLM_WORKER_MULTIPROC_METHOD"] or "spawn" + os.environ["VLLM_TARGET_DEVICE"] = previous["VLLM_TARGET_DEVICE"] or ( + "cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu" + ) + yield + for key, value in previous.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + +@pytest.fixture(scope="session") +def model_prefix() -> str: + prefix = os.environ.get("MODEL_PREFIX", "") + return f"{prefix.rstrip('/')}/" if prefix else "" + + +@pytest.fixture(autouse=True) +def clean_gpu_memory_between_tests(): + _run_pre_test_cleanup() + yield + _run_post_test_cleanup() + + +@pytest.fixture(scope="session", autouse=True) +def default_vllm_config(): + """Set a default VllmConfig for the whole test session. + + Session scope ensures module-scoped fixtures (e.g. ``omni_runner``) and + deferred imports of ``tests.helpers.runtime`` both see the same context. + Function-scoped autouse ran too late for ``OmniRunner`` setup and could + desynchronize vLLM init vs request preprocessing (e.g. renderer state). + """ + from vllm.config import DeviceConfig, VllmConfig, set_current_vllm_config + + # Use CPU device if no GPU is available (e.g., in CI environments) + has_gpu = torch.cuda.is_available() and torch.cuda.device_count() > 0 + device = "cuda" if has_gpu else "cpu" + device_config = DeviceConfig(device=device) + + with set_current_vllm_config(VllmConfig(device_config=device_config)): + yield diff --git a/tests/helpers/fixtures/log.py b/tests/helpers/fixtures/log.py new file mode 100644 index 00000000000..798fa4ae6c7 --- /dev/null +++ b/tests/helpers/fixtures/log.py @@ -0,0 +1,7 @@ +import pytest + + +@pytest.fixture(autouse=True) +def log_test_name_before_test(request: pytest.FixtureRequest): + print(f"--- Running test: {request.node.name}") + yield diff --git a/tests/helpers/fixtures/run_args.py b/tests/helpers/fixtures/run_args.py new file mode 100644 index 00000000000..630513e07c6 --- /dev/null +++ b/tests/helpers/fixtures/run_args.py @@ -0,0 +1,16 @@ +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--run-level", + action="store", + default="core_model", + choices=["core_model", "advanced_model"], + help="Test level to run: L2, L3", + ) + + +@pytest.fixture(scope="session") +def run_level(request) -> str: + return request.config.getoption("--run-level") diff --git a/tests/helpers/fixtures/runtime.py b/tests/helpers/fixtures/runtime.py new file mode 100644 index 00000000000..b328e414c9d --- /dev/null +++ b/tests/helpers/fixtures/runtime.py @@ -0,0 +1,80 @@ +"""Runtime fixtures (OmniRunner / OmniServer). Imports are deferred to fixture time. + +Loading ``tests.helpers.runtime`` at plugin import time (before session fixtures) +pulls in vLLM/vllm_omni too early and breaks initialization order vs the legacy +monolithic conftest. Defer imports until fixtures run so ``default_env`` / +``default_vllm_config`` run first. +""" + +from __future__ import annotations + +import threading +from collections.abc import Generator +from typing import Any + +import pytest +import yaml + +from tests.helpers.stage_config import modify_stage_config + +omni_fixture_lock = threading.Lock() + + +@pytest.fixture(scope="module") +def omni_server(request: pytest.FixtureRequest, run_level: str, model_prefix: str) -> Generator[Any, Any, None]: + from tests.helpers.runtime import OmniServer, OmniServerParams + + with omni_fixture_lock: + params: OmniServerParams = request.param + model = model_prefix + params.model + port = params.port + stage_config_path = params.stage_config_path + + if run_level == "advanced_model" and stage_config_path is not None: + with open(stage_config_path, encoding="utf-8") as f: + cfg = yaml.safe_load(f) or {} + stage_ids = [stage["stage_id"] for stage in cfg.get("stage_args", []) if "stage_id" in stage] + stage_config_path = modify_stage_config( + stage_config_path, + deletes={"stage_args": {stage_id: ["engine_args.load_format"] for stage_id in stage_ids}}, + ) + + server_args = params.server_args or [] + if params.use_omni: + server_args = ["--stage-init-timeout", "120", *server_args] + if stage_config_path is not None: + server_args += ["--stage-configs-path", stage_config_path] + + with OmniServer( + model, + server_args, + port=port, + env_dict=params.env_dict, + use_omni=params.use_omni, + ) as server: + yield server + + +@pytest.fixture +def openai_client(omni_server: Any, run_level: str): + from tests.helpers.runtime import OpenAIClientHandler + + return OpenAIClientHandler(host=omni_server.host, port=omni_server.port, api_key="EMPTY", run_level=run_level) + + +@pytest.fixture(scope="module") +def omni_runner(request: pytest.FixtureRequest, model_prefix: str): + from tests.helpers.runtime import OmniRunner + + with omni_fixture_lock: + model, stage_config_path = request.param + model = model_prefix + model + with OmniRunner(model, seed=42, stage_configs_path=stage_config_path, stage_init_timeout=300) as runner: + yield runner + + +@pytest.fixture +def omni_runner_handler(omni_runner: Any): + from tests.helpers.runtime import OmniRunnerHandler + + return OmniRunnerHandler(omni_runner) diff --git a/tests/helpers/mark.py b/tests/helpers/mark.py new file mode 100644 index 00000000000..ed45dd7e9a1 --- /dev/null +++ b/tests/helpers/mark.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Pytest marks and decorators for hardware / resource selection (CUDA, ROCm, …).""" + +import pytest +from vllm.utils.torch_utils import cuda_device_count_stateless + +# Re-exported from tests.helpers.env (GPU wait + DeviceMemoryMonitor). + + +def cuda_marks(*, res: str, num_cards: int): + test_platform_detail = pytest.mark.cuda + if res == "L4": + test_resource = pytest.mark.L4 + elif res == "H100": + test_resource = pytest.mark.H100 + else: + raise ValueError(f"Invalid CUDA resource type: {res}. Supported: L4, H100") + marks = [test_resource, test_platform_detail] + if num_cards == 1: + return marks + test_distributed = pytest.mark.distributed_cuda(num_cards=num_cards) + test_skipif = pytest.mark.skipif_cuda( + cuda_device_count_stateless() < num_cards, + reason=f"Need at least {num_cards} CUDA GPUs to run the test.", + ) + return marks + [test_distributed, test_skipif] + + +def rocm_marks(*, res: str, num_cards: int): + test_platform_detail = pytest.mark.rocm + if res == "MI325": + test_resource = pytest.mark.MI325 + else: + raise ValueError(f"Invalid ROCm resource type: {res}. Supported: MI325") + marks = [test_resource, test_platform_detail] + if num_cards == 1: + return marks + test_distributed = pytest.mark.distributed_rocm(num_cards=num_cards) + return marks + [test_distributed] + + +def xpu_marks(*, res: str, num_cards: int): + test_platform_detail = pytest.mark.xpu + if res == "B60": + test_resource = pytest.mark.B60 + else: + raise ValueError(f"Invalid XPU resource type: {res}. Supported: B60") + marks = [test_resource, test_platform_detail] + if num_cards == 1: + return marks + test_distributed = pytest.mark.distributed_rocm(num_cards=num_cards) + return marks + [test_distributed] + + +def musa_marks(*, res: str, num_cards: int): + test_platform_detail = pytest.mark.musa + if res == "S5000": + test_resource = pytest.mark.S5000 + else: + raise ValueError(f"Invalid MUSA resource type: {res}. Supported: S5000") + marks = [test_resource, test_platform_detail] + if num_cards == 1: + return marks + test_distributed = pytest.mark.distributed_musa(num_cards=num_cards) + return marks + [test_distributed] + + +def gpu_marks(*, res: str, num_cards: int): + test_platform = pytest.mark.gpu + if res in ("L4", "H100"): + return [test_platform] + cuda_marks(res=res, num_cards=num_cards) + if res == "MI325": + return [test_platform] + rocm_marks(res=res, num_cards=num_cards) + if res == "B60": + return [test_platform] + xpu_marks(res=res, num_cards=num_cards) + if res == "S5000": + return [test_platform] + musa_marks(res=res, num_cards=num_cards) + raise ValueError(f"Invalid resource type: {res}. Supported: L4, H100, MI325, B60, S5000") + + +def npu_marks(*, res: str, num_cards: int): + test_platform = pytest.mark.npu + if res == "A2": + test_resource = pytest.mark.A2 + elif res == "A3": + test_resource = pytest.mark.A3 + else: + test_resource = None + if num_cards == 1: + return [mark for mark in [test_platform, test_resource] if mark is not None] + test_distributed = pytest.mark.distributed_npu(num_cards=num_cards) + return [mark for mark in [test_platform, test_resource, test_distributed] if mark is not None] + + +def hardware_marks(*, res: dict[str, str], num_cards: int | dict[str, int] = 1): + for platform, _ in res.items(): + if platform not in ("cuda", "rocm", "xpu", "npu", "musa"): + raise ValueError(f"Unsupported platform: {platform}") + if isinstance(num_cards, int): + num_cards_dict = {platform: num_cards for platform in res.keys()} + else: + num_cards_dict = num_cards + for platform in num_cards_dict.keys(): + if platform not in res: + raise ValueError(f"Platform '{platform}' in num_cards but not in res.") + for platform in res.keys(): + if platform not in num_cards_dict: + num_cards_dict[platform] = 1 + + all_marks: list[pytest.MarkDecorator] = [] + for platform, resource in res.items(): + cards = num_cards_dict[platform] + if platform in ("cuda", "rocm", "xpu"): + marks = gpu_marks(res=resource, num_cards=cards) + elif platform == "musa": + marks = musa_marks(res=resource, num_cards=cards) + elif platform == "npu": + marks = npu_marks(res=resource, num_cards=cards) + else: + raise ValueError(f"Unsupported platform: {platform}") + all_marks.extend(marks) + return all_marks + + +def hardware_test(*, res: dict[str, str], num_cards: int | dict[str, int] = 1): + all_marks = hardware_marks(res=res, num_cards=num_cards) + + def wrapper(f): + func = f + for mark in reversed(all_marks): + func = mark(func) + return func + + return wrapper diff --git a/tests/helpers/media.py b/tests/helpers/media.py new file mode 100644 index 00000000000..77261987592 --- /dev/null +++ b/tests/helpers/media.py @@ -0,0 +1,550 @@ +"""Synthetic media generation and media/text utilities for tests.""" + +import base64 +import concurrent.futures +import datetime +import gc +import io +import logging +import math +import multiprocessing +import os +import random +import re +import subprocess +import tempfile +import time +import uuid +from typing import Any + +import numpy as np +import soundfile as sf +from PIL import Image + +logger = logging.getLogger(__name__) + + +def generate_synthetic_audio( + duration: int, + num_channels: int, + sample_rate: int = 48000, + save_to_file: bool = False, +) -> dict[str, Any]: + """ + Generate TTS speech with pyttsx3 and return base64 string. + """ + import pyttsx3 + + def _pick_voice(engine: pyttsx3.Engine) -> str | None: + voices = engine.getProperty("voices") + if not voices: + return None + + preferred_tokens = ( + "natural", + "jenny", + "sonia", + "susan", + "zira", + "aria", + "hazel", + "samantha", + "ava", + "allison", + "female", + "woman", + "english-us", + "en-us", + "english", + ) + discouraged_tokens = ( + "espeak", + "robot", + "mbrola", + "microsoft david", + "male", + "man", + ) + + best_voice = voices[0] + best_score = float("-inf") + for voice in voices: + voice_text = f"{getattr(voice, 'id', '')} {getattr(voice, 'name', '')}".lower() + voice_languages = " ".join( + lang.decode(errors="ignore") if isinstance(lang, bytes) else str(lang) + for lang in getattr(voice, "languages", []) + ).lower() + combined_text = f"{voice_text} {voice_languages}" + score = 0 + for idx, token in enumerate(preferred_tokens): + if token in combined_text: + score += 20 - idx + for token in discouraged_tokens: + if token in combined_text: + score -= 10 + if "english" in combined_text or "en_" in combined_text or "en-" in combined_text: + score += 4 + if "en-us" in combined_text or "english-us" in combined_text: + score += 4 + if score > best_score: + best_score = score + best_voice = voice + + return best_voice.id + + def _resample_audio(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray: + if src_sr == dst_sr or len(audio) == 0: + return audio.astype(np.float32) + src_len = audio.shape[0] + dst_len = max(1, int(round(src_len * float(dst_sr) / float(src_sr)))) + src_idx = np.arange(src_len, dtype=np.float32) + dst_idx = np.linspace(0, src_len - 1, dst_len, dtype=np.float32) + resampled_channels: list[np.ndarray] = [] + for ch in range(audio.shape[1]): + resampled_channels.append(np.interp(dst_idx, src_idx, audio[:, ch]).astype(np.float32)) + return np.stack(resampled_channels, axis=1) + + def _match_channels(audio: np.ndarray, target_channels: int) -> np.ndarray: + current_channels = audio.shape[1] + if current_channels == target_channels: + return audio.astype(np.float32) + if target_channels == 1: + return np.mean(audio, axis=1, keepdims=True, dtype=np.float32) + if current_channels == 1: + return np.repeat(audio, target_channels, axis=1).astype(np.float32) + collapsed = np.mean(audio, axis=1, keepdims=True, dtype=np.float32) + return np.repeat(collapsed, target_channels, axis=1).astype(np.float32) + + def _trim_silence(audio: np.ndarray, threshold: float = 0.01) -> np.ndarray: + if len(audio) == 0: + return audio + energy = np.max(np.abs(audio), axis=1) + voiced = np.where(energy > threshold)[0] + if len(voiced) == 0: + return audio + start = max(0, int(voiced[0]) - int(sample_rate * 0.02)) + end = min(len(audio), int(voiced[-1]) + int(sample_rate * 0.04) + 1) + return audio[start:end] + + def _enhance_speech(audio: np.ndarray) -> np.ndarray: + if len(audio) == 0: + return audio.astype(np.float32) + enhanced = audio.astype(np.float32).copy() + enhanced -= np.mean(enhanced, axis=0, keepdims=True, dtype=np.float32) + if len(enhanced) > 1: + preemphasis = enhanced.copy() + preemphasis[1:] = enhanced[1:] - 0.94 * enhanced[:-1] + enhanced = 0.7 * enhanced + 0.3 * preemphasis + enhanced = np.sign(enhanced) * np.sqrt(np.abs(enhanced)) + fade = min(len(enhanced) // 4, max(1, int(sample_rate * 0.01))) + if fade > 1: + ramp_in = np.linspace(0.0, 1.0, fade, dtype=np.float32) + ramp_out = np.linspace(1.0, 0.0, fade, dtype=np.float32) + enhanced[:fade] *= ramp_in[:, None] + enhanced[-fade:] *= ramp_out[:, None] + peak = float(np.max(np.abs(enhanced))) + if peak > 1e-8: + enhanced = enhanced / peak * 0.95 + return enhanced.astype(np.float32) + + phrase_text = "test" + num_samples = int(sample_rate * max(1, duration)) + audio_data = np.zeros((num_samples, num_channels), dtype=np.float32) + + engine = pyttsx3.init() + engine.setProperty("rate", 112) + engine.setProperty("volume", 1.0) + selected_voice = _pick_voice(engine) + if selected_voice is not None: + engine.setProperty("voice", selected_voice) + + temp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) + temp_wav.close() + try: + engine.save_to_file(phrase_text, temp_wav.name) + engine.runAndWait() + engine.stop() + + ready = False + for _ in range(50): + if os.path.exists(temp_wav.name) and os.path.getsize(temp_wav.name) > 44: + ready = True + break + time.sleep(0.1) + if not ready: + raise RuntimeError("pyttsx3 did not produce a WAV file in time.") + + tts_audio, tts_sr = sf.read(temp_wav.name, dtype="float32", always_2d=True) + finally: + if os.path.exists(temp_wav.name): + os.unlink(temp_wav.name) + + if len(tts_audio) == 0: + raise RuntimeError("pyttsx3 produced an empty WAV file.") + + tts_audio = _resample_audio(tts_audio, tts_sr, sample_rate) + tts_audio = _match_channels(tts_audio, num_channels) + tts_audio = _trim_silence(tts_audio, threshold=0.012) + tts_audio = _enhance_speech(tts_audio) + + lead_silence = min(int(sample_rate * 0.02), num_samples // 8) + pause_samples = int(sample_rate * 0.18) + start = lead_silence + phrase_len = tts_audio.shape[0] + while start < num_samples: + take = min(phrase_len, num_samples - start) + audio_data[start : start + take] = tts_audio[:take] + start += phrase_len + pause_samples + + max_amp = float(np.max(np.abs(audio_data))) + if max_amp > 0: + audio_data = audio_data / max_amp * 0.95 + + audio_bytes: bytes | None = None + output_path: str | None = None + result: dict[str, Any] = {"np_array": audio_data.copy()} + + if save_to_file: + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + output_path = f"audio_{num_channels}ch_{timestamp}.wav" + try: + sf.write(output_path, audio_data, sample_rate, format="WAV", subtype="PCM_16") + with open(output_path, "rb") as f: + audio_bytes = f.read() + except Exception as e: + print(f"Save failed: {e}") + save_to_file = False + + if not save_to_file or audio_bytes is None: + buffer = io.BytesIO() + sf.write(buffer, audio_data, sample_rate, format="WAV", subtype="PCM_16") + buffer.seek(0) + audio_bytes = buffer.read() + + result["base64"] = base64.b64encode(audio_bytes).decode("utf-8") + result["file_path"] = output_path if save_to_file and output_path else None + return result + + +def _mux_mp4_bytes_with_synthetic_audio( + video_mp4_bytes: bytes, + *, + num_frames: int, + fps: float = 30.0, + sample_rate: int = 48000, +) -> bytes: + duration_sec = num_frames / fps if fps > 0 else 0.0 + duration_int = max(1, int(math.ceil(duration_sec))) + + try: + audio_result = generate_synthetic_audio( + duration=duration_int, + num_channels=1, + sample_rate=sample_rate, + save_to_file=False, + ) + audio_pcm = audio_result["np_array"] + except Exception as e: + logger.warning("Synthetic video: generate_synthetic_audio failed (%s); using video-only MP4.", e) + return video_mp4_bytes + + try: + import imageio_ffmpeg + + ffmpeg_exe = imageio_ffmpeg.get_ffmpeg_exe() + except Exception: + ffmpeg_exe = "ffmpeg" + + try: + with tempfile.TemporaryDirectory(prefix="syn_vid_mux_") as tmp: + vid_path = os.path.join(tmp, "video.mp4") + wav_path = os.path.join(tmp, "audio.wav") + out_path = os.path.join(tmp, "out.mp4") + with open(vid_path, "wb") as f: + f.write(video_mp4_bytes) + sf.write(wav_path, audio_pcm, sample_rate, format="WAV", subtype="PCM_16") + cmd = [ + ffmpeg_exe, + "-y", + "-nostdin", + "-hide_banner", + "-loglevel", + "error", + "-i", + vid_path, + "-i", + wav_path, + "-c:v", + "copy", + "-c:a", + "aac", + "-b:a", + "128k", + "-shortest", + "-movflags", + "+faststart", + out_path, + ] + subprocess.run(cmd, check=True, stdin=subprocess.DEVNULL, timeout=300) + with open(out_path, "rb") as f: + return f.read() + except ( + FileNotFoundError, + subprocess.CalledProcessError, + subprocess.TimeoutExpired, + OSError, + ) as e: + logger.warning("Synthetic video: audio mux failed (%s); using video-only MP4.", e) + return video_mp4_bytes + + +def generate_synthetic_video( + width: int, + height: int, + num_frames: int, + save_to_file: bool = False, + *, + embed_audio: bool = False, +) -> dict[str, Any]: + import cv2 + import imageio + + num_balls = random.randint(3, 8) + balls = [] + for _ in range(num_balls): + radius = min(width, height) // 8 + if radius < 1: + raise ValueError(f"Video dimensions ({width}x{height}) too small") + x = random.randint(radius, width - radius) + y = random.randint(radius, height - radius) + speed = random.uniform(3.0, 8.0) + angle = random.uniform(0, 2 * math.pi) + vx = speed * math.cos(angle) + vy = speed * math.sin(angle) + color_bgr = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255)) + balls.append({"x": x, "y": y, "vx": vx, "vy": vy, "radius": radius, "color_bgr": color_bgr}) + + video_frames = [] + for _ in range(num_frames): + frame_bgr = np.zeros((height, width, 3), dtype=np.uint8) + for ball in balls: + ball["x"] += ball["vx"] + ball["y"] += ball["vy"] + if ball["x"] - ball["radius"] <= 0 or ball["x"] + ball["radius"] >= width: + ball["vx"] = -ball["vx"] + ball["x"] = max(ball["radius"], min(width - ball["radius"], ball["x"])) + if ball["y"] - ball["radius"] <= 0 or ball["y"] + ball["radius"] >= height: + ball["vy"] = -ball["vy"] + ball["y"] = max(ball["radius"], min(height - ball["radius"], ball["y"])) + x, y = int(ball["x"]), int(ball["y"]) + radius = int(ball["radius"]) + cv2.circle(frame_bgr, (x, y), radius, ball["color_bgr"], -1) + frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB) + video_frames.append(frame_rgb) + + fps = 30 + buffer = io.BytesIO() + writer_kwargs = { + "format": "mp4", + "fps": fps, + "codec": "libx264", + "quality": 7, + "pixelformat": "yuv420p", + "macro_block_size": 16, + "ffmpeg_params": ["-preset", "medium", "-crf", "23", "-movflags", "+faststart", "-pix_fmt", "yuv420p"], + } + with imageio.get_writer(buffer, **writer_kwargs) as writer: + for frame in video_frames: + writer.append_data(frame) + buffer.seek(0) + video_only_bytes = buffer.read() + video_bytes = ( + _mux_mp4_bytes_with_synthetic_audio(video_only_bytes, num_frames=num_frames, fps=float(fps)) + if embed_audio + else video_only_bytes + ) + + result: dict[str, Any] = { + "np_array": np.array(video_frames), + "base64": base64.b64encode(video_bytes).decode("utf-8"), + } + if save_to_file: + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + output_path = f"video_{width}x{height}_{timestamp}.mp4" + with open(output_path, "wb") as f: + f.write(video_bytes) + result["file_path"] = output_path + return result + + +def generate_synthetic_image(width: int, height: int, save_to_file: bool = False) -> dict[str, Any]: + from PIL import ImageDraw + + image = Image.new("RGB", (width, height), (255, 255, 255)) + draw = ImageDraw.Draw(image) + num_squares = random.randint(3, 8) + for _ in range(num_squares): + square_size = random.randint(max(1, min(width, height) // 8), max(2, min(width, height) // 4)) + x = random.randint(0, max(0, width - square_size - 1)) + y = random.randint(0, max(0, height - square_size - 1)) + color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) + border_width = random.randint(1, 5) + draw.rectangle([x, y, x + square_size, y + square_size], fill=color, outline=(0, 0, 0), width=border_width) + + result: dict[str, Any] = {"np_array": np.array(image).copy()} + image_bytes: bytes | None = None + saved_file_path: str | None = None + if save_to_file: + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + output_path = f"image_{width}x{height}_{timestamp}.jpg" + image.save(output_path, format="JPEG", quality=85, optimize=True) + saved_file_path = output_path + with open(output_path, "rb") as f: + image_bytes = f.read() + if not save_to_file or image_bytes is None: + buffer = io.BytesIO() + image.save(buffer, format="JPEG", quality=85, optimize=True) + buffer.seek(0) + image_bytes = buffer.read() + + result["base64"] = base64.b64encode(image_bytes).decode("utf-8") + if save_to_file and saved_file_path: + result["file_path"] = saved_file_path + return result + + +def decode_b64_image(b64: str): + img = Image.open(io.BytesIO(base64.b64decode(b64))) + img.load() + return img + + +def preprocess_text(text): + import opencc + + word_to_num = { + "zero": "0", + "one": "1", + "two": "2", + "three": "3", + "four": "4", + "five": "5", + "six": "6", + "seven": "7", + "eight": "8", + "nine": "9", + "ten": "10", + } + for word, num in word_to_num.items(): + pattern = r"\b" + re.escape(word) + r"\b" + text = re.sub(pattern, num, text, flags=re.IGNORECASE) + + text = re.sub(r"[^\w\s]", "", text) + text = re.sub(r"\s+", " ", text) + cc = opencc.OpenCC("t2s") + text = cc.convert(text) + text = re.sub(r"(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])", "", text) + return text.lower().strip() + + +def cosine_similarity_text(text1, text2, n: int = 3): + from collections import Counter + + if not text1 or not text2: + return 0.0 + + text1 = preprocess_text(text1) + text2 = preprocess_text(text2) + + ngrams1 = [text1[i : i + n] for i in range(len(text1) - n + 1)] + ngrams2 = [text2[i : i + n] for i in range(len(text2) - n + 1)] + counter1 = Counter(ngrams1) + counter2 = Counter(ngrams2) + + all_ngrams = set(counter1.keys()) | set(counter2.keys()) + vec1 = [counter1.get(ng, 0) for ng in all_ngrams] + vec2 = [counter2.get(ng, 0) for ng in all_ngrams] + dot_product = sum(a * b for a, b in zip(vec1, vec2)) + norm1 = sum(a * a for a in vec1) ** 0.5 + norm2 = sum(b * b for b in vec2) ** 0.5 + if norm1 == 0 or norm2 == 0: + return 0.0 + return dot_product / (norm1 * norm2) + + +def _merge_base64_audio_to_segment(base64_list: list[str]): + from pydub import AudioSegment + + merged = None + for b64 in base64_list: + raw = base64.b64decode(b64.split(",", 1)[-1]) + seg = AudioSegment.from_file(io.BytesIO(raw)) + merged = seg if merged is None else merged + seg + return merged + + +def _whisper_transcribe_in_current_process(output_path: str) -> str: + import whisper + + device_index = None + from vllm_omni.platforms import current_omni_platform + + if current_omni_platform.is_available(): + n = current_omni_platform.get_device_count() + if n == 1: + device_index = 0 + elif n > 1: + device_index = n - 1 + + if device_index is not None: + torch_device = current_omni_platform.get_torch_device(device_index) + current_omni_platform.set_device(torch_device) + device = str(torch_device) + use_accelerator = True + else: + use_accelerator = False + device = "cpu" + + model = whisper.load_model("small", device=device) + try: + text = model.transcribe( + output_path, + temperature=0.0, + word_timestamps=True, + condition_on_previous_text=False, + )["text"] + finally: + del model + gc.collect() + if use_accelerator: + current_omni_platform.synchronize() + current_omni_platform.empty_cache() + return text or "" + + +def convert_audio_file_to_text(output_path: str) -> str: + """Convert an audio file to text in an isolated subprocess.""" + ctx = multiprocessing.get_context("spawn") + with concurrent.futures.ProcessPoolExecutor(max_workers=1, mp_context=ctx) as executor: + future = executor.submit(_whisper_transcribe_in_current_process, output_path) + return future.result() + + +def convert_audio_bytes_to_text(raw_bytes: bytes) -> str: + output_path = f"./test_{uuid.uuid4().hex}.wav" + data, samplerate = sf.read(io.BytesIO(raw_bytes)) + sf.write(output_path, data, samplerate, format="WAV", subtype="PCM_16") + return convert_audio_file_to_text(output_path) + + +__all__ = [ + "_merge_base64_audio_to_segment", + "convert_audio_bytes_to_text", + "convert_audio_file_to_text", + "cosine_similarity_text", + "decode_b64_image", + "generate_synthetic_audio", + "generate_synthetic_image", + "generate_synthetic_video", + "preprocess_text", +] diff --git a/tests/e2e/offline_inference/utils.py b/tests/helpers/process.py similarity index 58% rename from tests/e2e/offline_inference/utils.py rename to tests/helpers/process.py index 3113599a305..094de965239 100644 --- a/tests/e2e/offline_inference/utils.py +++ b/tests/helpers/process.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import contextlib import functools import os import signal @@ -10,73 +9,48 @@ import tempfile from collections.abc import Callable from contextlib import ExitStack, suppress -from pathlib import Path from typing import Any, Literal import cloudpickle from typing_extensions import ParamSpec from vllm.platforms import current_platform -VLLM_PATH = Path(__file__).parent.parent.parent -"""Path to root of the vLLM repository.""" - - _P = ParamSpec("_P") def fork_new_process_for_each_test(func: Callable[_P, None]) -> Callable[_P, None]: - """Decorator to fork a new process for each test function. - See https://github.com/vllm-project/vllm/issues/7053 for more details. - """ + """Decorator to fork a new process for each test function.""" @functools.wraps(func) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: - # Make the process the leader of its own process group - # to avoid sending SIGTERM to the parent process os.setpgrp() from _pytest.outcomes import Skipped - # Create a unique temporary file to store exception info from child - # process. Use test function name and process ID to avoid collisions. with ( tempfile.NamedTemporaryFile( - delete=False, - mode="w+b", - prefix=f"vllm_test_{func.__name__}_{os.getpid()}_", - suffix=".exc", + delete=False, mode="w+b", prefix=f"vllm_test_{func.__name__}_{os.getpid()}_", suffix=".exc" ) as exc_file, ExitStack() as delete_after, ): exc_file_path = exc_file.name delete_after.callback(os.remove, exc_file_path) - pid = os.fork() - print(f"Fork a new process to run a test {pid}") if pid == 0: - # Parent process responsible for deleting, don't delete - # in child. delete_after.pop_all() try: func(*args, **kwargs) except Skipped as e: - # convert Skipped to exit code 0 print(str(e)) os._exit(0) except Exception as e: import traceback tb_string = traceback.format_exc() - - # Try to serialize the exception object first exc_to_serialize: dict[str, Any] try: - # First, try to pickle the actual exception with - # its traceback. exc_to_serialize = {"pickled_exception": e} - # Test if it can be pickled cloudpickle.dumps(exc_to_serialize) except (Exception, KeyboardInterrupt): - # Fall back to string-based approach. exc_to_serialize = { "exception_type": type(e).__name__, "exception_msg": str(e), @@ -86,7 +60,6 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: with open(exc_file_path, "wb") as f: cloudpickle.dump(exc_to_serialize, f) except Exception: - # Fallback: just print the traceback. print(tb_string) os._exit(1) else: @@ -94,40 +67,24 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: else: pgid = os.getpgid(pid) _pid, _exitcode = os.waitpid(pid, 0) - # ignore SIGTERM signal itself old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN) - # kill all child processes os.killpg(pgid, signal.SIGTERM) - # restore the signal handler signal.signal(signal.SIGTERM, old_signal_handler) if _exitcode != 0: - # Try to read the exception from the child process exc_info = {} if os.path.exists(exc_file_path): - with ( - contextlib.suppress(Exception), - open(exc_file_path, "rb") as f, - ): + with suppress(Exception), open(exc_file_path, "rb") as f: exc_info = cloudpickle.load(f) - - original_exception = exc_info.get("pickled_exception") - if original_exception is not None and isinstance(original_exception, Exception): - # Re-raise the actual exception object if it was - # successfully pickled. + if (original_exception := exc_info.get("pickled_exception")) is not None: + assert isinstance(original_exception, Exception) raise original_exception - if (original_tb := exc_info.get("traceback")) is not None: - # Use string-based traceback for fallback case raise AssertionError( - f"Test {func.__name__} failed when called with" - f" args {args} and kwargs {kwargs}" + f"Test {func.__name__} failed when called with args {args} and kwargs {kwargs}" f" (exit code: {_exitcode}):\n{original_tb}" ) from None - - # Fallback to the original generic error raise AssertionError( - f"function {func.__name__} failed when called with" - f" args {args} and kwargs {kwargs}" + f"function {func.__name__} failed when called with args {args} and kwargs {kwargs}" f" (exit code: {_exitcode})" ) from None @@ -139,9 +96,7 @@ def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None] @functools.wraps(f) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: - # Check if we're already in a subprocess if os.environ.get("RUNNING_IN_SUBPROCESS") == "1": - # If we are, just run the function directly return f(*args, **kwargs) import torch.multiprocessing as mp @@ -149,33 +104,18 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: with suppress(RuntimeError): mp.set_start_method("spawn") - # Get the module module_name = f.__module__ - - # Create a process with environment variable set env = os.environ.copy() env["RUNNING_IN_SUBPROCESS"] = "1" with tempfile.TemporaryDirectory() as tempdir: output_filepath = os.path.join(tempdir, "new_process.tmp") - - # `cloudpickle` allows pickling complex functions directly input_bytes = cloudpickle.dumps((f, output_filepath)) - - repo_root = str(VLLM_PATH.resolve()) - - env = dict(env or os.environ) - env["PYTHONPATH"] = repo_root + os.pathsep + env.get("PYTHONPATH", "") - cmd = [sys.executable, "-m", f"{module_name}"] - returned = subprocess.run(cmd, input=input_bytes, capture_output=True, env=env) - - # check if the subprocess is successful try: returned.check_returncode() except Exception as e: - # wrap raised exception to provide more information raise RuntimeError(f"Error raised in subprocess:\n{returned.stderr.decode()}") from e return wrapper @@ -184,27 +124,11 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: def create_new_process_for_each_test( method: Literal["spawn", "fork"] | None = None, ) -> Callable[[Callable[_P, None]], Callable[_P, None]]: - """Creates a decorator that runs each test function in a new process. - - Args: - method: The process creation method. Can be either "spawn" or "fork". - If not specified, it defaults to "spawn" on ROCm and XPU - platforms and "fork" otherwise. - - Returns: - A decorator to run test functions in separate processes. - """ + """Creates a decorator that runs each test function in a new process.""" if method is None: - # TODO: Find out why spawn is not working correctly on ROCm - # The test content will not run and tests passed immediately. - # For now, using `fork` for ROCm as it can run with `fork` - # and tests are running correctly. use_spawn = current_platform.is_xpu() method = "spawn" if use_spawn else "fork" - assert method in ["spawn", "fork"], "Method must be either 'spawn' or 'fork'" - if method == "fork": return fork_new_process_for_each_test - return spawn_new_process_for_each_test diff --git a/tests/helpers/runtime.py b/tests/helpers/runtime.py new file mode 100644 index 00000000000..9936f45d3ef --- /dev/null +++ b/tests/helpers/runtime.py @@ -0,0 +1,1024 @@ +"""Server/client/runner runtime primitives for tests.""" + +import base64 +import concurrent.futures +import io +import json +import os +import socket +import subprocess +import sys +import time +from dataclasses import dataclass +from io import BytesIO +from typing import Any, NamedTuple + +import psutil +import requests +import soundfile as sf +import torch +from openai import OpenAI, omit +from PIL import Image +from vllm import TextPrompt +from vllm.logger import init_logger + +from tests.helpers.assertions import ( + assert_audio_speech_response, + assert_diffusion_response, + assert_omni_response, +) +from tests.helpers.env import _run_post_test_cleanup, _run_pre_test_cleanup +from tests.helpers.media import ( + _merge_base64_audio_to_segment, + convert_audio_bytes_to_text, + cosine_similarity_text, + decode_b64_image, +) + +logger = init_logger(__name__) + +PromptAudioInput = list[tuple[Any, int]] | tuple[Any, int] | None +PromptImageInput = list[Any] | Any | None +PromptVideoInput = list[Any] | Any | None + +try: + from vllm.distributed.parallel_state import cleanup_dist_env_and_memory # type: ignore +except Exception: # pragma: no cover + + def cleanup_dist_env_and_memory() -> None: + return None + + +def get_open_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return int(s.getsockname()[1]) + + +class OmniServerParams(NamedTuple): + model: str + port: int | None = None + stage_config_path: str | None = None + server_args: list[str] | None = None + env_dict: dict[str, str] | None = None + use_omni: bool = True + + +class OmniServer: + """Omniserver for vLLM-Omni tests.""" + + def __init__( + self, + model: str, + serve_args: list[str], + *, + port: int | None = None, + env_dict: dict[str, str] | None = None, + use_omni: bool = True, + ) -> None: + _run_pre_test_cleanup(enable_force=True) + _run_post_test_cleanup(enable_force=True) + cleanup_dist_env_and_memory() + self.model = model + self.serve_args = serve_args + self.env_dict = env_dict + self.use_omni = use_omni + self.proc: subprocess.Popen | None = None + self.host = "127.0.0.1" + self.port = get_open_port() if port is None else port + + def _start_server(self) -> None: + env = os.environ.copy() + env.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + if self.env_dict is not None: + env.update(self.env_dict) + + cmd = [ + sys.executable, + "-m", + "vllm_omni.entrypoints.cli.main", + "serve", + self.model, + "--host", + self.host, + "--port", + str(self.port), + ] + if self.use_omni: + cmd.append("--omni") + cmd += self.serve_args + + self.proc = subprocess.Popen( + cmd, + env=env, + cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))), + ) + + max_wait = 1200 + start_time = time.time() + while time.time() - start_time < max_wait: + ret = self.proc.poll() + if ret is not None: + raise RuntimeError(f"Server processes exited with code {ret} before becoming ready.") + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(1) + if sock.connect_ex((self.host, self.port)) == 0: + return + time.sleep(2) + raise RuntimeError(f"Server failed to start within {max_wait} seconds") + + def _kill_process_tree(self, pid): + try: + parent = psutil.Process(pid) + children = parent.children(recursive=True) + for child in children: + try: + child.terminate() + except psutil.NoSuchProcess: + pass + _, still_alive = psutil.wait_procs(children, timeout=10) + for child in still_alive: + try: + child.kill() + except psutil.NoSuchProcess: + pass + try: + parent.terminate() + parent.wait(timeout=10) + except (psutil.NoSuchProcess, psutil.TimeoutExpired): + try: + parent.kill() + except psutil.NoSuchProcess: + pass + except psutil.NoSuchProcess: + pass + + def __enter__(self): + self._start_server() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.proc: + self._kill_process_tree(self.proc.pid) + _run_pre_test_cleanup(enable_force=True) + _run_post_test_cleanup(enable_force=True) + cleanup_dist_env_and_memory() + + +@dataclass +class OmniResponse: + text_content: str | None = None + audio_data: list[str] | None = None + audio_content: str | None = None + audio_format: str | None = None + audio_bytes: bytes | None = None + similarity: float | None = None + e2e_latency: float | None = None + success: bool = False + error_message: str | None = None + + +@dataclass +class DiffusionResponse: + text_content: str | None = None + images: list[Image.Image] | None = None + audios: list[Any] | None = None + videos: list[Any] | None = None + e2e_latency: float | None = None + success: bool = False + error_message: str | None = None + + +class OpenAIClientHandler: + def __init__(self, host: str = "127.0.0.1", port: int = None, api_key: str = "EMPTY", run_level: str = None): + if port is None: + port = get_open_port() + self.base_url = f"http://{host}:{port}" + self.client = OpenAI(base_url=f"http://{host}:{port}/v1", api_key=api_key) + self.run_level = run_level + + def _process_stream_omni_response(self, chat_completion) -> OmniResponse: + result = OmniResponse() + start_time = time.perf_counter() + try: + text_content = "" + audio_data = [] + for chunk in chat_completion: + for choice in chunk.choices: + content = getattr(getattr(choice, "delta", None), "content", None) + modality = getattr(chunk, "modality", None) + if modality == "audio" and content: + audio_data.append(content) + elif modality == "text" and content: + text_content += content + result.e2e_latency = time.perf_counter() - start_time + audio_content = None + similarity = None + if audio_data: + merged_seg = _merge_base64_audio_to_segment(audio_data) + wav_buf = BytesIO() + merged_seg.export(wav_buf, format="wav") + result.audio_bytes = wav_buf.getvalue() + audio_content = convert_audio_bytes_to_text(result.audio_bytes) + if audio_content and text_content: + similarity = cosine_similarity_text(audio_content.lower(), text_content.lower()) + result.text_content = text_content + result.audio_data = audio_data + result.audio_content = audio_content + result.similarity = similarity + result.success = True + except Exception as e: + result.error_message = f"Stream processing error: {str(e)}" + return result + + def _process_non_stream_omni_response(self, chat_completion) -> OmniResponse: + result = OmniResponse() + start_time = time.perf_counter() + try: + audio_data = None + text_content = None + for choice in chat_completion.choices: + if hasattr(choice.message, "audio") and choice.message.audio is not None: + audio_data = choice.message.audio.data + if hasattr(choice.message, "content") and choice.message.content is not None: + text_content = choice.message.content + result.e2e_latency = time.perf_counter() - start_time + audio_content = None + similarity = None + if audio_data: + result.audio_bytes = base64.b64decode(audio_data) + audio_content = convert_audio_bytes_to_text(result.audio_bytes) + if audio_content and text_content: + similarity = cosine_similarity_text(audio_content.lower(), text_content.lower()) + result.text_content = text_content + result.audio_content = audio_content + result.similarity = similarity + result.success = True + except Exception as e: + result.error_message = f"Non-stream processing error: {str(e)}" + return result + + def _process_diffusion_response(self, chat_completion) -> DiffusionResponse: + result = DiffusionResponse() + start_time = time.perf_counter() + try: + images = [] + for choice in chat_completion.choices: + content = getattr(choice.message, "content", None) + if isinstance(content, list): + for item in content: + image_url = None + if isinstance(item, dict): + image_url = item.get("image_url", {}).get("url") + else: + image_url_obj = getattr(item, "image_url", None) + image_url = getattr(image_url_obj, "url", None) if image_url_obj else None + if image_url and image_url.startswith("data:image"): + b64_data = image_url.split(",", 1)[1] + images.append(decode_b64_image(b64_data)) + result.e2e_latency = time.perf_counter() - start_time + result.images = images if images else None + result.success = True + except Exception as e: + result.error_message = f"Diffusion response processing error: {str(e)}" + return result + + def send_omni_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]: + responses: list[OmniResponse] = [] + stream = request_config.get("stream", False) + modalities = request_config.get("modalities", ["text", "audio"]) + extra_body: dict[str, Any] = {} + if "speaker" in request_config: + extra_body["speaker"] = request_config["speaker"] + if request_config.get("use_audio_in_video"): + mm = dict(extra_body.get("mm_processor_kwargs") or {}) + mm["use_audio_in_video"] = True + extra_body["mm_processor_kwargs"] = mm + create_kwargs: dict[str, Any] = { + "model": request_config.get("model"), + "messages": request_config.get("messages"), + "stream": stream, + "modalities": modalities, + } + if extra_body: + create_kwargs["extra_body"] = extra_body + + if request_num == 1: + chat_completion = self.client.chat.completions.create(**create_kwargs) + resp = ( + self._process_stream_omni_response(chat_completion) + if stream + else self._process_non_stream_omni_response(chat_completion) + ) + assert_omni_response(resp, request_config, run_level=self.run_level) + responses.append(resp) + return responses + + def _one(): + chat_completion = self.client.chat.completions.create(**create_kwargs) + return ( + self._process_stream_omni_response(chat_completion) + if stream + else self._process_non_stream_omni_response(chat_completion) + ) + + with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor: + futures = [executor.submit(_one) for _ in range(request_num)] + for future in concurrent.futures.as_completed(futures): + resp = future.result() + assert_omni_response(resp, request_config, run_level=self.run_level) + responses.append(resp) + return responses + + def _process_stream_audio_speech_response(self, response, *, response_format: str | None = None) -> OmniResponse: + """ + Process streaming /v1/audio/speech responses into an OmniResponse. + + This mirrors _process_stream_omni_response but operates on low-level + audio bytes and produces an OmniResponse with audio_content filled + from Whisper transcription. + """ + result = OmniResponse() + start_time = time.perf_counter() + + try: + # Aggregate all audio bytes from the streaming response. + data = bytearray() + + # Preferred OpenAI helper. + if hasattr(response, "iter_bytes") and callable(getattr(response, "iter_bytes")): + for chunk in response.iter_bytes(): + if chunk: + data.extend(chunk) + else: + # Generic iterable-of-bytes fallback (e.g., generator or list of chunks). + try: + iterator = iter(response) + except TypeError: + iterator = None + + if iterator is not None: + for chunk in iterator: + if not chunk: + continue + if isinstance(chunk, (bytes, bytearray)): + data.extend(chunk) + elif hasattr(chunk, "data"): + data.extend(chunk.data) # type: ignore[arg-type] + elif hasattr(chunk, "content"): + data.extend(chunk.content) # type: ignore[arg-type] + else: + raise TypeError(f"Unsupported stream chunk type: {type(chunk)}") + else: + raise TypeError(f"Unsupported audio speech streaming response type: {type(response)}") + + raw_bytes = bytes(data) + if response_format == "pcm": + transcript = None + else: + transcript = convert_audio_bytes_to_text(raw_bytes) + + # Populate OmniResponse. + result.audio_bytes = raw_bytes + result.audio_content = transcript + result.e2e_latency = time.perf_counter() - start_time + result.success = True + result.audio_format = getattr(response, "response", None) + if result.audio_format is not None: + result.audio_format = result.audio_format.headers.get("content-type", "") + + except Exception as e: + result.error_message = f"Audio speech stream processing error: {str(e)}" + print(f"Error: {result.error_message}") + + return result + + def _process_non_stream_audio_speech_response( + self, response, *, response_format: str | None = None + ) -> OmniResponse: + """ + Process non-streaming /v1/audio/speech responses into an OmniResponse. + + This mirrors _process_non_stream_omni_response but for the binary + audio payload returned by audio.speech.create. + """ + result = OmniResponse() + start_time = time.perf_counter() + + try: + # OpenAI non-streaming audio.speech.create returns HttpxBinaryResponseContent (.read() or .content) + if hasattr(response, "read") and callable(getattr(response, "read")): + raw_bytes = response.read() + elif hasattr(response, "content"): + raw_bytes = response.content # type: ignore[assignment] + else: + raise TypeError(f"Unsupported audio speech response type: {type(response)}") + + if response_format == "pcm": + transcript = None + else: + transcript = convert_audio_bytes_to_text(raw_bytes) + + result.audio_bytes = raw_bytes + result.audio_content = transcript + result.e2e_latency = time.perf_counter() - start_time + result.success = True + result.audio_format = getattr(response, "response", None) + if result.audio_format is not None: + result.audio_format = result.audio_format.headers.get("content-type", "") + + except Exception as e: + result.error_message = f"Audio speech non-stream processing error: {str(e)}" + print(f"Error: {result.error_message}") + + return result + + def send_audio_speech_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]: + """ + Call the /v1/audio/speech endpoint using the same configuration-dict + style as send_omni_request, but via the OpenAI Python client's + audio.speech APIs. + + Expected keys in request_config: + - model: model name/path (required) + - input: text to synthesize (required) + - response_format: audio format such as "wav" or "pcm" (optional) + - task_type, ref_text, ref_audio: TTS-specific extras (optional, passed via extra_body) + - timeout: request timeout in seconds (float, optional, default 120.0) + - stream: whether to use streaming API (bool, optional, default False) + """ + timeout = float(request_config.get("timeout", 120.0)) + + model = request_config["model"] + text_input = request_config["input"] + stream = bool(request_config.get("stream", False)) + voice = request_config.get("voice", None) + + # Standard OpenAI param: use omit when not provided to keep default behavior. + response_format = request_config.get("response_format", omit) + + # Qwen3-TTS custom fields, forwarded via extra_body. + extra_body: dict[str, Any] = {} + # Keep this list aligned with vllm_omni.entrypoints.openai.protocol.audio params. + for key in ("task_type", "ref_text", "ref_audio", "language", "max_new_tokens"): + if key in request_config: + extra_body[key] = request_config[key] + + responses: list[OmniResponse] = [] + + speech_fmt: str | None = None if response_format is omit else str(response_format).lower() + + if request_num == 1: + if stream: + # Use streaming response helper. + with self.client.audio.speech.with_streaming_response.create( + model=model, + input=text_input, + response_format=response_format, + extra_body=extra_body or None, + timeout=timeout, + voice=voice, + ) as resp: + omni_resp = self._process_stream_audio_speech_response(resp, response_format=speech_fmt) + else: + # Non-streaming response. + resp = self.client.audio.speech.create( + model=model, + input=text_input, + response_format=response_format, + extra_body=extra_body or None, + timeout=timeout, + voice=voice, + ) + omni_resp = self._process_non_stream_audio_speech_response(resp, response_format=speech_fmt) + + assert_audio_speech_response(omni_resp, request_config, run_level=self.run_level) + responses.append(omni_resp) + return responses + else: + # request_num > 1: concurrent requests (use same params as single-request path) + + if stream: + + def _stream_task(): + with self.client.audio.speech.with_streaming_response.create( + model=model, + input=text_input, + response_format=response_format, + extra_body=extra_body or None, + timeout=timeout, + voice=voice, + ) as resp: + return self._process_stream_audio_speech_response(resp, response_format=speech_fmt) + + with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor: + futures = [executor.submit(_stream_task) for _ in range(request_num)] + for future in concurrent.futures.as_completed(futures): + omni_resp = future.result() + assert_audio_speech_response(omni_resp, request_config, run_level=self.run_level) + responses.append(omni_resp) + else: + with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor: + futures = [] + for _ in range(request_num): + future = executor.submit( + self.client.audio.speech.create, + model=model, + input=text_input, + response_format=response_format, + extra_body=extra_body or None, + timeout=timeout, + voice=voice, + ) + futures.append(future) + + for future in concurrent.futures.as_completed(futures): + resp = future.result() + omni_resp = self._process_non_stream_audio_speech_response(resp, response_format=speech_fmt) + assert_audio_speech_response(omni_resp, request_config, run_level=self.run_level) + responses.append(omni_resp) + + return responses + + def send_diffusion_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]: + if request_num != 1: + raise NotImplementedError("Concurrent diffusion requests not supported here") + chat_completion = self.client.chat.completions.create( + model=request_config.get("model"), + messages=request_config.get("messages"), + extra_body=request_config.get("extra_body", None), + modalities=request_config.get("modalities", omit), + ) + resp = self._process_diffusion_response(chat_completion) + assert_diffusion_response(resp, request_config, run_level=self.run_level) + return [resp] + + def send_video_diffusion_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]: + if request_num != 1: + raise NotImplementedError("Concurrent video diffusion requests are not currently implemented") + form_data = request_config.get("form_data") + if not isinstance(form_data, dict): + raise ValueError("Video request_config must contain 'form_data'") + normalized_form_data = {key: str(value) for key, value in form_data.items() if value is not None} + files: dict[str, tuple[str, BytesIO, str]] = {} + image_reference = request_config.get("image_reference") + if image_reference: + if image_reference.startswith("data:image"): + header, encoded = image_reference.split(",", 1) + content_type = header.split(";")[0].removeprefix("data:") + extension = content_type.split("/")[-1] + file_data = base64.b64decode(encoded) + files["input_reference"] = (f"reference.{extension}", BytesIO(file_data), content_type) + else: + normalized_form_data["image_reference"] = json.dumps({"image_url": image_reference}) + + result = DiffusionResponse() + start_time = time.perf_counter() + create_url = self._build_url("/v1/videos") + response = requests.post( + create_url, + data=normalized_form_data, + files=files, + headers={"Accept": "application/json"}, + timeout=60, + ) + response.raise_for_status() + job_data = response.json() + video_id = job_data["id"] + self._wait_until_video_completed(video_id) + video_content = self._download_video_content(video_id) + result.success = True + result.videos = [video_content] + result.e2e_latency = time.perf_counter() - start_time + assert_diffusion_response(result, request_config, run_level=self.run_level) + return [result] + + def _wait_until_video_completed( + self, video_id: str, poll_interval_seconds: int = 2, timeout_seconds: int = 300 + ) -> None: + status_url = self._build_url(f"/v1/videos/{video_id}") + deadline = time.monotonic() + timeout_seconds + while time.monotonic() < deadline: + status_resp = requests.get(status_url, headers={"Accept": "application/json"}, timeout=30) + status_resp.raise_for_status() + status_data = status_resp.json() + current_status = status_data["status"] + if current_status == "completed": + return + if current_status == "failed": + error_msg = status_data.get("last_error", "Unknown error") + raise RuntimeError(f"Job failed: {error_msg}") + time.sleep(poll_interval_seconds) + raise TimeoutError(f"Video job {video_id} did not complete within {timeout_seconds}s") + + def _download_video_content(self, video_id: str) -> bytes: + download_url = self._build_url(f"/v1/videos/{video_id}/content") + video_resp = requests.get(download_url, stream=True, timeout=60) + video_resp.raise_for_status() + video_bytes = BytesIO() + for chunk in video_resp.iter_content(chunk_size=8192): + if chunk: + video_bytes.write(chunk) + return video_bytes.getvalue() + + def _build_url(self, path: str) -> str: + return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}" + + +class OmniRunner: + def __init__( + self, + model_name: str, + seed: int = 42, + stage_init_timeout: int = 300, + batch_timeout: int = 10, + init_timeout: int = 300, + shm_threshold_bytes: int = 65536, + log_stats: bool = False, + stage_configs_path: str | None = None, + **kwargs, + ) -> None: + cleanup_dist_env_and_memory() + _run_pre_test_cleanup(enable_force=True) + _run_post_test_cleanup(enable_force=True) + self.model_name = model_name + self.seed = seed + self._prompt_len_estimate_cache: dict[str, Any] = {} + from vllm_omni.entrypoints.omni import Omni + + self.omni = Omni( + model=model_name, + log_stats=log_stats, + stage_init_timeout=stage_init_timeout, + batch_timeout=batch_timeout, + init_timeout=init_timeout, + shm_threshold_bytes=shm_threshold_bytes, + stage_configs_path=stage_configs_path, + **kwargs, + ) + + def get_default_sampling_params_list(self) -> list[Any]: + if not hasattr(self.omni, "default_sampling_params_list"): + raise AttributeError("Omni.default_sampling_params_list is not available") + return list(self.omni.default_sampling_params_list) + + def _estimate_prompt_len( + self, + additional_information: dict[str, Any], + model_name: str, + ) -> int: + """Estimate prompt_token_ids placeholder length for the Talker stage. + + The AR Talker replaces all input embeddings via ``preprocess``, so the + placeholder values are irrelevant but the **length** must match the + embeddings that ``preprocess`` will produce. + """ + _cache = self._prompt_len_estimate_cache + try: + from vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts import Qwen3TTSConfig + from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import ( + Qwen3TTSTalkerForConditionalGeneration, + ) + + if model_name not in _cache: + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side="left") + cfg = Qwen3TTSConfig.from_pretrained(model_name, trust_remote_code=True) + _cache[model_name] = (tok, getattr(cfg, "talker_config", None)) + + tok, tcfg = _cache[model_name] + task_type = (additional_information.get("task_type") or ["CustomVoice"])[0] + return Qwen3TTSTalkerForConditionalGeneration.estimate_prompt_len_from_additional_information( + additional_information=additional_information, + task_type=task_type, + tokenize_prompt=lambda t: tok(t, padding=False)["input_ids"], + codec_language_id=getattr(tcfg, "codec_language_id", None), + spk_is_dialect=getattr(tcfg, "spk_is_dialect", None), + ) + except Exception as exc: + logger.warning("Failed to estimate prompt length, using fallback 2048: %s", exc) + return 2048 + + def get_omni_inputs( + self, + prompts: list[str] | str, + system_prompt: str | None = None, + audios: PromptAudioInput = None, + images: PromptImageInput = None, + videos: PromptVideoInput = None, + mm_processor_kwargs: dict[str, Any] | None = None, + modalities: list[str] | None = None, + ) -> list[TextPrompt]: + if system_prompt is None: + system_prompt = ( + "You are Qwen, a virtual human developed by the Qwen Team, Alibaba " + "Group, capable of perceiving auditory and visual inputs, as well as " + "generating text and speech." + ) + video_padding_token = "<|VIDEO|>" + image_padding_token = "<|IMAGE|>" + audio_padding_token = "<|AUDIO|>" + if "Qwen3-Omni-30B-A3B-Instruct" in self.model_name: + video_padding_token = "<|video_pad|>" + image_padding_token = "<|image_pad|>" + audio_padding_token = "<|audio_pad|>" + if isinstance(prompts, str): + prompts = [prompts] + + # Qwen-TTS: follow examples/offline_inference/qwen3_tts/end2end.py style. + # Stage 0 expects token placeholders + additional_information (text/speaker/task_type/...), + # and Talker replaces embeddings in preprocess based on additional_information only. + is_tts_model = "Qwen3-TTS" in self.model_name or "qwen3_tts" in self.model_name.lower() + if is_tts_model and modalities == ["audio"]: + tts_kw = mm_processor_kwargs or {} + task_type = tts_kw.get("task_type", "CustomVoice") + speaker = tts_kw.get("speaker", "Vivian") + language = tts_kw.get("language", "Auto") + max_new_tokens = int(tts_kw.get("max_new_tokens", 2048)) + ref_audio = tts_kw.get("ref_audio", None) + ref_text = tts_kw.get("ref_text", None) + + omni_inputs: list[TextPrompt] = [] + for prompt_text in prompts: + text_str = str(prompt_text).strip() or " " + additional_information: dict[str, Any] = { + "task_type": [task_type], + "text": [text_str], + "language": [language], + "speaker": [speaker], + "max_new_tokens": [max_new_tokens], + } + if ref_audio is not None: + additional_information["ref_audio"] = [ref_audio] + if ref_text is not None: + additional_information["ref_text"] = [ref_text] + plen = self._estimate_prompt_len(additional_information, self.model_name) + input_dict: TextPrompt = { + "prompt_token_ids": [0] * plen, + "additional_information": additional_information, + } + omni_inputs.append(input_dict) + return omni_inputs + + def _normalize(mm_input, num_prompts): + if mm_input is None: + return [None] * num_prompts + if isinstance(mm_input, list): + if len(mm_input) != num_prompts: + raise ValueError("Multimodal input list length must match prompts length") + return mm_input + return [mm_input] * num_prompts + + num_prompts = len(prompts) + audios_list = _normalize(audios, num_prompts) + images_list = _normalize(images, num_prompts) + videos_list = _normalize(videos, num_prompts) + + omni_inputs = [] + for i, prompt_text in enumerate(prompts): + user_content = "" + multi_modal_data = {} + audio = audios_list[i] + if audio is not None: + if isinstance(audio, list): + for _ in audio: + user_content += f"<|audio_bos|>{audio_padding_token}<|audio_eos|>" + multi_modal_data["audio"] = audio + else: + user_content += f"<|audio_bos|>{audio_padding_token}<|audio_eos|>" + multi_modal_data["audio"] = audio + image = images_list[i] + if image is not None: + if isinstance(image, list): + for _ in image: + user_content += f"<|vision_bos|>{image_padding_token}<|vision_eos|>" + multi_modal_data["image"] = image + else: + user_content += f"<|vision_bos|>{image_padding_token}<|vision_eos|>" + multi_modal_data["image"] = image + video = videos_list[i] + if video is not None: + if isinstance(video, list): + for _ in video: + user_content += f"<|vision_bos|>{video_padding_token}<|vision_eos|>" + multi_modal_data["video"] = video + else: + user_content += f"<|vision_bos|>{video_padding_token}<|vision_eos|>" + multi_modal_data["video"] = video + user_content += prompt_text + + full_prompt = ( + f"<|im_start|>system\n{system_prompt}<|im_end|>\n" + f"<|im_start|>user\n{user_content}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + input_dict: dict[str, Any] = {"prompt": full_prompt} + if multi_modal_data: + input_dict["multi_modal_data"] = multi_modal_data + if modalities: + input_dict["modalities"] = modalities + if mm_processor_kwargs: + input_dict["mm_processor_kwargs"] = mm_processor_kwargs + omni_inputs.append(input_dict) + return omni_inputs + + def generate( + self, + prompts: list[Any], + sampling_params_list: list[Any] | None = None, + ) -> list[Any]: + if sampling_params_list is None: + sampling_params_list = self.get_default_sampling_params_list() + return self.omni.generate(prompts, sampling_params_list) + + def generate_multimodal( + self, + prompts: list[str] | str, + sampling_params_list: list[Any] | None = None, + system_prompt: str | None = None, + audios: PromptAudioInput = None, + images: PromptImageInput = None, + videos: PromptVideoInput = None, + mm_processor_kwargs: dict[str, Any] | None = None, + modalities: list[str] | None = None, + ) -> list[Any]: + omni_inputs = self.get_omni_inputs( + prompts=prompts, + system_prompt=system_prompt, + audios=audios, + images=images, + videos=videos, + mm_processor_kwargs=mm_processor_kwargs, + modalities=modalities, + ) + return self.generate(omni_inputs, sampling_params_list) + + def start_profile(self, profile_prefix: str | None = None, stages: list[int] | None = None) -> list[Any]: + return self.omni.start_profile(profile_prefix=profile_prefix, stages=stages) + + def stop_profile(self, stages: list[int] | None = None) -> list[Any]: + return self.omni.stop_profile(stages=stages) + + def _cleanup_process(self): + try: + keywords = ["enginecore"] + matched = [] + for proc in psutil.process_iter(["pid", "name", "cmdline", "username"]): + try: + cmdline = " ".join(proc.cmdline()).lower() if proc.cmdline() else "" + name = proc.name().lower() + if any(k in cmdline for k in keywords) or any(k in name for k in keywords): + matched.append(proc) + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + for proc in matched: + try: + proc.terminate() + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + _, still_alive = psutil.wait_procs(matched, timeout=5) + for proc in still_alive: + try: + proc.kill() + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + except Exception as e: + print(f"Error in psutil vllm cleanup: {e}") + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if hasattr(self.omni, "close"): + self.omni.close() + self._cleanup_process() + _run_pre_test_cleanup(enable_force=True) + _run_post_test_cleanup(enable_force=True) + cleanup_dist_env_and_memory() + + +class OmniRunnerHandler: + def __init__(self, omni_runner): + self.runner = omni_runner + + def _process_output(self, outputs: list[Any]) -> OmniResponse: + result = OmniResponse() + try: + text_content = None + audio_content = None + for stage_output in outputs: + if getattr(stage_output, "final_output_type", None) == "text": + text_content = stage_output.request_output.outputs[0].text + if getattr(stage_output, "final_output_type", None) == "audio": + audio_content = stage_output.request_output.outputs[0].multimodal_output["audio"] + result.audio_content = audio_content + result.text_content = text_content + result.success = True + except Exception as e: + result.error_message = f"Output processing error: {str(e)}" + result.success = False + return result + + def send_request(self, request_config: dict[str, Any] | None = None) -> OmniResponse: + if request_config is None: + request_config = {} + prompts = request_config.get("prompts") + videos = request_config.get("videos") + images = request_config.get("images") + audios = request_config.get("audios") + modalities = request_config.get("modalities", ["text", "audio"]) + outputs = self.runner.generate_multimodal( + prompts=prompts, videos=videos, images=images, audios=audios, modalities=modalities + ) + response = self._process_output(outputs) + assert_omni_response(response, request_config, run_level="core_model") + return response + + def send_audio_speech_request(self, request_config: dict[str, Any]) -> OmniResponse: + """ + Offline TTS: text -> audio via generate_multimodal, then validate with assert_audio_speech_response. + + request_config must contain: + - 'input' or 'prompts': text to synthesize. + Optional keys: + - 'voice' -> speaker (CustomVoice) + - 'task_type' -> task_type in additional_information (default: "CustomVoice") + - 'language' -> language in additional_information (default: "Auto") + - 'max_new_tokens' -> max_new_tokens in additional_information (default: 2048) + - 'response_format' -> desired audio format (used only for assertion) + """ + input_text = request_config.get("input") or request_config.get("prompts") + if input_text is None: + raise ValueError("request_config must contain 'input' or 'prompts' for TTS") + if isinstance(input_text, list): + input_text = input_text[0] if input_text else "" + + mm_processor_kwargs: dict[str, Any] = {} + if "voice" in request_config: + mm_processor_kwargs["speaker"] = request_config["voice"] + if "task_type" in request_config: + mm_processor_kwargs["task_type"] = request_config["task_type"] + if "ref_audio" in request_config: + mm_processor_kwargs["ref_audio"] = request_config["ref_audio"] + if "ref_text" in request_config: + mm_processor_kwargs["ref_text"] = request_config["ref_text"] + if "language" in request_config: + mm_processor_kwargs["language"] = request_config["language"] + if "max_new_tokens" in request_config: + mm_processor_kwargs["max_new_tokens"] = request_config["max_new_tokens"] + + outputs = self.runner.generate_multimodal( + prompts=input_text, + modalities=["audio"], + mm_processor_kwargs=mm_processor_kwargs or None, + ) + mm_out: dict[str, Any] | None = None + for stage_out in outputs: + if getattr(stage_out, "final_output_type", None) == "audio": + mm_out = stage_out.request_output.outputs[0].multimodal_output + break + if mm_out is None: + result = OmniResponse(success=False, error_message="No audio output from pipeline") + assert result.success, result.error_message + return result + + audio_data = mm_out.get("audio") + if audio_data is None: + result = OmniResponse(success=False, error_message="No audio tensor in multimodal output") + assert result.success, result.error_message + return result + + sr_raw = mm_out.get("sr") + sr_val = sr_raw[-1] if isinstance(sr_raw, list) and sr_raw else sr_raw + sr = int(sr_val.item() if hasattr(sr_val, "item") else sr_val) + wav_tensor = torch.cat(audio_data, dim=-1) if isinstance(audio_data, list) else audio_data + wav_buf = io.BytesIO() + sf.write( + wav_buf, + wav_tensor.float().cpu().numpy().reshape(-1), + samplerate=sr, + format="WAV", + subtype="PCM_16", + ) + result = OmniResponse(success=True, audio_bytes=wav_buf.getvalue(), audio_format="audio/wav") + assert_audio_speech_response(result, request_config, run_level="core_model") + return result + + def start_profile(self, profile_prefix: str | None = None, stages: list[int] | None = None) -> list[Any]: + return self.runner.start_profile(profile_prefix=profile_prefix, stages=stages) + + def stop_profile(self, stages: list[int] | None = None) -> list[Any]: + return self.runner.stop_profile(stages=stages) + + +__all__ = [ + "DiffusionResponse", + "OmniResponse", + "OmniRunner", + "OmniRunnerHandler", + "OmniServer", + "OmniServerParams", + "OpenAIClientHandler", + "get_open_port", +] diff --git a/tests/helpers/stage_config.py b/tests/helpers/stage_config.py new file mode 100644 index 00000000000..52a7da10034 --- /dev/null +++ b/tests/helpers/stage_config.py @@ -0,0 +1,193 @@ +"""Config/message construction helpers used by tests.""" + +import time +from pathlib import Path +from typing import Any + +import yaml + + +def dummy_messages_from_mix_data( + system_prompt: dict[str, Any] = None, + video_data_url: Any = None, + audio_data_url: Any = None, + image_data_url: Any = None, + content_text: str = None, +): + """Create messages with video、image、audio data URL for OpenAI API.""" + if content_text is not None: + content = [{"type": "text", "text": content_text}] + else: + content = [] + + media_items = [] + if isinstance(video_data_url, list): + for video_url in video_data_url: + media_items.append((video_url, "video")) + else: + media_items.append((video_data_url, "video")) + + if isinstance(image_data_url, list): + for url in image_data_url: + media_items.append((url, "image")) + else: + media_items.append((image_data_url, "image")) + + if isinstance(audio_data_url, list): + for url in audio_data_url: + media_items.append((url, "audio")) + else: + media_items.append((audio_data_url, "audio")) + + content.extend( + {"type": f"{media_type}_url", f"{media_type}_url": {"url": url}} + for url, media_type in media_items + if url is not None + ) + messages = [{"role": "user", "content": content}] + if system_prompt is not None: + messages = [system_prompt] + messages + return messages + + +def modify_stage_config( + yaml_path: str, + updates: dict[str, Any] = None, + deletes: dict[str, Any] = None, +) -> str: + path = Path(yaml_path) + if not path.exists(): + raise FileNotFoundError(f"yaml does not exist: {path}") + + try: + with open(yaml_path, encoding="utf-8") as f: + config = yaml.safe_load(f) or {} + except Exception as e: + raise ValueError(f"Cannot parse YAML file: {e}") + + def apply_update(config_dict: dict, key_path: str, value: Any) -> None: + if "." not in key_path: + config_dict[key_path] = value + return + current = config_dict + keys = key_path.split(".") + for i in range(len(keys) - 1): + key = keys[i] + if key.isdigit() and isinstance(current, list): + index = int(key) + if index < 0: + raise ValueError(f"Negative list index not allowed: {index}") + if index >= len(current): + while len(current) <= index: + current.append({} if i < len(keys) - 2 else None) + current = current[index] + elif isinstance(current, dict): + if key not in current: + if keys[i + 1].isdigit(): + current[key] = [] + else: + current[key] = {} + elif not isinstance(current[key], (dict, list)) and i < len(keys) - 1: + current[key] = [] if keys[i + 1].isdigit() else {} + current = current[key] + else: + raise TypeError( + f"Cannot access {'.'.join(keys[: i + 1])} as a dict/list. It's a {type(current).__name__}" + ) + + last_key = keys[-1] + if isinstance(current, list) and last_key.isdigit(): + index = int(last_key) + if index < 0: + raise ValueError(f"Negative list index not allowed: {index}") + if index >= len(current): + while len(current) <= index: + current.append(None) + current[index] = value + elif isinstance(current, dict): + current[last_key] = value + else: + raise TypeError(f"Cannot set value at {key_path}.") + + def delete_by_path(config_dict: dict, path: str) -> None: + if not path: + return + current = config_dict + keys = path.split(".") + for i in range(len(keys) - 1): + key = keys[i] + if key.isdigit() and isinstance(current, list): + index = int(key) + if index < 0 or index >= len(current): + raise KeyError(f"List index {index} out of bounds") + current = current[index] + elif isinstance(current, dict): + if key not in current: + raise KeyError(f"Path {'.'.join(keys[: i + 1])} does not exist") + current = current[key] + else: + raise TypeError(f"Cannot access {'.'.join(keys[: i + 1])} as a dict/list.") + last_key = keys[-1] + if isinstance(current, list) and last_key.isdigit(): + index = int(last_key) + if index < 0 or index >= len(current): + raise KeyError(f"List index {index} out of bounds") + del current[index] + elif isinstance(current, dict) and last_key in current: + del current[last_key] + + if deletes: + for key, value in deletes.items(): + if key == "stage_args": + if value and isinstance(value, dict): + stage_args = config.get("stage_args", []) + for stage_id, delete_paths in value.items(): + target_stage = None + for stage in stage_args: + if stage.get("stage_id") == int(stage_id): + target_stage = stage + break + if target_stage is None: + continue + for p in delete_paths: + delete_by_path(target_stage, p) + elif "." in key: + delete_by_path(config, key) + elif value is None and key in config: + del config[key] + + if updates: + for key, value in updates.items(): + if key == "stage_args": + if value and isinstance(value, dict): + stage_args = config.get("stage_args", []) + for stage_id, stage_updates in value.items(): + target_stage = None + for stage in stage_args: + if stage.get("stage_id") == int(stage_id): + target_stage = stage + break + if target_stage is None: + available_ids = [s.get("stage_id") for s in stage_args if "stage_id" in s] + raise KeyError(f"Stage ID {stage_id} not found, available: {available_ids}") + for p, val in stage_updates.items(): + if "." not in p: + target_stage[p] = val + else: + apply_update(target_stage, p, val) + elif "." in key: + apply_update(config, key, value) + else: + config[key] = value + + base_name = yaml_path.rsplit(".", 1)[0] if "." in yaml_path else yaml_path + output_path = f"{base_name}_{time.time_ns()}.yaml" + with open(output_path, "w", encoding="utf-8") as f: + yaml.dump(config, f, default_flow_style=None, sort_keys=False, allow_unicode=True, indent=2) + return output_path + + +__all__ = [ + "dummy_messages_from_mix_data", + "modify_stage_config", +] diff --git a/tests/utils.py b/tests/utils.py deleted file mode 100644 index 84edbbf3d11..00000000000 --- a/tests/utils.py +++ /dev/null @@ -1,621 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# Some functions are copied from vllm/tests/utils.py -import functools -import os -import signal -import subprocess -import sys -import tempfile -import threading -import time -from collections.abc import Callable -from contextlib import ExitStack, contextmanager, suppress -from typing import Any, Literal - -import cloudpickle -import pytest -import torch -from typing_extensions import ParamSpec -from vllm.platforms import current_platform -from vllm.utils.torch_utils import cuda_device_count_stateless - -from vllm_omni.platforms import current_omni_platform - -_P = ParamSpec("_P") - -if current_platform.is_rocm(): - from amdsmi import ( - amdsmi_get_gpu_vram_usage, - amdsmi_get_processor_handles, - amdsmi_init, - amdsmi_shut_down, - ) - - @contextmanager - def _nvml(): - try: - amdsmi_init() - yield - finally: - amdsmi_shut_down() -elif current_platform.is_cuda(): - from vllm.third_party.pynvml import ( - nvmlDeviceGetHandleByIndex, - nvmlDeviceGetMemoryInfo, - nvmlInit, - nvmlShutdown, - ) - - @contextmanager - def _nvml(): - try: - nvmlInit() - yield - finally: - nvmlShutdown() -else: - - @contextmanager - def _nvml(): - yield - - -def get_physical_device_indices(devices): - visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES") - if visible_devices is None: - return devices - - visible_indices = [int(x) for x in visible_devices.split(",")] - index_mapping = {i: physical for i, physical in enumerate(visible_indices)} - return [index_mapping[i] for i in devices if i in index_mapping] - - -@_nvml() -def wait_for_gpu_memory_to_clear( - *, - devices: list[int], - threshold_bytes: int | None = None, - threshold_ratio: float | None = None, - timeout_s: float = 120, -) -> None: - import gc - - assert threshold_bytes is not None or threshold_ratio is not None - # Use nvml instead of pytorch to reduce measurement error from torch cuda - # context. - devices = get_physical_device_indices(devices) - start_time = time.time() - - # Print waiting start information - device_list = ", ".join(str(d) for d in devices) - if threshold_bytes is not None: - threshold_str = f"{threshold_bytes / 2**30:.2f} GiB" - condition_str = f"Memory usage ≤ {threshold_str}" - else: - threshold_percent = threshold_ratio * 100 - threshold_str = f"{threshold_percent:.1f}%" - condition_str = f"Memory usage ratio ≤ {threshold_str}" - - print(f"[GPU Memory Monitor] Waiting for GPU {device_list} to free memory, Condition: {condition_str}") - - # Define the is_free function based on threshold type - if threshold_bytes is not None: - - def is_free(used, total): - return used <= threshold_bytes / 2**30 - else: - - def is_free(used, total): - return used / total <= threshold_ratio - - while True: - output: dict[int, str] = {} - output_raw: dict[int, tuple[float, float]] = {} - for device in devices: - if current_platform.is_rocm(): - dev_handle = amdsmi_get_processor_handles()[device] - mem_info = amdsmi_get_gpu_vram_usage(dev_handle) - gb_used = mem_info["vram_used"] / 2**10 - gb_total = mem_info["vram_total"] / 2**10 - else: - dev_handle = nvmlDeviceGetHandleByIndex(device) - mem_info = nvmlDeviceGetMemoryInfo(dev_handle) - gb_used = mem_info.used / 2**30 - gb_total = mem_info.total / 2**30 - output_raw[device] = (gb_used, gb_total) - # Format to more readable form - usage_percent = (gb_used / gb_total) * 100 if gb_total > 0 else 0 - output[device] = f"{gb_used:.1f}GiB/{gb_total:.1f}GiB ({usage_percent:.1f}%)" - - # Optimized GPU memory status print - print("[GPU Memory Status] Current usage:") - for device_id, mem_info in output.items(): - print(f" GPU {device_id}: {mem_info}") - - # Calculate waiting duration - dur_s = time.time() - start_time - elapsed_minutes = dur_s / 60 - - # Check if all devices meet the condition - if all(is_free(used, total) for used, total in output_raw.values()): - # Optimized completion message - print(f"[GPU Memory Freed] Devices {device_list} meet memory condition") - print(f" Condition: {condition_str}") - print(f" Wait time: {dur_s:.1f} seconds ({elapsed_minutes:.1f} minutes)") - print(" Final status:") - for device_id, mem_info in output.items(): - print(f" GPU {device_id}: {mem_info}") - break - - # Check timeout - if dur_s >= timeout_s: - raise ValueError( - f"[GPU Memory Timeout] Devices {device_list} still don't meet memory condition after {dur_s:.1f} seconds\n" - f"Condition: {condition_str}\n" - f"Current status:\n" + "\n".join(f" GPU {device}: {output[device]}" for device in devices) - ) - - # Add waiting hint (optional) - if dur_s > 10 and int(dur_s) % 10 == 0: # Show hint every 10 seconds - print(f"Waiting... Already waited {dur_s:.1f} seconds ({elapsed_minutes:.1f} minutes)") - - gc.collect() - torch.cuda.empty_cache() - - time.sleep(5) - - -def fork_new_process_for_each_test(func: Callable[_P, None]) -> Callable[_P, None]: - """Decorator to fork a new process for each test function. - See https://github.com/vllm-project/vllm/issues/7053 for more details. - """ - - @functools.wraps(func) - def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: - # Make the process the leader of its own process group - # to avoid sending SIGTERM to the parent process - os.setpgrp() - from _pytest.outcomes import Skipped - - # Create a unique temporary file to store exception info from child - # process. Use test function name and process ID to avoid collisions. - with ( - tempfile.NamedTemporaryFile( - delete=False, mode="w+b", prefix=f"vllm_test_{func.__name__}_{os.getpid()}_", suffix=".exc" - ) as exc_file, - ExitStack() as delete_after, - ): - exc_file_path = exc_file.name - delete_after.callback(os.remove, exc_file_path) - - pid = os.fork() - print(f"Fork a new process to run a test {pid}") - if pid == 0: - # Parent process responsible for deleting, don't delete - # in child. - delete_after.pop_all() - try: - func(*args, **kwargs) - except Skipped as e: - # convert Skipped to exit code 0 - print(str(e)) - os._exit(0) - except Exception as e: - import traceback - - tb_string = traceback.format_exc() - - # Try to serialize the exception object first - exc_to_serialize: dict[str, Any] - try: - # First, try to pickle the actual exception with - # its traceback. - exc_to_serialize = {"pickled_exception": e} - # Test if it can be pickled - cloudpickle.dumps(exc_to_serialize) - except (Exception, KeyboardInterrupt): - # Fall back to string-based approach. - exc_to_serialize = { - "exception_type": type(e).__name__, - "exception_msg": str(e), - "traceback": tb_string, - } - try: - with open(exc_file_path, "wb") as f: - cloudpickle.dump(exc_to_serialize, f) - except Exception: - # Fallback: just print the traceback. - print(tb_string) - os._exit(1) - else: - os._exit(0) - else: - pgid = os.getpgid(pid) - _pid, _exitcode = os.waitpid(pid, 0) - # ignore SIGTERM signal itself - old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN) - # kill all child processes - os.killpg(pgid, signal.SIGTERM) - # restore the signal handler - signal.signal(signal.SIGTERM, old_signal_handler) - if _exitcode != 0: - # Try to read the exception from the child process - exc_info = {} - if os.path.exists(exc_file_path): - with suppress(Exception), open(exc_file_path, "rb") as f: - exc_info = cloudpickle.load(f) - - if (original_exception := exc_info.get("pickled_exception")) is not None: - # Re-raise the actual exception object if it was - # successfully pickled. - assert isinstance(original_exception, Exception) - raise original_exception - - if (original_tb := exc_info.get("traceback")) is not None: - # Use string-based traceback for fallback case - raise AssertionError( - f"Test {func.__name__} failed when called with" - f" args {args} and kwargs {kwargs}" - f" (exit code: {_exitcode}):\n{original_tb}" - ) from None - - # Fallback to the original generic error - raise AssertionError( - f"function {func.__name__} failed when called with" - f" args {args} and kwargs {kwargs}" - f" (exit code: {_exitcode})" - ) from None - - return wrapper - - -def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None]: - """Decorator to spawn a new process for each test function.""" - - @functools.wraps(f) - def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: - # Check if we're already in a subprocess - if os.environ.get("RUNNING_IN_SUBPROCESS") == "1": - # If we are, just run the function directly - return f(*args, **kwargs) - - import torch.multiprocessing as mp - - with suppress(RuntimeError): - mp.set_start_method("spawn") - - # Get the module - module_name = f.__module__ - - # Create a process with environment variable set - env = os.environ.copy() - env["RUNNING_IN_SUBPROCESS"] = "1" - - with tempfile.TemporaryDirectory() as tempdir: - output_filepath = os.path.join(tempdir, "new_process.tmp") - - # `cloudpickle` allows pickling complex functions directly - input_bytes = cloudpickle.dumps((f, output_filepath)) - - cmd = [sys.executable, "-m", f"{module_name}"] - - returned = subprocess.run(cmd, input=input_bytes, capture_output=True, env=env) - - # check if the subprocess is successful - try: - returned.check_returncode() - except Exception as e: - # wrap raised exception to provide more information - raise RuntimeError(f"Error raised in subprocess:\n{returned.stderr.decode()}") from e - - return wrapper - - -def create_new_process_for_each_test( - method: Literal["spawn", "fork"] | None = None, -) -> Callable[[Callable[_P, None]], Callable[_P, None]]: - """Creates a decorator that runs each test function in a new process. - - Args: - method: The process creation method. Can be either "spawn" or "fork". - If not specified, it defaults to "spawn" on ROCm and XPU - platforms and "fork" otherwise. - - Returns: - A decorator to run test functions in separate processes. - """ - if method is None: - # TODO: Spawn is not working correctly on ROCm - # The test content will not run and tests passed immediately. - # For now, using `fork` for ROCm as it can run with `fork` - # and tests are running correctly. - use_spawn = current_platform.is_xpu() - method = "spawn" if use_spawn else "fork" - - assert method in ["spawn", "fork"], "Method must be either 'spawn' or 'fork'" - - if method == "fork": - return fork_new_process_for_each_test - - return spawn_new_process_for_each_test - - -def cuda_marks(*, res: str, num_cards: int): - """ - Get a collection of pytest marks to apply for `@cuda_test`. - - Args: - res: Resource type, e.g., "L4" or "H100". - num_cards: Number of GPU cards required. - - Returns: - List of pytest marks to apply. - """ - test_platform_detail = pytest.mark.cuda - - if res == "L4": - test_resource = pytest.mark.L4 - elif res == "H100": - test_resource = pytest.mark.H100 - else: - raise ValueError(f"Invalid CUDA resource type: {res}. Supported: L4, H100") - - marks = [test_resource, test_platform_detail] - - if num_cards == 1: - return marks - else: - test_distributed = pytest.mark.distributed_cuda(num_cards=num_cards) - test_skipif = pytest.mark.skipif_cuda( - cuda_device_count_stateless() < num_cards, - reason=f"Need at least {num_cards} CUDA GPUs to run the test.", - ) - return marks + [test_distributed, test_skipif] - - -def rocm_marks(*, res: str, num_cards: int): - """ - Get a collection of pytest marks to apply for `@rocm_test`. - - Args: - res: Resource type, e.g., "MI325". - num_cards: Number of GPU cards required. - - Returns: - List of pytest marks to apply. - """ - test_platform_detail = pytest.mark.rocm - - if res == "MI325": - test_resource = pytest.mark.MI325 - else: - raise ValueError(f"Invalid ROCm resource type: {res}. Supported: MI325") - - marks = [test_resource, test_platform_detail] - - if num_cards == 1: - return marks - else: - test_distributed = pytest.mark.distributed_rocm(num_cards=num_cards) - # TODO: add ROCm support for `skipif_rocm` marker - return marks + [test_distributed] - - -def xpu_marks(*, res: str, num_cards: int): - """ - Get a collection of pytest marks to apply for `@xpu_test`. - - Args: - res: Resource type, e.g., "B60". - num_cards: Number of GPU cards required. - - Returns: - List of pytest marks to apply. - """ - test_platform_detail = pytest.mark.xpu - - if res == "B60": - test_resource = pytest.mark.B60 - else: - raise ValueError(f"Invalid XPU resource type: {res}. Supported: B60") - - marks = [test_resource, test_platform_detail] - - if num_cards == 1: - return marks - else: - test_distributed = pytest.mark.distributed_rocm(num_cards=num_cards) - # TODO: add XPU support for `skipif_xpu` marker - return marks + [test_distributed] - - -def musa_marks(*, res: str, num_cards: int): - """ - Get a collection of pytest marks to apply for `@musa_test`. - - Args: - res: Resource type, e.g., "S5000". - num_cards: Number of GPU cards required. - - Returns: - List of pytest marks to apply. - """ - test_platform_detail = pytest.mark.musa - - if res == "S5000": - test_resource = pytest.mark.S5000 - else: - raise ValueError(f"Invalid MUSA resource type: {res}. Supported: S5000") - - marks = [test_resource, test_platform_detail] - - if num_cards == 1: - return marks - else: - test_distributed = pytest.mark.distributed_musa(num_cards=num_cards) - # TODO: add MUSA support for `skipif_musa` marker - return marks + [test_distributed] - - -def gpu_marks(*, res: str, num_cards: int): - """ - Get a collection of pytest marks to apply for `@gpu_test`. - Platform is automatically determined based on resource type. - - Args: - res: Resource type, e.g., "L4", "H100" for CUDA, or "MI325" for ROCm, or "B60" for XPU, or "S5000" for MUSA. - num_cards: Number of GPU cards required. - - Returns: - List of pytest marks to apply. - """ - test_platform = pytest.mark.gpu - if res in ("L4", "H100"): - return [test_platform] + cuda_marks(res=res, num_cards=num_cards) - if res == "MI325": - return [test_platform] + rocm_marks(res=res, num_cards=num_cards) - if res == "B60": - return [test_platform] + xpu_marks(res=res, num_cards=num_cards) - if res == "S5000": - return [test_platform] + musa_marks(res=res, num_cards=num_cards) - raise ValueError(f"Invalid resource type: {res}. Supported: L4, H100, MI325, B60, S5000") - - -def npu_marks(*, res: str, num_cards: int): - """Get a collection of pytest marks to apply for `@npu_test`.""" - test_platform = pytest.mark.npu - if res == "A2": - test_resource = pytest.mark.A2 - elif res == "A3": - test_resource = pytest.mark.A3 - else: - # TODO: Currently we don't have various NPU card types defined - # Use None to skip resource-specific marking for unknown types - test_resource = None - - if num_cards == 1: - return [mark for mark in [test_platform, test_resource] if mark is not None] - else: - # Multiple cards scenario needs distributed_npu mark - test_distributed = pytest.mark.distributed_npu(num_cards=num_cards) - # TODO: add NPU support for `skipif_npu` marker - return [mark for mark in [test_platform, test_resource, test_distributed] if mark is not None] - - -def hardware_marks(*, res: dict[str, str], num_cards: int | dict[str, int] = 1): - """ - Get a collection of pytest marks to apply for `@hardware_test`, - including CUDA, ROCm, XPU, NPU, and MUSA, - based on the specified platforms and resources. - """ - # Validate platforms - # Don't validate platform details in this decorator - for platform, _ in res.items(): - if platform not in ("cuda", "rocm", "xpu", "npu", "musa"): - raise ValueError(f"Unsupported platform: {platform}") - - # Normalize num_cards - if isinstance(num_cards, int): - num_cards_dict = {platform: num_cards for platform in res.keys()} - else: - num_cards_dict = num_cards - for platform in num_cards_dict.keys(): - if platform not in res: - raise ValueError( - f"Platform '{platform}' in num_cards but not in res. Available platforms: {list(res.keys())}" - ) - for platform in res.keys(): - if platform not in num_cards_dict: - num_cards_dict[platform] = 1 - - # Collect marks from all platforms - all_marks: list[pytest.MarkDecorator] = [] - for platform, resource in res.items(): - cards = num_cards_dict[platform] - if platform == "cuda" or platform == "rocm" or platform == "xpu": - marks = gpu_marks(res=resource, num_cards=cards) - elif platform == "musa": - marks = musa_marks(res=resource, num_cards=cards) - elif platform == "npu": - marks = npu_marks(res=resource, num_cards=cards) - else: - raise ValueError(f"Unsupported platform: {platform}") - all_marks.extend(marks) - return all_marks - - -def hardware_test(*, res: dict[str, str], num_cards: int | dict[str, int] = 1): - """ - Decorate a test for multiple hardware platforms with a single call. - Automatically wraps the test with @create_new_process_for_each_test() for distributed tests. - - Args: - res: Mapping from platform to resource type. Supported platforms/resources: - - cuda: L4, H100 - - rocm: MI325 - - xpu: B60 - - npu: A2, A3 - - musa: S5000 - num_cards: Number of cards required. Can be: - - int: same card count for all platforms (default: 1) - - dict: per-platform card count, e.g., {"cuda": 2, "rocm": 2} - - Example: - @hardware_test( - res={"cuda": "L4", "rocm": "MI325", "npu": "A2", "musa": "S5000"}, - num_cards={"cuda": 2, "rocm": 2, "npu": 2, "musa": 2}, - ) - def test_multi_platform(): - ... - """ - all_marks = hardware_marks(res=res, num_cards=num_cards) - - def wrapper(f: Callable[_P, None]) -> Callable[_P, None]: - func = f - for mark in reversed(all_marks): - func = mark(func) - return func - - return wrapper - - -class DeviceMemoryMonitor: - """Poll global device memory usage.""" - - def __init__(self, device_index: int, interval: float = 0.05): - self.device_index = device_index - self.interval = interval - self._peak_used_mb = 0.0 - self._stop_event = threading.Event() - self._thread: threading.Thread | None = None - - def start(self) -> None: - def monitor_loop() -> None: - while not self._stop_event.is_set(): - try: - with current_omni_platform.device(self.device_index): - free_bytes, total_bytes = current_omni_platform.mem_get_info() - used_mb = (total_bytes - free_bytes) / (1024**2) - self._peak_used_mb = max(self._peak_used_mb, used_mb) - except Exception: - pass - time.sleep(self.interval) - - self._thread = threading.Thread(target=monitor_loop, daemon=False) - self._thread.start() - - def stop(self) -> None: - if self._thread is None: - return - self._stop_event.set() - self._thread.join(timeout=2.0) - - @property - def peak_used_mb(self) -> float: - fallback_alloc = current_omni_platform.max_memory_allocated(device=self.device_index) / (1024**2) - fallback_reserved = current_omni_platform.max_memory_reserved(device=self.device_index) / (1024**2) - return max(self._peak_used_mb, fallback_alloc, fallback_reserved) - - def __del__(self): - self.stop() diff --git a/tools/pre_commit/check_pickle_imports.py b/tools/pre_commit/check_pickle_imports.py index 1c08a1543d2..819a7c8c3dd 100644 --- a/tools/pre_commit/check_pickle_imports.py +++ b/tools/pre_commit/check_pickle_imports.py @@ -16,8 +16,7 @@ # alternatives like msgpack or pydantic that are already in use in vLLM. Only # add to this list if absolutely necessary and after careful security review. ALLOWED_FILES = { - "tests/e2e/offline_inference/utils.py", - "tests/utils.py", + "tests/helpers/process.py", "vllm_omni/diffusion/distributed/group_coordinator.py", "tests/diffusion/attention/test_attention_sp.py", } diff --git a/vllm_omni/benchmarks/metrics/metrics.py b/vllm_omni/benchmarks/metrics/metrics.py index a2acc7d7567..dbf764698a0 100644 --- a/vllm_omni/benchmarks/metrics/metrics.py +++ b/vllm_omni/benchmarks/metrics/metrics.py @@ -185,7 +185,7 @@ def calculate_metrics( # Note : this may inflate the output token count slightly output_len = len(tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids) actual_output_lens.append(output_len) - total_input += input_requests[i].prompt_len + total_input += outputs[i].prompt_len tpot = 0 if output_len > 1: latency_minus_ttft = outputs[i].text_latency - outputs[i].ttft diff --git a/vllm_omni/benchmarks/patch/patch.py b/vllm_omni/benchmarks/patch/patch.py index d8145c40bcd..343655df20e 100644 --- a/vllm_omni/benchmarks/patch/patch.py +++ b/vllm_omni/benchmarks/patch/patch.py @@ -190,6 +190,10 @@ async def async_request_openai_chat_omni_completions( if metrics := data.get("metrics"): output.output_tokens = metrics.get("num_tokens_out", 0) + if usage := data.get("usage"): + if (pt := usage.get("prompt_tokens")) is not None: + output.prompt_len = pt + output.latency = timestamp - st output.generated_text = generated_text if generated_audio is not None: From 651c636b381852bc13ea5da7c93a2d739fa39d32 Mon Sep 17 00:00:00 2001 From: wangyu <410167048@qq.com> Date: Thu, 9 Apr 2026 10:43:42 +0800 Subject: [PATCH 02/19] Enhance error handling and logging in assertion and media helper functions Signed-off-by: wangyu <410167048@qq.com> --- tests/helpers/assertions.py | 43 +++++++++++++++++++-- tests/helpers/env.py | 63 +++++++++++++++++++++++++++++++ tests/helpers/fixtures/env.py | 1 + tests/helpers/fixtures/runtime.py | 8 ++++ tests/helpers/media.py | 41 ++++++++++++++------ tests/helpers/runtime.py | 35 +++++++++++++++++ tests/helpers/stage_config.py | 2 + 7 files changed, 178 insertions(+), 15 deletions(-) diff --git a/tests/helpers/assertions.py b/tests/helpers/assertions.py index 44346150d20..b97f769567d 100644 --- a/tests/helpers/assertions.py +++ b/tests/helpers/assertions.py @@ -192,6 +192,9 @@ def assert_video_valid( "fps": actual_fps, "num_frames": actual_frames, } + except Exception as e: + print(f"ERROR: {type(e).__name__}: {e}", flush=True) + raise finally: if cap is not None: cap.release() @@ -280,6 +283,7 @@ def _estimate_voice_gender_from_audio(audio_bytes: bytes) -> str: median_f0 = _median_pitch_hz_from_autocorr(mono, sr) clf = _load_gender_pipeline() if clf is None: + print("gender model not available, returning 'unknown'") return "unknown" with _GENDER_PIPELINE_LOCK: outputs = clf(mono, sampling_rate=sr) @@ -298,21 +302,30 @@ def _estimate_voice_gender_from_audio(audio_bytes: bytes) -> str: gender = "unknown" if gender == "female" and median_f0 is not None and median_f0 < 165.0 and conf < 0.88: + print(f"gender pitch assist: reclassifying female->male (median_f0={median_f0:.1f} Hz, conf={conf:.3f})") gender = "male" elif gender == "male" and median_f0 is not None and median_f0 > 230.0 and conf < 0.88: + print(f"gender pitch assist: reclassifying male->female (median_f0={median_f0:.1f} Hz, conf={conf:.3f})") gender = "female" + print( + f"gender classifier: label={label}, conf={conf:.3f}, gender={gender}" + + (f", median_f0={median_f0:.1f}Hz" if median_f0 is not None else "") + ) return gender - except Exception: # pragma: no cover + except Exception as exc: # pragma: no cover + print(f"Warning: gender classification failed, returning 'unknown': {exc}") return "unknown" def _assert_preset_voice_gender_from_audio(audio_bytes: bytes | None, voice_name: str | None) -> None: if not voice_name or not audio_bytes: return - expected_gender = _PRESET_VOICE_GENDER_MAP.get(str(voice_name).lower()) + key = str(voice_name).lower() + expected_gender = _PRESET_VOICE_GENDER_MAP.get(key) if expected_gender is None: return estimated_gender = _estimate_voice_gender_from_audio(audio_bytes) + print(f"Preset voice gender check: preset={key!r}, estimated={estimated_gender!r}, expected={expected_gender!r}") if estimated_gender != "unknown": assert estimated_gender == expected_gender @@ -343,20 +356,27 @@ def _assert_pcm_int16_speech_hnr(audio_bytes: bytes) -> None: assert len(audio_bytes) % 2 == 0, "PCM byte length must be aligned to int16" pcm_samples = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0 hnr = _compute_pcm_hnr_db(pcm_samples) + print(f"PCM speech HNR: {hnr:.2f} dB (threshold: {_MIN_PCM_SPEECH_HNR_DB} dB)") assert hnr >= _MIN_PCM_SPEECH_HNR_DB def assert_omni_response(response: Any, request_config: dict[str, Any], run_level): assert response.success, "The request failed." + e2e_latency = getattr(response, "e2e_latency", None) + if e2e_latency is not None: + print(f"the e2e latency is: {e2e_latency}") + modalities = request_config.get("modalities", ["text", "audio"]) if run_level == "advanced_model": if "audio" in modalities: assert response.audio_content is not None, "No audio output is generated" + print(f"audio content is: {response.audio_content}") speaker = request_config.get("speaker") if speaker: _assert_preset_voice_gender_from_audio(response.audio_bytes, speaker) if "text" in modalities: assert response.text_content is not None, "No text output is generated" + print(f"text content is: {response.text_content}") keywords_dict = request_config.get("key_words", {}) for word_type in ["text", "image", "audio", "video"]: keywords = keywords_dict.get(word_type) @@ -370,13 +390,22 @@ def assert_omni_response(response: Any, request_config: dict[str, Any], run_leve assert any(str(kw).lower() in audio_lower for kw in keywords) if "text" in modalities and "audio" in modalities: assert response.similarity is not None and response.similarity > 0.9 + print(f"similarity is: {response.similarity}") def assert_audio_speech_response(response: Any, request_config: dict[str, Any], run_level: str) -> None: assert response.success, "The request failed." + e2e_latency = getattr(response, "e2e_latency", None) + if e2e_latency is not None: + print(f"the avg e2e latency is: {e2e_latency}") + req_fmt = request_config.get("response_format") if req_fmt == "pcm" and response.audio_bytes: _assert_pcm_int16_speech_hnr(response.audio_bytes) + if response.audio_format: + assert "pcm" in response.audio_format.lower(), ( + f"Expected audio/pcm content-type, got {response.audio_format!r}" + ) elif req_fmt == "wav" and response.audio_format: assert req_fmt in response.audio_format @@ -384,13 +413,21 @@ def assert_audio_speech_response(response: Any, request_config: dict[str, Any], expected_text = request_config.get("input") if expected_text: transcript = (response.audio_content or "").strip() + print(f"audio content is: {transcript}") + print(f"input text is: {expected_text}") similarity = cosine_similarity_text(transcript.lower(), expected_text.lower()) - assert similarity > 0.9 + print(f"Cosine similarity: {similarity:.3f}") + assert similarity > 0.9, ( + f"Transcript doesn't match input: similarity={similarity:.2f}, transcript='{transcript}'" + ) _assert_preset_voice_gender_from_audio(response.audio_bytes, request_config.get("voice")) def assert_diffusion_response(response: Any, request_config: dict[str, Any], run_level: str = None): assert response.success, "The request failed." + e2e_latency = getattr(response, "e2e_latency", None) + if e2e_latency is not None: + print(f"the avg e2e is: {e2e_latency}") has_any_content = any(content is not None for content in (response.images, response.videos, response.audios)) assert has_any_content, "Response contains no images, videos, or audios" if response.images is not None: diff --git a/tests/helpers/env.py b/tests/helpers/env.py index ac7f226b7ea..bba1cec1b48 100644 --- a/tests/helpers/env.py +++ b/tests/helpers/env.py @@ -4,6 +4,7 @@ import gc import os +import subprocess import threading import time from contextlib import contextmanager @@ -133,10 +134,68 @@ def is_free(used, total): time.sleep(5) +def _print_gpu_processes() -> None: + """Print GPU information including nvidia-smi and system processes.""" + + print("\n" + "=" * 80) + print("NVIDIA GPU Information (nvidia-smi)") + print("=" * 80) + + try: + nvidia_result = subprocess.run( + ["nvidia-smi"], + capture_output=True, + text=True, + timeout=5, + ) + + if nvidia_result.returncode == 0: + lines = nvidia_result.stdout.strip().split("\n") + for line in lines[:20]: + print(line) + + if len(lines) > 20: + print(f"... (showing first 20 of {len(lines)} lines)") + else: + print("nvidia-smi command failed") + + except (subprocess.TimeoutExpired, FileNotFoundError): + print("nvidia-smi not available or timed out") + except Exception as e: + print(f"Error running nvidia-smi: {e}") + + print("\n" + "=" * 80) + print("Detailed GPU Processes (nvidia-smi pmon)") + print("=" * 80) + + try: + pmon_result = subprocess.run( + ["nvidia-smi", "pmon", "-c", "1"], + capture_output=True, + text=True, + timeout=3, + ) + + if pmon_result.returncode == 0 and pmon_result.stdout.strip(): + print(pmon_result.stdout) + else: + print("No active GPU processes found via nvidia-smi pmon") + + except Exception: + print("nvidia-smi pmon not available") + + print("\n" + "=" * 80) + print("System Processes with GPU keywords") + print("=" * 80) + + def _run_pre_test_cleanup(enable_force: bool = False) -> None: if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1" and not enable_force: + print("GPU cleanup disabled") return + print("Pre-test GPU status:") + num_gpus = torch.cuda.device_count() if num_gpus > 0: try: @@ -150,12 +209,16 @@ def _run_pre_test_cleanup(enable_force: bool = False) -> None: def _run_post_test_cleanup(enable_force: bool = False) -> None: if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1" and not enable_force: + print("GPU cleanup disabled") return if torch.cuda.is_available(): gc.collect() torch.cuda.empty_cache() + print("Post-test GPU status:") + _print_gpu_processes() + class DeviceMemoryMonitor: """Poll global device memory usage.""" diff --git a/tests/helpers/fixtures/env.py b/tests/helpers/fixtures/env.py index 0126ff2782c..8abcbc4dee7 100644 --- a/tests/helpers/fixtures/env.py +++ b/tests/helpers/fixtures/env.py @@ -31,6 +31,7 @@ def model_prefix() -> str: @pytest.fixture(autouse=True) def clean_gpu_memory_between_tests(): + print("\n=== PRE-TEST GPU CLEANUP ===") _run_pre_test_cleanup() yield _run_post_test_cleanup() diff --git a/tests/helpers/fixtures/runtime.py b/tests/helpers/fixtures/runtime.py index b328e414c9d..e6eb5f620ce 100644 --- a/tests/helpers/fixtures/runtime.py +++ b/tests/helpers/fixtures/runtime.py @@ -52,7 +52,11 @@ def omni_server(request: pytest.FixtureRequest, run_level: str, model_prefix: st env_dict=params.env_dict, use_omni=params.use_omni, ) as server: + print("OmniServer started successfully") yield server + print("OmniServer stopping...") + + print("OmniServer stopped") @pytest.fixture @@ -70,7 +74,11 @@ def omni_runner(request: pytest.FixtureRequest, model_prefix: str): model, stage_config_path = request.param model = model_prefix + model with OmniRunner(model, seed=42, stage_configs_path=stage_config_path, stage_init_timeout=300) as runner: + print("OmniRunner started successfully") yield runner + print("OmniRunner stopping...") + + print("OmniRunner stopped") @pytest.fixture diff --git a/tests/helpers/media.py b/tests/helpers/media.py index 77261987592..989aebc276f 100644 --- a/tests/helpers/media.py +++ b/tests/helpers/media.py @@ -209,6 +209,7 @@ def _enhance_speech(audio: np.ndarray) -> np.ndarray: output_path = f"audio_{num_channels}ch_{timestamp}.wav" try: sf.write(output_path, audio_data, sample_rate, format="WAV", subtype="PCM_16") + print(f"Audio saved: {output_path}") with open(output_path, "rb") as f: audio_bytes = f.read() except Exception as e: @@ -353,11 +354,15 @@ def generate_synthetic_video( "macro_block_size": 16, "ffmpeg_params": ["-preset", "medium", "-crf", "23", "-movflags", "+faststart", "-pix_fmt", "yuv420p"], } - with imageio.get_writer(buffer, **writer_kwargs) as writer: - for frame in video_frames: - writer.append_data(frame) - buffer.seek(0) - video_only_bytes = buffer.read() + try: + with imageio.get_writer(buffer, **writer_kwargs) as writer: + for frame in video_frames: + writer.append_data(frame) + buffer.seek(0) + video_only_bytes = buffer.read() + except Exception as e: + print(f"Warning: Failed to encode synthetic video: {e}") + raise video_bytes = ( _mux_mp4_bytes_with_synthetic_audio(video_only_bytes, num_frames=num_frames, fps=float(fps)) if embed_audio @@ -371,9 +376,13 @@ def generate_synthetic_video( if save_to_file: timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") output_path = f"video_{width}x{height}_{timestamp}.mp4" - with open(output_path, "wb") as f: - f.write(video_bytes) - result["file_path"] = output_path + try: + with open(output_path, "wb") as f: + f.write(video_bytes) + print(f"Video saved to: {output_path}") + result["file_path"] = output_path + except Exception as e: + print(f"Warning: Failed to save video to file {output_path}: {e}") return result @@ -397,10 +406,16 @@ def generate_synthetic_image(width: int, height: int, save_to_file: bool = False if save_to_file: timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") output_path = f"image_{width}x{height}_{timestamp}.jpg" - image.save(output_path, format="JPEG", quality=85, optimize=True) - saved_file_path = output_path - with open(output_path, "rb") as f: - image_bytes = f.read() + try: + image.save(output_path, format="JPEG", quality=85, optimize=True) + saved_file_path = output_path + print(f"Image saved to: {saved_file_path}") + with open(output_path, "rb") as f: + image_bytes = f.read() + except Exception as e: + print(f"Warning: Failed to save image to file {output_path}: {e}") + saved_file_path = None + image_bytes = None if not save_to_file or image_bytes is None: buffer = io.BytesIO() image.save(buffer, format="JPEG", quality=85, optimize=True) @@ -455,6 +470,7 @@ def cosine_similarity_text(text1, text2, n: int = 3): text1 = preprocess_text(text1) text2 = preprocess_text(text2) + print(f"cosine similarity text1 is: {text1}, text2 is: {text2}") ngrams1 = [text1[i : i + n] for i in range(len(text1) - n + 1)] ngrams2 = [text2[i : i + n] for i in range(len(text2) - n + 1)] @@ -534,6 +550,7 @@ def convert_audio_bytes_to_text(raw_bytes: bytes) -> str: output_path = f"./test_{uuid.uuid4().hex}.wav" data, samplerate = sf.read(io.BytesIO(raw_bytes)) sf.write(output_path, data, samplerate, format="WAV", subtype="PCM_16") + print(f"audio data is saved: {output_path}") return convert_audio_file_to_text(output_path) diff --git a/tests/helpers/runtime.py b/tests/helpers/runtime.py index 9936f45d3ef..3db6f118137 100644 --- a/tests/helpers/runtime.py +++ b/tests/helpers/runtime.py @@ -108,6 +108,7 @@ def _start_server(self) -> None: cmd.append("--omni") cmd += self.serve_args + print(f"Launching OmniServer with: {' '.join(cmd)}") self.proc = subprocess.Popen( cmd, env=env, @@ -123,6 +124,7 @@ def _start_server(self) -> None: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: sock.settimeout(1) if sock.connect_ex((self.host, self.port)) == 0: + print(f"Server ready on {self.host}:{self.port}") return time.sleep(2) raise RuntimeError(f"Server failed to start within {max_wait} seconds") @@ -131,17 +133,22 @@ def _kill_process_tree(self, pid): try: parent = psutil.Process(pid) children = parent.children(recursive=True) + all_pids = [pid] + [child.pid for child in children] + for child in children: try: child.terminate() except psutil.NoSuchProcess: pass + _, still_alive = psutil.wait_procs(children, timeout=10) + for child in still_alive: try: child.kill() except psutil.NoSuchProcess: pass + try: parent.terminate() parent.wait(timeout=10) @@ -150,6 +157,21 @@ def _kill_process_tree(self, pid): parent.kill() except psutil.NoSuchProcess: pass + + time.sleep(1) + alive_processes = [] + for check_pid in all_pids: + if psutil.pid_exists(check_pid): + alive_processes.append(check_pid) + + if alive_processes: + print(f"Warning: Processes still alive: {alive_processes}") + for alive_pid in alive_processes: + try: + subprocess.run(["kill", "-9", str(alive_pid)], timeout=2) + except Exception as e: + print(f"Cleanup failed: {e}") + except psutil.NoSuchProcess: pass @@ -229,6 +251,7 @@ def _process_stream_omni_response(self, chat_completion) -> OmniResponse: result.success = True except Exception as e: result.error_message = f"Stream processing error: {str(e)}" + print(f"Error: {result.error_message}") return result def _process_non_stream_omni_response(self, chat_completion) -> OmniResponse: @@ -256,6 +279,7 @@ def _process_non_stream_omni_response(self, chat_completion) -> OmniResponse: result.success = True except Exception as e: result.error_message = f"Non-stream processing error: {str(e)}" + print(f"Error: {result.error_message}") return result def _process_diffusion_response(self, chat_completion) -> DiffusionResponse: @@ -281,6 +305,7 @@ def _process_diffusion_response(self, chat_completion) -> DiffusionResponse: result.success = True except Exception as e: result.error_message = f"Diffusion response processing error: {str(e)}" + print(f"Error: {result.error_message}") return result def send_omni_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]: @@ -869,6 +894,7 @@ def _cleanup_process(self): cmdline = " ".join(proc.cmdline()).lower() if proc.cmdline() else "" name = proc.name().lower() if any(k in cmdline for k in keywords) or any(k in name for k in keywords): + print(f"Found vllm process: PID={proc.pid}, cmd={cmdline[:100]}") matched.append(proc) except (psutil.NoSuchProcess, psutil.AccessDenied): pass @@ -883,6 +909,14 @@ def _cleanup_process(self): proc.kill() except (psutil.NoSuchProcess, psutil.AccessDenied): pass + if still_alive: + _, stubborn = psutil.wait_procs(still_alive, timeout=3) + if stubborn: + print(f"Warning: failed to kill residual vllm pids: {[p.pid for p in stubborn]}") + else: + print(f"Force-killed residual vllm pids: {[p.pid for p in still_alive]}") + elif matched: + print(f"Terminated vllm pids: {[p.pid for p in matched]}") except Exception as e: print(f"Error in psutil vllm cleanup: {e}") @@ -918,6 +952,7 @@ def _process_output(self, outputs: list[Any]) -> OmniResponse: except Exception as e: result.error_message = f"Output processing error: {str(e)}" result.success = False + print(f"Error: {result.error_message}") return result def send_request(self, request_config: dict[str, Any] | None = None) -> OmniResponse: diff --git a/tests/helpers/stage_config.py b/tests/helpers/stage_config.py index 52a7da10034..cd2970fd35f 100644 --- a/tests/helpers/stage_config.py +++ b/tests/helpers/stage_config.py @@ -135,6 +135,8 @@ def delete_by_path(config_dict: dict, path: str) -> None: del current[index] elif isinstance(current, dict) and last_key in current: del current[last_key] + else: + print(f"Path {path} does not exist") if deletes: for key, value in deletes.items(): From b44ee2d2efa6db794a6194022115beb8f5087679 Mon Sep 17 00:00:00 2001 From: wangyu <410167048@qq.com> Date: Thu, 9 Apr 2026 11:12:23 +0800 Subject: [PATCH 03/19] Adapt to the latest version of the code. Signed-off-by: wangyu <410167048@qq.com> --- .../ci/test_examples/l4_functionality_tests.inc.md | 2 +- docs/contributing/ci/tests_markers.md | 8 ++++---- tests/conftest.py | 2 +- tests/diffusion/cache/test_teacache_extractors.py | 2 +- tests/diffusion/distributed/test_ulysses_uaa_perf.py | 2 +- tests/diffusion/models/flux2/test_flux2_transformer_tp.py | 2 +- tests/diffusion/test_diffusion_model_runner.py | 2 +- tests/e2e/offline_inference/test_bagel_lora.py | 4 ++-- .../offline_inference/test_ltx2_cfg_parallel_parity.py | 2 +- tests/e2e/online_serving/test_bagel_expansion.py | 2 +- tests/e2e/online_serving/test_qwen3_omni_expansion.py | 4 ++-- .../models/cosyvoice3/test_cosyvoice3_components.py | 2 +- 12 files changed, 17 insertions(+), 17 deletions(-) diff --git a/docs/contributing/ci/test_examples/l4_functionality_tests.inc.md b/docs/contributing/ci/test_examples/l4_functionality_tests.inc.md index 69d6ad82871..ab4deecd60d 100644 --- a/docs/contributing/ci/test_examples/l4_functionality_tests.inc.md +++ b/docs/contributing/ci/test_examples/l4_functionality_tests.inc.md @@ -40,7 +40,7 @@ Currently all the features are available in online serving mode. Hence, only nee - Test marks: always add `advanced_model` and `diffusion`. Add GPU-related marks if needed. Ref: [Markers for Tests](https://docs.vllm.ai/projects/vllm-omni/en/latest/contributing/ci/tests_markers/). - To maximize code reuse, you may refer to - `tests/conftest.py` for `omni_server` (running server in subprocess) and `openai_client` fixtures (sending requests and validating output), `generate_synthetic_image` and `assert_XXX_valid` helper. - - `tests/utils.py` for `@hardware_test(...)` and `hardware_marks`. + - `tests/helpers/mark.py` for `@hardware_test(...)` and `hardware_marks`. - [Parametrizing tests (pytest doc)](https://docs.pytest.org/en/stable/example/parametrize.html) to reuse test function implementation for different cases. - Doc: add a concise docstring for each test function. - Reference L4 test implementation: [tests/e2e/online_serving/test_qwen_image_edit_expansion.py](https://github.com/vllm-project/vllm-omni/blob/main/tests/e2e/online_serving/test_qwen_image_edit_expansion.py). diff --git a/docs/contributing/ci/tests_markers.md b/docs/contributing/ci/tests_markers.md index f6145e81160..7628db284a7 100644 --- a/docs/contributing/ci/tests_markers.md +++ b/docs/contributing/ci/tests_markers.md @@ -53,7 +53,7 @@ def test_video_to_audio() ### Decorator: `@hardware_test` -This decorator is intended to make hardware-aware, cross-platform test authoring easier and more robust for CI/CD environments. The `hardware_test` decorator in `vllm-omni/tests/utils.py` performs the following actions: +This decorator is intended to make hardware-aware, cross-platform test authoring easier and more robust for CI/CD environments. The `hardware_test` decorator in `vllm-omni/tests/helpers/mark.py` performs the following actions: 1. **Applies platform and resource markers** Adds the appropriate pytest markers for each specified hardware platform (e.g., `cuda`, `rocm`, `xpu`, `npu`) and resource type (e.g., `L4`, `H100`, `MI325`, `B60`, `A2`, `A3`). @@ -133,9 +133,9 @@ If you want to add support for a new platform (e.g., "tpu" for a new accelerator "distributed_tpu: Tests that require multiple TPU devices", ] ``` -2. **Implement a marker construction function for your platform** in `vllm-omni/tests/utils.py`: +2. **Implement a marker construction function for your platform** in `vllm-omni/tests/helpers/mark.py`: ```python - # In vllm-omni/tests/utils.py + # In vllm-omni/tests/helpers/mark.py def tpu_marks(*, res: str, num_cards: int): test_platform = pytest.mark.tpu @@ -175,4 +175,4 @@ If you want to add support for a new platform (e.g., "tpu" for a new accelerator - Plug into `hardware_marks` - You're done: tests using `@hardware_test` or `hardware_marks` with your platform now automatically get the correct markers, distribution, and isolation! -See code in `vllm-omni/tests/utils.py` for existing examples (`cuda_marks`, `rocm_marks`, `npu_marks`). +See code in `vllm-omni/tests/helpers/mark.py` for existing examples (`cuda_marks`, `rocm_marks`, `npu_marks`). diff --git a/tests/conftest.py b/tests/conftest.py index 685a7c663a1..81c9b67362f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ - `tests/conftest.py` stays thin: plugin registration + compatibility re-exports. - Importable utilities live under `tests/helpers/`. -- Fixtures live under `tests/fixtures/` and are loaded via `pytest_plugins`. +- Fixtures live under `tests/helpers/fixtures/` and are loaded via `pytest_plugins`. """ from __future__ import annotations diff --git a/tests/diffusion/cache/test_teacache_extractors.py b/tests/diffusion/cache/test_teacache_extractors.py index a52e11b3d46..afaf8bc1e3f 100644 --- a/tests/diffusion/cache/test_teacache_extractors.py +++ b/tests/diffusion/cache/test_teacache_extractors.py @@ -21,7 +21,7 @@ import pytest import torch -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test from vllm_omni.diffusion.cache.teacache.extractors import extract_flux2_klein_context from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import ( Flux2Transformer2DModel, diff --git a/tests/diffusion/distributed/test_ulysses_uaa_perf.py b/tests/diffusion/distributed/test_ulysses_uaa_perf.py index 04bbf5ee863..2a16a9ae578 100644 --- a/tests/diffusion/distributed/test_ulysses_uaa_perf.py +++ b/tests/diffusion/distributed/test_ulysses_uaa_perf.py @@ -17,7 +17,7 @@ import torch import torch.distributed as dist -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test from vllm_omni.diffusion.attention.parallel.ulysses import ( _all_gather_int, _ulysses_all_to_all_any_o, diff --git a/tests/diffusion/models/flux2/test_flux2_transformer_tp.py b/tests/diffusion/models/flux2/test_flux2_transformer_tp.py index faad08afd1c..aa9e9392aca 100644 --- a/tests/diffusion/models/flux2/test_flux2_transformer_tp.py +++ b/tests/diffusion/models/flux2/test_flux2_transformer_tp.py @@ -3,7 +3,7 @@ import pytest import torch -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test from vllm_omni.diffusion.models.flux2.flux2_transformer import ( Flux2PosEmbed, Flux2Transformer2DModel, diff --git a/tests/diffusion/test_diffusion_model_runner.py b/tests/diffusion/test_diffusion_model_runner.py index 8768986f01d..b63f6d8887f 100644 --- a/tests/diffusion/test_diffusion_model_runner.py +++ b/tests/diffusion/test_diffusion_model_runner.py @@ -8,7 +8,7 @@ import torch import vllm_omni.diffusion.worker.diffusion_model_runner as model_runner_module -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner pytestmark = [pytest.mark.diffusion] diff --git a/tests/e2e/offline_inference/test_bagel_lora.py b/tests/e2e/offline_inference/test_bagel_lora.py index 593a640478d..9d82741e431 100644 --- a/tests/e2e/offline_inference/test_bagel_lora.py +++ b/tests/e2e/offline_inference/test_bagel_lora.py @@ -32,8 +32,8 @@ from PIL import Image from safetensors.torch import save_file -from tests.conftest import modify_stage_config -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.stage_config import modify_stage_config from vllm_omni.entrypoints.omni import Omni from vllm_omni.lora.request import LoRARequest from vllm_omni.lora.utils import stable_lora_int_id diff --git a/tests/e2e/offline_inference/test_ltx2_cfg_parallel_parity.py b/tests/e2e/offline_inference/test_ltx2_cfg_parallel_parity.py index 659040929e2..07aa5a647be 100644 --- a/tests/e2e/offline_inference/test_ltx2_cfg_parallel_parity.py +++ b/tests/e2e/offline_inference/test_ltx2_cfg_parallel_parity.py @@ -11,7 +11,7 @@ import pytest from PIL import Image -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test REPO_ROOT = Path(__file__).resolve().parents[3] T2V_EXAMPLE = REPO_ROOT / "examples" / "offline_inference" / "text_to_video" / "text_to_video.py" diff --git a/tests/e2e/online_serving/test_bagel_expansion.py b/tests/e2e/online_serving/test_bagel_expansion.py index 342cd60351d..8e6cf0233c5 100644 --- a/tests/e2e/online_serving/test_bagel_expansion.py +++ b/tests/e2e/online_serving/test_bagel_expansion.py @@ -134,7 +134,7 @@ def test_bagel( - Ulysses-SP (degree=2) - Ring-Attention (degree=2) - Validation is delegated to assert_diffusion_response in tests.conftest, + Validation is delegated to assert_diffusion_response in tests/helpers/assertions.py, which checks output dimensions and basic correctness. """ diff --git a/tests/e2e/online_serving/test_qwen3_omni_expansion.py b/tests/e2e/online_serving/test_qwen3_omni_expansion.py index a9d3a515d4b..f35a1a2ddfb 100644 --- a/tests/e2e/online_serving/test_qwen3_omni_expansion.py +++ b/tests/e2e/online_serving/test_qwen3_omni_expansion.py @@ -489,7 +489,7 @@ def test_one_word_prompt_001(omni_server, openai_client) -> None: "key_words": {"text": ["london"]}, } - # Retry only when assert_omni_response fails on text/audio cosine similarity (see tests/conftest.py). + # Retry only when assert_omni_response fails on text/audio cosine similarity (see tests/helpers/assertions.py). _similarity_assert_msg = "The audio content is not same as the text" _max_retries = 3 for attempt in range(_max_retries): @@ -553,7 +553,7 @@ def test_speaker_002(omni_server, openai_client) -> None: "key_words": {"text": ["beijing"]}, } - # Retry only when assert_omni_response fails on preset voice gender (see tests/conftest.py). + # Retry only when assert_omni_response fails on preset voice gender (see tests/helpers/assertions.py). _gender_assert_substr = "estimated gender" _max_retries = 3 for attempt in range(_max_retries): diff --git a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py index ec24f6949fe..0e071f724e5 100644 --- a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py +++ b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py @@ -8,7 +8,7 @@ import torch import torch.nn as nn -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test class TestPreLookaheadLayer: From a80d9e36455c925f3b5c6299075f4a368e191790 Mon Sep 17 00:00:00 2001 From: wangyu <410167048@qq.com> Date: Fri, 10 Apr 2026 17:06:18 +0800 Subject: [PATCH 04/19] Refactor media helper functions to support caching of synthetic audio, video, and image generation. Introduce parameters for cache directory and force regeneration, enhancing performance and usability. Remove deprecated save_to_file logic and improve error handling for media processing. Signed-off-by: wangyu <410167048@qq.com> --- tests/helpers/media.py | 190 +++++++++++++++++++++++++++-------------- 1 file changed, 124 insertions(+), 66 deletions(-) diff --git a/tests/helpers/media.py b/tests/helpers/media.py index 989aebc276f..63bfc5ec676 100644 --- a/tests/helpers/media.py +++ b/tests/helpers/media.py @@ -2,7 +2,6 @@ import base64 import concurrent.futures -import datetime import gc import io import logging @@ -15,6 +14,7 @@ import tempfile import time import uuid +from pathlib import Path from typing import Any import numpy as np @@ -24,15 +24,71 @@ logger = logging.getLogger(__name__) +def _resolve_synthetic_media_cache_dir(cache_dir: Path | str | None) -> Path: + if cache_dir is not None: + return Path(cache_dir).expanduser().resolve() + return Path(tempfile.gettempdir()) / "vllm_omni_test_synthetic_media" + + +def _np_array_from_mp4_bytes(video_bytes: bytes) -> np.ndarray: + """Decode MP4 bytes to a (T, H, W, 3) uint8 RGB stack (matches in-memory synthetic frames).""" + import cv2 + + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp: + tmp.write(video_bytes) + path = tmp.name + cap = None + try: + cap = cv2.VideoCapture(path) + if not cap.isOpened(): + raise RuntimeError("Failed to open cached synthetic video for decode") + frames: list[np.ndarray] = [] + while True: + ok, frame_bgr = cap.read() + if not ok: + break + frames.append(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)) + if not frames: + raise RuntimeError("Cached synthetic video has no decodable frames") + return np.stack(frames, axis=0) + finally: + if cap is not None: + cap.release() + try: + os.unlink(path) + except OSError: + pass + + def generate_synthetic_audio( duration: int, num_channels: int, sample_rate: int = 48000, - save_to_file: bool = False, + *, + force_regenerate: bool = False, + cache_dir: Path | str | None = None, ) -> dict[str, Any]: """ Generate TTS speech with pyttsx3 and return base64 string. + + Caches the WAV under ``cache_dir`` when given, else under the default temp + subdirectory. Reuses the file when the same + ``duration`` / ``num_channels`` / ``sample_rate`` are requested unless + ``force_regenerate`` is true. """ + root = _resolve_synthetic_media_cache_dir(cache_dir) + root.mkdir(parents=True, exist_ok=True) + cache_path = root / f"synth_audio_d{duration}_ch{num_channels}_sr{sample_rate}.wav" + + if not force_regenerate and cache_path.is_file(): + data, _sr = sf.read(str(cache_path), dtype="float32", always_2d=True) + audio_bytes = cache_path.read_bytes() + return { + "np_array": np.asarray(data, dtype=np.float32), + "base64": base64.b64encode(audio_bytes).decode("utf-8"), + "file_path": str(cache_path.resolve()), + } + import pyttsx3 def _pick_voice(engine: pyttsx3.Engine) -> str | None: @@ -200,31 +256,14 @@ def _enhance_speech(audio: np.ndarray) -> np.ndarray: if max_amp > 0: audio_data = audio_data / max_amp * 0.95 - audio_bytes: bytes | None = None - output_path: str | None = None - result: dict[str, Any] = {"np_array": audio_data.copy()} - - if save_to_file: - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - output_path = f"audio_{num_channels}ch_{timestamp}.wav" - try: - sf.write(output_path, audio_data, sample_rate, format="WAV", subtype="PCM_16") - print(f"Audio saved: {output_path}") - with open(output_path, "rb") as f: - audio_bytes = f.read() - except Exception as e: - print(f"Save failed: {e}") - save_to_file = False - - if not save_to_file or audio_bytes is None: - buffer = io.BytesIO() - sf.write(buffer, audio_data, sample_rate, format="WAV", subtype="PCM_16") - buffer.seek(0) - audio_bytes = buffer.read() + sf.write(str(cache_path), audio_data, sample_rate, format="WAV", subtype="PCM_16") + audio_bytes = cache_path.read_bytes() - result["base64"] = base64.b64encode(audio_bytes).decode("utf-8") - result["file_path"] = output_path if save_to_file and output_path else None - return result + return { + "np_array": audio_data.copy(), + "base64": base64.b64encode(audio_bytes).decode("utf-8"), + "file_path": str(cache_path.resolve()), + } def _mux_mp4_bytes_with_synthetic_audio( @@ -242,7 +281,6 @@ def _mux_mp4_bytes_with_synthetic_audio( duration=duration_int, num_channels=1, sample_rate=sample_rate, - save_to_file=False, ) audio_pcm = audio_result["np_array"] except Exception as e: @@ -303,10 +341,29 @@ def generate_synthetic_video( width: int, height: int, num_frames: int, - save_to_file: bool = False, *, embed_audio: bool = False, + force_regenerate: bool = False, + cache_dir: Path | str | None = None, ) -> dict[str, Any]: + """ + Generate synthetic MP4 (optional AAC audio). Caches final bytes by + ``width`` / ``height`` / ``num_frames`` / ``embed_audio`` unless + ``force_regenerate`` is true. Cache root: ``cache_dir`` if given, else the + default temp subdirectory. + """ + root = _resolve_synthetic_media_cache_dir(cache_dir) + root.mkdir(parents=True, exist_ok=True) + cache_path = root / f"synth_video_w{width}_h{height}_nf{num_frames}_ea{int(embed_audio)}.mp4" + + if not force_regenerate and cache_path.is_file(): + video_bytes = cache_path.read_bytes() + return { + "np_array": _np_array_from_mp4_bytes(video_bytes), + "base64": base64.b64encode(video_bytes).decode("utf-8"), + "file_path": str(cache_path.resolve()), + } + import cv2 import imageio @@ -369,24 +426,43 @@ def generate_synthetic_video( else video_only_bytes ) - result: dict[str, Any] = { + cache_path.write_bytes(video_bytes) + + return { "np_array": np.array(video_frames), "base64": base64.b64encode(video_bytes).decode("utf-8"), + "file_path": str(cache_path.resolve()), } - if save_to_file: - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - output_path = f"video_{width}x{height}_{timestamp}.mp4" - try: - with open(output_path, "wb") as f: - f.write(video_bytes) - print(f"Video saved to: {output_path}") - result["file_path"] = output_path - except Exception as e: - print(f"Warning: Failed to save video to file {output_path}: {e}") - return result -def generate_synthetic_image(width: int, height: int, save_to_file: bool = False) -> dict[str, Any]: +def generate_synthetic_image( + width: int, + height: int, + *, + force_regenerate: bool = False, + cache_dir: Path | str | None = None, +) -> dict[str, Any]: + """ + Random colored squares on white background. Caches JPEG by ``width`` / + ``height`` unless ``force_regenerate`` is true. Cache root: ``cache_dir`` + if given, else the default temp subdirectory. + """ + root = _resolve_synthetic_media_cache_dir(cache_dir) + root.mkdir(parents=True, exist_ok=True) + cache_path = root / f"synth_image_w{width}_h{height}.jpg" + + if not force_regenerate and cache_path.is_file(): + from PIL import Image as PILImage + + image = PILImage.open(cache_path) + image.load() + image_bytes = cache_path.read_bytes() + return { + "np_array": np.array(image).copy(), + "base64": base64.b64encode(image_bytes).decode("utf-8"), + "file_path": str(cache_path.resolve()), + } + from PIL import ImageDraw image = Image.new("RGB", (width, height), (255, 255, 255)) @@ -400,32 +476,14 @@ def generate_synthetic_image(width: int, height: int, save_to_file: bool = False border_width = random.randint(1, 5) draw.rectangle([x, y, x + square_size, y + square_size], fill=color, outline=(0, 0, 0), width=border_width) - result: dict[str, Any] = {"np_array": np.array(image).copy()} - image_bytes: bytes | None = None - saved_file_path: str | None = None - if save_to_file: - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - output_path = f"image_{width}x{height}_{timestamp}.jpg" - try: - image.save(output_path, format="JPEG", quality=85, optimize=True) - saved_file_path = output_path - print(f"Image saved to: {saved_file_path}") - with open(output_path, "rb") as f: - image_bytes = f.read() - except Exception as e: - print(f"Warning: Failed to save image to file {output_path}: {e}") - saved_file_path = None - image_bytes = None - if not save_to_file or image_bytes is None: - buffer = io.BytesIO() - image.save(buffer, format="JPEG", quality=85, optimize=True) - buffer.seek(0) - image_bytes = buffer.read() + image.save(str(cache_path), format="JPEG", quality=85, optimize=True) + image_bytes = cache_path.read_bytes() - result["base64"] = base64.b64encode(image_bytes).decode("utf-8") - if save_to_file and saved_file_path: - result["file_path"] = saved_file_path - return result + return { + "np_array": np.array(image).copy(), + "base64": base64.b64encode(image_bytes).decode("utf-8"), + "file_path": str(cache_path.resolve()), + } def decode_b64_image(b64: str): From 15399bd8f7edac0db7aa0c2785dbfa0e769c2f0a Mon Sep 17 00:00:00 2001 From: wangyu <410167048@qq.com> Date: Mon, 13 Apr 2026 19:34:07 +0800 Subject: [PATCH 05/19] Refactor GPU memory management in test environment helpers. Delay imports of platform-specific modules until needed to ensure proper execution order of fixtures. Introduce a new function for forced GPU cleanup to streamline cleanup processes across different classes. Enhance memory monitoring logic for better clarity and performance during tests. Signed-off-by: wangyu <410167048@qq.com> --- tests/helpers/env.py | 159 +++++++++++++++++----------------- tests/helpers/fixtures/env.py | 6 +- tests/helpers/runtime.py | 21 +++-- 3 files changed, 95 insertions(+), 91 deletions(-) diff --git a/tests/helpers/env.py b/tests/helpers/env.py index bba1cec1b48..d6b4903b1d4 100644 --- a/tests/helpers/env.py +++ b/tests/helpers/env.py @@ -1,4 +1,8 @@ -"""Test environment / lifecycle helpers (GPU cleanup hooks and memory monitoring for tests).""" +"""Test environment / lifecycle helpers (GPU cleanup hooks and memory monitoring for tests). + +``vllm.platforms`` / ``vllm_omni.platforms`` are imported only inside functions that need them +so importing this module at pytest plugin load does not run before session autouse fixtures +""" from __future__ import annotations @@ -10,45 +14,6 @@ from contextlib import contextmanager import torch -from vllm.platforms import current_platform - -from vllm_omni.platforms import current_omni_platform - -if current_platform.is_rocm(): - from amdsmi import ( - amdsmi_get_gpu_vram_usage, - amdsmi_get_processor_handles, - amdsmi_init, - amdsmi_shut_down, - ) - - @contextmanager - def _nvml(): - try: - amdsmi_init() - yield - finally: - amdsmi_shut_down() -elif current_platform.is_cuda(): - from vllm.third_party.pynvml import ( - nvmlDeviceGetHandleByIndex, - nvmlDeviceGetMemoryInfo, - nvmlInit, - nvmlShutdown, - ) - - @contextmanager - def _nvml(): - try: - nvmlInit() - yield - finally: - nvmlShutdown() -else: - - @contextmanager - def _nvml(): - yield def get_physical_device_indices(devices): @@ -60,7 +25,6 @@ def get_physical_device_indices(devices): return [index_mapping[i] for i in devices if i in index_mapping] -@_nvml() def wait_for_gpu_memory_to_clear( *, devices: list[int], @@ -68,6 +32,8 @@ def wait_for_gpu_memory_to_clear( threshold_ratio: float | None = None, timeout_s: float = 120, ) -> None: + from vllm.platforms import current_platform + assert threshold_bytes is not None or threshold_ratio is not None devices = get_physical_device_indices(devices) start_time = time.time() @@ -92,46 +58,75 @@ def is_free(used, total): def is_free(used, total): return used / total <= threshold_ratio - while True: - output: dict[int, str] = {} - output_raw: dict[int, tuple[float, float]] = {} - for device in devices: - if current_platform.is_rocm(): - dev_handle = amdsmi_get_processor_handles()[device] - mem_info = amdsmi_get_gpu_vram_usage(dev_handle) - gb_used = mem_info["vram_used"] / 2**10 - gb_total = mem_info["vram_total"] / 2**10 - else: - dev_handle = nvmlDeviceGetHandleByIndex(device) - mem_info = nvmlDeviceGetMemoryInfo(dev_handle) - gb_used = mem_info.used / 2**30 - gb_total = mem_info.total / 2**30 - output_raw[device] = (gb_used, gb_total) - usage_percent = (gb_used / gb_total) * 100 if gb_total > 0 else 0 - output[device] = f"{gb_used:.1f}GiB/{gb_total:.1f}GiB ({usage_percent:.1f}%)" - - print("[GPU Memory Status] Current usage:") - for device_id, mem_info in output.items(): - print(f" GPU {device_id}: {mem_info}") - - dur_s = time.time() - start_time - elapsed_minutes = dur_s / 60 - if all(is_free(used, total) for used, total in output_raw.values()): - print(f"[GPU Memory Freed] Devices {device_list} meet memory condition") - print(f" Condition: {condition_str}") - print(f" Wait time: {dur_s:.1f} seconds ({elapsed_minutes:.1f} minutes)") - break - - if dur_s >= timeout_s: - raise ValueError( - f"[GPU Memory Timeout] Devices {device_list} still don't meet memory condition after {dur_s:.1f} seconds\n" - f"Condition: {condition_str}\n" - f"Current status:\n" + "\n".join(f" GPU {device}: {output[device]}" for device in devices) - ) + @contextmanager + def nvml_scope(): + if current_platform.is_rocm(): + from amdsmi import amdsmi_init, amdsmi_shut_down - gc.collect() - torch.cuda.empty_cache() - time.sleep(5) + amdsmi_init() + try: + yield + finally: + amdsmi_shut_down() + elif current_platform.is_cuda(): + from vllm.third_party.pynvml import nvmlInit, nvmlShutdown + + nvmlInit() + try: + yield + finally: + nvmlShutdown() + else: + yield + + is_rocm = current_platform.is_rocm() + + with nvml_scope(): + if is_rocm: + from amdsmi import amdsmi_get_gpu_vram_usage, amdsmi_get_processor_handles + elif current_platform.is_cuda(): + from vllm.third_party.pynvml import nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo + + while True: + output: dict[int, str] = {} + output_raw: dict[int, tuple[float, float]] = {} + for device in devices: + if is_rocm: + dev_handle = amdsmi_get_processor_handles()[device] + mem_info = amdsmi_get_gpu_vram_usage(dev_handle) + gb_used = mem_info["vram_used"] / 2**10 + gb_total = mem_info["vram_total"] / 2**10 + else: + dev_handle = nvmlDeviceGetHandleByIndex(device) + mem_info = nvmlDeviceGetMemoryInfo(dev_handle) + gb_used = mem_info.used / 2**30 + gb_total = mem_info.total / 2**30 + output_raw[device] = (gb_used, gb_total) + usage_percent = (gb_used / gb_total) * 100 if gb_total > 0 else 0 + output[device] = f"{gb_used:.1f}GiB/{gb_total:.1f}GiB ({usage_percent:.1f}%)" + + print("[GPU Memory Status] Current usage:") + for device_id, mem_info in output.items(): + print(f" GPU {device_id}: {mem_info}") + + dur_s = time.time() - start_time + elapsed_minutes = dur_s / 60 + if all(is_free(used, total) for used, total in output_raw.values()): + print(f"[GPU Memory Freed] Devices {device_list} meet memory condition") + print(f" Condition: {condition_str}") + print(f" Wait time: {dur_s:.1f} seconds ({elapsed_minutes:.1f} minutes)") + break + + if dur_s >= timeout_s: + raise ValueError( + f"[GPU Memory Timeout] Devices {device_list} still don't meet memory condition after {dur_s:.1f} seconds\n" + f"Condition: {condition_str}\n" + f"Current status:\n" + "\n".join(f" GPU {device}: {output[device]}" for device in devices) + ) + + gc.collect() + torch.cuda.empty_cache() + time.sleep(5) def _print_gpu_processes() -> None: @@ -231,6 +226,8 @@ def __init__(self, device_index: int, interval: float = 0.05): self._thread: threading.Thread | None = None def start(self) -> None: + from vllm_omni.platforms import current_omni_platform + def monitor_loop() -> None: while not self._stop_event.is_set(): try: @@ -253,6 +250,8 @@ def stop(self) -> None: @property def peak_used_mb(self) -> float: + from vllm_omni.platforms import current_omni_platform + fallback_alloc = current_omni_platform.max_memory_allocated(device=self.device_index) / (1024**2) fallback_reserved = current_omni_platform.max_memory_reserved(device=self.device_index) / (1024**2) return max(self._peak_used_mb, fallback_alloc, fallback_reserved) diff --git a/tests/helpers/fixtures/env.py b/tests/helpers/fixtures/env.py index 8abcbc4dee7..1fe01ee09b0 100644 --- a/tests/helpers/fixtures/env.py +++ b/tests/helpers/fixtures/env.py @@ -3,8 +3,6 @@ import pytest import torch -from tests.helpers.env import _run_post_test_cleanup, _run_pre_test_cleanup - @pytest.fixture(scope="session", autouse=True) def default_env(): @@ -31,6 +29,10 @@ def model_prefix() -> str: @pytest.fixture(autouse=True) def clean_gpu_memory_between_tests(): + # Import here so ``tests.helpers.env`` (and vLLM platform modules) load only + # after session autouse fixtures like ``default_env`` have run (RFC #2299). + from tests.helpers.env import _run_post_test_cleanup, _run_pre_test_cleanup + print("\n=== PRE-TEST GPU CLEANUP ===") _run_pre_test_cleanup() yield diff --git a/tests/helpers/runtime.py b/tests/helpers/runtime.py index 3db6f118137..dbd7bf6180f 100644 --- a/tests/helpers/runtime.py +++ b/tests/helpers/runtime.py @@ -27,7 +27,6 @@ assert_diffusion_response, assert_omni_response, ) -from tests.helpers.env import _run_post_test_cleanup, _run_pre_test_cleanup from tests.helpers.media import ( _merge_base64_audio_to_segment, convert_audio_bytes_to_text, @@ -49,6 +48,14 @@ def cleanup_dist_env_and_memory() -> None: return None +def _run_forced_gpu_cleanup_round() -> None: + """Defer ``tests.helpers.env`` import until cleanup runs (RFC #2299).""" + from tests.helpers.env import _run_post_test_cleanup, _run_pre_test_cleanup + + _run_pre_test_cleanup(enable_force=True) + _run_post_test_cleanup(enable_force=True) + + def get_open_port() -> int: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("127.0.0.1", 0)) @@ -76,8 +83,7 @@ def __init__( env_dict: dict[str, str] | None = None, use_omni: bool = True, ) -> None: - _run_pre_test_cleanup(enable_force=True) - _run_post_test_cleanup(enable_force=True) + _run_forced_gpu_cleanup_round() cleanup_dist_env_and_memory() self.model = model self.serve_args = serve_args @@ -182,8 +188,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): if self.proc: self._kill_process_tree(self.proc.pid) - _run_pre_test_cleanup(enable_force=True) - _run_post_test_cleanup(enable_force=True) + _run_forced_gpu_cleanup_round() cleanup_dist_env_and_memory() @@ -664,8 +669,7 @@ def __init__( **kwargs, ) -> None: cleanup_dist_env_and_memory() - _run_pre_test_cleanup(enable_force=True) - _run_post_test_cleanup(enable_force=True) + _run_forced_gpu_cleanup_round() self.model_name = model_name self.seed = seed self._prompt_len_estimate_cache: dict[str, Any] = {} @@ -927,8 +931,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): if hasattr(self.omni, "close"): self.omni.close() self._cleanup_process() - _run_pre_test_cleanup(enable_force=True) - _run_post_test_cleanup(enable_force=True) + _run_forced_gpu_cleanup_round() cleanup_dist_env_and_memory() From 3859a966238892c8043296cbf99e17101fbdf408 Mon Sep 17 00:00:00 2001 From: wangyu <410167048@qq.com> Date: Mon, 13 Apr 2026 20:47:18 +0800 Subject: [PATCH 06/19] Refactor test imports to use helpers for consistency and clarity. Update test markers for better categorization in zimage_parallelism tests. Enhance GPU memory cleanup messages for improved debugging during test execution. Signed-off-by: wangyu <410167048@qq.com> --- pyproject.toml | 1 + tests/e2e/offline_inference/test_dynin_omni.py | 2 +- tests/e2e/offline_inference/test_voxcpm2.py | 4 ++-- tests/e2e/offline_inference/test_zimage_parallelism.py | 5 ++++- tests/e2e/online_serving/test_dynin_omni_expansion.py | 8 ++++---- tests/engine/test_async_omni_engine_abort.py | 1 - tests/entrypoints/openai_api/test_text_splitter.py | 2 +- tests/helpers/env.py | 9 +++++++-- tests/helpers/fixtures/runtime.py | 2 +- 9 files changed, 21 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e49aa6e3251..d814372f424 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -181,6 +181,7 @@ markers = [ # specified computation resources marks (auto-added) "H100: Tests that require H100 GPU", "L4: Tests that require L4 GPU", + "B60: Tests that require B60", "MI325: Tests that require MI325 GPU (AMD/ROCm)", "S5000: Tests that require S5000 GPU (Moore Threads/MUSA)", "A2: Tests that require A2 NPU", diff --git a/tests/e2e/offline_inference/test_dynin_omni.py b/tests/e2e/offline_inference/test_dynin_omni.py index 5388ac67468..f891fc4f12e 100644 --- a/tests/e2e/offline_inference/test_dynin_omni.py +++ b/tests/e2e/offline_inference/test_dynin_omni.py @@ -18,7 +18,7 @@ import torch from transformers import AutoTokenizer -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0" diff --git a/tests/e2e/offline_inference/test_voxcpm2.py b/tests/e2e/offline_inference/test_voxcpm2.py index 4e4f635d5c4..9835d08414e 100644 --- a/tests/e2e/offline_inference/test_voxcpm2.py +++ b/tests/e2e/offline_inference/test_voxcpm2.py @@ -5,8 +5,8 @@ import pytest import torch -from tests.conftest import OmniRunner -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniRunner VOXCPM2_MODEL = "openbmb/VoxCPM2" STAGE_CONFIG = os.path.join( diff --git a/tests/e2e/offline_inference/test_zimage_parallelism.py b/tests/e2e/offline_inference/test_zimage_parallelism.py index a66ae328bde..ab330ee9a26 100644 --- a/tests/e2e/offline_inference/test_zimage_parallelism.py +++ b/tests/e2e/offline_inference/test_zimage_parallelism.py @@ -214,7 +214,10 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path): ) -@pytest.mark.integration +@pytest.mark.advanced_model +@pytest.mark.diffusion +@pytest.mark.parallel +@hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards={"cuda": 4, "rocm": 2}) def test_zimage_vae_patch_parallel_tp2(tmp_path: Path): if current_omni_platform.is_npu(): pytest.skip("Z-Image VAE patch parallel e2e test is only supported on CUDA and ROCm for now.") diff --git a/tests/e2e/online_serving/test_dynin_omni_expansion.py b/tests/e2e/online_serving/test_dynin_omni_expansion.py index 39b6dc8e212..ca6fe7c6637 100644 --- a/tests/e2e/online_serving/test_dynin_omni_expansion.py +++ b/tests/e2e/online_serving/test_dynin_omni_expansion.py @@ -15,9 +15,9 @@ import soundfile as sf from vllm.assets.image import ImageAsset -from tests import conftest as tests_conftest -from tests.conftest import OmniServerParams -from tests.utils import hardware_test +from tests.helpers.mark import hardware_test +from tests.helpers.media import convert_audio_bytes_to_text +from tests.helpers.runtime import OmniServerParams os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0" @@ -87,7 +87,7 @@ def _convert_audio_bytes_to_text_without_ffmpeg(raw_bytes: bytes) -> str: @pytest.fixture def dynin_t2s_openai_client(openai_client, monkeypatch): monkeypatch.setattr( - tests_conftest, + convert_audio_bytes_to_text, "convert_audio_bytes_to_text", _convert_audio_bytes_to_text_without_ffmpeg, ) diff --git a/tests/engine/test_async_omni_engine_abort.py b/tests/engine/test_async_omni_engine_abort.py index 20fa9e3f82e..c05ab901387 100644 --- a/tests/engine/test_async_omni_engine_abort.py +++ b/tests/engine/test_async_omni_engine_abort.py @@ -60,7 +60,6 @@ async def generate( @pytest.mark.core_model @pytest.mark.omni -@pytest.mark.real_hf_config @hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards=1) @pytest.mark.asyncio async def test_abort(): diff --git a/tests/entrypoints/openai_api/test_text_splitter.py b/tests/entrypoints/openai_api/test_text_splitter.py index a1886662ae5..b9022e015dd 100644 --- a/tests/entrypoints/openai_api/test_text_splitter.py +++ b/tests/entrypoints/openai_api/test_text_splitter.py @@ -4,7 +4,7 @@ from vllm_omni.entrypoints.openai.text_splitter import SentenceSplitter -pytestmark = [pytest.mark.openai, pytest.mark.speech, pytest.mark.core_model, pytest.mark.cpu] +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] class TestSentenceSplitterEnglish: diff --git a/tests/helpers/env.py b/tests/helpers/env.py index c027fd95b8f..069179b711e 100644 --- a/tests/helpers/env.py +++ b/tests/helpers/env.py @@ -184,9 +184,14 @@ def _print_gpu_processes() -> None: print("=" * 80) +_SKIPPED_GPU_CLEANUP_MSG = ( + "\nSkipping GPU memory cleanup check (typically: instance already up; no check needed between tests)\n" +) + + def run_pre_test_cleanup(enable_force: bool = False) -> None: if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1" and not enable_force: - print("GPU cleanup disabled") + print(_SKIPPED_GPU_CLEANUP_MSG) return print("Pre-test GPU status:") @@ -204,7 +209,7 @@ def run_pre_test_cleanup(enable_force: bool = False) -> None: def run_post_test_cleanup(enable_force: bool = False) -> None: if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1" and not enable_force: - print("GPU cleanup disabled") + print(_SKIPPED_GPU_CLEANUP_MSG) return if torch.cuda.is_available(): diff --git a/tests/helpers/fixtures/runtime.py b/tests/helpers/fixtures/runtime.py index 5e0cfcf86a5..5289142a08e 100644 --- a/tests/helpers/fixtures/runtime.py +++ b/tests/helpers/fixtures/runtime.py @@ -109,7 +109,7 @@ def omni_runner(request: pytest.FixtureRequest, model_prefix: str): with omni_fixture_lock: model, stage_config_path = request.param model = model_prefix + model - with OmniRunner(model, seed=42, stage_configs_path=stage_config_path, stage_init_timeout=300) as runner: + with OmniRunner(model, seed=42, stage_configs_path=stage_config_path) as runner: print("OmniRunner started successfully") yield runner print("OmniRunner stopping...") From 926542bc34aa0fc8d5d827114db723791275d2ea Mon Sep 17 00:00:00 2001 From: wangyu <410167048@qq.com> Date: Wed, 15 Apr 2026 18:56:24 +0800 Subject: [PATCH 07/19] Implement context manager to serialize Whisper small model downloads across processes Signed-off-by: wangyu <410167048@qq.com> --- tests/helpers/media.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/helpers/media.py b/tests/helpers/media.py index 63bfc5ec676..3c45c2a9d95 100644 --- a/tests/helpers/media.py +++ b/tests/helpers/media.py @@ -14,6 +14,7 @@ import tempfile import time import uuid +from contextlib import contextmanager from pathlib import Path from typing import Any @@ -557,6 +558,22 @@ def _merge_base64_audio_to_segment(base64_list: list[str]): return merged +@contextmanager +def _serialize_whisper_small_model_download(): + """Serialize Whisper ``small`` cache writes across processes (Linux/Unix).""" + import fcntl + + lock_path = Path.home() / ".cache" / "whisper" / ".small_model_download.lock" + lock_path.parent.mkdir(parents=True, exist_ok=True) + f = open(lock_path, "a+b") + try: + fcntl.flock(f.fileno(), fcntl.LOCK_EX) + yield + finally: + fcntl.flock(f.fileno(), fcntl.LOCK_UN) + f.close() + + def _whisper_transcribe_in_current_process(output_path: str) -> str: import whisper @@ -579,7 +596,8 @@ def _whisper_transcribe_in_current_process(output_path: str) -> str: use_accelerator = False device = "cpu" - model = whisper.load_model("small", device=device) + with _serialize_whisper_small_model_download(): + model = whisper.load_model("small", device=device) try: text = model.transcribe( output_path, From bd6508dd69f41fb8582dd28ed66cb74ca6720ca5 Mon Sep 17 00:00:00 2001 From: wangyu <410167048@qq.com> Date: Wed, 15 Apr 2026 20:39:14 +0800 Subject: [PATCH 08/19] Enhance assert_video_valid function to accommodate codec-aligned frame count discrepancies in MP4 validation. Update docstring for clarity on expected behavior and adjust frame count assertion logic. Signed-off-by: wangyu <410167048@qq.com> --- tests/helpers/assertions.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/helpers/assertions.py b/tests/helpers/assertions.py index b97f769567d..992ddd309be 100644 --- a/tests/helpers/assertions.py +++ b/tests/helpers/assertions.py @@ -146,7 +146,12 @@ def assert_video_valid( height: int | None = None, fps: float | None = None, ) -> dict[str, int | float]: - """Assert the MP4 has the expected resolution and exact frame count.""" + """Assert the MP4 has the expected resolution and frame count. + + For several diffusion backends, encoded MP4 frame count follows a codec-aligned + convention (e.g. request `num_frames=8` can produce 9 encoded frames). Keep + this compatibility behavior to avoid false negatives in online-serving tests. + """ temp_path = None cap = None try: @@ -184,7 +189,8 @@ def assert_video_valid( if fps is not None and actual_fps: assert abs(actual_fps - float(fps)) < 1.0, f"Expected fps~={fps}, got {actual_fps}" if num_frames is not None: - assert actual_frames == num_frames, f"Expected frames={num_frames}, got {actual_frames}" + expected_frames = (int(num_frames) // 4) * 4 + 1 + assert actual_frames == expected_frames, f"Expected frames={expected_frames}, got {actual_frames}" return { "width": actual_width, From edf922b1031be8a2622797dabbd9abcbcd4e6b2b Mon Sep 17 00:00:00 2001 From: wangyu <410167048@qq.com> Date: Wed, 15 Apr 2026 20:46:42 +0800 Subject: [PATCH 09/19] Refactor test imports in Qwen image edit and VoxCPM test files to use helpers from the runtime module. This change improves code organization and maintainability by consolidating cleanup functions under a common namespace. Signed-off-by: wangyu <410167048@qq.com> --- tests/e2e/accuracy/test_qwen_image_edit.py | 11 ++++------- tests/e2e/offline_inference/test_voxcpm.py | 8 ++++---- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/tests/e2e/accuracy/test_qwen_image_edit.py b/tests/e2e/accuracy/test_qwen_image_edit.py index 9a970103438..59e486cb12c 100644 --- a/tests/e2e/accuracy/test_qwen_image_edit.py +++ b/tests/e2e/accuracy/test_qwen_image_edit.py @@ -10,12 +10,9 @@ from PIL import Image from benchmarks.accuracy.common import decode_base64_image, pil_to_png_bytes -from tests.conftest import ( - OmniServer, - _run_post_test_cleanup, - _run_pre_test_cleanup, -) from tests.e2e.accuracy.utils import assert_similarity, model_output_dir +from tests.helpers.env import run_post_test_cleanup, run_pre_test_cleanup +from tests.helpers.runtime import OmniServer from tests.utils import hardware_test SINGLE_MODEL = "Qwen/Qwen-Image-Edit" @@ -77,7 +74,7 @@ def _run_diffusers_image_edit( input_images: list[Image.Image], output_path: Path, ) -> Image.Image: - _run_pre_test_cleanup(enable_force=True) + run_pre_test_cleanup(enable_force=True) pipe: QwenImageEditPipeline | QwenImageEditPlusPipeline | None = None device = torch.device("cuda:0") torch.cuda.set_device(device) @@ -110,7 +107,7 @@ def _run_diffusers_image_edit( gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() - _run_post_test_cleanup(enable_force=True) + run_post_test_cleanup(enable_force=True) def _vllm_omni_output_single_image( diff --git a/tests/e2e/offline_inference/test_voxcpm.py b/tests/e2e/offline_inference/test_voxcpm.py index d7f65525e93..24f94b08758 100644 --- a/tests/e2e/offline_inference/test_voxcpm.py +++ b/tests/e2e/offline_inference/test_voxcpm.py @@ -12,8 +12,8 @@ import pytest import torch -import tests.conftest as omni_test_conftest -from tests.conftest import OmniRunner +import tests.helpers.runtime as omni_runtime +from tests.helpers.runtime import OmniRunner from tests.utils import hardware_test from vllm_omni.model_executor.models.voxcpm.voxcpm_runtime_utils import ( prepare_voxcpm_hf_config_dir, @@ -30,7 +30,7 @@ @pytest.fixture(autouse=True) def _patch_npu_cleanup_for_voxcpm(monkeypatch: pytest.MonkeyPatch): """Limit the NPU cleanup workaround to this VoxCPM test module only.""" - original_cleanup = omni_test_conftest.cleanup_dist_env_and_memory + original_cleanup = omni_runtime.cleanup_dist_env_and_memory def _safe_cleanup() -> None: try: @@ -40,7 +40,7 @@ def _safe_cleanup() -> None: return raise - monkeypatch.setattr(omni_test_conftest, "cleanup_dist_env_and_memory", _safe_cleanup) + monkeypatch.setattr(omni_runtime, "cleanup_dist_env_and_memory", _safe_cleanup) def _build_prompt(text: str) -> dict[str, Any]: From a05663b789150422ad41f413ed172dac69c44d6d Mon Sep 17 00:00:00 2001 From: wangyu <410167048@qq.com> Date: Wed, 15 Apr 2026 22:07:20 +0800 Subject: [PATCH 10/19] Refactor and enhance image similarity assertion functions by moving them to a new helpers module. This change improves code organization and reusability, while also adding functionality to compute and assert SSIM and PSNR metrics for model outputs. The previous utility functions have been removed from the utils module to streamline the codebase. Signed-off-by: wangyu <410167048@qq.com> --- tests/e2e/accuracy/helpers.py | 71 +++++++++++++++++++++ tests/e2e/accuracy/test_qwen_image_edit.py | 4 +- tests/e2e/accuracy/utils.py | 74 ---------------------- tests/e2e/offline_inference/test_voxcpm.py | 2 +- 4 files changed, 74 insertions(+), 77 deletions(-) delete mode 100644 tests/e2e/accuracy/utils.py diff --git a/tests/e2e/accuracy/helpers.py b/tests/e2e/accuracy/helpers.py index 24b71a471e5..726873e58bc 100644 --- a/tests/e2e/accuracy/helpers.py +++ b/tests/e2e/accuracy/helpers.py @@ -1,5 +1,11 @@ from pathlib import Path +import numpy as np +import pytest +import torch +from PIL import Image +from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure + def reset_artifact_dir(path: Path) -> Path: import shutil @@ -13,3 +19,68 @@ def reset_artifact_dir(path: Path) -> Path: def infer_model_label(model: str) -> str: label = Path(model.rstrip("/\\")).name or "model" return "".join(char if char.isalnum() or char in {"-", "_"} else "_" for char in label) + + +def model_output_dir(parent_dir: Path, model: str) -> Path: + safe_model_name = model.split("/")[-1].replace(".", "_") + path = parent_dir / safe_model_name + path.mkdir(parents=True, exist_ok=True) + return path + + +def assert_similarity( + *, + model_name: str, + vllm_image: Image.Image, + diffusers_image: Image.Image, + width: int, + height: int, + ssim_threshold: float, + psnr_threshold: float, +) -> None: + requested_size = (width, height) + if diffusers_image.size != requested_size: + pytest.skip( + "Skipping as diffusers baseline output is corrupt and not comparable: " + f"dimensions do not match requested size; requested={requested_size}, got={diffusers_image.size}." + ) + + assert vllm_image.size == diffusers_image.size, ( + f"Online and diffusers output sizes mismatch: online={vllm_image.size}, diffusers={diffusers_image.size}" + ) + + ssim_score, psnr_score = compute_image_ssim_psnr(prediction=vllm_image, reference=diffusers_image) + print(f"{model_name} similarity metrics:") + print(f" SSIM: value={ssim_score:.6f}, threshold>={ssim_threshold:.6f}, range=[-1, 1], higher_is_better=True") + print( + f" PSNR: value={psnr_score:.6f} dB, threshold>={psnr_threshold:.6f} dB, range=[0, +inf), higher_is_better=True" + ) + + assert ssim_score >= ssim_threshold, ( + f"SSIM below threshold for {model_name}: got {ssim_score:.6f}, expected >= {ssim_threshold:.6f}." + ) + assert psnr_score >= psnr_threshold, ( + f"PSNR below threshold for {model_name}: got {psnr_score:.6f}, expected >= {psnr_threshold:.6f}." + ) + + +def compute_image_ssim_psnr( + *, + prediction: Image.Image, + reference: Image.Image, +) -> tuple[float, float]: + pred_tensor = _pil_to_batched_tensor(prediction) + ref_tensor = _pil_to_batched_tensor(reference) + + ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0) + psnr_metric = PeakSignalNoiseRatio(data_range=1.0) + + ssim_value = float(ssim_metric(pred_tensor, ref_tensor).item()) + psnr_value = float(psnr_metric(pred_tensor, ref_tensor).item()) + return ssim_value, psnr_value + + +def _pil_to_batched_tensor(image: Image.Image) -> torch.Tensor: + array = np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0 + tensor = torch.from_numpy(array).permute(2, 0, 1).unsqueeze(0) + return tensor diff --git a/tests/e2e/accuracy/test_qwen_image_edit.py b/tests/e2e/accuracy/test_qwen_image_edit.py index 59e486cb12c..e17aca6e99b 100644 --- a/tests/e2e/accuracy/test_qwen_image_edit.py +++ b/tests/e2e/accuracy/test_qwen_image_edit.py @@ -10,10 +10,10 @@ from PIL import Image from benchmarks.accuracy.common import decode_base64_image, pil_to_png_bytes -from tests.e2e.accuracy.utils import assert_similarity, model_output_dir +from tests.e2e.accuracy.helpers import assert_similarity, model_output_dir from tests.helpers.env import run_post_test_cleanup, run_pre_test_cleanup +from tests.helpers.mark import hardware_test from tests.helpers.runtime import OmniServer -from tests.utils import hardware_test SINGLE_MODEL = "Qwen/Qwen-Image-Edit" MULTIPLE_MODEL = "Qwen/Qwen-Image-Edit-2509" diff --git a/tests/e2e/accuracy/utils.py b/tests/e2e/accuracy/utils.py deleted file mode 100644 index eb0eea757ee..00000000000 --- a/tests/e2e/accuracy/utils.py +++ /dev/null @@ -1,74 +0,0 @@ -from __future__ import annotations - -from pathlib import Path - -import numpy as np -import pytest -import torch -from PIL import Image -from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure - - -def model_output_dir(parent_dir: Path, model: str) -> Path: - safe_model_name = model.split("/")[-1].replace(".", "_") - path = parent_dir / safe_model_name - path.mkdir(parents=True, exist_ok=True) - return path - - -def assert_similarity( - *, - model_name: str, - vllm_image: Image.Image, - diffusers_image: Image.Image, - width: int, - height: int, - ssim_threshold: float, - psnr_threshold: float, -) -> None: - requested_size = (width, height) - if diffusers_image.size != requested_size: - pytest.skip( - "Skipping as diffusers baseline output is corrupt and not comparable: " - f"dimensions do not match requested size; requested={requested_size}, got={diffusers_image.size}." - ) - - assert vllm_image.size == diffusers_image.size, ( - f"Online and diffusers output sizes mismatch: online={vllm_image.size}, diffusers={diffusers_image.size}" - ) - - ssim_score, psnr_score = compute_image_ssim_psnr(prediction=vllm_image, reference=diffusers_image) - print(f"{model_name} similarity metrics:") - print(f" SSIM: value={ssim_score:.6f}, threshold>={ssim_threshold:.6f}, range=[-1, 1], higher_is_better=True") - print( - f" PSNR: value={psnr_score:.6f} dB, threshold>={psnr_threshold:.6f} dB, range=[0, +inf), higher_is_better=True" - ) - - assert ssim_score >= ssim_threshold, ( - f"SSIM below threshold for {model_name}: got {ssim_score:.6f}, expected >= {ssim_threshold:.6f}." - ) - assert psnr_score >= psnr_threshold, ( - f"PSNR below threshold for {model_name}: got {psnr_score:.6f}, expected >= {psnr_threshold:.6f}." - ) - - -def compute_image_ssim_psnr( - *, - prediction: Image.Image, - reference: Image.Image, -) -> tuple[float, float]: - pred_tensor = _pil_to_batched_tensor(prediction) - ref_tensor = _pil_to_batched_tensor(reference) - - ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0) - psnr_metric = PeakSignalNoiseRatio(data_range=1.0) - - ssim_value = float(ssim_metric(pred_tensor, ref_tensor).item()) - psnr_value = float(psnr_metric(pred_tensor, ref_tensor).item()) - return ssim_value, psnr_value - - -def _pil_to_batched_tensor(image: Image.Image) -> torch.Tensor: - array = np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0 - tensor = torch.from_numpy(array).permute(2, 0, 1).unsqueeze(0) - return tensor diff --git a/tests/e2e/offline_inference/test_voxcpm.py b/tests/e2e/offline_inference/test_voxcpm.py index 24f94b08758..bda087612de 100644 --- a/tests/e2e/offline_inference/test_voxcpm.py +++ b/tests/e2e/offline_inference/test_voxcpm.py @@ -13,8 +13,8 @@ import torch import tests.helpers.runtime as omni_runtime +from tests.helpers.mark import hardware_test from tests.helpers.runtime import OmniRunner -from tests.utils import hardware_test from vllm_omni.model_executor.models.voxcpm.voxcpm_runtime_utils import ( prepare_voxcpm_hf_config_dir, resolve_voxcpm_model_dir, From 4a6c4e241f7c0e5f829b5c7166194138acaa072d Mon Sep 17 00:00:00 2001 From: wangyu <410167048@qq.com> Date: Wed, 15 Apr 2026 22:46:11 +0800 Subject: [PATCH 11/19] Refactor timeout argument handling in omni_server fixture to ensure compatibility with non-omni paths. Timeout flags are now gated behind the use_omni parameter, aligning with legacy behavior and improving code clarity. Signed-off-by: wangyu <410167048@qq.com> --- tests/helpers/fixtures/runtime.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/helpers/fixtures/runtime.py b/tests/helpers/fixtures/runtime.py index 5289142a08e..ffe21723409 100644 --- a/tests/helpers/fixtures/runtime.py +++ b/tests/helpers/fixtures/runtime.py @@ -41,14 +41,17 @@ def omni_server(request: pytest.FixtureRequest, run_level: str, model_prefix: st ) server_args = params.server_args or [] - if params.use_omni and params.stage_init_timeout is not None: - server_args = [*server_args, "--stage-init-timeout", str(params.stage_init_timeout)] - else: - server_args = [*server_args, "--stage-init-timeout", "600"] - if params.init_timeout is not None: - server_args = [*server_args, "--init-timeout", str(params.init_timeout)] - else: - server_args = [*server_args, "--init-timeout", "900"] + # Upstream vLLM non-omni path does not accept omni-specific timeout args. + # Keep timeout flags gated behind use_omni (matches legacy conftest behavior). + if params.use_omni: + if params.stage_init_timeout is not None: + server_args = [*server_args, "--stage-init-timeout", str(params.stage_init_timeout)] + else: + server_args = [*server_args, "--stage-init-timeout", "600"] + if params.init_timeout is not None: + server_args = [*server_args, "--init-timeout", str(params.init_timeout)] + else: + server_args = [*server_args, "--init-timeout", "900"] if params.use_stage_cli: if not params.use_omni: From 64e6f7f4e39f367120108b2274f8afaa928c3ef4 Mon Sep 17 00:00:00 2001 From: wangyu <410167048@qq.com> Date: Thu, 16 Apr 2026 16:17:01 +0800 Subject: [PATCH 12/19] Add pytest_terminal_summary hook to conftest.py for Buildkite log folding Signed-off-by: wangyu <410167048@qq.com> --- tests/conftest.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 3bad7e02b2b..bd5f8baafde 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,12 @@ "tests.helpers.fixtures.runtime", ) + +def pytest_terminal_summary(terminalreporter, exitstatus, config): + # Marker for Buildkite log folding before pytest summary lines. + terminalreporter.write_sep("-", "Result Summary") + + # Backward-compatible re-exports. # (Many tests still import from `tests.conftest`; migrate these imports to `tests.helpers.*` over time.) from tests.helpers.assertions import ( # noqa: F401,E402 From 8528717d48e1055e4916395430458354ad40d571 Mon Sep 17 00:00:00 2001 From: wangyu <410167048@qq.com> Date: Thu, 16 Apr 2026 17:51:30 +0800 Subject: [PATCH 13/19] Add conftest.py for DFX benchmarks with configuration loading and parameter mapping functions Signed-off-by: wangyu <410167048@qq.com> --- tests/dfx/{helpers.py => conftest.py} | 0 tests/dfx/perf/scripts/run_benchmark.py | 2 +- tests/dfx/stability/scripts/test_benchmark_stability.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename tests/dfx/{helpers.py => conftest.py} (100%) diff --git a/tests/dfx/helpers.py b/tests/dfx/conftest.py similarity index 100% rename from tests/dfx/helpers.py rename to tests/dfx/conftest.py diff --git a/tests/dfx/perf/scripts/run_benchmark.py b/tests/dfx/perf/scripts/run_benchmark.py index 1497e463739..c8f06136b99 100644 --- a/tests/dfx/perf/scripts/run_benchmark.py +++ b/tests/dfx/perf/scripts/run_benchmark.py @@ -8,7 +8,7 @@ import pytest -from tests.dfx.helpers import ( +from tests.dfx.conftest import ( create_benchmark_indices, create_test_parameter_mapping, create_unique_server_params, diff --git a/tests/dfx/stability/scripts/test_benchmark_stability.py b/tests/dfx/stability/scripts/test_benchmark_stability.py index a8dfa2a7cff..2692312ce4b 100644 --- a/tests/dfx/stability/scripts/test_benchmark_stability.py +++ b/tests/dfx/stability/scripts/test_benchmark_stability.py @@ -24,7 +24,7 @@ import pytest -from tests.dfx.helpers import ( +from tests.dfx.conftest import ( create_benchmark_indices, create_test_parameter_mapping, create_unique_server_params, From 52be8a1bfc3a81cea661ad6a0e1308d4eb51181d Mon Sep 17 00:00:00 2001 From: wangyu <410167048@qq.com> Date: Thu, 16 Apr 2026 20:00:08 +0800 Subject: [PATCH 14/19] Refactor tests and configuration files for improved clarity and performance Signed-off-by: wangyu <410167048@qq.com> --- pyproject.toml | 1 - tests/conftest.py | 70 +++++++++++++++---------- tests/dfx/perf/scripts/run_benchmark.py | 2 +- tests/helpers/assertions.py | 58 +++++++++++++++----- tests/helpers/fixtures/runtime.py | 25 ++++----- 5 files changed, 96 insertions(+), 60 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 58261f74ec6..9b034a7c8e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -193,7 +193,6 @@ markers = [ # specified computation resources marks (auto-added) "H100: Tests that require H100 GPU", "L4: Tests that require L4 GPU", - "B60: Tests that require B60", "MI325: Tests that require MI325 GPU (AMD/ROCm)", "B60: Tests that require Intel Arc Pro B60 XPU", "S5000: Tests that require S5000 GPU (Moore Threads/MUSA)", diff --git a/tests/conftest.py b/tests/conftest.py index bd5f8baafde..f8b36ec3264 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,36 +18,34 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): # Marker for Buildkite log folding before pytest summary lines. - terminalreporter.write_sep("-", "Result Summary") + terminalreporter.write_line("--- Running Summary") -# Backward-compatible re-exports. +# Backward-compatible lazy re-exports. # (Many tests still import from `tests.conftest`; migrate these imports to `tests.helpers.*` over time.) -from tests.helpers.assertions import ( # noqa: F401,E402 - assert_audio_speech_response, - assert_diffusion_response, - assert_image_diffusion_response, - assert_image_valid, - assert_omni_response, - assert_video_diffusion_response, - assert_video_valid, +# Keep these lazy so conftest import does not trigger heavy helper dependencies. +_ASSERTION_EXPORT_NAMES = ( + "assert_audio_speech_response", + "assert_diffusion_response", + "assert_image_diffusion_response", + "assert_image_valid", + "assert_omni_response", + "assert_video_diffusion_response", + "assert_video_valid", ) -from tests.helpers.media import ( # noqa: F401,E402 - convert_audio_bytes_to_text, - convert_audio_file_to_text, - cosine_similarity_text, - decode_b64_image, - generate_synthetic_audio, - generate_synthetic_image, - generate_synthetic_video, +_MEDIA_EXPORT_NAMES = ( + "convert_audio_bytes_to_text", + "convert_audio_file_to_text", + "cosine_similarity_text", + "decode_b64_image", + "generate_synthetic_audio", + "generate_synthetic_image", + "generate_synthetic_video", ) -from tests.helpers.stage_config import ( # noqa: F401,E402 - dummy_messages_from_mix_data, - modify_stage_config, +_STAGE_CONFIG_EXPORT_NAMES = ( + "dummy_messages_from_mix_data", + "modify_stage_config", ) - -# Lazy: importing `tests.helpers.runtime` at conftest load runs before session -# autouse fixtures and can scramble vLLM/vllm_omni init order. _RUNTIME_EXPORT_NAMES = ( "DiffusionResponse", "OmniResponse", @@ -58,15 +56,29 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): "OmniServerStageCli", "OpenAIClientHandler", ) +_LAZY_EXPORT_MODULES = { + **{name: "tests.helpers.assertions" for name in _ASSERTION_EXPORT_NAMES}, + **{name: "tests.helpers.media" for name in _MEDIA_EXPORT_NAMES}, + **{name: "tests.helpers.stage_config" for name in _STAGE_CONFIG_EXPORT_NAMES}, + **{name: "tests.helpers.runtime" for name in _RUNTIME_EXPORT_NAMES}, +} def __getattr__(name: str): - if name in _RUNTIME_EXPORT_NAMES: - import tests.helpers.runtime as _runtime - - return getattr(_runtime, name) + module_name = _LAZY_EXPORT_MODULES.get(name) + if module_name is not None: + module = __import__(module_name, fromlist=[name]) + return getattr(module, name) raise AttributeError(f"module {__name__!r} has no attribute {name!r}") def __dir__(): - return sorted({*globals(), *_RUNTIME_EXPORT_NAMES}) + return sorted( + { + *globals(), + *_ASSERTION_EXPORT_NAMES, + *_MEDIA_EXPORT_NAMES, + *_STAGE_CONFIG_EXPORT_NAMES, + *_RUNTIME_EXPORT_NAMES, + } + ) diff --git a/tests/dfx/perf/scripts/run_benchmark.py b/tests/dfx/perf/scripts/run_benchmark.py index c8f06136b99..c94a4ce39b7 100644 --- a/tests/dfx/perf/scripts/run_benchmark.py +++ b/tests/dfx/perf/scripts/run_benchmark.py @@ -64,7 +64,7 @@ def omni_server(request): print(f"Starting OmniServer with test: {test_name}, model: {model}") - server_args = ["--stage-init-timeout", "300", "--init-timeout", "900"] + server_args = ["--stage-init-timeout", "600", "--init-timeout", "900"] if stage_config_path: server_args = ["--stage-configs-path", stage_config_path] + server_args with OmniServer(model, server_args) as server: diff --git a/tests/helpers/assertions.py b/tests/helpers/assertions.py index 23bfa67fc4b..533806477e5 100644 --- a/tests/helpers/assertions.py +++ b/tests/helpers/assertions.py @@ -10,7 +10,6 @@ import numpy as np import soundfile as sf from PIL import Image -from transformers import pipeline from tests.helpers.media import cosine_similarity_text @@ -246,6 +245,8 @@ def _load_gender_pipeline(): return _GENDER_PIPELINE model_name = "7wolf/wav2vec2-base-gender-classification" try: + from transformers import pipeline + _GENDER_PIPELINE = pipeline(task="audio-classification", model=model_name, device=-1) return _GENDER_PIPELINE except Exception as exc: # pragma: no cover @@ -341,6 +342,7 @@ def _estimate_voice_gender_from_audio(audio_bytes: bytes) -> str: def _assert_preset_voice_gender_from_audio(audio_bytes: bytes | None, voice_name: str | None) -> None: + """If ``voice_name`` matches a known preset, assert classifier gender matches (skip when unknown).""" if not voice_name or not audio_bytes: return key = str(voice_name).lower() @@ -350,7 +352,9 @@ def _assert_preset_voice_gender_from_audio(audio_bytes: bytes | None, voice_name estimated_gender = _estimate_voice_gender_from_audio(audio_bytes) print(f"Preset voice gender check: preset={key!r}, estimated={estimated_gender!r}, expected={expected_gender!r}") if estimated_gender != "unknown": - assert estimated_gender == expected_gender + assert estimated_gender == expected_gender, ( + f"{voice_name!r} is expected {expected_gender}, but estimated gender is {estimated_gender!r}" + ) def _compute_pcm_hnr_db(pcm_samples: np.ndarray, sr: int = _PCM_SPEECH_SAMPLE_RATE_HZ) -> float: @@ -380,39 +384,67 @@ def _assert_pcm_int16_speech_hnr(audio_bytes: bytes) -> None: pcm_samples = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0 hnr = _compute_pcm_hnr_db(pcm_samples) print(f"PCM speech HNR: {hnr:.2f} dB (threshold: {_MIN_PCM_SPEECH_HNR_DB} dB)") - assert hnr >= _MIN_PCM_SPEECH_HNR_DB + assert hnr >= _MIN_PCM_SPEECH_HNR_DB, ( + f"Audio distortion detected: HNR={hnr:.2f} dB < {_MIN_PCM_SPEECH_HNR_DB} dB. " + "Voice clone decoder may be losing ref_code speaker context on later chunks." + ) def assert_omni_response(response: Any, request_config: dict[str, Any], run_level): + """ + Validate response results. + + Args: + response: OmniResponse object + + Raises: + AssertionError: When the response does not meet validation criteria + """ assert response.success, "The request failed." - e2e_latency = getattr(response, "e2e_latency", None) + e2e_latency = response.e2e_latency if e2e_latency is not None: print(f"the e2e latency is: {e2e_latency}") modalities = request_config.get("modalities", ["text", "audio"]) + if run_level == "advanced_model": if "audio" in modalities: assert response.audio_content is not None, "No audio output is generated" print(f"audio content is: {response.audio_content}") speaker = request_config.get("speaker") if speaker: - _assert_preset_voice_gender_from_audio(response.audio_bytes, speaker) + _assert_preset_voice_gender_from_audio( + response.audio_bytes, + speaker, + ) + if "text" in modalities: assert response.text_content is not None, "No text output is generated" print(f"text content is: {response.text_content}") + + # Verify image description + word_types = ["text", "image", "audio", "video"] keywords_dict = request_config.get("key_words", {}) - for word_type in ["text", "image", "audio", "video"]: + for word_type in word_types: keywords = keywords_dict.get(word_type) - if not keywords: - continue if "text" in modalities: - text_lower = (response.text_content or "").lower() - assert any(str(kw).lower() in text_lower for kw in keywords) + if keywords: + text_lower = response.text_content.lower() + assert any(str(kw).lower() in text_lower for kw in keywords), ( + "The output does not contain any of the keywords." + ) else: - audio_lower = (response.audio_content or "").lower() - assert any(str(kw).lower() in audio_lower for kw in keywords) + if keywords: + audio_lower = response.audio_content.lower() + assert any(str(kw).lower() in audio_lower for kw in keywords), ( + "The output does not contain any of the keywords." + ) + + # Verify similarity (Whisper transcript vs streamed/detokenized text) if "text" in modalities and "audio" in modalities: - assert response.similarity is not None and response.similarity > 0.9 + assert response.similarity is not None and response.similarity > 0.9, ( + "The audio content is not same as the text" + ) print(f"similarity is: {response.similarity}") diff --git a/tests/helpers/fixtures/runtime.py b/tests/helpers/fixtures/runtime.py index ffe21723409..27216ac60b1 100644 --- a/tests/helpers/fixtures/runtime.py +++ b/tests/helpers/fixtures/runtime.py @@ -73,22 +73,15 @@ def omni_server(request: pytest.FixtureRequest, run_level: str, model_prefix: st if stage_config_path is not None: server_args += ["--stage-configs-path", stage_config_path] - with ( - OmniServer( - model, - server_args, - port=port, - env_dict=params.env_dict, - use_omni=params.use_omni, - ) - if port - else OmniServer( - model, - server_args, - env_dict=params.env_dict, - use_omni=params.use_omni, - ) - ) as server: + kwargs: dict[str, Any] = dict( + model=model, + serve_args=server_args, + env_dict=params.env_dict, + use_omni=params.use_omni, + ) + if port is not None: + kwargs["port"] = port + with OmniServer(**kwargs) as server: print("OmniServer started successfully") yield server print("OmniServer stopping...") From 15bd4aa3c50184f9521208d21f7af864396c8cd4 Mon Sep 17 00:00:00 2001 From: wangyu <410167048@qq.com> Date: Fri, 17 Apr 2026 10:46:08 +0800 Subject: [PATCH 15/19] Refactor import statement in e2e test for Flux2 Klein inpaint expansion Signed-off-by: wangyu <410167048@qq.com> --- tests/e2e/online_serving/test_flux2_klein_inpaint_expansion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/e2e/online_serving/test_flux2_klein_inpaint_expansion.py b/tests/e2e/online_serving/test_flux2_klein_inpaint_expansion.py index f59a0e783d7..0c45bb33f1d 100644 --- a/tests/e2e/online_serving/test_flux2_klein_inpaint_expansion.py +++ b/tests/e2e/online_serving/test_flux2_klein_inpaint_expansion.py @@ -14,7 +14,7 @@ import pytest from PIL import Image, ImageDraw -from tests.conftest import OmniServer, OmniServerParams +from tests.helpers.runtime import OmniServer, OmniServerParams MODEL = "black-forest-labs/FLUX.2-klein-4B" From cc7d7067f19d284b44ea760e902ac992c7a9a5b1 Mon Sep 17 00:00:00 2001 From: wangyu <410167048@qq.com> Date: Fri, 17 Apr 2026 11:18:09 +0800 Subject: [PATCH 16/19] Fix import path in conftest.py for stability tests Signed-off-by: wangyu <410167048@qq.com> --- tests/dfx/stability/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dfx/stability/conftest.py b/tests/dfx/stability/conftest.py index be816ff4557..e36c88b9aa6 100644 --- a/tests/dfx/stability/conftest.py +++ b/tests/dfx/stability/conftest.py @@ -6,7 +6,7 @@ import pytest -from .helpers import ( +from tests.dfx.stability.helpers import ( finalize_resource_monitor, report_latest_gpu_samples, start_resource_monitor, From 2bacffcb82fd4ebe95eea20bb701dc12fa31d6d0 Mon Sep 17 00:00:00 2001 From: wangyu <410167048@qq.com> Date: Fri, 17 Apr 2026 16:58:23 +0800 Subject: [PATCH 17/19] Update pytest command in nightly test configuration to use marker for advanced model Signed-off-by: wangyu <410167048@qq.com> --- .buildkite/test-nightly.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/test-nightly.yml b/.buildkite/test-nightly.yml index ac43b597d15..6739a5ad9bf 100644 --- a/.buildkite/test-nightly.yml +++ b/.buildkite/test-nightly.yml @@ -506,7 +506,7 @@ steps: - label: ":full_moon: Diffusion X2V · Accuracy Test" timeout_in_minutes: 180 commands: - - pytest -s -v tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py --run-level advanced_model + - pytest -s -v tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py -m advanced_model --run-level advanced_model agents: queue: "mithril-h100-pool" plugins: From e6e649af3ca940bda76b08aaf2dce387b7d9bc2c Mon Sep 17 00:00:00 2001 From: wangyu <410167048@qq.com> Date: Mon, 20 Apr 2026 10:18:46 +0800 Subject: [PATCH 18/19] Refactor stage config exports and update test imports - Removed unused export `dummy_messages_from_mix_data` from `_STAGE_CONFIG_EXPORT_NAMES`. - Added `dummy_messages_from_mix_data` to the imports in multiple test files for consistency. - Adjusted the `_REPO_ROOT` path comment in `stage_config.py` for clarity. Signed-off-by: wangyu <410167048@qq.com> --- tests/conftest.py | 6 ++---- tests/e2e/online_serving/test_bagel_expansion.py | 3 +-- tests/e2e/online_serving/test_flux_2_dev_expansion.py | 3 +-- tests/e2e/online_serving/test_flux_kontext_expansion.py | 3 +-- .../e2e/online_serving/test_longcat_image_edit_expansion.py | 3 +-- tests/e2e/online_serving/test_longcat_image_expansion.py | 3 +-- tests/e2e/online_serving/test_mimo_audio.py | 4 ++-- tests/e2e/online_serving/test_qwen_image_edit_expansion.py | 3 +-- tests/e2e/online_serving/test_qwen_image_expansion.py | 3 +-- .../e2e/online_serving/test_qwen_image_layered_expansion.py | 3 +-- tests/helpers/stage_config.py | 3 ++- 11 files changed, 14 insertions(+), 23 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index ed30aa924c1..77075f9525a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -42,10 +42,7 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): "generate_synthetic_image", "generate_synthetic_video", ) -_STAGE_CONFIG_EXPORT_NAMES = ( - "dummy_messages_from_mix_data", - "modify_stage_config", -) +_STAGE_CONFIG_EXPORT_NAMES = ("modify_stage_config",) _RUNTIME_EXPORT_NAMES = ( "DiffusionResponse", "OmniResponse", @@ -55,6 +52,7 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): "OmniServerParams", "OmniServerStageCli", "OpenAIClientHandler", + "dummy_messages_from_mix_data", ) _LAZY_EXPORT_MODULES = { **{name: "tests.helpers.assertions" for name in _ASSERTION_EXPORT_NAMES}, diff --git a/tests/e2e/online_serving/test_bagel_expansion.py b/tests/e2e/online_serving/test_bagel_expansion.py index 3509945f95a..aa289c6d9a2 100644 --- a/tests/e2e/online_serving/test_bagel_expansion.py +++ b/tests/e2e/online_serving/test_bagel_expansion.py @@ -17,8 +17,7 @@ import pytest from tests.helpers.mark import hardware_marks -from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler -from tests.helpers.stage_config import dummy_messages_from_mix_data +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data PROMPT = "A futuristic city skyline at twilight, cyberpunk style, ultra-detailed, high resolution." NEGATIVE_PROMPT = "low quality, blurry, distorted, deformed, watermark" diff --git a/tests/e2e/online_serving/test_flux_2_dev_expansion.py b/tests/e2e/online_serving/test_flux_2_dev_expansion.py index 6d98c9d1bd7..addf6f00248 100644 --- a/tests/e2e/online_serving/test_flux_2_dev_expansion.py +++ b/tests/e2e/online_serving/test_flux_2_dev_expansion.py @@ -15,8 +15,7 @@ import pytest from tests.helpers.mark import hardware_marks -from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler -from tests.helpers.stage_config import dummy_messages_from_mix_data +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data MODEL = "black-forest-labs/FLUX.2-dev" PROMPT = "A cinematic mountain landscape at sunrise, dramatic clouds, ultra-detailed, realistic photography." diff --git a/tests/e2e/online_serving/test_flux_kontext_expansion.py b/tests/e2e/online_serving/test_flux_kontext_expansion.py index c85d8a3c3c0..574bb3db4a9 100644 --- a/tests/e2e/online_serving/test_flux_kontext_expansion.py +++ b/tests/e2e/online_serving/test_flux_kontext_expansion.py @@ -6,8 +6,7 @@ import pytest from tests.helpers.media import generate_synthetic_image -from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler -from tests.helpers.stage_config import dummy_messages_from_mix_data +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data EDIT_PROMPT = "Transform this modern, geometrist image into a Vincent van Gogh style impressionist painting." NEGATIVE_PROMPT = "blurry, low quality, modern, geometrist" diff --git a/tests/e2e/online_serving/test_longcat_image_edit_expansion.py b/tests/e2e/online_serving/test_longcat_image_edit_expansion.py index 28f6c3de005..9a96280f2e9 100644 --- a/tests/e2e/online_serving/test_longcat_image_edit_expansion.py +++ b/tests/e2e/online_serving/test_longcat_image_edit_expansion.py @@ -15,8 +15,7 @@ from tests.helpers.mark import hardware_marks from tests.helpers.media import generate_synthetic_image -from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler -from tests.helpers.stage_config import dummy_messages_from_mix_data +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data EDIT_PROMPT = "Transform this modern image into a cinematic animation style with vibrant colors and soft lighting." NEGATIVE_PROMPT = "blurry, low quality, distorted, oversaturated" diff --git a/tests/e2e/online_serving/test_longcat_image_expansion.py b/tests/e2e/online_serving/test_longcat_image_expansion.py index f0b0ca905d0..7d12db6a2ca 100644 --- a/tests/e2e/online_serving/test_longcat_image_expansion.py +++ b/tests/e2e/online_serving/test_longcat_image_expansion.py @@ -14,8 +14,7 @@ import pytest from tests.helpers.mark import hardware_marks -from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler -from tests.helpers.stage_config import dummy_messages_from_mix_data +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data TEXT_TO_IMAGE_PROMPT = ( "A cinematic illustration of a cat typing on a silver laptop, soft window light, highly detailed." diff --git a/tests/e2e/online_serving/test_mimo_audio.py b/tests/e2e/online_serving/test_mimo_audio.py index 3349e0a8a63..331baf18a54 100644 --- a/tests/e2e/online_serving/test_mimo_audio.py +++ b/tests/e2e/online_serving/test_mimo_audio.py @@ -11,8 +11,8 @@ from tests.helpers.mark import hardware_test from tests.helpers.media import generate_synthetic_audio -from tests.helpers.runtime import OmniServerParams -from tests.helpers.stage_config import dummy_messages_from_mix_data, modify_stage_config +from tests.helpers.runtime import OmniServerParams, dummy_messages_from_mix_data +from tests.helpers.stage_config import modify_stage_config from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) diff --git a/tests/e2e/online_serving/test_qwen_image_edit_expansion.py b/tests/e2e/online_serving/test_qwen_image_edit_expansion.py index 8b3f9b0e95b..c2461977d64 100644 --- a/tests/e2e/online_serving/test_qwen_image_edit_expansion.py +++ b/tests/e2e/online_serving/test_qwen_image_edit_expansion.py @@ -9,8 +9,7 @@ from tests.helpers.mark import hardware_marks from tests.helpers.media import generate_synthetic_image -from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler -from tests.helpers.stage_config import dummy_messages_from_mix_data +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data EDIT_PROMPT = "Transform this modern, geometrist image into a Vincent van Gogh style impressionist painting." MULTI_EDIT_PROMPT = ( diff --git a/tests/e2e/online_serving/test_qwen_image_expansion.py b/tests/e2e/online_serving/test_qwen_image_expansion.py index 7c31042694e..b6f91c13daa 100644 --- a/tests/e2e/online_serving/test_qwen_image_expansion.py +++ b/tests/e2e/online_serving/test_qwen_image_expansion.py @@ -13,8 +13,7 @@ import pytest from tests.helpers.mark import hardware_marks -from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler -from tests.helpers.stage_config import dummy_messages_from_mix_data +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data T2I_PROMPT = "A photo of a cat sitting on a laptop keyboard, digital art style." NEGATIVE_PROMPT = "blurry, low quality" diff --git a/tests/e2e/online_serving/test_qwen_image_layered_expansion.py b/tests/e2e/online_serving/test_qwen_image_layered_expansion.py index 8f08be928fc..b958cfc054c 100644 --- a/tests/e2e/online_serving/test_qwen_image_layered_expansion.py +++ b/tests/e2e/online_serving/test_qwen_image_layered_expansion.py @@ -16,8 +16,7 @@ from tests.helpers.mark import hardware_marks from tests.helpers.media import decode_b64_image, generate_synthetic_image -from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler -from tests.helpers.stage_config import dummy_messages_from_mix_data +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data MODEL = "Qwen/Qwen-Image-Layered" EDIT_PROMPT = "Decompose this image into layers." diff --git a/tests/helpers/stage_config.py b/tests/helpers/stage_config.py index 1b901c356b5..29a80372ecf 100644 --- a/tests/helpers/stage_config.py +++ b/tests/helpers/stage_config.py @@ -253,7 +253,8 @@ def delete_by_path(config_dict: dict, path: str) -> None: return str(output_path) -_REPO_ROOT = Path(__file__).resolve().parent.parent +# ``stage_config.py`` lives under ``tests/helpers/``; repo root is three parents up. +_REPO_ROOT = Path(__file__).resolve().parent.parent.parent _DEPLOY_DIR = _REPO_ROOT / "vllm_omni" / "deploy" _CI_GENERATED_DIR = _REPO_ROOT / "tests" / ".ci_generated" From d62b67a2eb671014659fdc8a9841d7d02fd73195 Mon Sep 17 00:00:00 2001 From: wangyu <410167048@qq.com> Date: Mon, 20 Apr 2026 10:59:12 +0800 Subject: [PATCH 19/19] Refactor test imports and update helper module paths Signed-off-by: wangyu <410167048@qq.com> --- docs/contributing/ci/tests_style.md | 12 +++++------- tests/e2e/offline_inference/test_ming_flash_omni.py | 6 +++--- tests/e2e/online_serving/test_ming_flash_omni.py | 9 ++++----- tests/e2e/online_serving/test_nextstep_expansion.py | 4 ++-- .../test_qwen3_omni_realtime_websocket.py | 8 ++++---- tests/e2e/online_serving/test_qwen3_tts_websocket.py | 2 +- tests/engine/test_async_omni_engine_abort.py | 2 +- tests/test_config_factory.py | 6 +++--- 8 files changed, 23 insertions(+), 26 deletions(-) diff --git a/docs/contributing/ci/tests_style.md b/docs/contributing/ci/tests_style.md index 323ed3bf535..3a8cb0f127c 100644 --- a/docs/contributing/ci/tests_style.md +++ b/docs/contributing/ci/tests_style.md @@ -221,15 +221,13 @@ from pathlib import Path import openai import pytest -from tests.conftest import ( - OmniServer, - convert_audio_to_text, +from tests.helpers.media import ( + convert_audio_bytes_to_text, cosine_similarity_text, - dummy_messages_from_mix_data, generate_synthetic_video, - merge_base64_and_convert_to_text, ) -from tests.utils import get_deploy_config_path +from tests.helpers.runtime import OmniServer, dummy_messages_from_mix_data +from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config from vllm_omni.platforms import current_omni_platform # Edit: model name and stage config path @@ -406,7 +404,7 @@ def test_mix_to_text_audio_001(client: openai.OpenAI, omni_server, request) -> N # PURPOSE: Verify text and audio outputs convey the same information # CUSTOMIZATION: Adjust similarity threshold (0.9) based on accuracy requirements assert audio_data is not None, "No audio output is generated" - audio_content = merge_base64_and_convert_to_text(audio_data) + audio_content = convert_audio_bytes_to_text(audio_data) print(f"text content is: {text_content}") print(f"audio content is: {audio_content}") similarity = cosine_similarity_text(audio_content.lower(), text_content.lower()) diff --git a/tests/e2e/offline_inference/test_ming_flash_omni.py b/tests/e2e/offline_inference/test_ming_flash_omni.py index be0ed3b056f..c591e910ac3 100644 --- a/tests/e2e/offline_inference/test_ming_flash_omni.py +++ b/tests/e2e/offline_inference/test_ming_flash_omni.py @@ -10,13 +10,13 @@ import pytest -from tests.conftest import ( +from tests.helpers.mark import hardware_test +from tests.helpers.media import ( generate_synthetic_audio, generate_synthetic_image, generate_synthetic_video, - modify_stage_config, ) -from tests.utils import hardware_test +from tests.helpers.stage_config import modify_stage_config models = ["Jonathan1909/Ming-flash-omni-2.0"] diff --git a/tests/e2e/online_serving/test_ming_flash_omni.py b/tests/e2e/online_serving/test_ming_flash_omni.py index 35b7b64c061..8161c438929 100644 --- a/tests/e2e/online_serving/test_ming_flash_omni.py +++ b/tests/e2e/online_serving/test_ming_flash_omni.py @@ -10,15 +10,14 @@ import pytest -from tests.conftest import ( - OmniServerParams, - dummy_messages_from_mix_data, +from tests.helpers.mark import hardware_test +from tests.helpers.media import ( generate_synthetic_audio, generate_synthetic_image, generate_synthetic_video, - modify_stage_config, ) -from tests.utils import hardware_test +from tests.helpers.runtime import OmniServerParams, dummy_messages_from_mix_data +from tests.helpers.stage_config import modify_stage_config os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0" diff --git a/tests/e2e/online_serving/test_nextstep_expansion.py b/tests/e2e/online_serving/test_nextstep_expansion.py index cd3d7f9bca8..004a0967b56 100644 --- a/tests/e2e/online_serving/test_nextstep_expansion.py +++ b/tests/e2e/online_serving/test_nextstep_expansion.py @@ -6,13 +6,13 @@ import pytest -from tests.conftest import ( +from tests.helpers.mark import hardware_marks +from tests.helpers.runtime import ( OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data, ) -from tests.utils import hardware_marks # L4: 4 GPUs + TP=4; XPU B60: 2 cards (use num_cards={"cuda": 4, "xpu": 4} if needed) FOUR_CARD_MARKS = hardware_marks( diff --git a/tests/e2e/online_serving/test_qwen3_omni_realtime_websocket.py b/tests/e2e/online_serving/test_qwen3_omni_realtime_websocket.py index 6a7cf1c67ec..f3b26108199 100644 --- a/tests/e2e/online_serving/test_qwen3_omni_realtime_websocket.py +++ b/tests/e2e/online_serving/test_qwen3_omni_realtime_websocket.py @@ -16,14 +16,14 @@ import pytest import websockets -from tests.conftest import ( - OmniServerParams, +from tests.helpers.mark import hardware_test +from tests.helpers.media import ( convert_audio_bytes_to_text, cosine_similarity_text, generate_synthetic_audio, - modify_stage_config, ) -from tests.utils import get_deploy_config_path, hardware_test +from tests.helpers.runtime import OmniServerParams +from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" diff --git a/tests/e2e/online_serving/test_qwen3_tts_websocket.py b/tests/e2e/online_serving/test_qwen3_tts_websocket.py index 836fee9ae5f..5ac021cf88b 100644 --- a/tests/e2e/online_serving/test_qwen3_tts_websocket.py +++ b/tests/e2e/online_serving/test_qwen3_tts_websocket.py @@ -11,8 +11,8 @@ import pytest import websockets -from tests.conftest import OmniServer from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniServer from tests.helpers.stage_config import get_deploy_config_path os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" diff --git a/tests/engine/test_async_omni_engine_abort.py b/tests/engine/test_async_omni_engine_abort.py index 71d7ffe53e2..eda7a7a788e 100644 --- a/tests/engine/test_async_omni_engine_abort.py +++ b/tests/engine/test_async_omni_engine_abort.py @@ -15,7 +15,7 @@ SEED = 42 -# Single-stage thinker-only deploy, materialized from tests.utils._CI_OVERLAYS. +# Single-stage thinker-only deploy, materialized from tests.helpers.stage_config._CI_OVERLAYS. stage_config = get_deploy_config_path("ci/qwen2_5_omni_thinker_only.yaml") model = "Qwen/Qwen2.5-Omni-7B" diff --git a/tests/test_config_factory.py b/tests/test_config_factory.py index 1d65d3acd27..de403df2d09 100644 --- a/tests/test_config_factory.py +++ b/tests/test_config_factory.py @@ -945,7 +945,7 @@ class TestBaseConfigInheritance: """Test deploy YAML base_config inheritance.""" def test_ci_inherits_from_main(self): - from tests.utils import get_deploy_config_path + from tests.helpers.stage_config import get_deploy_config_path from vllm_omni.config.stage_config import load_deploy_config ci_path = Path(get_deploy_config_path("ci/qwen3_omni_moe.yaml")) @@ -962,12 +962,12 @@ def test_ci_inherits_from_main(self): assert deploy.connectors is not None assert "connector_of_shared_memory" in deploy.connectors # CI overlay explicitly sets async_chunk: False (see - # tests/utils.py::_CI_OVERLAYS and PR #2383 discussion). Overlay + # tests.helpers.stage_config._CI_OVERLAYS and PR #2383 discussion). Overlay # bool overrides base even when the base yaml has async_chunk: true. assert deploy.async_chunk is False def test_ci_sampling_merge(self): - from tests.utils import get_deploy_config_path + from tests.helpers.stage_config import get_deploy_config_path from vllm_omni.config.stage_config import load_deploy_config ci_path = Path(get_deploy_config_path("ci/qwen3_omni_moe.yaml"))