Skip to content

[Feat][Qwen3-Omni] Shared code predictor module for Qwen3-TTS and Qwen3-Omni#2375

Merged
ZeldaHuang merged 28 commits into
vllm-project:mainfrom
JuanPZuluaga:feat/cuda-graph-code-predictor
Apr 15, 2026
Merged

[Feat][Qwen3-Omni] Shared code predictor module for Qwen3-TTS and Qwen3-Omni#2375
ZeldaHuang merged 28 commits into
vllm-project:mainfrom
JuanPZuluaga:feat/cuda-graph-code-predictor

Conversation

@JuanPZuluaga
Copy link
Copy Markdown
Contributor

@JuanPZuluaga JuanPZuluaga commented Mar 31, 2026

PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTTOM) HAVE BEEN CONSIDERED.

Purpose

in this PR:

  • unified code predictor logic of Qwen3-TTS and Qwen3-Omni into a shared CodePredictorWrapper base class common/qwen3_code_predictor.py
  • Both model-specific wrappers are now config-only subclasses, can be modified by a CodePredictorWrapperConfig dataclass with 5 behavioral flags: use_cuda_graphs,
    use_parallel_embedding, use_projection, return_proj_buf, sampling_mode
  • it also includes the torch.compile(dynamic=False) flag + CUDA graph capture per power-of-2 batch buckets, with epilogue_fusion=False to preserve float32 precision in RMSNorm/RoPE for audio quality (this was reported in previous PRs)
  • Bugfix in stage_init_utils.py: hasattr returned True for None-valued custom_process_input_func; replaced with getattr(..., None) truthiness check

Test Plan

Test Result

the e2e time is more or less the same overall, but the code-predictor is a bit faster.

main_vs_code-predictor

some audios generated at concurrency=16:

output_0_4178681c-d9ac-423e-a274-8daaf2bd4b64.wav
output_1_fd088db1-2725-4321-9286-cf7d966dfff0.wav


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan. Please provide the test scripts & test commands. Please state the reasons if your codes don't require additional test scripts. For test file guidelines, please check the test style doc
  • The test results. Please paste the results comparison before and after, or the e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model. Please run mkdocs serve to sync the documentation editions to ./docs.
  • (Optional) Release notes update. If your change is user-facing, please update the release notes draft.

BEFORE SUBMITTING, PLEASE READ https://github.com/vllm-project/vllm-omni/blob/main/CONTRIBUTING.md (anything written below this line will be removed by GitHub Actions)

… warmup

Signed-off-by: JuanPZuluaga <juanz9312@gmail.com>
@amy-why-3459
Copy link
Copy Markdown
Contributor

@LJH-LBJ PTAL

Copy link
Copy Markdown
Collaborator

@lishunyang12 lishunyang12 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

left a few comments. the compile+bucket change looks solid overall — nice that it follows the TTS code predictor pattern.

# Convert to numpy array and ensure correct format
# In async_chunk mode, audio may arrive as a list of chunks
if isinstance(audio_tensor, list):
import torch
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch is already used transitively elsewhere in this file (via audio_tensor.float()). Move the import to the top-level imports instead of burying it inside a conditional.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved to the top.

def _ensure_buffers(self, device: torch.device, dtype: torch.dtype, min_bsz: int = 0) -> None:
"""Pre-allocate projection buffer sized to max(max_num_seqs, min_bsz)."""
max_seq = self.num_code_groups + 1
max_bsz = max(self._vllm_config.scheduler_config.max_num_seqs, min_bsz)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The min_bsz parameter is not present in the TTS code predictor version of _ensure_buffers. Is this needed? max_num_seqs should already be the upper bound — if bsz > max_num_seqs something else has gone wrong.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, fixed now.

proj_buf[:bsz, 0:1, :] = last_talker_hidden
proj_buf[:bsz, 1:2, :] = layer0_embed

# Get pre-computed pos_ids for this bucket
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: _setup_compile does warmup internally which can be expensive. Might be worth adding a log line or comment at the call site so someone debugging a slow first-call knows to look there.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perfect, done

Copy link
Copy Markdown
Collaborator

@hsliuustc0106 hsliuustc0106 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ZeldaHuang
Copy link
Copy Markdown
Collaborator

The implementations of _setup_compile_warmup_buckets _padded_bsz seem to overlap significantly with those in qwen3_tts_code_predictor. We could abstract them into a separated class for future reuse.

@JuanPZuluaga
Copy link
Copy Markdown
Contributor Author

The implementations of _setup_compile_warmup_buckets _padded_bsz seem to overlap significantly with those in qwen3_tts_code_predictor. We could abstract them into a separated class for future reuse.

@ZeldaHuang that's correct. It's quite overlapped; should I propose a shared module for Qwen3TTS and Qwen3Omni?

@ZeldaHuang
Copy link
Copy Markdown
Collaborator

The implementations of _setup_compile_warmup_buckets _padded_bsz seem to overlap significantly with those in qwen3_tts_code_predictor. We could abstract them into a separated class for future reuse.

@ZeldaHuang that's correct. It's quite overlapped; should I propose a shared module for Qwen3TTS and Qwen3Omni?

You can include it in this PR if it’s not too complicated, and it would be great to add some tests to protect the module as well. Thanks!

@ZeldaHuang
Copy link
Copy Markdown
Collaborator

@JuanPZuluaga Hi, I notice you abstract the whole code predictor model, can you change the PR title?

Signed-off-by: JuanPZuluaga <juanz9312@gmail.com>
Signed-off-by: JuanPZuluaga <juanz9312@gmail.com>
@ZeldaHuang
Copy link
Copy Markdown
Collaborator

@JuanPZuluaga To speed up the process, it would be better to first land just the torch.compile abstraction in this PR, and leave the rest (modeling, cudagraph support, etc.) for follow-up PRs.

@JuanPZuluaga
Copy link
Copy Markdown
Contributor Author

JuanPZuluaga commented Apr 11, 2026

Hi, I'll update the body. @ZeldaHuang

@JuanPZuluaga To speed up the process, it would be better to first land just the torch.compile abstraction in this PR, and leave the rest (modeling, cudagraph support, etc.) for follow-up PRs.

the issue is that these optimizations are already done on Qwen3TTS code-predictor model. If we drop them, it would regress Qwen3TTS. The shared module keeps the full stack and gates cudagraph capture behind CodePredictorWrapperConfig.use_cuda_graphs, so Qwen3TTS migrates with use_cuda_graphs=True (byte-identical behavior) and Qwen3Omni can opt in separately. But, if you think it's still fine, i can do it. please let know.

@JuanPZuluaga JuanPZuluaga changed the title [Feat][Qwen3-Omni] Optimize code predictor with torch.compile bucket warmup [Feat][Qwen3-Omni] Shared code predictor module for Qwen3-TTS and Qwen3-Omni Apr 11, 2026
@hsliuustc0106 hsliuustc0106 added the merge-test label to trigger buildkite merge test CI label Apr 11, 2026
@ZeldaHuang
Copy link
Copy Markdown
Collaborator

Hi, I'll update the body. @ZeldaHuang

@JuanPZuluaga To speed up the process, it would be better to first land just the torch.compile abstraction in this PR, and leave the rest (modeling, cudagraph support, etc.) for follow-up PRs.

the issue is that these optimizations are already done on Qwen3TTS code-predictor model. If we drop them, it would regress Qwen3TTS. The shared module keeps the full stack and gates cudagraph capture behind CodePredictorWrapperConfig.use_cuda_graphs, so Qwen3TTS migrates with use_cuda_graphs=True (byte-identical behavior) and Qwen3Omni can opt in separately. But, if you think it's still fine, i can do it. please let know.

It make sense. For this PR, we can focus on resolving the shared module first, while keeping the current CUDA graph capture approach for each code predictor unchanged.

@@ -0,0 +1,654 @@
"""Code Predictor -- optimized re-prefill, no KV cache.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better to rename this shared module from CodePredictor to QwenCodePredictor (since other models also use code predictors, such as Fish Speech), or to Qwen3OmniCodePredictor (since it was first introduced in Qwen3Omni)?

@ZeldaHuang
Copy link
Copy Markdown
Collaborator

@JuanPZuluaga Please fix conflicts

Signed-off-by: JuanPZuluaga <juanz9312@gmail.com>
Signed-off-by: JuanPZuluaga <juanz9312@gmail.com>
@ZeldaHuang
Copy link
Copy Markdown
Collaborator

CI Test Failure: Tensor Shape Mismatch in Code Predictor

The CI test test_mix_to_text_audio_001[omni_server0] is failing with a tensor dimension mismatch error.
Error Message

  RuntimeError: The expanded size of the tensor (5) must match the existing size (8) at non-singleton dimension 0. Target sizes:   [5, 1024]. Tensor sizes: [8, 1024]

Location: vllm_omni/model_executor/models/common/qwen3_code_predictor.py line 537

@JuanPZuluaga PTAL

Signed-off-by: JuanPZuluaga <juanz9312@gmail.com>
@JuanPZuluaga
Copy link
Copy Markdown
Contributor Author

CI Test Failure: Tensor Shape Mismatch in Code Predictor

The CI test test_mix_to_text_audio_001[omni_server0] is failing with a tensor dimension mismatch error. Error Message

  RuntimeError: The expanded size of the tensor (5) must match the existing size (8) at non-singleton dimension 0. Target sizes:   [5, 1024]. Tensor sizes: [8, 1024]

Location: vllm_omni/model_executor/models/common/qwen3_code_predictor.py line 537

@JuanPZuluaga PTAL

thanks for caching this issue:

the thing was that with the unified _ensure_buffers, which sizes _proj_buf to max_num_seqs (5 in CI), but during server init _capture_talker_mtp_graphs calls the code predictor with CUDA graph capture sizes with powers of 2: 1, 2, 4, 8. this exceeds the max_num_seqs. The original Omni code avoided this when allocating proj_buf fresh each call.

the fix is only _ensure_buffers`` now takes the actual batch size needed instead of reading max_num_seqs` internally, does the buffer grows on demand.

output_wav = os.path.join(output_dir, f"output_{request_id}.wav")

# Convert to numpy array and ensure correct format
# In async_chunk mode, audio may arrive as a list of chunks
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have examples/offline_inference/qwen3_omni/end2end_async_chunk.py to run offline inference with async_chunk enabled

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed this @ZeldaHuang and also modified the PR body with more consistent to what was done. Thanks for the review :)

Signed-off-by: JuanPZuluaga <juanz9312@gmail.com>
@ZeldaHuang ZeldaHuang added the ready label to trigger buildkite CI label Apr 15, 2026
@ZeldaHuang ZeldaHuang enabled auto-merge (squash) April 15, 2026 07:14
@ZeldaHuang ZeldaHuang merged commit 82f8c93 into vllm-project:main Apr 15, 2026
8 checks passed
y123456y78 pushed a commit to y123456y78/vllm-omni that referenced this pull request Apr 15, 2026
…n3-Omni (vllm-project#2375)

Signed-off-by: JuanPZuluaga <juanz9312@gmail.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
lvliang-intel pushed a commit to lvliang-intel/vllm-omni that referenced this pull request Apr 20, 2026
…n3-Omni (vllm-project#2375)

Signed-off-by: JuanPZuluaga <juanz9312@gmail.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
lengrongfu pushed a commit to lengrongfu/vllm-omni that referenced this pull request May 1, 2026
…n3-Omni (vllm-project#2375)

Signed-off-by: JuanPZuluaga <juanz9312@gmail.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
clodaghwalsh17 pushed a commit to clodaghwalsh17/nm-vllm-omni-ent that referenced this pull request May 12, 2026
…n3-Omni (vllm-project#2375)

Signed-off-by: JuanPZuluaga <juanz9312@gmail.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
@JuanPZuluaga JuanPZuluaga deleted the feat/cuda-graph-code-predictor branch May 17, 2026 09:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

merge-test label to trigger buildkite merge test CI ready label to trigger buildkite CI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants