
[Feat][Qwen3TTS] reduce TTFA with flexible initial phase #1583

Merged
hsliuustc0106 merged 22 commits into vllm-project:main from JuanPZuluaga:feat/qwen3tts-config-ttfp on Mar 5, 2026

Conversation

@JuanPZuluaga
Contributor

@JuanPZuluaga JuanPZuluaga commented Mar 1, 2026


Purpose

Related to #938.

In this PR, we lower the TTFA of Qwen3TTS by reducing the number of frames required before the output starts streaming. Ideally, this makes the whole system feel more responsive, since the TTFA goes down.

The current implementation is not very flexible: to reduce TTFA today, one has to decode with very small chunks for the entire request, which can degrade audio quality.

  • What we add is initial_chunk_size, which dictates the chunk size used to generate audio at the start of a request; we can call this phase the "warmup phase": initial_chunk_size frames per chunk with the full accumulated left context.
  • Then, once enough frames have been collected, we move to the standard decoding phase: chunk_size frames per chunk with left_context_size frames of left context (a sketch of this two-phase logic follows this list).
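For illustration, a minimal sketch of this two-phase schedule in Python; the function name chunk_schedule and the exact switch rule (move to the standard phase once chunk_size frames have been emitted) are assumptions, not the actual implementation:

```python
# Illustrative sketch of the two-phase chunking described above. The function
# name and the exact transition rule (switch to the standard phase once
# `chunk_size` frames have been emitted) are assumptions for illustration.

def chunk_schedule(total_frames: int,
                   initial_chunk_size: int,
                   chunk_size: int,
                   left_context_size: int) -> list[tuple[int, int, int]]:
    """Return (context_start, chunk_start, chunk_end) frame indices per code2wav call."""
    schedule = []
    emitted = 0
    while emitted < total_frames:
        in_initial_phase = initial_chunk_size > 0 and emitted < chunk_size
        step = initial_chunk_size if in_initial_phase else chunk_size
        chunk_end = min(emitted + step, total_frames)
        # Initial phase: full accumulated left context; standard phase: bounded left context.
        context_start = 0 if in_initial_phase else max(0, emitted - left_context_size)
        schedule.append((context_start, emitted, chunk_end))
        emitted = chunk_end
    return schedule

# Example: ic=2 with the default chunk_size=25 and a 25-frame left context.
print(chunk_schedule(total_frames=60, initial_chunk_size=2,
                     chunk_size=25, left_context_size=25))
```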

PS: I decided not to add support for left_context_size=-1 (i.e. full left context), because very long sequences could OOM with an unbounded context; the user can instead just increase the left context to, say, 10 s.

I will add some audio samples later.

Test Plan

Test Result

  • Benchmark: sending 10 prompts to the model.
  • I tested initial_codec_chunk_frames = 0, 2, 5, 10, 15, 20; 0 means no initial_codec_chunk_frames (i.e. the baseline).
  • I ran the e2e script with a timer between chunks and printed the results (a minimal sketch of that timing logic follows this list).
  • ic means the value set for initial_codec_chunk_frames.
  • The first request always takes more time, probably due to compilation.
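A minimal sketch of that timing logic; iter_audio_chunks is a hypothetical stand-in for the real e2e streaming client, and only the measurement itself is meant to be representative:

```python
import time

def time_stream(iter_audio_chunks, prompt):
    """Measure TTFA, inter-chunk gaps, and total time for one streamed request.

    `iter_audio_chunks(prompt)` is a hypothetical stand-in for the e2e client
    that yields audio chunks as they arrive.
    """
    t0 = time.perf_counter()
    prev = t0
    ttfa_ms = None
    inter_chunk_ms = []
    for i, _chunk in enumerate(iter_audio_chunks(prompt)):
        now = time.perf_counter()
        if i == 0:
            ttfa_ms = (now - t0) * 1000  # Time To First Audio
        else:
            inter_chunk_ms.append((now - prev) * 1000)
        prev = now
    total_ms = (time.perf_counter() - t0) * 1000
    return ttfa_ms, inter_chunk_ms, total_ms
```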

TTFA (Time To First Audio) in milliseconds

| Config | Req 1 | Req 2 | Req 3 | Req 4 | Req 5 | Req 6 | Req 7 | Req 8 | Req 9 | Avg | Min | Max |
|--------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-----|-----|-----|
| ic_0  | 2092 | 2099 | 2263 | 1835 | 2180 | 2097 | 1879 | 1870 | 1952 | 2030 | 1835 | 2263 |
| ic_2  | 253 | 196 | 197 | 178 | 191 | 184 | 216 | 215 | 220 | 205 | 178 | 253 |
| ic_5  | 489 | 403 | 401 | 417 | 471 | 383 | 412 | 522 | 436 | 437 | 383 | 522 |
| ic_10 | 796 | 790 | 737 | 732 | 906 | 859 | 777 | 767 | 837 | 800 | 732 | 906 |
| ic_15 | 1283 | 1398 | 1209 | 1122 | 1274 | 1166 | 1227 | 1090 | 1108 | 1209 | 1090 | 1398 |
| ic_20 | 1641 | 1673 | 1632 | 1461 | 1324 | 1309 | 1285 | 1327 | 1274 | 1436 | 1274 | 1673 |

Total Generation Time (ms)

| Config | Req 1 | Req 2 | Req 3 | Req 4 | Req 5 | Req 6 | Req 7 | Req 8 | Req 9 | Avg |
|--------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-----|
| ic_0  | 5395 | 5809 | 5806 | 6054 | 5894 | 6013 | 5458 | 7293 | 6296 | 6002 |
| ic_2  | 4886 | 5486 | 4724 | 5164 | 4886 | 5503 | 5500 | 7974 | 6725 | 5650 |
| ic_5  | 4976 | 5256 | 5072 | 5256 | 5073 | 5340 | 5892 | 7141 | 6213 | 5580 |
| ic_10 | 4720 | 6344 | 5287 | 5815 | 5488 | 5817 | 5416 | 6778 | 6310 | 5775 |
| ic_15 | 5047 | 5777 | 5049 | 5421 | 5336 | 5741 | 5931 | 6998 | 6078 | 5709 |
| ic_20 | 4880 | 6340 | 5103 | 4805 | 4372 | 4663 | 4578 | 6086 | 4855 | 5076 |

Inter-Chunk Time (ms)

| Config | Avg | Min | Max | Std | Chunks |
|--------|-----|-----|-----|-----|--------|
| ic_0  | 2025 | 1733 | 2271 | 173 | 11 |
| ic_2  | 328 | 124 | 2245 | 532 | 110 |
| ic_5  | 714 | 305 | 2058 | 634 | 47 |
| ic_10 | 1522 | 702 | 2294 | 596 | 25 |
| ic_15 | 1901 | 1706 | 2100 | 111 | 18 |
| ic_20 | 1664 | 1503 | 2173 | 217 | 16 |

output_1_ic_0.wav
output_1_ic_2.wav
output_1_ic_5.wav
output_1_ic_10.wav
output_1_ic_15.wav
output_1_ic_20.wav

Note that ic_15 and ic_20 yield similar results.

EDIT: removed the first sample (in the tables) due to compilation overhead.
EDIT 2: updated the numbers with a fresh run.

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan. Please provide the test scripts & test commands. Please state the reasons if your codes don't require additional test scripts. For test file guidelines, please check the test style doc
  • The test results. Please paste the results comparison before and after, or the e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model. Please run mkdocs serve to sync the documentation editions to ./docs.
  • (Optional) Release notes update. If your change is user-facing, please update the release notes draft.


@JuanPZuluaga JuanPZuluaga changed the title [Feat][Qwen3TTS] increase TTFA by reduced initial_codec_frames at decoding time [Feat][Qwen3TTS] reduce TTFA with flexible warmup phase Mar 1, 2026
pablo added 2 commits March 1, 2026 19:30
Signed-off-by: pablo <pablo@agigo.ai>
Signed-off-by: pablo <pablo@agigo.ai>
@amy-why-3459
Contributor

amy-why-3459 commented Mar 2, 2026

Can we reduce TTFP by adjusting chunk_size?

@JuanPZuluaga
Contributor Author

JuanPZuluaga commented Mar 2, 2026

Can we reduce TTFP by adjusting chunk_size?

Indeed, but then:

  • A lower chunk size means more calls to the code2wav model. If we go from 25 (the current default) to 5, we end up with 5x more calls.
  • We could increase the left context to compensate for the smaller chunk size, but the pain point would remain: TTFA stays high unless we reduce the chunk size to very low values.

With this approach, at least, we know the TTFA can go below 500 ms without compromising quality too much and without much overhead. We could even increase the chunk size to improve audio quality without increasing TTFA.
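To make the first bullet concrete, here is a rough count of code2wav invocations as a function of chunk size; the ~40 ms of audio per codec frame figure is only used for illustration:

```python
import math

def code2wav_calls(total_frames: int, chunk_size: int) -> int:
    # One code2wav forward pass per emitted chunk.
    return math.ceil(total_frames / chunk_size)

# Roughly 4 s of audio at ~40 ms per frame is ~100 frames.
print(code2wav_calls(100, 25))  # 4 calls with the default chunk size
print(code2wav_calls(100, 5))   # 20 calls, i.e. 5x more, if chunk_size drops to 5
```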

Signed-off-by: pablo <pablo@agigo.ai>
@JuanPZuluaga JuanPZuluaga marked this pull request as ready for review March 2, 2026 08:33
@JuanPZuluaga
Contributor Author

JuanPZuluaga commented Mar 2, 2026

Please let me know what you think @amy-why-3459 @linyueqian. TTFA can now go below ~300 ms.

Also, should I add some tests?

@amy-why-3459
Contributor

Thank you so much for your contribution, it's a great idea. May I ask what your test scenario is? For example, concurrency and input/output length?

@JuanPZuluaga
Contributor Author

Do you mean a real use case? If so, ideally we want to reduce TTFA under high-concurrency loads for voice assistants, where very low latency to the first audio heard by the user is needed. In batched offline decoding scenarios, for instance, I don't see this value mattering much.

@Sy0307
Contributor

Sy0307 commented Mar 2, 2026

It makes sense to me that there is initial_codec_chunk_frames in the warm-up stage to reduce TTFA, and we need a similar scenario as well.

Additionally, I have a question: is it better to implement initial_codec_chunk_frames as a stage-level configuration or a request-level configuration? Or could both exist, with the request-level config taking higher priority? Does this make sense? Welcome to discuss.

@JuanPZuluaga
Contributor Author

JuanPZuluaga commented Mar 2, 2026

It makes sense to me that there is initial_codec_chunk_frames in the warm-up stage to reduce TTFA, and we need a similar scenario as well.

Additionally, I have a question: is it better to implement initial_codec_chunk_frames as a stage-level configuration or a request-level configuration? Or could both exist, with the request-level config taking higher priority? Does this make sense? Welcome to discuss.

this is a good idea actually, though it means adding yet more params to be set request-level. Let me know what you think, and we can implement it.
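For illustration, a minimal sketch of the precedence being discussed, with a request-level value (if present) overriding the stage-level config; the names here are hypothetical:

```python
from typing import Optional

def resolve_initial_chunk_frames(request_value: Optional[int],
                                 stage_config_value: int) -> int:
    """Request-level setting wins over the stage-level default, when provided."""
    return request_value if request_value is not None else stage_config_value

# e.g. the stage config keeps the default of 25, while a latency-sensitive
# request asks for 2-frame initial chunks.
assert resolve_initial_chunk_frames(2, 25) == 2
assert resolve_initial_chunk_frames(None, 25) == 25
```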

@amy-why-3459
Contributor

I'd like to ask about the test scenarios in your Test Result. Also, could you please adapt this solution to qwen3-omni and mimo_audio as well?

fix
Signed-off-by: pablo <pablo@agigo.ai>
@hsliuustc0106
Collaborator

what are the definitions of TTFA/TTFP?

@JuanPZuluaga
Contributor Author

JuanPZuluaga commented Mar 2, 2026

what are the definitions of TTFA/TTFP?

Sorry about that, I've mixed both, but they mean the same thing: Time To First Audio or Time To First Packet. So far, I hear glitches while doing "live streaming" because the model cannot keep up with real-time processing (meaning RTFx is less than one). I expect it to be fast once we get the different models compiled, etc.

@linyueqian
Collaborator

Thanks for this feature! The TTFA improvement at ic=2 is really impressive (~10x reduction with negligible total latency overhead).

I had a question about the warmup to normal phase transition for longer responses. Correct me if my understanding is wrong:

At ic=2, each warmup chunk delivers 2 × 40ms = 80ms of audio. By the time all warmup chunks are delivered, the player has received a total of 12 × 80ms = 960ms of audio. But from that point, the first normal-phase chunk takes ~1593ms to generate (matching the ic_0 baseline). Since the total accumulated warmup audio (960ms) is less than the normal chunk generation time (1593ms), would the player run out of audio before the first normal chunk arrives? I noticed ic_2 has max inter-chunk = 1614ms and std = 430ms. Is that spike the transition point?

I suspect the benchmark avoids this because the test prompts are short enough (~25 frames, roughly 1 second) that the entire response is served within the initial phase and the normal phase is never entered. Would it be possible to test with a longer prompt (say, 5 to 10 seconds of output) to confirm?

Also a small terminology question: the codebase already uses "warmup" in a few other contexts (model warmup, benchmark warmup requests, cache-dit warmup steps). Would a different name like "startup phase" or "initial-chunk phase" avoid ambiguity? Happy to defer to whatever the team prefers!

@JuanPZuluaga
Contributor Author

Thanks for this feature! The TTFA improvement at ic=2 is really impressive (~10x reduction with negligible total latency overhead).

Indeed! It was very surprising to me that with only 2 frames per initial chunk the output is so good! I did a signal-processing analysis, and the difference is minimal. I'd paste the script, but it's too long...

I had a question about the warmup to normal phase transition for longer responses. Correct me if my understanding is wrong:

At ic=2, each warmup chunk delivers 2 × 40ms = 80ms of audio. By the time all warmup chunks are delivered, the player has received a total of 12 × 80ms = 960ms of audio. But from that point, the first normal-phase chunk takes ~1593ms to generate (matching the ic_0 baseline). Since the total accumulated warmup audio (960ms) is less than the normal chunk generation time (1593ms), would the player run out of audio before the first normal chunk arrives? I noticed ic_2 has max inter-chunk = 1614ms and std = 430ms. Is that spike the transition point?

You are correct, I can notice the glitches in the startup phase when the ic is very low (in reality the audio itself has no glitches; it's only the perceived playback that kind of gets cut, and it only happens during the startup phase). I think this issue will be gone with the optimizations that are planned (CUDA graph in the speech tokenizer, etc.). Also, perhaps there's some hyper-parameter I'm missing that needs to be tuned.

In any case, once the current functionality is supported, the further optimizations already planned will allow us to use a very low ic.
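A back-of-the-envelope check of that transition, using the figures quoted above (roughly 40 ms of audio per codec frame and ~1.6 s to generate a normal 25-frame chunk); both numbers are illustrative rather than constants of the model:

```python
FRAME_MS = 40  # approx. audio duration per codec frame, per the figures above

def transition_gap_ms(ic: int, chunk_size: int = 25,
                      normal_chunk_gen_ms: float = 1600.0) -> float:
    """Audio buffered by the end of the initial phase minus the time to generate
    the first normal chunk; a negative value means playback can run dry."""
    warmup_chunks = chunk_size // ic          # e.g. 25 // 2 = 12 initial chunks
    buffered_audio_ms = warmup_chunks * ic * FRAME_MS
    return buffered_audio_ms - normal_chunk_gen_ms

print(transition_gap_ms(ic=2))  # ~960 - 1600 = -640 ms -> matches the perceived cut
```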

I suspect the benchmark avoids this because the test prompts are short enough (~25 frames, roughly 1 second) that the entire response is served within the initial phase and the normal phase is never entered. Would it be possible to test with a longer prompt (say, 5 to 10 seconds of output) to confirm?

I will try with longer sequences, but in my earlier experiments I didn't see any issue, TBH.

Also a small terminology question: the codebase already uses "warmup" in a few other contexts (model warmup, benchmark warmup requests, cache-dit warmup steps). Would a different name like "startup phase" or "initial-chunk phase" avoid ambiguity? Happy to defer to whatever the team prefers!

Startup phase sounds good! I'll change this, thanks for the feedback.

@JuanPZuluaga
Contributor Author

I'd like to ask about the test scenarios in your Test Result. Also, could you please adapt this solution to qwen3-omni and mimo_audio as well?

@amy-why-3459 we would need to wait for #1423 in order to add the same functionality in qwen3-omni

pablo added 2 commits March 2, 2026 22:18
@Sy0307
Contributor

Sy0307 commented Mar 3, 2026

It makes sense to me that there is initial_codec_chunk_frames in the warm-up stage to reduce TTFA, and we need a similar scenario as well.
Additionally, I have a question: is it better to implement initial_codec_chunk_frames as a stage-level configuration or a request-level configuration? Or could both exist, with the request-level config taking higher priority? Does this make sense? Welcome to discuss.

this is a good idea actually, though it means adding yet more params to be set request-level. Let me know what you think, and we can implement it.

For example, the current configurations could be included as request-level parameters in something like additional_information. Additionally, we could specify the warmup duration in a single request. For instance, we might allow the network protocol to dynamically adjust the chunk_size during warmup within a single request based on current network conditions, which would provide better real-time feedback. (By the way, I also agree that the term "warm_up" needs to be revised, as it can be somewhat ambiguous.)

Signed-off-by: pablo <pablo@agigo.ai>
@JuanPZuluaga
Contributor Author

JuanPZuluaga commented Mar 3, 2026

It makes sense to me that there is initial_codec_chunk_frames in the warm-up stage to reduce TTFA, and we need a similar scenario as well.
Additionally, I have a question: is it better to implement initial_codec_chunk_frames as a stage-level configuration or a request-level configuration? Or could both exist, with the request-level config taking higher priority? Does this make sense? Welcome to discuss.

this is a good idea actually, though it means adding yet more params to be set request-level. Let me know what you think, and we can implement it.

For example, the current configurations could be included as request-level parameters in something like additional_information. Additionally, we could specify the warmup duration in a single request. For instance, we might allow the network protocol to dynamically adjust the chunk_size during warmup within a single request based on current network conditions, which would provide better real-time feedback. (By the way, I also agree that the term "warm_up" needs to be revised, as it can be somewhat ambiguous.)

This sounds cool; is it worth doing? I can add the variable chunk size that goes into additional_information.

@amy-why-3459 @linyueqian I changed "warmup" to "initial phase" everywhere. Please let me know what else you'd like me to try; otherwise, could we check this for merge?

@JuanPZuluaga
Contributor Author

OK, updated with a configurable initial_codec_chunk_frames that can be set in the request; the rest is the same. @linyueqian @Sy0307
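For illustration only, a hypothetical client request that sets the value per request; the exact payload shape (top-level field vs. nesting under something like additional_information) is an assumption here and should follow whatever schema the PR actually defines:

```python
import requests

# Hypothetical per-request override of initial_codec_chunk_frames for a
# streaming TTS request. Field placement is an assumption for illustration;
# check the PR diff / serving docs for the real request schema.
resp = requests.post(
    "http://127.0.0.1:8000/v1/audio/speech",
    json={
        "model": "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
        "input": "Hello, streaming with a low time to first audio.",
        "initial_codec_chunk_frames": 2,  # small initial chunks -> lower TTFA
    },
    stream=True,
)
with open("out.audio", "wb") as f:
    for chunk in resp.iter_content(chunk_size=4096):
        f.write(chunk)  # raw streamed audio bytes; format depends on the server config
```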

Test results below

@JuanPZuluaga
Contributor Author

JuanPZuluaga commented Mar 4, 2026

Benchmark overriding initial-codec-chunk-frames via the request.

Benchmark on:

NVIDIA RTX 6000 Ada 48GB.

Step 1: run the model server:

CUDA_VISIBLE_DEVICES=0 python -m vllm_omni.entrypoints.cli.main serve \
    "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice" \
    --omni --host 127.0.0.1 --port 8000 \
    --stage-configs-path benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs4.yaml \
    --trust-remote-code

Note this is the default config, the one already in main.

Step 2: run the benchmark:

python vllm_omni/bench_tts_serve.py \
    --port 8000 \
    --num-prompts 50 \
    --max-concurrency 1 4 10 \
    --config-name "async_chunk" \
    --result-dir results

(My local vllm_omni/bench_tts_serve.py includes the initial-codec-chunk-frames param.)

comparison_baseline_req_ic2_ic0 comparison_baseline_req_ic5_ic0 comparison_baseline_req_ic10_ic0 comparison_baseline_req_ic15_ic0

pablo added 2 commits March 4, 2026 09:12
Signed-off-by: pablo <pablo@agigo.ai>
@hsliuustc0106
Collaborator

Benchmark overriding initial-codec-chunk-frames via the request. [...]

comparison_baseline_req_ic2_ic0 comparison_baseline_req_ic5_ic0 comparison_baseline_req_ic10_ic0 comparison_baseline_req_ic15_ic0

Why does RTF become worse in this scenario?

@JuanPZuluaga
Contributor Author

I think they're very similar @hsliuustc0106. Could you please point out which experiment specifically?

@hsliuustc0106
Collaborator

I think they're very similar @hsliuustc0106. Could you please point out which experiment specifically?

In the first figure, at concurrency 1/4, the RTF looks higher when IC=2.

@JuanPZuluaga
Contributor Author

I think they're very similar @hsliuustc0106. Could you please point out which experiment specifically?

In the first figure, at concurrency 1/4, the RTF looks higher when IC=2.

Oh, you're correct. In this setting, the small RTF overhead is probably due to doing 12 forward passes of the code2wav for ic=2 (2×12 = 24 frames), instead of a single forward pass for ic=0, which runs once we have cached 25 frames.

Collaborator

@lishunyang12 lishunyang12 left a comment

see inline

window_frames = transfer_manager.code_prompt_token_ids[request_id][-end_index:]
if in_initial_phase:
    # Initial-chunk phase: emit every initial_chunk_size frames with full accumulated context.
    already_sent = transfer_manager.put_req_chunk[request_id] * initial_chunk_size
Collaborator

Where does put_req_chunk get incremented after a chunk is emitted? This function reads it but I don't see the write side in this diff.

Contributor Author

@JuanPZuluaga JuanPZuluaga Mar 4, 2026

I think it is done already in:

if success:
    self.put_req_chunk[request_id] += 1
    logger.debug(f"[Stage-{stage_id}] Sent {connector_put_key}")

specifically in ChunkTransferAdapter._send_single_request().

Collaborator

@hsliuustc0106 hsliuustc0106 left a comment

lgtm

@Sy0307
Contributor

Sy0307 commented Mar 5, 2026

lgtm. BTW I believe this should probably become a general method that can be used to reduce TTFA/TTFV for all models with high real-time requirements.

@JuanPZuluaga
Contributor Author

lgtm. BTW I believe this should probably become a general method that can be used to reduce TTFA/TTFV for all models with high real-time requirements.

Thanks! yeah, which models do you think we should target? qwen3-omni is down in line. I can work on it. Also, I think we should implement a smarter way to select this chunk.

@hsliuustc0106
Collaborator

lgtm. BTW I believe this should probably become a general method that can be used to reduce TTFA/TTFV for all models with high real-time requirements.

Thanks! yeah, which models do you think we should target? qwen3-omni is down in line. I can work on it. Also, I think we should implement a smarter way to select this chunk.

qwen3.5-omni is on the way :)

@hsliuustc0106
Collaborator

lgtm. BTW I believe this should probably become a general method that can be used to reduce TTFA/TTFV for all models with high real-time requirements.

Thanks! yeah, which models do you think we should target? qwen3-omni is down in line. I can work on it. Also, I think we should implement a smarter way to select this chunk.

we should provide an elegant abstraction for this dynamic IC selection

@hsliuustc0106 hsliuustc0106 merged commit 6a45efb into vllm-project:main Mar 5, 2026
7 checks passed
@JuanPZuluaga
Contributor Author

lgtm. BTW I believe this should probably become a general method that can be used to reduce TTFA/TTFV for all models with high real-time requirements.

Thanks! yeah, which models do you think we should target? qwen3-omni is down in line. I can work on it. Also, I think we should implement a smarter way to select this chunk.

we should provide an elegant abstraction for this dynamic IC selection

I'll brainstorm a bit and come up with a useful abstraction based on the load in the 'code2wav' block.

linyueqian pushed a commit to lishunyang12/vllm-omni that referenced this pull request Mar 5, 2026
…t#1583)

Signed-off-by: pablo <pablo@agigo.ai>
Co-authored-by: pablo <pablo@agigo.ai>
@JuanPZuluaga JuanPZuluaga deleted the feat/qwen3tts-config-ttfp branch March 5, 2026 21:02
hsliuustc0106 added a commit to hsliuustc0106/vllm-omni-skills that referenced this pull request Mar 7, 2026
lishunyang12 pushed a commit to lishunyang12/vllm-omni that referenced this pull request Mar 11, 2026
…t#1583)

Signed-off-by: pablo <pablo@agigo.ai>
Co-authored-by: pablo <pablo@agigo.ai>
Signed-off-by: lishunyang <lishunyang12@163.com>
@mkgs210

mkgs210 commented Apr 6, 2026

When I use initial_codec_chunk_frames < 25, I always have ~1 s of latency in the middle of every generation:
chunk=1 t=0.360s size=7680
...
chunk=11 t=1.894s size=7680
chunk=12 t=2.054s size=7680
chunk=13 t=3.496s size=8689

chunk=14 t=3.604s size=16384
...
chunk=22 t=7.773s size=18182
done total=8.250s bytes=318720

It happens with concurrency 4 or more on 3090/A100, and 6 or more on 4090.
RTF is good (>1.5-2).
I use qwen3_tts_bs16.yaml.

@linyueqian
Collaborator

@JuanPZuluaga ptal

When I use initial_codec_chunk_frames < 25, I always have ~1 s of latency in the middle of every generation [...]

It happens with concurrency 4 or more on 3090/A100, and 6 or more on 4090. RTF is good (>1.5-2). I use qwen3_tts_bs16.yaml.


Labels

ready label to trigger buildkite CI


7 participants