
[NPU] Support code predictor NPU graph #2695

Merged

gcanlin merged 16 commits into vllm-project:main from gxxx-hum:optimize-qwen3-tts-talker on Apr 22, 2026
Conversation

@gxxx-hum (Contributor) commented Apr 11, 2026


Purpose

Optimize Qwen3-TTS code predictor inference on NPU.
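
For context on where the gain comes from: supporting an NPU graph for the code predictor means its fixed-shape step is captured once and replayed, removing per-step dispatch overhead. The sketch below only illustrates that generic capture-and-replay pattern; it is not this PR's implementation, and it assumes torch_npu exposes an NPUGraph/graph API analogous to torch.cuda.CUDAGraph. The code_predictor callable, shapes, and dtype are placeholders.

# Hedged sketch: generic graph capture/replay for a fixed-shape decode step.
# Assumes torch.npu.NPUGraph / torch.npu.graph mirror the CUDA graph API once
# torch_npu is imported; the module, shapes, and dtype below are placeholders.
import torch
import torch_npu  # noqa: F401  (registers the "npu" device and graph utilities)

def capture_predictor_step(code_predictor, hidden_size=1024, num_groups=32):
    # Graph replay requires static input/output buffers with fixed shapes.
    static_in = torch.zeros(1, num_groups + 1, hidden_size,
                            device="npu", dtype=torch.float16)

    # Warm up once in eager mode so lazy initialization happens outside capture.
    with torch.no_grad():
        code_predictor(static_in)

    graph = torch.npu.NPUGraph()
    with torch.no_grad(), torch.npu.graph(graph):
        static_out = code_predictor(static_in)

    def replay(x):
        # Copy fresh data into the captured buffer, replay, and read the result.
        static_in.copy_(x)
        graph.replay()
        return static_out.clone()

    return replay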

Test Plan

Functional Test

Used examples/online_serving/qwen3_tts/openai_speech_client.py for smoke tests and compared Eager vs NPUGraph outputs:

# Base
python openai_speech_client.py \
--api-base "http://127.0.0.1:8080" \
--model /Qwen3-TTS-12Hz-1.7B-Base \
--task-type Base \
--text "Today I want to take a moment to reflect on how quickly technology is changing the way we communicate, learn, and create. A few years ago, speaking naturally with an intelligent system still felt like science fiction, but now it is becoming part of everyday life. What matters most is not only whether the voice sounds clear, but whether it feels steady, expressive, and comfortable to listen to over time. A good speech model should handle long sentences, short pauses, changes in rhythm, and subtle emotional tone without sounding mechanical or rushed. If this audio sounds smooth from beginning to end, with natural pacing and consistent pronunciation, then it is a strong sign that the generation quality is moving in the right direction." \
--ref-audio "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-TTS-Repo/clone_2.wav" \
--ref-text "Okay. Yeah. I resent you. I love you. I respect you. But you know what? You blew it! And thanks to you."

# CustomVoice
python openai_speech_client.py \
--api-base "http://127.0.0.1:8080" \
--text "你好,我是通义千问。今天我们来做一段稍微长一点的语音生成测试,主要观察声音在连续表达时的稳定性、自然度和情绪一致性。一个好的语音模型,不仅要把每个字读清楚,还要在长句子里保持合适的停顿、语速和语调变化。比如在讲解复杂内容的时候,声音应该听起来从容、连贯,而不是忽快忽慢,或者在句子中间突然变得生硬。如果这段音频从开头到结尾都比较自然,发音清晰,节奏稳定,听起来没有明显的机械感,那就说明当前模型在中文长文本生成上的表现是比较可靠的。" \
--model /Qwen3-TTS-12Hz-1.7B-CustomVoice \
--voice vivian \
--language Chinese \
--output customvoice_long_zh.wav

# VoiceDesign
python /openai_speech_client.py \
--api-base "http://127.0.0.1:8080" \
--model /nas/disk1/Qwen3-TTS-12Hz-1.7B-VoiceDesign \
--task-type VoiceDesign \
--text "今天我给你讲一个小笑话。有一只小猫第一次去上学,老师问它,一加一等于几呀?小猫想了半天,认真地说,等于两条小鱼!老师愣了一下,问它为什么。小猫眨眨眼说,因为我只要看到数字,就会想起午饭啦。虽然这个答案不太对,但是它说得太认真了,大家都忍不住笑了起来。" \
--instructions "体现稚嫩的萝莉女声,音调偏高,语气活泼可爱,讲笑话时带一点俏皮和开心的感觉" \
--output voicedesign_loli_joke.wav

Performance Test

A3 (Tested by @gcanlin)

Qwen3-TTS Performance Comparison

Note:

  • Lower is better for E2EL, AUDIO_TTFP, and AUDIO_RTF
  • Higher is better for Request throughput and Audio throughput
  • Sample sizes differ (main: 50 requests, current PR: 10 requests), so throughput and tail-latency results should be interpreted with that in mind
| Metric | main | Current PR | Delta | Summary |
| --- | --- | --- | --- | --- |
| Successful requests | 50 | 10 | - | No failures in either run |
| Failed requests | 0 | 0 | - | Stable |
| Maximum concurrency | 1 | 1 | 0 | Same setup |
| Benchmark duration (s) | 389.10 | 36.01 | -90.7% | Strongly affected by sample count |
| Request throughput (req/s) | 0.13 | 0.28 | +115.4% | ~2.15x higher |
| Mean E2EL (ms) | 7781.79 | 3600.34 | -53.7% | Significantly reduced |
| Median E2EL (ms) | 7656.77 | 3511.29 | -54.1% | Significantly reduced |
| P99 E2EL (ms) | 13959.07 | 5020.51 | -64.0% | Major tail-latency improvement |
| Total audio duration generated (s) | 287.04 | 56.72 | - | Affected by sample count |
| Audio throughput (audio duration/s) | 0.74 | 1.58 | +113.5% | ~2.14x higher |
| Mean AUDIO_TTFP (ms) | 293.47 | 178.08 | -39.3% | Faster first packet |
| Median AUDIO_TTFP (ms) | 291.97 | 177.42 | -39.2% | Faster first packet |
| P99 AUDIO_TTFP (ms) | 310.69 | 194.94 | -37.3% | Better tail first-packet latency |
| Mean AUDIO_RTF | 1.356 | 0.636 | -53.1% | ~2.13x improvement |
| Median AUDIO_RTF | 1.349 | 0.636 | -52.9% | ~2.12x improvement |
| P99 AUDIO_RTF | 1.431 | 0.660 | -53.9% | ~2.17x improvement |

Qwen3-Omni Performance Comparison

Note:

  • Lower is better for E2EL, TTFT, TPOT, ITL, AUDIO_TTFP, and AUDIO_RTF
  • Higher is better for throughput metrics
  • Results are grouped by maximum request concurrency

Concurrency = 1

| Metric | main | Current PR | Delta | Summary |
| --- | --- | --- | --- | --- |
| Successful requests | 4 | 4 | - | No failures in either run |
| Failed requests | 0 | 0 | - | Stable |
| Maximum request concurrency | 1 | 1 | 0 | Same setup |
| Benchmark duration (s) | 239.25 | 114.52 | -52.1% | Much shorter run time |
| Request throughput (req/s) | 0.02 | 0.03 | +50.0% | Improved |
| Peak concurrent requests | 2.00 | 2.00 | 0 | Same |
| Mean E2EL (ms) | 59812.32 | 28629.16 | -52.1% | Significantly reduced |
| Median E2EL (ms) | 63586.14 | 30457.98 | -52.1% | Significantly reduced |
| P99 E2EL (ms) | 64524.93 | 31488.19 | -51.2% | Strong tail-latency improvement |
| Total input tokens | 456 | 456 | 0 | Same workload |
| Total generated tokens | 400 | 400 | 0 | Same workload |
| Output token throughput (tok/s) | 1.67 | 3.49 | +109.0% | ~2.09x higher |
| Peak output token throughput (tok/s) | 24.00 | 24.00 | 0 | Same peak |
| Total token throughput (tok/s) | 3.58 | 7.47 | +108.7% | ~2.09x higher |
| Mean TTFT (ms) | 183.42 | 163.41 | -10.9% | Faster first token |
| Median TTFT (ms) | 165.34 | 163.03 | -1.4% | Slightly better |
| P99 TTFT (ms) | 244.62 | 166.80 | -31.8% | Better tail |
| Mean TPOT (ms) | 50.95 | 50.43 | -1.0% | Nearly unchanged |
| Median TPOT (ms) | 50.53 | 50.45 | -0.2% | Nearly unchanged |
| P99 TPOT (ms) | 52.35 | 50.60 | -3.3% | Slight improvement |
| Mean ITL (ms) | 50.44 | 49.92 | -1.0% | Nearly unchanged |
| Median ITL (ms) | 50.44 | 49.98 | -0.9% | Nearly unchanged |
| P99 ITL (ms) | 98.99 | 249.36 | +151.9% | Worse tail ITL |
| Total audio duration generated (s) | 133.61 | 128.61 | -3.7% | Comparable workload |
| Audio throughput (audio duration/s) | 0.56 | 1.12 | +100.0% | ~2.0x higher |
| Mean AUDIO_TTFP (ms) | 9643.59 | 2065.79 | -78.6% | Major improvement |
| Median AUDIO_TTFP (ms) | 3890.90 | 2069.08 | -46.8% | Strong improvement |
| P99 AUDIO_TTFP (ms) | 26272.60 | 2088.16 | -92.1% | Massive tail improvement |
| Mean AUDIO_RTF | 1.79 | 0.88 | -50.8% | ~2.03x improvement |
| Median AUDIO_RTF | 1.78 | 0.89 | -50.0% | ~2.0x improvement |
| P99 AUDIO_RTF | 1.84 | 0.89 | -51.6% | Strong improvement |

Key Takeaways for Concurrency = 1

  • Mean E2EL reduced by 52.1%
  • Output token throughput improved by ~2.09x
  • Audio throughput improved by ~2.0x
  • Mean AUDIO_TTFP reduced by 78.6%
  • Mean AUDIO_RTF reduced by 50.8%
  • Text decoding steady-state latency (TPOT / ITL) is mostly unchanged, though P99 ITL regressed

Concurrency = 4

| Metric | main | Current PR | Delta | Summary |
| --- | --- | --- | --- | --- |
| Successful requests | 10 | 10 | - | No failures in either run |
| Failed requests | 0 | 0 | - | Stable |
| Maximum request concurrency | 4 | 4 | 0 | Same setup |
| Benchmark duration (s) | 212.97 | 105.73 | -50.4% | Much shorter run time |
| Request throughput (req/s) | 0.05 | 0.09 | +80.0% | Improved |
| Peak concurrent requests | 5.00 | 6.00 | +20.0% | Higher peak concurrency observed |
| Mean E2EL (ms) | 77712.11 | 36416.19 | -53.1% | Significantly reduced |
| Median E2EL (ms) | 69565.38 | 35756.54 | -48.6% | Strong improvement |
| P99 E2EL (ms) | 116068.41 | 50156.70 | -56.8% | Strong tail improvement |
| Total input tokens | 1140 | 1140 | 0 | Same workload |
| Total generated tokens | 939 | 942 | +0.3% | Comparable workload |
| Output token throughput (tok/s) | 4.41 | 8.91 | +102.0% | ~2.02x higher |
| Peak output token throughput (tok/s) | 46.00 | 80.00 | +73.9% | Much higher peak |
| Total token throughput (tok/s) | 9.76 | 19.69 | +101.7% | ~2.02x higher |
| Mean TTFT (ms) | 291.23 | 250.25 | -14.1% | Faster first token |
| Median TTFT (ms) | 300.33 | 239.80 | -20.2% | Faster first token |
| P99 TTFT (ms) | 480.08 | 349.39 | -27.2% | Better tail |
| Mean TPOT (ms) | 78.86 | 64.21 | -18.6% | Better decode efficiency |
| Median TPOT (ms) | 77.01 | 60.56 | -21.4% | Better decode efficiency |
| P99 TPOT (ms) | 109.53 | 108.92 | -0.6% | Nearly unchanged |
| Mean ITL (ms) | 72.30 | 57.50 | -20.5% | Better inter-token latency |
| Median ITL (ms) | 51.75 | 54.23 | +4.8% | Slight regression |
| P99 ITL (ms) | 288.31 | 451.59 | +56.6% | Worse tail ITL |
| Total audio duration generated (s) | 330.61 | 318.32 | -3.7% | Comparable workload |
| Audio throughput (audio duration/s) | 1.55 | 3.01 | +94.2% | ~1.94x higher |
| Mean AUDIO_TTFP (ms) | 5679.28 | 2976.20 | -47.6% | Strong improvement |
| Median AUDIO_TTFP (ms) | 4436.45 | 2929.14 | -34.0% | Strong improvement |
| P99 AUDIO_TTFP (ms) | 8039.99 | 3473.08 | -56.8% | Major tail improvement |
| Mean AUDIO_RTF | 2.36 | 1.14 | -51.7% | ~2.07x improvement |
| Median AUDIO_RTF | 1.91 | 1.15 | -39.8% | Strong improvement |
| P99 AUDIO_RTF | 3.43 | 1.16 | -66.2% | Major tail improvement |

Key Takeaways for Concurrency = 4

  • Mean E2EL reduced by 53.1%
  • Output token throughput improved by ~2.02x
  • Audio throughput improved by ~1.94x
  • Mean AUDIO_TTFP reduced by 47.6%
  • Mean AUDIO_RTF reduced by 51.7%
  • Decode efficiency improved (TTFT, TPOT, mean ITL), but P99 ITL regressed

Overall Summary

Across both concurrency settings, the current PR shows:

  • ~2x improvement in output token throughput
  • ~2x improvement in audio throughput
  • ~50% reduction in mean end-to-end latency
  • substantial reduction in audio time-to-first-packet
  • ~2x improvement in audio real-time factor

The main remaining regression to watch is:

  • Tail inter-token latency (P99 ITL) is worse in the current PR, despite overall throughput and E2EL improvements

Key Takeaways

Compared with main, the current PR shows:

  • Request throughput improved by ~2.15x
  • Audio throughput improved by ~2.14x
  • Mean end-to-end latency reduced by 53.7%
  • P99 end-to-end latency reduced by 64.0%
  • Mean audio time-to-first-packet reduced by 39.3%
  • Mean audio real-time factor reduced by 53.1%

Short PR Summary

| Metric | main | Current PR | Improvement |
| --- | --- | --- | --- |
| Request throughput (req/s) | 0.13 | 0.28 | +115% |
| Mean E2EL (ms) | 7781.79 | 3600.34 | -53.7% |
| P99 E2EL (ms) | 13959.07 | 5020.51 | -64.0% |
| Audio throughput | 0.74 | 1.58 | +113% |
| Mean AUDIO_TTFP (ms) | 293.47 | 178.08 | -39.3% |
| Mean AUDIO_RTF | 1.356 | 0.636 | -53.1% |

Used benchmarks/qwen3_tts/vllm_omni/bench_tts_serve.py for performance tests and compared Eager vs NPUGraph, using Base as an example:

# run command
nohup vllm serve /nas/disk1/Qwen3-TTS-12Hz-1.7B-Base \
--omni \
--allowed-local-media-path /workspace \
--port 8080 > tts_base.log 2>&1 &

# test command 
# concurrency: 1
nohup python bench_tts_serve.py \
--host 127.0.0.1 --port 8080 \
--task-type Base \
--num-prompts 10 --max-concurrency 1 \
--config-name base_baseline \
--result-dir benchmarks/qwen3-tts/results/ > tts_base_1.log 2>&1 &

# concurrency: 4 
# serving config: stage0_max_num_seqs: 4, stage0_max_num_batched_tokens: 2048
nohup python bench_tts_serve.py \
--host 127.0.0.1 --port 8080 \
--task-type Base \
--num-prompts 40 --max-concurrency 4 \
--config-name base_baseline \
--result-dir benchmarks/qwen3-tts/results/ > tts_base_4.log 2>&1 &

# concurrency: 10 
# serving config: stage0_max_num_seqs: 10, max_num_batched_tokens: 5120, max_model_len: 5120
nohup python bench_tts_serve.py \
--host 127.0.0.1 --port 8080 \
--task-type Base \
--num-prompts 100 --max-concurrency 10 \
--config-name base_baseline \
--result-dir benchmarks/qwen3-tts/results/ > tts_base_10.log 2>&1 &

Test Result

Functional Test

Summary

  • Generated audio from Base, CustomVoice, and VoiceDesign was tested with both short/long texts and Chinese/English inputs. The spoken content was consistent and no obvious noise or artifacts were observed.

  • For CustomVoice, minor speaking-rate differences were observed. Running the same script multiple times in both Eager and NPUGraph modes also produced audio with slightly different speaking speeds, so this is likely due to CustomVoice being more sensitive to the sampling parameters rather than an NPUGraph-specific regression.

(Audio samples for Base, CustomVoice, and VoiceDesign were attached to the original PR.)

Performance Test

Summary

  • Base: Clear improvements are observed at concurrency 1/4/10. At low concurrency, request throughput and audio throughput nearly doubled; at concurrency 10, throughput still improved by over 50%. E2E and Audio RTF both dropped by around 50% at low concurrency, and still improved by around 35% at concurrency 10. TTFP dropped by around 50% at low concurrency and around 10% at concurrency 10.

  • CustomVoice: Clear improvements are observed at concurrency 1/4/10. At low concurrency, throughput improved by up to 133%, and at concurrency 10, it still improved by up to 95%. E2E and Audio RTF dropped by around 56%-48% at concurrency 1/4/10. TTFP decreased by around 38%-32%.

  • VoiceDesign: Clear improvements are observed at concurrency 1/4/10. At low concurrency, throughput improved by around 130%, and at concurrency 10, it still improved by around 30%. E2E dropped by around 50% at low concurrency and around 20% at concurrency 10. TTFP decreased by around 45% at low concurrency and around 32% at concurrency 10.

Base

| Concurrency | Metric | Eager | NPUGraph | Change |
| --- | --- | --- | --- | --- |
| 1 | Request throughput | 0.07 | 0.15 | +114.3% |
| 1 | Mean E2EL | 13540.65 | 6661.53 | -50.8% |
| 1 | Audio throughput | 0.32 | 0.65 | +103.1% |
| 1 | Mean AUDIO_TTFP | 1053.42 | 822.71 | -21.9% |
| 1 | Mean AUDIO_RTF | 3.123 | 1.54 | -50.7% |
| 4 | Request throughput | 0.22 | 0.36 | +63.6% |
| 4 | Mean E2EL | 17540.92 | 10929.86 | -37.7% |
| 4 | Audio throughput | 0.97 | 1.59 | +63.9% |
| 4 | Mean AUDIO_TTFP | 3327.8 | 2729.74 | -18.0% |
| 4 | Mean AUDIO_RTF | 4.066 | 2.498 | -38.6% |
| 10 | Request throughput | 0.31 | 0.49 | +58.1% |
| 10 | Mean E2EL | 30817.65 | 19862.14 | -35.5% |
| 10 | Audio throughput | 1.38 | 2.15 | +55.8% |
| 10 | Mean AUDIO_TTFP | 8761.28 | 8055.83 | -8.1% |
| 10 | Mean AUDIO_RTF | 7.009 | 4.561 | -34.9% |

CustomVoice

| Concurrency | Metric | Eager | NPUGraph | Change |
| --- | --- | --- | --- | --- |
| 1 | Request throughput | 0.06 | 0.14 | +133.3% |
| 1 | Mean E2EL | 16416.17 | 7129.04 | -56.6% |
| 1 | Audio throughput | 0.36 | 0.84 | +133.3% |
| 1 | Mean AUDIO_TTFP | 583.89 | 357.69 | -38.7% |
| 1 | Mean AUDIO_RTF | 2.744 | 1.201 | -56.2% |
| 4 | Request throughput | 0.22 | 0.50 | +127.3% |
| 4 | Mean E2EL | 17515.88 | 7914.26 | -54.8% |
| 4 | Audio throughput | 1.27 | 2.86 | +125.2% |
| 4 | Mean AUDIO_TTFP | 2211.63 | 1061.67 | -52% |
| 4 | Mean AUDIO_RTF | 3.033 | 1.384 | -54.4% |
| 10 | Request throughput | 0.48 | 0.94 | +95.8% |
| 10 | Mean E2EL | 19867.88 | 10183.69 | -48.7% |
| 10 | Audio throughput | 2.77 | 5.32 | +92.1% |
| 10 | Mean AUDIO_TTFP | 3878.31 | 2600.98 | -32.9% |
| 10 | Mean AUDIO_RTF | 3.473 | 1.798 | -48.2% |

VoiceDesign

| Concurrency | Metric | Eager | NPUGraph | Change |
| --- | --- | --- | --- | --- |
| 1 | Request throughput | 0.07 | 0.16 | +128.6% |
| 1 | Mean E2EL | 14693.99 | 6093.27 | -58.5% |
| 1 | Audio throughput | 0.32 | 0.78 | +143.8% |
| 1 | Mean AUDIO_TTFP | 662.51 | 368.01 | -44.5% |
| 1 | Mean AUDIO_RTF | 3.1 | 1.285 | -58.5% |
| 4 | Request throughput | 0.24 | 0.6 | +150% |
| 4 | Mean E2EL | 15744.72 | 6525.22 | -58.6% |
| 4 | Audio throughput | 1.14 | 2.81 | +146.5% |
| 4 | Mean AUDIO_TTFP | 2442.42 | 1084.59 | -55.6% |
| 4 | Mean AUDIO_RTF | 3.356 | 1.404 | -60.6% |
| 10 | Request throughput | 0.64 | 0.83 | +29.7% |
| 10 | Mean E2EL | 15007 | 11725.23 | -21.9% |
| 10 | Audio throughput | 2.92 | 3.79 | +29.8% |
| 10 | Mean AUDIO_TTFP | 3576.28 | 3021.47 | -15.5% |
| 10 | Mean AUDIO_RTF | 3.273 | 2.561 | -21.8% |

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan. Please provide the test scripts & test commands. Please state the reasons if your codes don't require additional test scripts. For test file guidelines, please check the test style doc
  • The test results. Please paste the results comparison before and after, or the e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model. Please run mkdocs serve to sync the documentation editions to ./docs.
  • (Optional) Release notes update. If your change is user-facing, please update the release notes draft.


@gxxx-hum gxxx-hum requested a review from hsliuustc0106 as a code owner April 11, 2026 05:05
@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Credits must be used to enable repository wide code reviews.

@gcanlin (Collaborator) commented Apr 11, 2026

Impressive performance improvement! Have you checked the accuracy of the output audio?

@hsliuustc0106 (Collaborator) left a review:

Review blocked by gate failures.

  • DCO: ACTION_REQUIRED
  • pre-commit: FAILURE

Please fix both before this can proceed.

Preliminary note: the hardcoded 2048×2048 fusion causal mask has no guard — if _num_groups + 1 > 2048, npu_fusion_attention will silently misbehave. Consider asserting or at minimum logging a warning when max_seq exceeds the mask size.

@hsliuustc0106 (Collaborator) commented:

NPU fusion attention + NPUGraph for code predictor. Two issues:

  1. DCO + pre-commit failing — please fix git commit signing and lint errors before requesting review.
  2. Hardcoded 2048 mask: torch.ones(2048, ...) with no guard. If max_seq_len ever exceeds 2048 this silently produces incorrect masks. Consider using max_model_len from the model config or at least adding an assertion.

# Ascend SDPA is_causal migration example uses a fixed 2048x2048
# compressed causal mask with sparse_mode=2.
fusion_mask = torch.triu(
    torch.ones(2048, 2048, dtype=torch.bool),
    diagonal=1,
)
A collaborator commented on the same diff:

ditto on the 2048 — please add an assert against self._num_groups + 1 at minimum.
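
A minimal form of the guard requested above could look like the sketch below. It is only an illustration: FUSION_MASK_SIZE, build_fusion_causal_mask, and num_groups are placeholder names chosen to match this discussion rather than the PR's actual code, and the raise could just as well be the warning the reviewers mention.

import torch

FUSION_MASK_SIZE = 2048  # size of the hardcoded compressed causal mask discussed above

def build_fusion_causal_mask(num_groups: int) -> torch.Tensor:
    # Assumes num_groups + 1 is the longest sequence the code predictor
    # attends over, as suggested in the review comments above.
    max_seq = num_groups + 1
    if max_seq > FUSION_MASK_SIZE:
        raise ValueError(
            f"code predictor sequence length {max_seq} exceeds the fixed "
            f"{FUSION_MASK_SIZE}x{FUSION_MASK_SIZE} npu_fusion_attention mask"
        )
    # Upper triangle with diagonal=1: True marks positions to mask out,
    # matching the compressed causal mask paired with sparse_mode=2 in this thread.
    return torch.triu(
        torch.ones(FUSION_MASK_SIZE, FUSION_MASK_SIZE, dtype=torch.bool),
        diagonal=1,
    )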

@gcanlin gcanlin self-assigned this Apr 11, 2026
@gxxx-hum (Contributor, Author) commented:

> Impressive performance improvement! Have you checked the accuracy of the output audio?

I only ran a functional test with a single case, and the generated audio matches the eager-mode output. Is there any dataset available for accuracy testing? I will continue to complete the performance and accuracy tests for the three model types.

@gxxx-hum (Contributor, Author) commented:

@hsliuustc0106 @lishunyang12 Thanks! I will update in a follow-up commit.

Signed-off-by: XIN GAO <1037396230@qq.com>
Signed-off-by: XIN GAO <1037396230@qq.com>
@gxxx-hum gxxx-hum force-pushed the optimize-qwen3-tts-talker branch from e714a89 to 2c021e3 on April 12, 2026 10:10
@hahadashi commented:

Which platform, 910B or 910C?

Signed-off-by: XIN GAO <1037396230@qq.com>
Signed-off-by: XIN GAO <1037396230@qq.com>
@gxxx-hum (Contributor, Author) commented Apr 13, 2026

> Which platform, 910B or 910C?

910B

@gxxx-hum (Contributor, Author) commented:

@gcanlin @hsliuustc0106 @lishunyang12 The previously mentioned issues have all been fixed. I also reran the functional and performance benchmarks and updated the PR description.

@gcanlin (Collaborator) commented Apr 14, 2026

@gxxx-hum Thanks for the clear benchmark! I ran a benchmark for Qwen3-TTS on v0.18.0 and got about 1.5 RTF. Why does this PR get 2.7~3.1 RTF at concurrency 1? Any idea?

@gxxx-hum (Contributor, Author) commented Apr 14, 2026

> @gxxx-hum Thanks for the clear benchmark! I ran a benchmark for Qwen3-TTS on v0.18.0 and got about 1.5 RTF. Why does this PR get 2.7~3.1 RTF at concurrency 1? Any idea?

We can align the parameters.

The environment I used is:

  • Driver: 25.5.0
  • NPU: 910B3
  • CANN: 8.5.1
  • vLLM: v0.18.0
  • vllm-ascend: commit d781902 on top of v0.18.0
  • vllm-omni: commit 32af3af on the main branch

Also, I adjusted gpu_memory_utilization in the config yaml to 0.6:

async_chunk: true
stage_args:
  - stage_id: 0
    stage_type: llm
    is_comprehension: true
    runtime:
      devices: "0"
    engine_args:
      model_stage: qwen3_tts
      max_num_seqs: 1
      model_arch: Qwen3TTSTalkerForConditionalGeneration
      worker_type: ar
      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
      enforce_eager: true
      trust_remote_code: true
      async_scheduling: false
      enable_prefix_caching: false
      engine_output_type: latent
      gpu_memory_utilization: 0.6
      distributed_executor_backend: "mp"
      max_num_batched_tokens: 512
      max_model_len: 4096
      custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
    # Use named connector to apply runtime.connectors.extra.
    output_connectors:
      to_stage_1: connector_of_shared_memory
    default_sampling_params:
      temperature: 0.9
      top_k: 50
      max_tokens: 4096
      seed: 42
      detokenize: false
      repetition_penalty: 1.05
      stop_token_ids: [2150]

  - stage_id: 1
    stage_type: llm
    runtime:
      devices: "0"
    engine_args:
      model_stage: code2wav
      max_num_seqs: 1
      model_arch: Qwen3TTSCode2Wav
      worker_type: generation
      scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
      enforce_eager: true
      trust_remote_code: true
      async_scheduling: false
      enable_prefix_caching: false
      engine_output_type: audio
      gpu_memory_utilization: 0.2
      distributed_executor_backend: "mp"
      # Must be divisible by num_code_groups and cover (left_context + chunk).
      max_num_batched_tokens: 32768
      # async_chunk appends windows per step; max_model_len must cover accumulated stream.
      max_model_len: 32768
    engine_input_source: [0]
    final_output: true
    final_output_type: audio
    # Distributed connector configuration
    input_connectors:
      from_stage_0: connector_of_shared_memory
    tts_args:
      max_instructions_length: 500
    default_sampling_params:
      temperature: 0.0
      top_p: 1.0
      top_k: -1
      max_tokens: 65536
      seed: 42
      detokenize: true
      repetition_penalty: 1.0

runtime:
  enabled: true
  defaults:
    window_size: -1
    max_inflight: 1

  connectors:
    connector_of_shared_memory:
      name: SharedMemoryConnector
      extra:
        shm_threshold_bytes: 65536
        # Frame-aligned codec streaming transport.
        codec_streaming: true
        # Connector polling / timeout (unit: loop count, sleep interval in seconds).
        connector_get_sleep_s: 0.01
        connector_get_max_wait_first_chunk: 3000
        connector_get_max_wait: 300
        # Align with Omni: small chunks with sufficient context overlap.
        codec_chunk_frames: 25
        codec_left_context_frames: 25

  edges:
    - from: 0
      to: 1
      window_size: -1

@gcanlin

@gcanlin (Collaborator) commented Apr 15, 2026

Good job! I ran it on A3 and got even more impressive performance.

==================================================
             Serving Benchmark Result
==================================================
Successful requests:                    50
Failed requests:                        0
Maximum request concurrency:            1
Benchmark duration (s):                 177.59
Request throughput (req/s):             0.28
--------------------------------------------------
                End-to-end Latency
--------------------------------------------------
Mean E2EL (ms):                         3551.67
Median E2EL (ms):                       3474.04
P99 E2EL (ms):                          5675.86
==================================================
                   Audio Result
==================================================
Total audio duration generated (s):     289.12
Audio throughput (audio duration/s):    1.63
--------------------------------------------------
               Time to First Packet
--------------------------------------------------
Mean AUDIO_TTFP (ms):                   178.66
Median AUDIO_TTFP (ms):                 174.51
P99 AUDIO_TTFP (ms):                    199.50
--------------------------------------------------
                 Real Time Factor
--------------------------------------------------
Mean AUDIO_RTF:                         0.617
Median AUDIO_RTF:                       0.613
P99 AUDIO_RTF:                          0.653
==================================================

@hahadashi commented, quoting gcanlin's A3 results above:

How does this differ from the 910B results? @gxxx-hum @gcanlin

@gxxx-hum (Contributor, Author) commented:

I might be wrong, but the 910C seems closer to 2x 910B designed with shared memory and an on-package interconnect. @hahadashi

@OceanWang71 commented, quoting gxxx-hum's environment and config reply above (which listed the vllm-omni commit as 32af3f):

The correct commit in the vllm-omni repository here should be 32af3af.

@gcanlin (Collaborator) commented Apr 21, 2026

I plan to land this PR first even though it adds new hardware-specific hardcoding. Before abstracting the attention backend (see #2967), we need at least two instances, and the CUDA hardcoding has existed for a long time, so integrating the NPU hardcoding temporarily makes sense. I will abstract them ASAP. @hsliuustc0106

gcanlin added 4 commits April 21, 2026 08:18
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@gcanlin gcanlin added the ready label (label to trigger buildkite CI) Apr 21, 2026
gcanlin added 4 commits April 22, 2026 01:43
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@gcanlin gcanlin changed the title from "[NPU] Optimize Qwen3-TTS code predictor" to "[NPU] Support code predictor NPU graph" Apr 22, 2026
@gcanlin gcanlin added the omni-test label (label to trigger buildkite omni model test in nightly CI) Apr 22, 2026
gcanlin added 2 commits April 22, 2026 07:32
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@gcanlin gcanlin removed the omni-test label (label to trigger buildkite omni model test in nightly CI) Apr 22, 2026
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@gcanlin (Collaborator) left a review:

LGTM. I updated the latest test results in the PR description. Some follow-up items remain:

  1. abstract the code predictor attention to a CustomOp (a rough sketch follows this list);
  2. abstract the sub-model graph wrapper;
  3. analyze why performance is better when enabling use_cuda_graph for the code predictor of Qwen3-Omni.
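
For item 1, the abstraction could follow vLLM's CustomOp idea of per-platform forward methods so the CUDA and NPU paths stop being hardcoded. The class below is only a sketch of that dispatch shape: the plain device check stands in for the real platform registry, and the names, signatures, and method bodies are placeholders rather than the planned implementation.

import torch
import torch.nn.functional as F

class CodePredictorAttention(torch.nn.Module):
    # Sketch of a CustomOp-style dispatch; not the planned implementation.
    def forward(self, q, k, v, mask=None):
        if q.device.type == "cuda":
            return self.forward_cuda(q, k, v, mask)
        if q.device.type == "npu":
            return self.forward_oot(q, k, v, mask)
        return self.forward_native(q, k, v, mask)

    def forward_native(self, q, k, v, mask=None):
        # Reference path in plain PyTorch SDPA.
        return F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, is_causal=mask is None)

    def forward_cuda(self, q, k, v, mask=None):
        # CUDA path; SDPA already selects a fused kernel here.
        return self.forward_native(q, k, v, mask)

    def forward_oot(self, q, k, v, mask=None):
        # Out-of-tree (NPU) path: this is where npu_fusion_attention with the
        # compressed causal mask would go; fall back to the reference in this sketch.
        return self.forward_native(q, k, v, mask)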

@gcanlin gcanlin added the omni-test label (label to trigger buildkite omni model test in nightly CI) Apr 22, 2026
@gcanlin gcanlin enabled auto-merge (squash) April 22, 2026 09:11
@gcanlin gcanlin merged commit 9d1392d into vllm-project:main Apr 22, 2026
8 checks passed
qinganrice pushed a commit to qinganrice/vllm-omni that referenced this pull request Apr 23, 2026
Signed-off-by: XIN GAO <1037396230@qq.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: gcanlin <canlinguosdu@gmail.com>

Labels

omni-test (label to trigger buildkite omni model test in nightly CI), ready (label to trigger buildkite CI)

6 participants