diff --git a/tests/conftest.py b/tests/conftest.py
index ace891db7bc..5692b094ff9 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -623,6 +623,36 @@ def convert_audio_to_text(audio_data):
     return ""
 
 
+def merge_base64_and_convert_to_text(base64_list):
+    """
+    Merge a list of base64-encoded audio chunks and transcribe the merged audio to text.
+    """
+    import whisper
+    from pydub import AudioSegment
+
+    merged_audio = None
+    for base64_data in base64_list:
+        audio_data = base64.b64decode(base64_data.split(",", 1)[-1])
+        seg = AudioSegment.from_file(io.BytesIO(audio_data))
+        if merged_audio is None:
+            merged_audio = seg
+        else:
+            merged_audio += seg
+    output_path = f"./test_{int(time.time())}"
+    merged_audio.export(output_path, format="wav")
+    model = whisper.load_model("base")
+    text = model.transcribe(
+        output_path,
+        temperature=0.0,
+        word_timestamps=True,
+        condition_on_previous_text=False,
+    )["text"]
+    if text:
+        return text
+    else:
+        return ""
+
+
 def modify_stage_config(
     yaml_path: str,
     updates: dict[str, Any],
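For context, a minimal usage sketch of the new helper. The chunk file and the import path below are assumptions for illustration, not part of the patch; the split(",", 1) inside the helper is what makes an optional data-URL prefix harmless.

# Hypothetical usage sketch for merge_base64_and_convert_to_text.
import base64

from conftest import merge_base64_and_convert_to_text  # assumed import path

with open("chunk0.wav", "rb") as f:  # stand-in for one streamed audio chunk
    chunk = base64.b64encode(f.read()).decode()

# A "data:audio/wav;base64," prefix would also be accepted, since the helper
# strips everything up to the first comma before decoding.
print(merge_base64_and_convert_to_text([chunk]))
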
diff --git a/tests/e2e/online_serving/test_qwen3_omni.py b/tests/e2e/online_serving/test_qwen3_omni.py
index eb134ab5e78..6aec8d1003c 100644
--- a/tests/e2e/online_serving/test_qwen3_omni.py
+++ b/tests/e2e/online_serving/test_qwen3_omni.py
@@ -25,6 +25,7 @@
     generate_synthetic_audio,
     generate_synthetic_image,
     generate_synthetic_video,
+    merge_base64_and_convert_to_text,
     modify_stage_config,
 )
 from vllm_omni.platforms import current_omni_platform
@@ -157,11 +158,6 @@ def test_mix_to_text_audio_001(client: openai.OpenAI, omni_server, request) -> None:
         Input Setting: stream=True
         Datasets: single request
     """
-    # TODO:This skip will be removed when the chunk scenario supports multimodal input.
-    param = request.node.callspec.params.get("omni_server")
-
-    if param[1] == CHUNK_CONFIG_PATH:
-        pytest.skip("The current chunk scenario does not support multimodal.")
 
     # Test single completion
     e2e_list = list()
@@ -181,7 +177,7 @@ def test_mix_to_text_audio_001(client: openai.OpenAI, omni_server, request) -> None:
         chat_completion = client.chat.completions.create(model=omni_server.model, messages=messages, stream=True)
 
         text_content = ""
-        audio_data = None
+        audio_data = []
         for chunk in chat_completion:
             for choice in chunk.choices:
                 if hasattr(choice, "delta"):
@@ -192,11 +188,7 @@ def test_mix_to_text_audio_001(client: openai.OpenAI, omni_server, request) -> None:
 
                     modality = getattr(chunk, "modality", None)
                     if modality == "audio" and content:
-                        # Audio chunk - content
-                        if audio_data is None:
-                            audio_data = content
-                        else:
-                            audio_data += content
+                        audio_data.append(content)
                     elif modality == "text" and content:
                         # Text chunk - accumulate text content
                         text_content += content if content else ""
@@ -218,7 +210,7 @@ def test_mix_to_text_audio_001(client: openai.OpenAI, omni_server, request) -> None:
         ), "The output does not contain any of the keywords."
 
         # Verify text output same as audio output
-        audio_content = convert_audio_to_text(audio_data)
+        audio_content = merge_base64_and_convert_to_text(audio_data)
         print(f"text content is: {text_content}")
         print(f"audio content is: {audio_content}")
         similarity = cosine_similarity_text(audio_content.lower(), text_content.lower())
diff --git a/vllm_omni/distributed/omni_connectors/adapter.py b/vllm_omni/distributed/omni_connectors/adapter.py
index a634595ac7f..b03202525cc 100644
--- a/vllm_omni/distributed/omni_connectors/adapter.py
+++ b/vllm_omni/distributed/omni_connectors/adapter.py
@@ -247,7 +247,7 @@ def get_through_connector(connector, target_stage_id, stage_id, req_id, connector_get_key):
                 connector.get_requests[req_id] += 1
                 logger.debug(f"[Stage-{stage_id}] Received one chunk for request {connector_get_key}")
                 break
-            time.sleep(0.1)
+            time.sleep(0.01)
 
     return payload_data
@@ -330,10 +330,11 @@ def put_chunk(
     if stage_id == 0 and chunk_id == 0:
         if connector.request_payload.get(request_id) is None:
-            connector.request_payload[request_id] = payload_data
-            return
+            if not payload_data.get("finished"):
+                connector.request_payload[request_id] = payload_data
+                return
         else:
-            save_payload = connector.request_payload.get(request_id)
+            save_payload = connector.request_payload.pop(request_id)
             payload_data["thinker_embeddings"] = torch.cat(
                 (save_payload.get("thinker_embeddings"), payload_data.get("thinker_embeddings")), dim=0
             )
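The put_chunk change is a buffer-then-merge pattern: the first thinker chunk of a request is held back (unless it is already final), then popped, rather than left behind by get, and concatenated when the next chunk arrives. A standalone sketch of the pattern, with a plain dict standing in for connector.request_payload and made-up tensor shapes:

# Sketch of the buffer-then-merge pattern now used in put_chunk.
import torch

pending: dict[str, dict] = {}  # stand-in for connector.request_payload

def merge_first_two_chunks(request_id: str, payload: dict) -> dict | None:
    saved = pending.pop(request_id, None)  # pop, not get: consume the buffered entry
    if saved is None:
        if not payload.get("finished"):
            pending[request_id] = payload  # hold the first chunk back
            return None                    # nothing to forward yet
        return payload                     # a lone finished chunk passes straight through
    payload["thinker_embeddings"] = torch.cat(
        (saved["thinker_embeddings"], payload["thinker_embeddings"]), dim=0
    )
    return payload

# First call buffers; second call returns the merged payload.
assert merge_first_two_chunks("r1", {"thinker_embeddings": torch.zeros(2, 4)}) is None
merged = merge_first_two_chunks("r1", {"thinker_embeddings": torch.ones(3, 4), "finished": True})
assert merged["thinker_embeddings"].shape == (5, 4)
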
diff --git a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py
index ab5cf878c5a..65a69e0d462 100644
--- a/vllm_omni/engine/output_processor.py
+++ b/vllm_omni/engine/output_processor.py
@@ -114,7 +114,7 @@ def _consolidate_multimodal_tensors(self) -> None:
                 if k == "audio":
                     # When the audio tensor shape is inconsistent, torch.cat will fail.
                     # We need to use torch.cat in -1 dimension.
-                    self.mm_accumulated[k] = torch.cat(v, dim=-1)
+                    continue
                 else:
                     self.mm_accumulated[k] = torch.cat(v, dim=0)
             except Exception:
diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py
index 179f58ed22b..ec2a18ac4d8 100644
--- a/vllm_omni/entrypoints/async_omni.py
+++ b/vllm_omni/entrypoints/async_omni.py
@@ -1,12 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import asyncio
+import copy
 import time
 import weakref
 from collections.abc import AsyncGenerator, Iterable, Sequence
 from dataclasses import asdict
 from pprint import pformat
-from typing import Any, cast
+from typing import Any
 
 from vllm.config import VllmConfig
 from vllm.inputs.preprocess import InputPreprocessor
@@ -32,7 +33,7 @@
 from vllm_omni.entrypoints.utils import (
     get_final_stage_id_for_e2e,
 )
-from vllm_omni.inputs.data import OmniPromptType, OmniSamplingParams, OmniTokensPrompt
+from vllm_omni.inputs.data import OmniPromptType, OmniSamplingParams
 
 # Internal imports (our code)
 from vllm_omni.lora.request import LoRARequest
@@ -304,38 +305,25 @@ async def generate(
         req_state = ClientRequestState(request_id)
         req_state.metrics = metrics
         self.request_states[request_id] = req_state
-
+        sp0: SamplingParams = sampling_params_list[0]  # type: ignore[index]
+        task = {
+            "request_id": request_id,
+            "engine_inputs": prompt,
+            "sampling_params": sp0,
+        }
+        self.stage_list[0].submit(task)
+        metrics.stage_first_ts[0] = metrics.stage_first_ts[0] or time.time()
+        _req_start_ts[request_id] = time.time()
+        logger.info(
+            f"[{self._name}] Entering scheduling loop: stages={num_stages}, final_stage={final_stage_id_for_e2e}"
+        )
         if self.async_chunk:
             stage_queues = {stage_id: asyncio.Queue() for stage_id in range(num_stages)}
             req_state.stage_queues = stage_queues
-            for i in range(num_stages):
-                sp: SamplingParams = cast(SamplingParams, sampling_params_list[i])
-                engine_inputs = cast(OmniTokensPrompt, prompt)
-                if i != 0:
-                    prompt_token_ids = engine_inputs["prompt_token_ids"]
-                    prompt_1 = engine_inputs.copy()
-                    prompt_1["prompt_token_ids"] = [0] * compute_talker_prompt_ids_length(prompt_token_ids)
-                    prompt_1["multi_modal_data"] = prompt_1["mm_processor_kwargs"] = None
-                    engine_inputs = prompt_1
-
-                task = {
-                    "request_id": request_id,
-                    "engine_inputs": engine_inputs,
-                    "sampling_params": sp,
-                }
-                self.stage_list[i].submit(task)
-                metrics.stage_first_ts[i] = metrics.stage_first_ts[0] or time.time()
-
-                logger.info(f"[{self._name}] Enqueued request {request_id} to stage-{str(i)}")
-
-            _req_start_ts[request_id] = time.time()
-
-            logger.info(
-                f"[{self._name}] Entering scheduling loop: "
-                f"stages={num_stages}, final_stage={final_stage_id_for_e2e}"
-            )
             async for output in self._process_async_results(
                 request_id,
+                prompt,
+                sampling_params_list,
                 req_state,
                 metrics,
                 final_stage_id_for_e2e,
@@ -344,22 +332,6 @@
             ):
                 yield output
         else:
-            sp0: SamplingParams = sampling_params_list[0]  # type: ignore[index]
-            task = {
-                "request_id": request_id,
-                "engine_inputs": prompt,
-                "sampling_params": sp0,
-            }
-            self.stage_list[0].submit(task)
-
-            _req_start_ts[request_id] = time.time()
-            # Mark first input time for stage-0
-            metrics.stage_first_ts[0] = metrics.stage_first_ts[0] or time.time()
-            logger.info(
-                f"[{self._name}] Entering scheduling loop: "
-                f"stages={num_stages}, final_stage={final_stage_id_for_e2e}"
-            )
-
             async for output in self._process_sequential_results(
                 request_id,
                 req_state,
@@ -390,6 +362,8 @@
     async def _process_async_results(
         self,
        request_id: str,
+        prompt: Any,
+        sampling_params_list: list[SamplingParams],
         req_state: ClientRequestState,
         metrics: OrchestratorMetrics,
         final_stage_id_for_e2e: int,
@@ -397,16 +371,34 @@ async def _process_async_results(
         wall_start_ts: float,
     ) -> AsyncGenerator[OmniRequestOutput, None]:
         all_stages_finished = {stage_id: False for stage_id in range(final_stage_id_for_e2e + 1)}
+        submit_flag = True
         while not all(all_stages_finished.values()):
             for stage_id, stage in enumerate(self.stage_list[: final_stage_id_for_e2e + 1]):
                 if all_stages_finished[stage_id]:
                     continue
-                result = await req_state.stage_queues[stage_id].get()
-                logger.info(f"[{self._name}] Received result from stage-{stage_id}: {result}")
+                try:
+                    result = req_state.stage_queues[stage_id].get_nowait()
+                except asyncio.QueueEmpty:
+                    await asyncio.sleep(0.001)
+                    continue
+
                 engine_outputs, finished, output_to_yield = self._process_single_result(
                     result, stage, stage_id, metrics, req_start_ts, wall_start_ts, final_stage_id_for_e2e
                 )
-
+                if submit_flag and stage_id == 0:
+                    submit_flag = False
+                    prompt_token_ids = engine_outputs.prompt_token_ids
+                    engine_input = copy.deepcopy(prompt)
+                    engine_input["prompt_token_ids"] = [0] * compute_talker_prompt_ids_length(prompt_token_ids)
+                    engine_input["multi_modal_data"] = engine_input["mm_processor_kwargs"] = None
+                    for i in range(1, len(self.stage_list)):
+                        task = {
+                            "request_id": request_id,
+                            "engine_inputs": engine_input,
+                            "sampling_params": sampling_params_list[i],
+                        }
+                        self.stage_list[i].submit(task)
+                        metrics.stage_first_ts[i] = time.time()
                 all_stages_finished[stage_id] = finished
 
                 if output_to_yield:
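Two things change in the scheduling path: stage queues are drained with get_nowait() plus a 1 ms sleep instead of a blocking get(), and downstream stages are submitted only once stage 0 returns its first result, so the talker prompt length can be derived from the real prompt_token_ids. A reduced sketch of the polling half, independent of the orchestrator classes (the queue contents and the None end-of-stream marker are illustrative):

# Sketch of the non-blocking multi-queue drain loop used in _process_async_results.
import asyncio

async def drain_round_robin(queues: list[asyncio.Queue]) -> list[list]:
    """Poll several stage queues without blocking on any single one."""
    finished = [False] * len(queues)
    results: list[list] = [[] for _ in queues]
    while not all(finished):
        for stage_id, q in enumerate(queues):
            if finished[stage_id]:
                continue
            try:
                item = q.get_nowait()
            except asyncio.QueueEmpty:
                await asyncio.sleep(0.001)  # yield control instead of blocking on one queue
                continue
            if item is None:  # None marks end-of-stream in this sketch
                finished[stage_id] = True
            else:
                results[stage_id].append(item)
    return results

async def main() -> None:
    q0, q1 = asyncio.Queue(), asyncio.Queue()
    for x in (1, 2, None):
        q0.put_nowait(x)
    for x in ("a", None):
        q1.put_nowait(x)
    print(await drain_round_robin([q0, q1]))  # [[1, 2], ['a']]

asyncio.run(main())
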
diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py
index db8c972a1fd..61b00d8030f 100644
--- a/vllm_omni/entrypoints/openai/serving_chat.py
+++ b/vllm_omni/entrypoints/openai/serving_chat.py
@@ -9,6 +9,7 @@
 from typing import TYPE_CHECKING, Any, Final, Optional, cast
 
 import jinja2
+import torch
 from fastapi import Request
 from PIL import Image
 from pydantic import TypeAdapter
@@ -1656,10 +1657,13 @@ def _create_audio_choice(
     ):
         choices: list[ChatCompletionResponseChoice] = []
         final_res = omni_outputs.request_output
+        audio_data = final_res.multimodal_output.get("audio")
         if stream:
-            audio_tensor = final_res.multimodal_output["audio"][-1].float().detach().cpu().numpy()
+            audio_tensor = audio_data[-1].float().detach().cpu().numpy()
         else:
-            audio_tensor = final_res.multimodal_output["audio"].float().detach().cpu().numpy()
+            if isinstance(audio_data, list):
+                audio_data = torch.cat(audio_data, dim=-1)
+            audio_tensor = audio_data.float().detach().cpu().numpy()
 
         # Ensure audio is 1D (flatten if needed)
         if audio_tensor.ndim > 1:
diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py
index bc2242f7066..7adb6f96f89 100644
--- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py
+++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py
@@ -219,7 +219,7 @@ def chunked_decode_streaming(
         end_index = codes.shape[-1]
         # TODO: need to optimize algorithms, current only support
         # chunk_size = left_context_size = 25
-        if end_index == chunk_size:
+        if end_index <= chunk_size:
             context_size = 0
         else:
             context_size = left_context_size
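With _consolidate_multimodal_tensors no longer concatenating audio, the non-streaming path receives a list of per-chunk tensors and joins them along the last (sample) dimension before the numpy conversion. The shape reasoning in isolation (chunk lengths here are made up):

# Why dim=-1 for audio: chunks agree on every axis except the sample axis,
# so concatenating on the last dimension stitches the waveform together.
import torch

chunks = [torch.randn(2400), torch.randn(1200), torch.randn(2400)]  # illustrative lengths
waveform = torch.cat(chunks, dim=-1)
assert waveform.shape == (6000,)

audio = waveform.float().detach().cpu().numpy()  # same conversion chain as the handler
assert audio.ndim == 1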