Skip to content

[Feat]audio streaming input for async chunk#3614

Merged
hsliuustc0106 merged 12 commits into
vllm-project:mainfrom
Shirley125:main_realtime_async
May 31, 2026
Merged

[Feat]audio streaming input for async chunk#3614
hsliuustc0106 merged 12 commits into
vllm-project:mainfrom
Shirley125:main_realtime_async

Conversation

@Shirley125
Copy link
Copy Markdown
Contributor

@Shirley125 Shirley125 commented May 14, 2026

PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTTOM) HAVE BEEN CONSIDERED.

Purpose

Streaming audio input means the client does not wait for the full audio request to arrive before starting inference. Instead, the audio is split into multiple slices. Each slice <|im_start|>user\n{audio_placeholder}<|im_end|>\n<|im_start|>assistant\n is submitted to the engine as a sub-request of the same session. A sub-request waits until the previous round finishes, then continues with conversation memory from earlier slices.

input slice 0 -> output segment 0 -> input slice 1 -> output segment 1 -> ... -> final slice -> final output

Difference from Chunked Prefill

Streaming input changes request boundaries: audio arrives as multiple slices, and generation can run between slices. Its output may differ from submitting the whole audio at once.

Chunked prefill only changes scheduling granularity: the full input is already known, but prefill is scheduled in smaller token chunks. It should preserve the same request semantics and output.

Streaming input:
  split input by time -> infer each slice sequentially -> result may differ

Chunked prefill:
  same full input -> split prefill scheduling only -> result should match

Flow Comparison

Without async_chunk, each streaming input segment is forwarded by the Orchestrator after the previous stage finishes. Downstream stages receive normal submit_update calls, and each stage replaces or extends its request state through the scheduler update path.
image

With async_chunk, the Orchestrator only prewarms downstream stages. Stage-to-stage data is transferred by connector chunks: the upstream stage puts chunks, while the downstream scheduler polls and gets chunks before running the next segment.
image

Main Changes

Orchestrator

  • Use add_request for the first audio slice and streaming_update for later slices. This reuses the same req_state and keeps all slices under one logical session.
  • In the existing async-chunk admission path, _prewarm_async_chunk_stages() passes resumable=req_state.streaming.enabled. This marks the downstream request as an unfinished streaming-input session.

Cross-stage data processing layer

  • Existing async-chunk logic treats the first transferred chunk as the prefill input.
    For streaming input, a later slice can also produce a new sub-request prefill. The current implementation identifies this case with request.resumable and thinker_emb.shape[0] > 1.

  • When async scheduling has produced one extra placeholder embedding at the end of a slice, that unconfirmed embedding must not be transferred downstream.

Data Transfer Layer

  • Add is_segment_finished beside the existing request-level finish marker:
`meta.is_segment_finished`: current audio slice is done
stops receiving upstream chunks for that input segment.It resumes receiving chunks when the next streaming input slice is scheduled. 
`meta.finished`: whole streaming request is done

Test Plan

accuracy:

python vllm-omni/examples/online_serving/qwen3_omni/openai_realtime_client.py --model /data/why/Qwen3-Omni-30B-A3B-Instruct --url ws://localhost:8316/v1/realtime --input-wav output.wav  --delta-dump-dir ./rt_delta_wav

input wave:
output.wav

performance:
Use different lengths' speech request and split it into streaming input slices.
Expected behavior:

  • TTFP improves because inference can start after the first audio slice instead of waiting for the full request.
  • The improvement ratio depends on slice granularity.
  • E2E/RTF should be at least on par with the non streaming input setup.
    Reproduce script:
    benchmark_realtime_vs_chat.py

Test Result

accuracy test output:
realtime_output.wav
perf:

audio wav audio input duration (s) TTFP realtime (s) TTFP chat (s) TTFP improve RTF realtime RTF chat
output.wav 20s 0.447 0.539 -17.1% 0.155 0.155
output1.wav 200s 0.398 0.806 -50.6% 0.149 0.147

TTFP:
1.Realtime is almost unaffected by input length: 0.45s → 0.40s, which matches the expectation of chunk-wise independent inference + concurrent uploading.
2.Chat latency increases significantly with longer input: 0.54s → 0.81s, because it requires uploading the entire audio as base64 at once and completing the full prefill before emitting the first packet.
3.The longer the input audio, the larger the advantage of realtime:
short input: only ~17% improvement
long input: nearly 50% reduction in TTFP
RTF:
1.The two endpoints are almost identical (difference ≤ 0.01), indicating that the throughput after the first packet is essentially the same.
2.RTF itself remains stable across different input lengths as well:
realtime: 0.149–0.155
chat: 0.147–0.155
Conclusion:
For long-audio scenarios, the primary value of the realtime endpoint lies in reducing TTFP, while the generation-phase throughput remains comparable between the two endpoints.

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan. Please provide the test scripts & test commands. Please state the reasons if your codes don't require additional test scripts. For test file guidelines, please check the test style doc
  • The test results. Please paste the results comparison before and after, or the e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model. Please run mkdocs serve to sync the documentation editions to ./docs.
  • (Optional) Release notes update. If your change is user-facing, please update the release notes draft.

BEFORE SUBMITTING, PLEASE READ https://github.com/vllm-project/vllm-omni/blob/main/CONTRIBUTING.md (anything written below this line will be removed by GitHub Actions)

@indevn
Copy link
Copy Markdown
Contributor

indevn commented May 16, 2026

hello, I’m opening a smaller related PR for /v1/realtime + async_chunk.

The PR only enables a commit-then-generate compatibility bridge: realtime audio chunks are buffered until input_audio_buffer.commit(final=true), then submitted as one normal Qwen3-Omni multimodal request through the existing async-chunk pipeline.

It does not implement early-start streaming input, prompt extension, or slice-by-slice async-chunk scheduling, so I believe it is complementary to this draft rather than a replacement. Leaving a note here since the PRs touch the same realtime/async-chunk area. Happy to adjust the scope if you think it conflicts with #3614.

Related PR: #3654

@Shirley125 Shirley125 force-pushed the main_realtime_async branch 4 times, most recently from 53b0a2a to 2260de2 Compare May 20, 2026 08:04
@Shirley125 Shirley125 marked this pull request as ready for review May 21, 2026 05:10
@chatgpt-codex-connector
Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Credits must be used to enable repository wide code reviews.

@Shirley125 Shirley125 force-pushed the main_realtime_async branch 3 times, most recently from d8f14c2 to e73729a Compare May 23, 2026 08:39
@Gaohan123 Gaohan123 added this to the v0.22.0 milestone May 26, 2026
@Shirley125 Shirley125 changed the title [wip]audio streaming input for async chunk [Feat]audio streaming input for async chunk May 26, 2026
@Shirley125
Copy link
Copy Markdown
Contributor Author

@amy-why-3459 @ZeldaHuang @Sy0307 @lishunyang12 PTAL

@Shirley125
Copy link
Copy Markdown
Contributor Author

Shirley125 commented May 27, 2026

pytest -sv qwen3/tests/e2e/online_serving/test_qwen3_omni_expansion.py

36 passed, 3 skipped, 19 warnings in 4288.87s (1:11:28)

@Shirley125 Shirley125 force-pushed the main_realtime_async branch from 6a401a0 to 919d63d Compare May 27, 2026 06:13
@hsliuustc0106 hsliuustc0106 added the high priority high priority issue, needs to be done asap label May 28, 2026
@Shirley125 Shirley125 force-pushed the main_realtime_async branch from 919d63d to 5e85587 Compare May 28, 2026 01:12
Signed-off-by: CHEN <116010019@link.cuhk.edu.cn>
@Shirley125
Copy link
Copy Markdown
Contributor Author

demo.html

realtime_demo_new.mp4

Signed-off-by: CHEN <116010019@link.cuhk.edu.cn>
Comment thread vllm_omni/core/sched/omni_generation_scheduler.py Outdated
Comment thread vllm_omni/core/sched/omni_generation_scheduler.py Outdated
Comment thread vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
Copy link
Copy Markdown
Collaborator

@hsliuustc0106 hsliuustc0106 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review Summary

This PR adds streaming audio input support for the async chunk path, enabling /v1/realtime to work with async_chunk enabled. The design is sound: a resumable flag marks streaming sessions, is_segment_finished separates segment-level from request-level finish, and prefill chunks are cached until the first decode chunk arrives.

Validated:

  • CI gates passing (DCO, build, pre-commit, docs)
  • PR provides accuracy test output and performance benchmarks (TTFP improvement 17-50%, RTF parity)
  • New unit tests cover AR scheduler segment handling, adapter segment tracking, and streaming input prefill caching
  • Realtime websocket test now runs with async_chunk enabled

Blocking issue: bare except Exception with incomplete fallback in construct_next_stage_streaming_input_prompt.

PR size note: 18 files changed (exceeds 10-file threshold). Please run L3 tests locally and paste results: https://docs.vllm.ai/projects/vllm-omni/en/latest/contributing/ci/test_guide/#l3-level--l4-level


Reviewed by Claude Code with glm-5.1

request.prompt_token_ids.extend(new_prompt or ())
request.update_block_hashes()
request.num_prompt_tokens = len(request.prompt_token_ids)
except Exception:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This bare except Exception silently swallows all errors. More critically, the fallback (line 249) only sets request.prompt_token_ids but does not update _all_token_ids, num_prompt_tokens, num_computed_tokens, or call update_block_hashes(). This can leave the request in an inconsistent state where the scheduler allocates the wrong number of KV blocks.

Suggested fix:

  1. Catch a narrower exception type (e.g., except (KeyError, AttributeError, IndexError)), or at minimum log the error.
  2. In the fallback, complete the state update:
except Exception:
    if prompt_token_ids is not None:
        next_prompt_len = max(1, len(prompt_token_ids))
        request._all_token_ids.clear()
        request.prompt_token_ids = [0] * next_prompt_len
        request._all_token_ids.extend(request.prompt_token_ids)
        request.num_computed_tokens = 0
        request.update_block_hashes()
        request.num_prompt_tokens = len(request.prompt_token_ids)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

pending_prefills = getattr(transfer_manager, "_pending_streaming_prefills", None)
if pending_prefills is None:
pending_prefills = {}
transfer_manager._pending_streaming_prefills = pending_prefills
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_pending_streaming_prefills is dynamically attached to transfer_manager here rather than initialized in ChunkTransferAdapter.__init__. This is fragile:

  1. The attribute may not exist when cleanup_sender runs (it uses getattr with None default, so cleanup is a no-op instead of cleaning a real dict).
  2. If this function is called after cleanup, it recreates the dict and adds an entry that is never cleaned up.

Please add self._pending_streaming_prefills: dict[str, dict] = {} to ChunkTransferAdapter.__init__ (line ~58 of chunk_transfer_adapter.py) and replace the getattr / monkey-patch pattern with a direct self._pending_streaming_prefills access.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Comment thread vllm_omni/core/sched/omni_generation_scheduler.py Outdated
@Shirley125 Shirley125 force-pushed the main_realtime_async branch from 7b31184 to 3f3580f Compare May 30, 2026 16:16
Signed-off-by: CHEN <116010019@link.cuhk.edu.cn>
@Shirley125 Shirley125 force-pushed the main_realtime_async branch from 3f3580f to 5689239 Compare May 30, 2026 16:35
Shirley125 and others added 2 commits May 31, 2026 00:41
Signed-off-by: CHEN <116010019@link.cuhk.edu.cn>
@hsliuustc0106
Copy link
Copy Markdown
Collaborator

can you help resolve the doc build failures?

Signed-off-by: CHEN <116010019@link.cuhk.edu.cn>
@Shirley125
Copy link
Copy Markdown
Contributor Author

can you help resolve the doc build failures?

sure, resolved

@hsliuustc0106 hsliuustc0106 added the ready label to trigger buildkite CI label May 31, 2026
Signed-off-by: CHEN <116010019@link.cuhk.edu.cn>
@Shirley125 Shirley125 force-pushed the main_realtime_async branch from b63d1d0 to b287f81 Compare May 31, 2026 02:43
Copy link
Copy Markdown
Collaborator

@ZeldaHuang ZeldaHuang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@Sy0307
Copy link
Copy Markdown
Collaborator

Sy0307 commented May 31, 2026

In vllm_omni/distributed/omni_connectors/adapter.py:219-245, the newly added helper’s docstring does not quite match the implementation. The docstring says this helper “resets the computed-token cursor”, but the code below does not reset request.num_computed_tokens. Instead, it keeps the old computed cursor and appends the previously generated output tokens back into prompt_token_ids:

def construct_next_stage_streaming_input_prompt(payload_data: dict[str, Any], request: Any) -> None:
    """Update a downstream streaming request prompt from connector payload ids.

    Async chunk downstream stages are prewarmed before the real Talker prompt is
    known. When a Thinker payload carries `ids.prompt`, this helper rebuilds the
    placeholder prompt length for the next stage, resets the computed-token
    cursor, and refreshes block hashes so the scheduler allocates KV slots that
    match the newly received streaming slice.
    """
    ids = payload_data.get("ids", {})
    prompt_token_ids = ids.get("prompt", None)
    if not prompt_token_ids:
        return
    num_computed_tokens = request.num_computed_tokens
    kept_output_tokens = request._all_token_ids[request.num_prompt_tokens : num_computed_tokens]
    del request._all_token_ids[num_computed_tokens:]
    request._output_token_ids.clear()
    assert request.prompt_token_ids is not None
    # Extend prompt with kept output tokens.
    request.prompt_token_ids.extend(kept_output_tokens)
    next_prompt_len = max(1, compute_talker_prompt_ids_length(prompt_token_ids))
    new_prompt = [0] * next_prompt_len
    request._all_token_ids.extend(new_prompt or ())
    request.prompt_token_ids.extend(new_prompt or ())
    request.update_block_hashes()
    request.num_prompt_tokens = len(request.prompt_token_ids)

This path appears to work in the current e2e flow, so I would not call it a correctness bug by itself. But the docstring should be clarified: this helper preserves num_computed_tokens and extends the prompt to maintain the scheduler’s token watermark; it does not reset the cursor. Otherwise future changes may misunderstand this state-machine invariant.

The rest content LGTM. Thanks for amazing improvement.

Signed-off-by: CHEN <116010019@link.cuhk.edu.cn>
@Shirley125 Shirley125 force-pushed the main_realtime_async branch from 3b1b400 to 93122b2 Compare May 31, 2026 08:15
@Shirley125
Copy link
Copy Markdown
Contributor Author

In vllm_omni/distributed/omni_connectors/adapter.py:219-245, the newly added helper’s docstring does not quite match the implementation. The docstring says this helper “resets the computed-token cursor”, but the code below does not reset request.num_computed_tokens. Instead, it keeps the old computed cursor and appends the previously generated output tokens back into prompt_token_ids:

def construct_next_stage_streaming_input_prompt(payload_data: dict[str, Any], request: Any) -> None:
    """Update a downstream streaming request prompt from connector payload ids.

    Async chunk downstream stages are prewarmed before the real Talker prompt is
    known. When a Thinker payload carries `ids.prompt`, this helper rebuilds the
    placeholder prompt length for the next stage, resets the computed-token
    cursor, and refreshes block hashes so the scheduler allocates KV slots that
    match the newly received streaming slice.
    """
    ids = payload_data.get("ids", {})
    prompt_token_ids = ids.get("prompt", None)
    if not prompt_token_ids:
        return
    num_computed_tokens = request.num_computed_tokens
    kept_output_tokens = request._all_token_ids[request.num_prompt_tokens : num_computed_tokens]
    del request._all_token_ids[num_computed_tokens:]
    request._output_token_ids.clear()
    assert request.prompt_token_ids is not None
    # Extend prompt with kept output tokens.
    request.prompt_token_ids.extend(kept_output_tokens)
    next_prompt_len = max(1, compute_talker_prompt_ids_length(prompt_token_ids))
    new_prompt = [0] * next_prompt_len
    request._all_token_ids.extend(new_prompt or ())
    request.prompt_token_ids.extend(new_prompt or ())
    request.update_block_hashes()
    request.num_prompt_tokens = len(request.prompt_token_ids)

This path appears to work in the current e2e flow, so I would not call it a correctness bug by itself. But the docstring should be clarified: this helper preserves num_computed_tokens and extends the prompt to maintain the scheduler’s token watermark; it does not reset the cursor. Otherwise future changes may misunderstand this state-machine invariant.

The rest content LGTM. Thanks for amazing improvement.

thx, updated

@hsliuustc0106 hsliuustc0106 enabled auto-merge (squash) May 31, 2026 08:26
@hsliuustc0106 hsliuustc0106 disabled auto-merge May 31, 2026 09:39
@hsliuustc0106 hsliuustc0106 merged commit 5dfdf58 into vllm-project:main May 31, 2026
7 of 8 checks passed
linyueqian added a commit that referenced this pull request May 31, 2026
Resolve docstring whitespace conflict in Qwen3-TTS prompt_embeds_builder; align with upstream/main style introduced by #3614.

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
86MaxCao pushed a commit to 86MaxCao/vllm-omni that referenced this pull request Jun 4, 2026
Signed-off-by: CHEN <116010019@link.cuhk.edu.cn>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

high priority high priority issue, needs to be done asap ready label to trigger buildkite CI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants