[Refactor][Perf] Qwen3-omni: code predictor with re-prefill + SDPA and eliminate decode hot-path CPU round-trips#1758
Merged
hsliuustc0106 merged 10 commits intoMar 10, 2026
Conversation
Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com>
Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com>
Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com>
Collaborator
|
have you testes the L3(PR Merge) pipeline locally for qwen3-omni accuracy? |
Contributor
Author
I have already tested it. Posted it in the Test Result. |
…H-LBJ/vllm-omni into qwen3-omni-decode-performance Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com>
lishunyang12
pushed a commit
to lishunyang12/vllm-omni
that referenced
this pull request
Mar 11, 2026
…d eliminate decode hot-path CPU round-trips (vllm-project#1758) Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com> Signed-off-by: lishunyang <lishunyang12@163.com>
Collaborator
|
I have to revert this PR since it introduce CI acc breakdown |
LJH-LBJ
added a commit
to LJH-LBJ/vllm-omni
that referenced
this pull request
Mar 19, 2026
Re-apply PR vllm-project#1758 (code predictor re-prefill + SDPA + eliminate decode hot-path CPU round-trips) with the following critical bug fixes: Bug fixes: - proj_buf: allocate locally each forward() call to prevent cross-request data pollution under concurrent requests (was persistent self._proj_buf shared across calls) - summed_embeddings: reshape 3D [B,S,H] to 2D [B*S,H] before adding text_step [B*S,H] to avoid silent broadcasting bug when batch_size > 1 - torch.compile: restore torch.compile on inner 5-layer transformer (mode=default, dynamic=True) to reduce BF16 intermediate round-trip precision loss across the 31-step AR loop Improvements over original PR vllm-project#1758: - SDPA with native GQA via enable_gqa (matching TTS code predictor) - Inline top-k + top-p sampling, removing custom op overhead - GPU-resident boolean mask for token suppression (no CPU roundtrip) - Cleaner code structure aligned with TTS code predictor pattern Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com>
LJH-LBJ
added a commit
to LJH-LBJ/vllm-omni
that referenced
this pull request
Mar 19, 2026
…curacy fixes Re-apply PR vllm-project#1758 optimizations with bug fixes: 1. Code predictor rewritten with re-prefill (no KV cache), SDPA attention with native GQA (enable_gqa), and inline top-k + top-p sampling. 2. Eliminate decode hot-path CPU round-trips: gpu_resident_buffer_keys keeps last_talker_hidden/trailing_text_hidden/tts_pad_embed_projected on GPU via detach().clone(). 3. Bug fixes over original PR vllm-project#1758: - Per-call proj_buf allocation to avoid cross-request buffer aliasing in concurrent batch inference. - summed_embeddings reshape 3D->2D before adding text_step to prevent silent broadcasting when batch_size > 1. - Restore torch.compile on inner 5-layer transformer (mode=default, dynamic=True) to reduce BF16 intermediate round-trip truncation. Signed-off-by: Lucas <lucas@example.com> Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com>
5 tasks
1 task
Gaohan123
pushed a commit
that referenced
this pull request
Mar 26, 2026
zhangj1an
pushed a commit
to zhangj1an/vllm-omni
that referenced
this pull request
Mar 26, 2026
…llm-project#1758 (vllm-project#2019) Signed-off-by: Zhang <jianmusings@gmail.com>
zhangj1an
pushed a commit
to zhangj1an/vllm-omni
that referenced
this pull request
Mar 26, 2026
zhangj1an
pushed a commit
to zhangj1an/vllm-omni
that referenced
this pull request
Mar 26, 2026
vraiti
pushed a commit
to vraiti/vllm-omni
that referenced
this pull request
Apr 9, 2026
lengrongfu
pushed a commit
to lengrongfu/vllm-omni
that referenced
this pull request
May 1, 2026
clodaghwalsh17
pushed a commit
to clodaghwalsh17/nm-vllm-omni-ent
that referenced
this pull request
May 12, 2026
…d eliminate decode hot-path CPU round-trips (vllm-project#1758) Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com>
clodaghwalsh17
pushed a commit
to clodaghwalsh17/nm-vllm-omni-ent
that referenced
this pull request
May 12, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTTOM) HAVE BEEN CONSIDERED.
Purpose
This PR delivers two groups of performance optimizations targeting per-token decode latency for Qwen3-Omni generation:
Motivation
Profiling the Qwen3-Omni decode loop revealed two categories of overhead:
model.generate()under the hood, which introduced heavy Python-level dispatch, dynamic memory allocation, and KV cache management for very short sequences (2→33 tokens).last_talker_hidden,trailing_text_hidden,tts_pad_embed_projected, suppressed token logits) were round-tripped GPU→CPU→GPU on every decode step, adding ~0.15-0.35ms per step × ~900 steps = 135-315ms total.Changes
1. Code Predictor MTP Rewrite (
qwen3_omni_moe_code_predictor_mtp.py)Qwen3MoeAttentionwith manual mask constructionF.scaled_dot_product_attentionwith native GQA viaenable_gqa=True_proj_buf,_pos_ids) lazily initialized, zero runtime allocationtorch.topk→masked_fill(-inf)→multinomial, no custom opsnn.ModuleListtraversal_lm_heads,_codec_embeds) bypassModuleList.__getitem__overheadtorch.compileON by default (fullgraph=True, mode="max-autotune-no-cudagraphs")@torch.inference_mode()to skip autograd bookkeeping2. Talker Single-Loop Refactor (
qwen3_omni_moe_talker.py)code_predictor_forwardwith a single unified loop that calls the rewritten MTP predictor once per group.3. Decode Hot-Path GPU Buffer Optimizations (
qwen3_omni.py)suppressed_tokenslogit maskinglogits.cpu()→ scatter →.to(device)torch.boolGPU mask +masked_fill_(-inf)last_talker_hidden.detach().to("cpu").contiguous()every step.detach()— keep on GPUtrailing_text_hidden.detach().to("cpu").contiguous()every step.detach()— keep on GPUtts_pad_embed_projected.detach().to("cpu").contiguous()every step.detach()— keep on GPUtalker_mtpclones.clone()oninputs_embedsandsummed_embeddings4. Runner GPU-Resident Buffer Mechanism (
gpu_model_runner.py)gpu_resident_buffer_keysprotocol: models declare aset[str]of buffer keys that should remain on GPU instead of being offloaded to CPU._update_intermediate_buffer, keys in this set usev.detach().clone()instead ofv.detach().to("cpu").contiguous()..clone()is required (not just.detach()) because CUDA Graph replay reuses the same GPU memory addresses — a bare.detach()returns a view that gets overwritten on the next forward pass.Files Changed
qwen3_omni_moe_code_predictor_mtp.pyqwen3_omni_moe_talker.pycode_predictor_forwardqwen3_omni.pygpu_model_runner.pygpu_resident_buffer_keysmechanism in_update_intermediate_bufferTest Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model. Please runmkdocs serveto sync the documentation editions to./docs.BEFORE SUBMITTING, PLEASE READ https://github.com/vllm-project/vllm-omni/blob/main/CONTRIBUTING.md (anything written below this line will be removed by GitHub Actions)