Skip to content

[Refactor][Perf] Qwen3-omni: code predictor with re-prefill + SDPA and eliminate decode hot-path CPU round-trips#1758

Merged
hsliuustc0106 merged 10 commits into
vllm-project:mainfrom
LJH-LBJ:qwen3-omni-decode-performance
Mar 10, 2026
Merged

[Refactor][Perf] Qwen3-omni: code predictor with re-prefill + SDPA and eliminate decode hot-path CPU round-trips#1758
hsliuustc0106 merged 10 commits into
vllm-project:mainfrom
LJH-LBJ:qwen3-omni-decode-performance

Conversation

@LJH-LBJ
Copy link
Copy Markdown
Contributor

@LJH-LBJ LJH-LBJ commented Mar 9, 2026

PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTTOM) HAVE BEEN CONSIDERED.

Purpose

This PR delivers two groups of performance optimizations targeting per-token decode latency for Qwen3-Omni generation:

  1. Code Predictor MTP rewrite — replace the original HF-based autoregressive code predictor with a streamlined re-prefill implementation
  2. Decode hot-path GPU buffer optimizations — eliminate unnecessary GPU→CPU→GPU tensor transfers in every decode step

Motivation

Profiling the Qwen3-Omni decode loop revealed two categories of overhead:

  • The code predictor used HuggingFace's model.generate() under the hood, which introduced heavy Python-level dispatch, dynamic memory allocation, and KV cache management for very short sequences (2→33 tokens).
  • Several tensors (last_talker_hidden, trailing_text_hidden, tts_pad_embed_projected, suppressed token logits) were round-tripped GPU→CPU→GPU on every decode step, adding ~0.15-0.35ms per step × ~900 steps = 135-315ms total.

Changes

1. Code Predictor MTP Rewrite (qwen3_omni_moe_code_predictor_mtp.py)

Aspect Before After
Attention HF Qwen3MoeAttention with manual mask construction F.scaled_dot_product_attention with native GQA via enable_gqa=True
Decoding strategy Autoregressive with KV Cache (33 sequential steps) Re-prefill: recompute full sequence each step (2→33 tokens), no KV Cache
Memory allocation Dynamic allocation per call Persistent buffers (_proj_buf, _pos_ids) lazily initialized, zero runtime allocation
Sampling Custom sampling operators Inline top-k: torch.topkmasked_fill(-inf)multinomial, no custom ops
Module references nn.ModuleList traversal Cached plain lists (_lm_heads, _codec_embeds) bypass ModuleList.__getitem__ overhead
Compilation Off by default torch.compile ON by default (fullgraph=True, mode="max-autotune-no-cudagraphs")
Decorator None @torch.inference_mode() to skip autograd bookkeeping

2. Talker Single-Loop Refactor (qwen3_omni_moe_talker.py)

  • Replaced the two-phase (prefill + per-group loop) code_predictor_forward with a single unified loop that calls the rewritten MTP predictor once per group.

3. Decode Hot-Path GPU Buffer Optimizations (qwen3_omni.py)

Optimization Before After
suppressed_tokens logit masking Python list → logits.cpu() → scatter → .to(device) Pre-built torch.bool GPU mask + masked_fill_(-inf)
last_talker_hidden .detach().to("cpu").contiguous() every step .detach() — keep on GPU
trailing_text_hidden .detach().to("cpu").contiguous() every step .detach() — keep on GPU
tts_pad_embed_projected .detach().to("cpu").contiguous() every step .detach() — keep on GPU
talker_mtp clones 2× unnecessary .clone() on inputs_embeds and summed_embeddings Removed

4. Runner GPU-Resident Buffer Mechanism (gpu_model_runner.py)

  • Added gpu_resident_buffer_keys protocol: models declare a set[str] of buffer keys that should remain on GPU instead of being offloaded to CPU.
  • In _update_intermediate_buffer, keys in this set use v.detach().clone() instead of v.detach().to("cpu").contiguous() .
  • .clone() is required (not just .detach()) because CUDA Graph replay reuses the same GPU memory addresses — a bare .detach() returns a view that gets overwritten on the next forward pass.

Files Changed

File Description
qwen3_omni_moe_code_predictor_mtp.py Full rewrite: SDPA, re-prefill, persistent buffers, inline top-k, torch.compile
qwen3_omni_moe_talker.py Single-loop code_predictor_forward
qwen3_omni.py GPU-resident buffers, suppressed_tokens bool mask, remove unnecessary clones
gpu_model_runner.py gpu_resident_buffer_keys mechanism in _update_intermediate_buffer

Test Plan

  • Accuracy
pytest -sv tests/e2e/online_serving/test_qwen3_omni.py -m "advanced_model" --run-level "advanced_model"

def get_chunk_config():
    path = modify_stage_config(
        str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci_async_chunk.yaml"),
        updates={
            "async_chunk": True,
            "stage_args": {
                0: {
                    "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk"
                },
                1: {
                    "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk"
                },
            },
        },
        deletes={"stage_args": {2: ["custom_process_input_func"]}},
    )
    return path
  • Benchmark
vllm serve /workspace/models/Qwen3-Omni-30B-A3B-Instruct --omni --port 50713 --stage-configs-path ./vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml

Test Result

  • Accuracy
================== 2 passed, 19 warnings in 664.80s (0:11:04) ===============
  • Benchmark
Concurrency TTFP (ms) E2E (ms) RTF Audio(s/s) Req/s
Before
1 271.5 61969.9 0.32 3.08 0.02
4 721.4 61345.7 0.43 7.81 0.05
10 9209.1 126101.1 0.83 7.47 0.05
After
1 285.8 54511.4 0.30 3.30 0.02
4 715.9 57835.8 0.42 9.41 0.07
10 6867.3 118687.9 0.74 10.39 0.06
After

============ Serving Benchmark Result ============							
Successful requests:                     10        							
Failed requests:                         0         							
Maximum request concurrency:             1         							
Benchmark duration (s):                  545.13    							
Request throughput (req/s):              0.02      							
Peak concurrent requests:                2.00      							
----------------End-to-end Latency----------------							
Mean E2EL (ms):                          54511.35  							
Median E2EL (ms):                        64923.70  							
P99 E2EL (ms):                           77553.20  							
================== Text Result ===================							
Total input tokens:                      25000     							
Total generated tokens:                  9000      							
Output token throughput (tok/s):         16.51     							
Peak output token throughput (tok/s):    79.00     							
Peak concurrent requests:                2.00      							
Total Token throughput (tok/s):          62.37     							
---------------Time to First Token----------------							
Mean TTFT (ms):                          285.83    							
Median TTFT (ms):                        284.50    							
P99 TTFT (ms):                           313.60    							
-----Time per Output Token (excl. 1st token)------							
Mean TPOT (ms):                          13.52     							
Median TPOT (ms):                        13.50     							
P99 TPOT (ms):                           13.81     							
---------------Inter-token Latency----------------							
Mean ITL (ms):                           13.54     							
Median ITL (ms):                         13.00     							
P99 ITL (ms):                            58.18     							
================== Audio Result ==================							
Total audio duration generated(s):       1801.30   							
Total audio frames generated:            43231140  							
Audio throughput(audio duration/s):      3.30      							
---------------Time to First Packet---------------							
Mean AUDIO_TTFP (ms):                    1116.31   							
Median AUDIO_TTFP (ms):                  1115.46   							
P99 AUDIO_TTFP (ms):                     1153.88   							
-----------------Real Time Factor-----------------							
Mean AUDIO_RTF:                          0.30      							
Median AUDIO_RTF:                        0.30      							
P99 AUDIO_RTF:                           0.32      							
==================================================							

============ Serving Benchmark Result ============							
Successful requests:                     10        							
Failed requests:                         0         							
Maximum request concurrency:             4         							
Benchmark duration (s):                  144.65    							
Request throughput (req/s):              0.07      							
Peak concurrent requests:                6.00      							
----------------End-to-end Latency----------------							
Mean E2EL (ms):                          57835.79  							
Median E2EL (ms):                        51807.38  							
P99 E2EL (ms):                           97968.98  							
================== Text Result ===================							
Total input tokens:                      25000     							
Total generated tokens:                  9000      							
Output token throughput (tok/s):         62.22     							
Peak output token throughput (tok/s):    175.00    							
Peak concurrent requests:                6.00      							
Total Token throughput (tok/s):          235.05    							
---------------Time to First Token----------------							
Mean TTFT (ms):                          715.86    							
Median TTFT (ms):                        520.59    							
P99 TTFT (ms):                           1599.19   							
-----Time per Output Token (excl. 1st token)------							
Mean TPOT (ms):                          22.98     							
Median TPOT (ms):                        24.36     							
P99 TPOT (ms):                           26.51     							
---------------Inter-token Latency----------------							
Mean ITL (ms):                           23.23     							
Median ITL (ms):                         20.59     							
P99 ITL (ms):                            83.59     							
================== Audio Result ==================							
Total audio duration generated(s):       1361.32   							
Total audio frames generated:            32671770  							
Audio throughput(audio duration/s):      9.41      							
---------------Time to First Packet---------------							
Mean AUDIO_TTFP (ms):                    2160.37   							
Median AUDIO_TTFP (ms):                  1878.93   							
P99 AUDIO_TTFP (ms):                     3092.25   							
-----------------Real Time Factor-----------------							
Mean AUDIO_RTF:                          0.42      							
Median AUDIO_RTF:                        0.40      							
P99 AUDIO_RTF:                           0.51      							
==================================================							

============ Serving Benchmark Result ============							
Successful requests:                     10        							
Failed requests:                         0         							
Maximum request concurrency:             10        							
Benchmark duration (s):                  175.13    							
Request throughput (req/s):              0.06      							
Peak concurrent requests:                10.00     							
----------------End-to-end Latency----------------							
Mean E2EL (ms):                          118687.95 							
Median E2EL (ms):                        112462.90 							
P99 E2EL (ms):                           175094.62 							
================== Text Result ===================							
Total input tokens:                      25000     							
Total generated tokens:                  9000      							
Output token throughput (tok/s):         51.39     							
Peak output token throughput (tok/s):    226.00    							
Peak concurrent requests:                10.00     							
Total Token throughput (tok/s):          194.14    							
---------------Time to First Token----------------							
Mean TTFT (ms):                          6867.34   							
Median TTFT (ms):                        7580.77   							
P99 TTFT (ms):                           7582.99   							
-----Time per Output Token (excl. 1st token)------							
Mean TPOT (ms):                          55.75     							
Median TPOT (ms):                        55.11     							
P99 TPOT (ms):                           61.50     							
---------------Inter-token Latency----------------							
Mean ITL (ms):                           56.18     							
Median ITL (ms):                         51.04     							
P99 ITL (ms):                            190.60    							
================== Audio Result ==================							
Total audio duration generated(s):       1820.10   							
Total audio frames generated:            43682295  							
Audio throughput(audio duration/s):      10.39     							
---------------Time to First Packet---------------							
Mean AUDIO_TTFP (ms):                    10537.64  							
Median AUDIO_TTFP (ms):                  10551.63  							
P99 AUDIO_TTFP (ms):                     10711.51  							
-----------------Real Time Factor-----------------							
Mean AUDIO_RTF:                          0.74      							
Median AUDIO_RTF:                        0.71      							
P99 AUDIO_RTF:                           1.11      							
==================================================							

Before
============ Serving Benchmark Result ============							
Successful requests:                     10        							
Failed requests:                         0         							
Maximum request concurrency:             1         							
Benchmark duration (s):                  619.72    							
Request throughput (req/s):              0.02      							
Peak concurrent requests:                2.00      							
----------------End-to-end Latency----------------							
Mean E2EL (ms):                          61969.98  							
Median E2EL (ms):                        73588.89  							
P99 E2EL (ms):                           86040.26  							
================== Text Result ===================							
Total input tokens:                      25000     							
Total generated tokens:                  9000      							
Output token throughput (tok/s):         14.52     							
Peak output token throughput (tok/s):    86.00     							
Peak concurrent requests:                2.00      							
Total Token throughput (tok/s):          54.86     							
---------------Time to First Token----------------							
Mean TTFT (ms):                          271.47    							
Median TTFT (ms):                        270.52    							
P99 TTFT (ms):                           296.50    							
-----Time per Output Token (excl. 1st token)------							
Mean TPOT (ms):                          13.53     							
Median TPOT (ms):                        13.54     							
P99 TPOT (ms):                           13.74     							
---------------Inter-token Latency----------------							
Mean ITL (ms):                           13.60     							
Median ITL (ms):                         12.98     							
P99 ITL (ms):                            60.96     							
================== Audio Result ==================							
Total audio duration generated(s):       1909.22   							
Total audio frames generated:            45821415  							
Audio throughput(audio duration/s):      3.08      							
---------------Time to First Packet---------------							
Mean AUDIO_TTFP (ms):                    1146.93   							
Median AUDIO_TTFP (ms):                  1145.37   							
P99 AUDIO_TTFP (ms):                     1171.91   							
-----------------Real Time Factor-----------------							
Mean AUDIO_RTF:                          0.32      							
Median AUDIO_RTF:                        0.32      							
P99 AUDIO_RTF:                           0.33      							
==================================================							


============ Serving Benchmark Result ============							
Successful requests:                     10        							
Failed requests:                         0         							
Maximum request concurrency:             4         							
Benchmark duration (s):                  185.32    							
Request throughput (req/s):              0.05      							
Peak concurrent requests:                5.00      							
----------------End-to-end Latency----------------							
Mean E2EL (ms):                          61345.65  							
Median E2EL (ms):                        64988.48  							
P99 E2EL (ms):                           119090.31 							
================== Text Result ===================							
Total input tokens:                      25000     							
Total generated tokens:                  9000      							
Output token throughput (tok/s):         48.56     							
Peak output token throughput (tok/s):    176.00    							
Peak concurrent requests:                5.00      							
Total Token throughput (tok/s):          183.47    							
---------------Time to First Token----------------							
Mean TTFT (ms):                          721.36    							
Median TTFT (ms):                        331.18    							
P99 TTFT (ms):                           1703.88   							
-----Time per Output Token (excl. 1st token)------							
Mean TPOT (ms):                          19.30     							
Median TPOT (ms):                        16.97     							
P99 TPOT (ms):                           26.55     							
---------------Inter-token Latency----------------							
Mean ITL (ms):                           19.44     							
Median ITL (ms):                         15.35     							
P99 ITL (ms):                            90.46     							
================== Audio Result ==================							
Total audio duration generated(s):       1446.90   							
Total audio frames generated:            34725600  							
Audio throughput(audio duration/s):      7.81      							
---------------Time to First Packet---------------							
Mean AUDIO_TTFP (ms):                    2144.88   							
Median AUDIO_TTFP (ms):                  1588.08   							
P99 AUDIO_TTFP (ms):                     3250.21   							
-----------------Real Time Factor-----------------							
Mean AUDIO_RTF:                          0.43      							
Median AUDIO_RTF:                        0.43      							
P99 AUDIO_RTF:                           0.48      							
==================================================							


============ Serving Benchmark Result ============							
Successful requests:                     10        							
Failed requests:                         0         							
Maximum request concurrency:             10        							
Benchmark duration (s):                  213.86    							
Request throughput (req/s):              0.05      							
Peak concurrent requests:                10.00     							
----------------End-to-end Latency----------------							
Mean E2EL (ms):                          126101.09 							
Median E2EL (ms):                        104143.77 							
P99 E2EL (ms):                           213629.02 							
================== Text Result ===================							
Total input tokens:                      25000     							
Total generated tokens:                  9000      							
Output token throughput (tok/s):         42.08     							
Peak output token throughput (tok/s):    323.00    							
Peak concurrent requests:                10.00     							
Total Token throughput (tok/s):          158.98    							
---------------Time to First Token----------------							
Mean TTFT (ms):                          9209.13   							
Median TTFT (ms):                        10168.94  							
P99 TTFT (ms):                           10174.43  							
-----Time per Output Token (excl. 1st token)------							
Mean TPOT (ms):                          55.94     							
Median TPOT (ms):                        54.96     							
P99 TPOT (ms):                           64.21     							
---------------Inter-token Latency----------------							
Mean ITL (ms):                           56.26     							
Median ITL (ms):                         51.60     							
P99 ITL (ms):                            182.38    							
================== Audio Result ==================							
Total audio duration generated(s):       1596.91   							
Total audio frames generated:            38325960  							
Audio throughput(audio duration/s):      7.47      							
---------------Time to First Packet---------------							
Mean AUDIO_TTFP (ms):                    20780.58  							
Median AUDIO_TTFP (ms):                  20989.71  							
P99 AUDIO_TTFP (ms):                     22899.67  							
-----------------Real Time Factor-----------------							
Mean AUDIO_RTF:                          0.83      							
Median AUDIO_RTF:                        0.84      							
P99 AUDIO_RTF:                           0.98      							
==================================================							


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan. Please provide the test scripts & test commands. Please state the reasons if your codes don't require additional test scripts. For test file guidelines, please check the test style doc
  • The test results. Please paste the results comparison before and after, or the e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model. Please run mkdocs serve to sync the documentation editions to ./docs.
  • (Optional) Release notes update. If your change is user-facing, please update the release notes draft.

BEFORE SUBMITTING, PLEASE READ https://github.com/vllm-project/vllm-omni/blob/main/CONTRIBUTING.md (anything written below this line will be removed by GitHub Actions)

LJH-LBJ added 4 commits March 9, 2026 11:45
Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com>
Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com>
Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com>
Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com>
@LJH-LBJ LJH-LBJ requested a review from hsliuustc0106 as a code owner March 9, 2026 12:27
@hsliuustc0106
Copy link
Copy Markdown
Collaborator

have you testes the L3(PR Merge) pipeline locally for qwen3-omni accuracy?

@LJH-LBJ
Copy link
Copy Markdown
Contributor Author

LJH-LBJ commented Mar 10, 2026

have you testes the L3(PR Merge) pipeline locally for qwen3-omni accuracy?

I have already tested it. Posted it in the Test Result.

LJH-LBJ and others added 4 commits March 10, 2026 10:27
Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com>
…H-LBJ/vllm-omni into qwen3-omni-decode-performance

Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com>
Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com>
@hsliuustc0106 hsliuustc0106 merged commit 2ee4a07 into vllm-project:main Mar 10, 2026
6 of 7 checks passed
lishunyang12 pushed a commit to lishunyang12/vllm-omni that referenced this pull request Mar 11, 2026
…d eliminate decode hot-path CPU round-trips (vllm-project#1758)

Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com>
Signed-off-by: lishunyang <lishunyang12@163.com>
@Gaohan123 Gaohan123 added this to the v0.18.0 milestone Mar 11, 2026
@hsliuustc0106
Copy link
Copy Markdown
Collaborator

I have to revert this PR since it introduce CI acc breakdown

LJH-LBJ added a commit to LJH-LBJ/vllm-omni that referenced this pull request Mar 19, 2026
Re-apply PR vllm-project#1758 (code predictor re-prefill + SDPA + eliminate decode
hot-path CPU round-trips) with the following critical bug fixes:

Bug fixes:
  - proj_buf: allocate locally each forward() call to prevent
    cross-request data pollution under concurrent requests
    (was persistent self._proj_buf shared across calls)
  - summed_embeddings: reshape 3D [B,S,H] to 2D [B*S,H] before
    adding text_step [B*S,H] to avoid silent broadcasting bug
    when batch_size > 1
  - torch.compile: restore torch.compile on inner 5-layer transformer
    (mode=default, dynamic=True) to reduce BF16 intermediate
    round-trip precision loss across the 31-step AR loop

Improvements over original PR vllm-project#1758:
  - SDPA with native GQA via enable_gqa (matching TTS code predictor)
  - Inline top-k + top-p sampling, removing custom op overhead
  - GPU-resident boolean mask for token suppression (no CPU roundtrip)
  - Cleaner code structure aligned with TTS code predictor pattern

Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com>
LJH-LBJ added a commit to LJH-LBJ/vllm-omni that referenced this pull request Mar 19, 2026
…curacy fixes

Re-apply PR vllm-project#1758 optimizations with bug fixes:

1. Code predictor rewritten with re-prefill (no KV cache), SDPA attention
   with native GQA (enable_gqa), and inline top-k + top-p sampling.

2. Eliminate decode hot-path CPU round-trips: gpu_resident_buffer_keys
   keeps last_talker_hidden/trailing_text_hidden/tts_pad_embed_projected
   on GPU via detach().clone().

3. Bug fixes over original PR vllm-project#1758:
   - Per-call proj_buf allocation to avoid cross-request buffer aliasing
     in concurrent batch inference.
   - summed_embeddings reshape 3D->2D before adding text_step to prevent
     silent broadcasting when batch_size > 1.
   - Restore torch.compile on inner 5-layer transformer (mode=default,
     dynamic=True) to reduce BF16 intermediate round-trip truncation.

Signed-off-by: Lucas <lucas@example.com>
Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com>
zhangj1an pushed a commit to zhangj1an/vllm-omni that referenced this pull request Mar 26, 2026
clodaghwalsh17 pushed a commit to clodaghwalsh17/nm-vllm-omni-ent that referenced this pull request May 12, 2026
…d eliminate decode hot-path CPU round-trips (vllm-project#1758)

Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready label to trigger buildkite CI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants