30 changes: 30 additions & 0 deletions tests/conftest.py
@@ -623,6 +623,36 @@ def convert_audio_to_text(audio_data):
return ""


def merge_base64_and_convert_to_text(base64_list):
    """
    Merge a list of base64-encoded audio chunks and convert the result to text.
    """
    import whisper
    from pydub import AudioSegment

    merged_audio = None
    for base64_data in base64_list:
        # Strip an optional data-URI prefix (e.g. "data:audio/wav;base64,") before decoding.
        audio_data = base64.b64decode(base64_data.split(",", 1)[-1])
        seg = AudioSegment.from_file(io.BytesIO(audio_data))
        if merged_audio is None:
            merged_audio = seg
        else:
            merged_audio += seg
    # Export the merged audio to a temporary WAV file for Whisper to transcribe.
    output_path = f"./test_{int(time.time())}"
    merged_audio.export(output_path, format="wav")
    model = whisper.load_model("base")
    text = model.transcribe(
        output_path,
        temperature=0.0,
        word_timestamps=True,
        condition_on_previous_text=False,
    )["text"]
    if text:
        return text
    else:
        return ""


def modify_stage_config(
yaml_path: str,
updates: dict[str, Any],
16 changes: 4 additions & 12 deletions tests/e2e/online_serving/test_qwen3_omni.py
@@ -25,6 +25,7 @@
generate_synthetic_audio,
generate_synthetic_image,
generate_synthetic_video,
merge_base64_and_convert_to_text,
modify_stage_config,
)
from vllm_omni.platforms import current_omni_platform
@@ -157,11 +158,6 @@ def test_mix_to_text_audio_001(client: openai.OpenAI, omni_server, request) -> N
Input Setting: stream=True
Datasets: single request
"""
# TODO:This skip will be removed when the chunk scenario supports multimodal input.
param = request.node.callspec.params.get("omni_server")

if param[1] == CHUNK_CONFIG_PATH:
pytest.skip("The current chunk scenario does not support multimodal.")

# Test single completion
e2e_list = list()
@@ -181,7 +177,7 @@ def test_mix_to_text_audio_001(client: openai.OpenAI, omni_server, request) -> N
chat_completion = client.chat.completions.create(model=omni_server.model, messages=messages, stream=True)

text_content = ""
audio_data = None
audio_data = []
for chunk in chat_completion:
for choice in chunk.choices:
if hasattr(choice, "delta"):
@@ -192,11 +188,7 @@ def test_mix_to_text_audio_001(client: openai.OpenAI, omni_server, request) -> N
modality = getattr(chunk, "modality", None)

if modality == "audio" and content:
# Audio chunk - content
if audio_data is None:
audio_data = content
else:
audio_data += content
audio_data.append(content)
elif modality == "text" and content:
# Text chunk - accumulate text content
text_content += content if content else ""
@@ -218,7 +210,7 @@ def test_mix_to_text_audio_001(client: openai.OpenAI, omni_server, request) -> N
), "The output does not contain any of the keywords."

# Verify text output same as audio output
audio_content = convert_audio_to_text(audio_data)
audio_content = merge_base64_and_convert_to_text(audio_data)
print(f"text content is: {text_content}")
print(f"audio content is: {audio_content}")
similarity = cosine_similarity_text(audio_content.lower(), text_content.lower())
9 changes: 5 additions & 4 deletions vllm_omni/distributed/omni_connectors/adapter.py
@@ -247,7 +247,7 @@ def get_through_connector(connector, target_stage_id, stage_id, req_id, connecto
connector.get_requests[req_id] += 1
logger.debug(f"[Stage-{stage_id}] Received one chunk for request {connector_get_key}")
break
time.sleep(0.1)
time.sleep(0.01)
return payload_data


@@ -330,10 +330,11 @@ def put_chunk(

if stage_id == 0 and chunk_id == 0:
if connector.request_payload.get(request_id) is None:
connector.request_payload[request_id] = payload_data
return
if not payload_data.get("finished"):
connector.request_payload[request_id] = payload_data
return
else:
save_payload = connector.request_payload.get(request_id)
save_payload = connector.request_payload.pop(request_id)
payload_data["thinker_embeddings"] = torch.cat(
(save_payload.get("thinker_embeddings"), payload_data.get("thinker_embeddings")), dim=0
)
2 changes: 1 addition & 1 deletion vllm_omni/engine/output_processor.py
@@ -114,7 +114,7 @@ def _consolidate_multimodal_tensors(self) -> None:
if k == "audio":
# When the audio tensor shape is inconsistent, torch.cat will fail.
# We need to use torch.cat in -1 dimension.
self.mm_accumulated[k] = torch.cat(v, dim=-1)
continue
else:
self.mm_accumulated[k] = torch.cat(v, dim=0)
except Exception:
88 changes: 40 additions & 48 deletions vllm_omni/entrypoints/async_omni.py
@@ -1,12 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import copy
import time
import weakref
from collections.abc import AsyncGenerator, Iterable, Sequence
from dataclasses import asdict
from pprint import pformat
from typing import Any, cast
from typing import Any

from vllm.config import VllmConfig
from vllm.inputs.preprocess import InputPreprocessor
@@ -32,7 +33,7 @@
from vllm_omni.entrypoints.utils import (
get_final_stage_id_for_e2e,
)
from vllm_omni.inputs.data import OmniPromptType, OmniSamplingParams, OmniTokensPrompt
from vllm_omni.inputs.data import OmniPromptType, OmniSamplingParams

# Internal imports (our code)
from vllm_omni.lora.request import LoRARequest
@@ -304,38 +305,25 @@ async def generate(
req_state = ClientRequestState(request_id)
req_state.metrics = metrics
self.request_states[request_id] = req_state

sp0: SamplingParams = sampling_params_list[0] # type: ignore[index]
task = {
"request_id": request_id,
"engine_inputs": prompt,
"sampling_params": sp0,
}
self.stage_list[0].submit(task)
metrics.stage_first_ts[0] = metrics.stage_first_ts[0] or time.time()
_req_start_ts[request_id] = time.time()
logger.info(
f"[{self._name}] Entering scheduling loop: stages={num_stages}, final_stage={final_stage_id_for_e2e}"
Comment on lines +314 to +318

P1: Initialize stage_queues before submitting stage-0 task

When async_chunk is enabled, the stage-0 task is submitted before req_state.stage_queues is created. The output handler only routes results to stage_queues if the attribute exists; otherwise it falls back to req_state.queue, which _process_async_results never drains. If stage-0 responds quickly (small prompts, cached model), its first result can be enqueued to the fallback queue and the async loop will wait forever on stage_queues, effectively hanging the request. Creating stage_queues before submit() (or draining the fallback queue in async mode) avoids this race.
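
A minimal sketch of the reordering, reusing the names from this hunk (req_state, num_stages, self.async_chunk, self.stage_list); not a drop-in patch:

if self.async_chunk:
    # Create the per-stage queues before stage-0 can emit anything,
    # so the output handler never falls back to req_state.queue.
    req_state.stage_queues = {stage_id: asyncio.Queue() for stage_id in range(num_stages)}

sp0: SamplingParams = sampling_params_list[0]  # type: ignore[index]
self.stage_list[0].submit(
    {
        "request_id": request_id,
        "engine_inputs": prompt,
        "sampling_params": sp0,
    }
)
metrics.stage_first_ts[0] = metrics.stage_first_ts[0] or time.time()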

)
if self.async_chunk:
stage_queues = {stage_id: asyncio.Queue() for stage_id in range(num_stages)}
req_state.stage_queues = stage_queues
for i in range(num_stages):
sp: SamplingParams = cast(SamplingParams, sampling_params_list[i])
engine_inputs = cast(OmniTokensPrompt, prompt)
if i != 0:
prompt_token_ids = engine_inputs["prompt_token_ids"]
prompt_1 = engine_inputs.copy()
prompt_1["prompt_token_ids"] = [0] * compute_talker_prompt_ids_length(prompt_token_ids)
prompt_1["multi_modal_data"] = prompt_1["mm_processor_kwargs"] = None
engine_inputs = prompt_1

task = {
"request_id": request_id,
"engine_inputs": engine_inputs,
"sampling_params": sp,
}
self.stage_list[i].submit(task)
metrics.stage_first_ts[i] = metrics.stage_first_ts[0] or time.time()

logger.info(f"[{self._name}] Enqueued request {request_id} to stage-{str(i)}")

_req_start_ts[request_id] = time.time()

logger.info(
f"[{self._name}] Entering scheduling loop: "
f"stages={num_stages}, final_stage={final_stage_id_for_e2e}"
)
async for output in self._process_async_results(
Comment on lines +308 to +323

Copilot AI Jan 28, 2026

In async_chunk mode, stage-0 is submitted before req_state.stage_queues is initialized. If stage-0 produces output quickly, output_handler will fall back to req_state.queue (because stage_queues doesn’t exist yet), but _process_async_results only consumes from stage_queues, which can lead to the request hanging. Initialize stage_queues before submitting stage-0 (or make _process_async_results also consume from the fallback queue).

request_id,
prompt,
sampling_params_list,
req_state,
metrics,
final_stage_id_for_e2e,
Expand All @@ -344,22 +332,6 @@ async def generate(
):
yield output
else:
sp0: SamplingParams = sampling_params_list[0] # type: ignore[index]
task = {
"request_id": request_id,
"engine_inputs": prompt,
"sampling_params": sp0,
}
self.stage_list[0].submit(task)

_req_start_ts[request_id] = time.time()
# Mark first input time for stage-0
metrics.stage_first_ts[0] = metrics.stage_first_ts[0] or time.time()
logger.info(
f"[{self._name}] Entering scheduling loop: "
f"stages={num_stages}, final_stage={final_stage_id_for_e2e}"
)

async for output in self._process_sequential_results(
request_id,
req_state,
Expand Down Expand Up @@ -390,23 +362,43 @@ async def generate(
async def _process_async_results(
self,
request_id: str,
prompt: Any,
sampling_params_list: list[SamplingParams],
req_state: ClientRequestState,
metrics: OrchestratorMetrics,
final_stage_id_for_e2e: int,
req_start_ts: dict[int, float],
wall_start_ts: float,
) -> AsyncGenerator[OmniRequestOutput, None]:
all_stages_finished = {stage_id: False for stage_id in range(final_stage_id_for_e2e + 1)}
submit_flag = True
while not all(all_stages_finished.values()):
for stage_id, stage in enumerate(self.stage_list[: final_stage_id_for_e2e + 1]):
if all_stages_finished[stage_id]:
continue
result = await req_state.stage_queues[stage_id].get()
logger.info(f"[{self._name}] Received result from stage-{stage_id}: {result}")
try:
result = req_state.stage_queues[stage_id].get_nowait()
except asyncio.QueueEmpty:
await asyncio.sleep(0.001)
continue
Comment on lines +379 to +383

Copilot AI Jan 28, 2026

_process_async_results uses per-stage polling with get_nowait() and a fixed sleep(0.001) when a queue is empty. With multiple stages this can add avoidable latency and CPU overhead (sleep can be hit once per stage per loop iteration). Consider awaiting on queue.get() with a timeout, or using asyncio.wait/gather over pending get() tasks so you react to whichever stage produces output first without busy-polling.
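
A sketch of one alternative, assuming the loop state from this hunk (all_stages_finished, req_state.stage_queues); unfinished readers are cancelled each iteration to keep the example simple:

# Await every still-active stage queue at once and react to whichever
# produces a result first, instead of polling with get_nowait() + sleep.
get_tasks = {
    asyncio.ensure_future(req_state.stage_queues[sid].get()): sid
    for sid, stage_done in all_stages_finished.items()
    if not stage_done
}
done, pending = await asyncio.wait(get_tasks, return_when=asyncio.FIRST_COMPLETED)
for fut in pending:
    fut.cancel()  # cancelled getters leave queued items untouched
for fut in done:
    stage_id = get_tasks[fut]
    result = fut.result()
    # ... existing per-result handling for stage_id ...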

engine_outputs, finished, output_to_yield = self._process_single_result(
result, stage, stage_id, metrics, req_start_ts, wall_start_ts, final_stage_id_for_e2e
)

if submit_flag and stage_id == 0:
submit_flag = False
prompt_token_ids = engine_outputs.prompt_token_ids
engine_input = copy.deepcopy(prompt)

Copilot AI Jan 28, 2026

Using copy.deepcopy(prompt) here can be very expensive for multimodal prompts (e.g., large base64 blobs / nested structures) even though multi_modal_data is immediately nulled out. A shallow copy (or constructing a minimal dict with just the needed fields like prompt_token_ids + required metadata) should be sufficient and avoids copying large unused payloads.

Suggested change
engine_input = copy.deepcopy(prompt)
engine_input = dict(prompt)

engine_input["prompt_token_ids"] = [0] * compute_talker_prompt_ids_length(prompt_token_ids)
engine_input["multi_modal_data"] = engine_input["mm_processor_kwargs"] = None
for i in range(1, len(self.stage_list)):
task = {
"request_id": request_id,
"engine_inputs": engine_input,
"sampling_params": sampling_params_list[i],
}
self.stage_list[i].submit(task)
metrics.stage_first_ts[i] = time.time()
Comment on lines +392 to +401

Copilot AI Jan 28, 2026

Async chunk submission currently enqueues tasks for stages 1..len(self.stage_list)-1, even when final_stage_id_for_e2e is smaller. Stages beyond final_stage_id_for_e2e won’t be drained by _process_async_results, so their outputs can accumulate in stage_queues (memory growth) and do unnecessary work. Limit submission to stages up to final_stage_id_for_e2e (and/or only the stages required for the selected output_modalities).
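
A minimal sketch of the narrower fan-out, reusing the names from this hunk:

# Only fan out to stages that _process_async_results will actually drain.
for i in range(1, final_stage_id_for_e2e + 1):
    self.stage_list[i].submit(
        {
            "request_id": request_id,
            "engine_inputs": engine_input,
            "sampling_params": sampling_params_list[i],
        }
    )
    metrics.stage_first_ts[i] = time.time()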

all_stages_finished[stage_id] = finished

if output_to_yield:
8 changes: 6 additions & 2 deletions vllm_omni/entrypoints/openai/serving_chat.py
@@ -9,6 +9,7 @@
from typing import TYPE_CHECKING, Any, Final, Optional, cast

import jinja2
import torch
from fastapi import Request
from PIL import Image
from pydantic import TypeAdapter
@@ -1656,10 +1657,13 @@ def _create_audio_choice(
):
choices: list[ChatCompletionResponseChoice] = []
final_res = omni_outputs.request_output
audio_data = final_res.multimodal_output.get("audio")
if stream:
audio_tensor = final_res.multimodal_output["audio"][-1].float().detach().cpu().numpy()
audio_tensor = audio_data[-1].float().detach().cpu().numpy()
else:
audio_tensor = final_res.multimodal_output["audio"].float().detach().cpu().numpy()
if isinstance(audio_data, list):
audio_data = torch.cat(audio_data, dim=-1)
audio_tensor = audio_data.float().detach().cpu().numpy()

# Ensure audio is 1D (flatten if needed)
if audio_tensor.ndim > 1:
@@ -219,7 +219,7 @@ def chunked_decode_streaming(
end_index = codes.shape[-1]
# TODO: need to optimize algorithms, current only support
# chunk_size = left_context_size = 25
if end_index == chunk_size:
if end_index <= chunk_size:
context_size = 0
else:
context_size = left_context_size