diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 837ca2d58b..6f1e285408 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -85,6 +85,7 @@ jobs: . - name: Run all tests env: + NV_INFERENCE_API_KEY: ${{ secrets.NV_INFERENCE_API_KEY }} NVIDIA_API_KEY: ${{ secrets.NVIDIA_API_KEY }} HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | diff --git a/.gitignore b/.gitignore index e5adf3c582..e577e9265b 100644 --- a/.gitignore +++ b/.gitignore @@ -49,3 +49,6 @@ AGENTS.md .codex .idea + +#scripts at root level +/*.sh diff --git a/nemo_skills/inference/generate.py b/nemo_skills/inference/generate.py index b3f0094c1a..21b559e10f 100644 --- a/nemo_skills/inference/generate.py +++ b/nemo_skills/inference/generate.py @@ -210,7 +210,7 @@ class GenerationTaskConfig: enable_litellm_cache: bool = False # List of content types to drop from messages (e.g., base64 audio) to keep output files smaller - drop_content_types: list[str] = field(default_factory=lambda: ["audio_url"]) + drop_content_types: list[str] = field(default_factory=lambda: ["audio_url", "input_audio"]) # Audio configuration - set by benchmarks that need audio processing (mmau-pro, audiobench, etc.) enable_audio: bool = False # Enable audio preprocessing (set by benchmark configs) @@ -632,7 +632,7 @@ def drop_fields_from_messages(self, output): # Filter out content types specified in drop_content_types config message["content"] = [ - content for content in message["content"] if content.get("type") not in self.cfg.drop_content_types + content for content in message["content"] if content["type"] not in self.cfg.drop_content_types ] async def postprocess_single_output(self, output, original_data_point): diff --git a/nemo_skills/inference/model/audio_utils.py b/nemo_skills/inference/model/audio_utils.py index 02c8eaf459..d4b4c44dbd 100644 --- a/nemo_skills/inference/model/audio_utils.py +++ b/nemo_skills/inference/model/audio_utils.py @@ -145,6 +145,8 @@ def make_audio_content_block(base64_audio: str, audio_format: str = "audio_url") if audio_format == "input_audio": # OpenAI native format (works with NVIDIA API / Gemini / Azure) return {"type": "input_audio", "input_audio": {"data": base64_audio, "format": "wav"}} - else: + elif audio_format == "audio_url": # Data URI format (works with vLLM / Qwen) return {"type": "audio_url", "audio_url": {"url": f"data:audio/wav;base64,{base64_audio}"}} + else: + raise ValueError(f"Unsupported audio_format '{audio_format}'. Use 'audio_url' or 'input_audio'.") diff --git a/nemo_skills/inference/model/vllm_multimodal.py b/nemo_skills/inference/model/vllm_multimodal.py index c2285b0ccc..4b8006abe5 100644 --- a/nemo_skills/inference/model/vllm_multimodal.py +++ b/nemo_skills/inference/model/vllm_multimodal.py @@ -17,6 +17,7 @@ This module provides a multimodal model class that handles: - Audio INPUT: encoding audio files to base64, chunking long audio - Audio OUTPUT: saving audio responses from the server to disk +- External API support: NVIDIA Inference API, OpenAI, and other OpenAI-compatible APIs """ import base64 @@ -32,6 +33,7 @@ audio_file_to_base64, chunk_audio, load_audio_file, + make_audio_content_block, save_audio_chunk_to_base64, ) from .vllm import VLLMModel @@ -43,10 +45,10 @@ class VLLMMultimodalModel(VLLMModel): - """VLLMModel with support for audio input and output. + """VLLMModel with support for audio input/output and external APIs. Audio INPUT capabilities: - 1. Converts audio file paths to base64-encoded audio_url format + 1. Converts audio file paths to base64-encoded input_audio format 2. Chunks long audio files for models with duration limits 3. Aggregates results from chunked audio processing @@ -54,30 +56,56 @@ class VLLMMultimodalModel(VLLMModel): 1. Saves audio responses from the server to disk / output_dir/audio/ 2. Replaces the base64 data with the file path in the result + Also supports external APIs (NVIDIA, OpenAI) via base_url parameter. + + Example usage: + # Local vLLM server + model = VLLMMultimodalModel(model="Qwen/Qwen2-Audio-7B") + + # NVIDIA Inference API + model = VLLMMultimodalModel( + model="gcp/google/gemini-2.5-pro", + base_url="https://inference-api.nvidia.com/v1" + ) """ def __init__( self, + model: str | None = None, + base_url: str | None = None, enable_audio_chunking: bool = True, audio_chunk_task_types: list[str] | None = None, chunk_audio_threshold_sec: int = 30, + audio_format: str | None = None, **kwargs, ): - """Initialize VLLMMultimodalModel with audio I/O support. + """Initialize VLLMMultimodalModel with audio I/O and external API support. Args: + model: Model name (e.g., "Qwen/Qwen2-Audio-7B" for local, "gcp/google/gemini-2.5-pro" for NVIDIA API). + base_url: API base URL. If None, defaults to local server. For external APIs, provide the full URL. enable_audio_chunking: Master switch for audio chunking. audio_chunk_task_types: If None, chunk all task types; if specified, only chunk these. chunk_audio_threshold_sec: Audio duration threshold for chunking (in seconds). + audio_format: Format for audio content ("audio_url" or "input_audio"). If None, select by mode. **kwargs: Other parameters passed to VLLMModel/BaseModel. """ - super().__init__(**kwargs) + super().__init__(model=model, base_url=base_url, **kwargs) + + # Determine if this is an external API (non-local URL) + self._external_api_mode = not self._is_local_url(self.base_url) # Audio INPUT config self.enable_audio_chunking = enable_audio_chunking self.audio_chunk_task_types = audio_chunk_task_types self.chunk_audio_threshold_sec = chunk_audio_threshold_sec + if audio_format is None: + audio_format = "input_audio" if self._external_api_mode else "audio_url" + if audio_format not in ("audio_url", "input_audio"): + raise ValueError(f"Unsupported audio_format '{audio_format}'. Use 'audio_url' or 'input_audio'.") + self.audio_format = audio_format + # Audio OUTPUT config self.output_audio_dir = None if self.output_dir: @@ -85,6 +113,108 @@ def __init__( os.makedirs(self.output_audio_dir, exist_ok=True) LOG.info(f"Audio responses will be saved to: {self.output_audio_dir}") + def _is_local_url(self, base_url: str | None) -> bool: + """Check if the base_url points to a local server. + + Args: + base_url: API base URL. + + Returns: + True if local server, False otherwise. + """ + if not base_url: + return True # No URL means local server (will use default host:port) + local_patterns = ["127.0.0.1", "localhost", "0.0.0.0"] + return any(pattern in base_url for pattern in local_patterns) + + def _get_api_key(self, api_key: str | None, api_key_env_var: str | None, base_url: str) -> str | None: + """Get API key with smart detection for external APIs. + + Checks for API keys in the following order: + 1. Explicit api_key argument + 2. Environment variable specified by api_key_env_var + 3. Auto-detect based on base_url (NVIDIA_API_KEY, OPENAI_API_KEY, etc.) + + Args: + api_key: Explicit API key. + api_key_env_var: Environment variable name for API key. + base_url: API base URL for auto-detection. + + Returns: + API key string or None. + """ + # First, try parent class logic (explicit key or env var) + api_key = super()._get_api_key(api_key, api_key_env_var, base_url) + + if api_key is not None: + return api_key + + # Auto-detect API key based on base_url + if base_url: + if "api.nvidia.com" in base_url or "inference-api.nvidia.com" in base_url: + api_key = os.getenv("NV_INFERENCE_API_KEY") or os.getenv("NVIDIA_API_KEY") + if not api_key: + raise ValueError( + "NV_INFERENCE_API_KEY or NVIDIA_API_KEY is required for NVIDIA APIs and could not be found. " + "Set NV_INFERENCE_API_KEY/NVIDIA_API_KEY environment variable or pass api_key explicitly." + ) + return api_key + + if "api.openai.com" in base_url: + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + raise ValueError( + "OPENAI_API_KEY is required for OpenAI APIs and could not be found. " + "Set OPENAI_API_KEY environment variable or pass api_key explicitly." + ) + return api_key + + if "generativelanguage.googleapis.com" in base_url: + api_key = os.getenv("GOOGLE_API_KEY") + if not api_key: + raise ValueError( + "GOOGLE_API_KEY is required for Google APIs and could not be found. " + "Set GOOGLE_API_KEY environment variable or pass api_key explicitly." + ) + return api_key + + return api_key + + def _build_request_body(self, top_k, min_p, repetition_penalty, extra_body: dict | None = None): + """Build request body, skipping vLLM-specific params for external APIs. + + Args: + top_k: Top-k sampling parameter (vLLM, default -1). + min_p: Min-p sampling parameter (vLLM, default 0.0). + repetition_penalty: Repetition penalty parameter (vLLM, default 1.0). + extra_body: Additional parameters to include. + + Returns: + Dictionary of extra body parameters for the request. + + Raises: + ValueError: If vLLM-specific params are set to non-default values in external API mode. + """ + # For external APIs, fail if user explicitly set vLLM-specific parameters + if self._external_api_mode: + non_default_params = [] + if top_k != -1: + non_default_params.append(f"top_k={top_k}") + if min_p != 0.0: + non_default_params.append(f"min_p={min_p}") + if repetition_penalty != 1.0: + non_default_params.append(f"repetition_penalty={repetition_penalty}") + + if non_default_params: + raise ValueError( + f"vLLM-specific parameters are not supported for external APIs: {', '.join(non_default_params)}. " + "These parameters only work with local vLLM servers." + ) + return extra_body or {} + + # For local vLLM server, use full parameter set + return super()._build_request_body(top_k, min_p, repetition_penalty, extra_body=extra_body) + def _parse_chat_completion_response(self, response, include_response: bool = False, **kwargs) -> dict: """Parse chat completion response and save any audio to disk.""" result = super()._parse_chat_completion_response(response, include_response=include_response, **kwargs) @@ -154,66 +284,63 @@ def _process_audio_response(self, audio_data, response_id: str) -> dict: # Audio INPUT methods # ===================== + def _preprocess_messages_for_model(self, messages: list[dict]) -> list[dict]: + """Preprocess messages - creates copies to avoid mutation. + + Note: /no_think suffix is passed through unchanged (handled by the model). + + Args: + messages: List of message dicts. + + Returns: + Copy of message dicts. + """ + return [copy.deepcopy(msg) for msg in messages] + def content_text_to_list(self, message: dict) -> dict: """Convert message content with audio to proper list format. Handles 'audio' or 'audios' keys in messages and converts them to - base64-encoded audio_url content items. + base64-encoded input_audio content items. - CRITICAL: Audio must come BEFORE text for Qwen models to transcribe correctly. + CRITICAL: Audio must come BEFORE text for models to process correctly. Args: message: Message dict that may contain 'audio' or 'audios' fields. Returns: - Message dict with content converted to list format including audio. + New message dict with content converted to list format including audio. """ if "audio" not in message and "audios" not in message: - return message + return copy.deepcopy(message) + + result = copy.deepcopy(message) - content = message.get("content", "") + if "content" not in result: + raise KeyError("Missing required 'content' in message") + content = result["content"] if isinstance(content, str): - message["content"] = [{"type": "text", "text": content}] - elif isinstance(content, list): - message["content"] = content - else: + result["content"] = [{"type": "text", "text": content}] + elif not isinstance(content, list): raise TypeError(f"Unexpected content type: {type(content)}") audio_items = [] - if "audio" in message: - audio = message["audio"] + if "audio" in result: + audio = result.pop("audio") audio_path = os.path.join(self.data_dir, audio["path"]) base64_audio = audio_file_to_base64(audio_path) - audio_message = {"type": "audio_url", "audio_url": {"url": f"data:audio/wav;base64,{base64_audio}"}} - audio_items.append(audio_message) - del message["audio"] # Remove original audio field after conversion - elif "audios" in message: - for audio in message["audios"]: + audio_items.append(make_audio_content_block(base64_audio, self.audio_format)) + elif "audios" in result: + for audio in result.pop("audios"): audio_path = os.path.join(self.data_dir, audio["path"]) base64_audio = audio_file_to_base64(audio_path) - audio_message = {"type": "audio_url", "audio_url": {"url": f"data:audio/wav;base64,{base64_audio}"}} - audio_items.append(audio_message) - del message["audios"] # Remove original audios field after conversion + audio_items.append(make_audio_content_block(base64_audio, self.audio_format)) - # Insert audio items at the BEGINNING of content list (before text) if audio_items: - message["content"] = audio_items + message["content"] - - return message - - def _preprocess_messages_for_model(self, messages: list[dict]) -> list[dict]: - """Preprocess messages - creates copies to avoid mutation. - - Note: /no_think suffix is passed through unchanged (handled by the model). - - Args: - messages: List of message dicts. + result["content"] = audio_items + result["content"] - Returns: - Copy of message dicts. - """ - return [copy.deepcopy(msg) for msg in messages] + return result def _needs_audio_chunking(self, messages: list[dict], task_type: str = None) -> tuple[bool, str, float]: """Check if audio in messages needs chunking. @@ -235,11 +362,14 @@ def _needs_audio_chunking(self, messages: list[dict], task_type: str = None) -> # Find audio in messages for msg in messages: - if msg.get("role") == "user": - audio_info = msg.get("audio") - if not audio_info: - audios = msg.get("audios", []) + if msg["role"] == "user": + if "audio" in msg: + audio_info = msg["audio"] + elif "audios" in msg: + audios = msg["audios"] audio_info = audios[0] if audios else {} + else: + continue if audio_info and "path" in audio_info: audio_path = os.path.join(self.data_dir, audio_info["path"]) @@ -287,9 +417,7 @@ async def _generate_with_chunking( result = None # Track cumulative statistics across chunks - total_input_tokens = 0 - total_generated_tokens = 0 - total_time = 0.0 + total_num_generated_tokens = 0 for chunk_idx, audio_chunk in enumerate(chunks): chunk_messages = [] @@ -300,16 +428,16 @@ async def _generate_with_chunking( if msg_copy["role"] == "user" and ("audio" in msg_copy or "audios" in msg_copy): chunk_base64 = save_audio_chunk_to_base64(audio_chunk, sampling_rate) - content = msg_copy.get("content", "") + if "content" not in msg_copy: + raise KeyError("Missing required 'content' in message") + content = msg_copy["content"] if isinstance(content, str): text_content = [{"type": "text", "text": content}] else: text_content = content # Add audio chunk at the beginning (before text) - msg_copy["content"] = [ - {"type": "audio_url", "audio_url": {"url": f"data:audio/wav;base64,{chunk_base64}"}} - ] + text_content + msg_copy["content"] = [make_audio_content_block(chunk_base64, self.audio_format)] + text_content # Remove original audio fields msg_copy.pop("audio", None) @@ -317,18 +445,13 @@ async def _generate_with_chunking( chunk_messages.append(msg_copy) - # Preprocess messages (strip /no_think for Qwen) - chunk_messages = self._preprocess_messages_for_model(chunk_messages) - # Generate for this chunk using parent's generate_async result = await super().generate_async( prompt=chunk_messages, tokens_to_generate=tokens_to_generate, **kwargs ) # Sum statistics from each chunk - total_input_tokens += result.get("input_tokens", 0) - total_generated_tokens += result.get("generated_tokens", 0) - total_time += result.get("time_elapsed", 0.0) + total_num_generated_tokens += result["num_generated_tokens"] generation = result["generation"] chunk_results.append(generation.strip()) @@ -336,17 +459,12 @@ async def _generate_with_chunking( # Aggregate results aggregated_text = " ".join(chunk_results) - if not result: - raise RuntimeError("Audio chunk generation returned no result") - final_result = result.copy() final_result["generation"] = aggregated_text final_result["num_audio_chunks"] = len(chunks) final_result["audio_duration"] = duration # Update with summed statistics - final_result["input_tokens"] = total_input_tokens - final_result["generated_tokens"] = total_generated_tokens - final_result["time_elapsed"] = total_time + final_result["num_generated_tokens"] = total_num_generated_tokens return final_result @@ -354,7 +472,7 @@ async def generate_async( self, prompt: str | list[dict] | None = None, tokens_to_generate: int | None = None, - task_type: str = None, + task_type: str | None = None, **kwargs, ) -> dict: """Generate with automatic audio chunking for long audio files. @@ -379,28 +497,8 @@ async def generate_async( return await self._generate_with_chunking(messages, audio_path, duration, tokens_to_generate, **kwargs) # No chunking needed - convert audio fields to base64 format - messages = [self.content_text_to_list(copy.deepcopy(msg)) for msg in messages] - messages = self._preprocess_messages_for_model(messages) + messages = [self.content_text_to_list(msg) for msg in messages] prompt = messages # Call parent's generate_async (which handles audio OUTPUT via _parse_chat_completion_response) return await super().generate_async(prompt=prompt, tokens_to_generate=tokens_to_generate, **kwargs) - - def _build_chat_request_params( - self, - messages: list[dict], - **kwargs, - ) -> dict: - """Build chat request parameters with audio preprocessing. - - Args: - messages: List of message dicts. - **kwargs: Additional parameters for the request. - - Returns: - Request parameters dict. - """ - # content_text_to_list THEN preprocess - messages = [self.content_text_to_list(copy.deepcopy(msg)) for msg in messages] - messages = self._preprocess_messages_for_model(messages) - return super()._build_chat_request_params(messages=messages, **kwargs) diff --git a/nemo_skills/pipeline/utils/server.py b/nemo_skills/pipeline/utils/server.py index 87abca4a99..9fb84a680f 100644 --- a/nemo_skills/pipeline/utils/server.py +++ b/nemo_skills/pipeline/utils/server.py @@ -24,6 +24,7 @@ class SupportedServersSelfHosted(str, Enum): trtllm = "trtllm" vllm = "vllm" + vllm_multimodal = "vllm_multimodal" sglang = "sglang" megatron = "megatron" generic = "generic" @@ -32,6 +33,7 @@ class SupportedServersSelfHosted(str, Enum): class SupportedServers(str, Enum): trtllm = "trtllm" vllm = "vllm" + vllm_multimodal = "vllm_multimodal" sglang = "sglang" megatron = "megatron" openai = "openai" @@ -125,7 +127,7 @@ def get_server_command( # check if the model path is mounted if not vllm, sglang, or trtllm; # vllm, sglang, trtllm can also pass model name as "model_path" so we need special processing - if server_type not in ["vllm", "sglang", "trtllm", "generic"]: + if server_type not in ["vllm", "vllm_multimodal", "sglang", "trtllm", "generic"]: check_if_mounted(cluster_config, model_path) # the model path will be mounted, so generally it will start with / @@ -158,7 +160,8 @@ def get_server_command( f" --micro-batch-size 1 " # that's a training argument, ignored here, but required to specify.. f" {server_args} " ) - elif server_type == "vllm": + elif server_type in ["vllm", "vllm_multimodal"]: + # vllm_multimodal uses the same vLLM server; multimodal handling is on the client side server_entrypoint = server_entrypoint or "-m nemo_skills.inference.server.serve_vllm" start_vllm_cmd = ( f"python3 {server_entrypoint} " diff --git a/requirements/common-tests.txt b/requirements/common-tests.txt index 4eca96f92a..8b2fbcce4d 100644 --- a/requirements/common-tests.txt +++ b/requirements/common-tests.txt @@ -16,3 +16,4 @@ pytest pytest-asyncio pytest-cov pytest-timeout +soundfile diff --git a/tests/test_nvidia_inference_api.py b/tests/test_nvidia_inference_api.py new file mode 100644 index 0000000000..e6abe6d8b5 --- /dev/null +++ b/tests/test_nvidia_inference_api.py @@ -0,0 +1,135 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for NVIDIA Inference API with VLLMMultimodalModel (audio input).""" + +import asyncio +import os +from pathlib import Path + +import pytest + +from nemo_skills.inference.model.vllm_multimodal import VLLMMultimodalModel + +NVIDIA_BASE_URL = "https://inference-api.nvidia.com/v1" +MODEL = "gcp/google/gemini-2.5-flash-lite" + +TEST_AUDIO_DIR = Path(__file__).parent / "slurm-tests" / "asr_nim" / "wavs" +TEST_AUDIO_T2 = TEST_AUDIO_DIR / "t2_16.wav" # "sample 2 this is a test of text to speech synthesis" +TEST_AUDIO_T3 = TEST_AUDIO_DIR / "t3_16.wav" # "sample 3 hello how are you today" + +requires_nvidia_api_key = pytest.mark.skipif( + not (os.getenv("NV_INFERENCE_API_KEY") or os.getenv("NVIDIA_API_KEY")), + reason="NV_INFERENCE_API_KEY/NVIDIA_API_KEY environment variable not set", +) + +requires_test_audio = pytest.mark.skipif( + not TEST_AUDIO_T2.exists() or not TEST_AUDIO_T3.exists(), + reason="Test audio files not found at tests/slurm-tests/asr_nim/wavs/", +) + + +@requires_nvidia_api_key +def test_nvidia_api_text_only(): + """Smoke test: text-only chat completion via NVIDIA Inference API.""" + model = VLLMMultimodalModel( + model=MODEL, + base_url=NVIDIA_BASE_URL, + audio_format="input_audio", + ) + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello! Can you help me?"}, + ] + + result = asyncio.run( + model.generate_async( + prompt=messages, + tokens_to_generate=1024, + temperature=0.7, + ) + ) + + assert "generation" in result + assert len(result["generation"]) > 0 + print(f"[text-only] generation: {result['generation'][:200]}") + + +@requires_nvidia_api_key +@requires_test_audio +def test_nvidia_api_audio_input(): + """Integration test: audio-input generation using a local test audio file.""" + model = VLLMMultimodalModel( + model=MODEL, + base_url=NVIDIA_BASE_URL, + audio_format="input_audio", + data_dir=str(TEST_AUDIO_DIR), + ) + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": "What do you hear in this audio? Describe it briefly.", + "audio": {"path": "t2_16.wav"}, + }, + ] + + result = asyncio.run( + model.generate_async( + prompt=messages, + tokens_to_generate=1024, + temperature=0.7, + ) + ) + + assert "generation" in result + assert len(result["generation"]) > 0 + print(f"[audio-input] generation: {result['generation'][:300]}") + + +@requires_nvidia_api_key +@requires_test_audio +def test_nvidia_api_audio_with_transcription_prompt(): + """Integration test: ask the model to transcribe audio content.""" + model = VLLMMultimodalModel( + model=MODEL, + base_url=NVIDIA_BASE_URL, + audio_format="input_audio", + data_dir=str(TEST_AUDIO_DIR), + ) + + messages = [ + {"role": "system", "content": "You are a helpful assistant that can listen to audio."}, + { + "role": "user", + "content": "Please listen to this audio and tell me what you hear.", + "audio": {"path": "t3_16.wav"}, + }, + ] + + result = asyncio.run( + model.generate_async( + prompt=messages, + tokens_to_generate=1024, + temperature=0.7, + ) + ) + + assert "generation" in result + assert len(result["generation"]) > 0 + assert result["num_generated_tokens"] > 0 + print(f"[transcription] generation: {result['generation'][:300]}") + print(f"[transcription] tokens: {result['num_generated_tokens']}") diff --git a/tests/test_vllm_audio.py b/tests/test_vllm_audio.py index 0c8ca1b89a..a919ab33cb 100644 --- a/tests/test_vllm_audio.py +++ b/tests/test_vllm_audio.py @@ -42,6 +42,15 @@ def test_audio_file_to_base64(): os.unlink(temp_path) +def _is_valid_audio_content(content_item: dict) -> bool: + """Check if content item is a valid audio block (either format).""" + if content_item.get("type") == "audio_url": + return content_item.get("audio_url", {}).get("url", "").startswith("data:audio/wav;base64,") + elif content_item.get("type") == "input_audio": + return "data" in content_item.get("input_audio", {}) + return False + + @pytest.fixture def mock_vllm_multimodal_model(tmp_path): """Create a mock VLLMMultimodalModel for testing audio preprocessing.""" @@ -53,6 +62,23 @@ def mock_vllm_multimodal_model(tmp_path): model.enable_audio_chunking = True model.audio_chunk_task_types = None model.chunk_audio_threshold_sec = 30 + model.audio_format = "audio_url" # Test audio_url format (for vLLM/Qwen) + model._tunnel = None + return model + + +@pytest.fixture +def mock_vllm_multimodal_model_input_audio(tmp_path): + """Create a mock VLLMMultimodalModel configured for input_audio.""" + with patch.object(VLLMMultimodalModel, "__init__", lambda self, **kwargs: None): + model = VLLMMultimodalModel() + model.data_dir = str(tmp_path) + model.output_dir = None + model.output_audio_dir = None + model.enable_audio_chunking = True + model.audio_chunk_task_types = None + model.chunk_audio_threshold_sec = 30 + model.audio_format = "input_audio" model._tunnel = None return model @@ -72,8 +98,25 @@ def test_content_text_to_list_with_audio(mock_vllm_multimodal_model, tmp_path): assert isinstance(result["content"], list) assert len(result["content"]) == 2 - assert result["content"][0]["type"] == "audio_url" - assert result["content"][0]["audio_url"]["url"].startswith("data:audio/wav;base64,") + assert _is_valid_audio_content(result["content"][0]) + assert result["content"][1]["type"] == "text" + + +def test_content_text_to_list_with_input_audio_format(mock_vllm_multimodal_model_input_audio, tmp_path): + """Test audio conversion with input_audio format (OpenAI native).""" + audio_path = tmp_path / "test.wav" + with open(audio_path, "wb") as f: + f.write(b"RIFF" + b"\x00" * 100) + + message = {"role": "user", "content": "Describe this audio", "audio": {"path": "test.wav"}} + result = mock_vllm_multimodal_model_input_audio.content_text_to_list(message) + + assert isinstance(result["content"], list) + assert len(result["content"]) == 2 + # Verify input_audio format structure + assert result["content"][0]["type"] == "input_audio" + assert "data" in result["content"][0]["input_audio"] + assert result["content"][0]["input_audio"]["format"] == "wav" assert result["content"][1]["type"] == "text" @@ -100,8 +143,8 @@ def test_content_text_to_list_with_multiple_audios(mock_vllm_multimodal_model, t assert isinstance(result["content"], list) assert len(result["content"]) == 3 # Audio MUST come before text for Qwen Audio - assert result["content"][0]["type"] == "audio_url" - assert result["content"][1]["type"] == "audio_url" + assert _is_valid_audio_content(result["content"][0]) + assert _is_valid_audio_content(result["content"][1]) assert result["content"][2]["type"] == "text"