Skip to content

[BugFix][Qwen3TTS] CodePredictor CudaGraph Pool#2059

Merged
Gaohan123 merged 4 commits into
vllm-project:mainfrom
JuanPZuluaga:fix/code-predictor-cuda-graph-pool
Mar 21, 2026
Merged

[BugFix][Qwen3TTS] CodePredictor CudaGraph Pool#2059
Gaohan123 merged 4 commits into
vllm-project:mainfrom
JuanPZuluaga:fix/code-predictor-cuda-graph-pool

Conversation

@JuanPZuluaga
Copy link
Copy Markdown
Contributor

@JuanPZuluaga JuanPZuluaga commented Mar 21, 2026

PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTTOM) HAVE BEEN CONSIDERED.

Purpose

In #1913 we switched the CodePredictor's torch.compile from mode="default" to mode="reduce-overhead" with dynamic=False and batch-size bucketing. This gave us quite a noticeble speed up (~30% better throughput) in high-concurrency cenarios, by letting Inductor capture its own internal CUDA graphs.

However, reduce-overhead creates a private CUDA graph memory pool (CUDAGraphTreeManager) that conflicts with vLLM's own CUDA graph pool. At batch sizes >= 16 (max_num_batched_tokens >= 2048), the two pools compete for GPU memory and cause cudaErrorIllegalAddress crashes. This was somehow signaled by @Sy0307 (#1913 (comment); sorry for not testing more thoroughly)

in this PR, we replace reduce-overhead with mode="default" again, but we add manual CUDA graph capture per bucket using the same pattern as vllm/compilation/cuda_graph.py (used in other parts of vllm-omni). This gives us both Inductor kernel fusion and CUDA graph replay, with all graphs sharing vLLM's single memory pool. No pool collision, no crash, same performance.

For me, I noticed the cudaErrorIllegalAddress with:

- stage_id: 0 
  - max_num_seqs: 16
  - max_num_batched_tokens: 2048

EDIT: minor update of the config yaml for benchmark of qwen3-tts.

Test Plan

Run benchmark with bs=8 and bs=32, concurrency 1 and 8, comparing three configurations: reduce-overhead (PR #1913), default-only (no CUDA graphs), and default+cudagraph (this PR).

Test Result

bs=32

Metric Concurrency default+cudagraph default-only reduce-overhead
Med TTFP (ms) 1 35.19 46.00 35.53
8 106.72 128.65 104.50
Med E2E (ms) 1 799.28 1169.25 809.77
8 1371.69 1558.34 1368.73
Mean RTF 1 0.14 0.21 0.14
8 0.24 0.28 0.25
Req/s 1 1.23 0.84 1.21
8 5.34 4.59 5.41
Audio tput (s/s) 1 7.11 4.84 7.01
8 30.78 26.15 30.86

bs=8

Metric Concurrency default+cudagraph default-only reduce-overhead
Med TTFP (ms) 1 41.98 54.05 41.90
8 180.63 373.76 204.52
Med E2E (ms) 1 802.98 1178.78 813.52
8 1291.19 1557.18 1318.93
Mean RTF 1 0.14 0.21 0.14
8 0.23 0.28 0.23
Req/s 1 1.24 0.84 1.22
8 5.75 4.69 5.66
Audio tput (s/s) 1 7.06 4.83 6.96
8 32.34 26.36 32.20

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan. Please provide the test scripts & test commands. Please state the reasons if your codes don't require additional test scripts. For test file guidelines, please check the test style doc
  • The test results. Please paste the results comparison before and after, or the e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model. Please run mkdocs serve to sync the documentation editions to ./docs.
  • (Optional) Release notes update. If your change is user-facing, please update the release notes draft.

BEFORE SUBMITTING, PLEASE READ https://github.com/vllm-project/vllm-omni/blob/main/CONTRIBUTING.md (anything written below this line will be removed by GitHub Actions)

JuanPZuluaga added 3 commits March 21, 2026 07:17
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: fabee3fb0e

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

enable_prefix_caching: false
engine_output_type: audio
gpu_memory_utilization: 0.2
gpu_memory_utilization: 0.3
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Keep stage-1 memory override distinct in bs1 benchmark config

run_benchmark.sh rewrites gpu_memory_utilization: 0.3 to GPU_MEM_TALKER and 0.2 to GPU_MEM_CODE2WAV (benchmarks/qwen3-tts/run_benchmark.sh:89-92). After changing stage 1 here from 0.2 to 0.3, the default qwen3_tts_bs1.yaml path now gives both stages the talker budget and silently ignores GPU_MEM_CODE2WAV. Users who rely on the documented 0.3/0.2 split or tune code2wav separately can end up with an unexpected config and avoidable OOMs.

Useful? React with 👍 / 👎.

@Sy0307
Copy link
Copy Markdown
Contributor

Sy0307 commented Mar 21, 2026

Thanks for fix. I will verify it later and check possible issue more carefully. Also cc @linyueqian

@Gaohan123
Copy link
Copy Markdown
Collaborator

Hello, could you please check if your PR can solve the CI failure? https://buildkite.com/vllm/vllm-omni/builds/4642/steps/canvas?sid=019d0fb4-f422-4858-b196-a9d1ec25fcaf&tab=output

@Gaohan123 Gaohan123 added the ready label to trigger buildkite CI label Mar 21, 2026
@Gaohan123 Gaohan123 added this to the v0.18.0 milestone Mar 21, 2026
Copy link
Copy Markdown
Collaborator

@Gaohan123 Gaohan123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks!

@Gaohan123 Gaohan123 merged commit 8e12007 into vllm-project:main Mar 21, 2026
7 of 8 checks passed
@JuanPZuluaga JuanPZuluaga deleted the fix/code-predictor-cuda-graph-pool branch March 21, 2026 11:37
hsliuustc0106 added a commit to hsliuustc0106/vllm-omni-skills that referenced this pull request Mar 22, 2026
### vllm-omni-audio-tts
- Source: [PR #2059](vllm-project/vllm-omni#2059) - [BugFix][Qwen3TTS] CodePredictor CudaGraph Pool
- Changes:
  - Bug fix: [BugFix][Qwen3TTS] CodePredictor CudaGraph Pool

### vllm-omni-perf
- Source: [PR #2059](vllm-project/vllm-omni#2059) - [BugFix][Qwen3TTS] CodePredictor CudaGraph Pool
- Changes:
  - Bug fix: [BugFix][Qwen3TTS] CodePredictor CudaGraph Pool

### vllm-omni-api
- Source: [PR #2058](vllm-project/vllm-omni#2058) - [Bugfix] Fix Fish Speech and CosyVoice3 online serving - missing is_comprehension and broken model detection
- Changes:
  - Bug fix: [Bugfix] Fix Fish Speech and CosyVoice3 online serving - missing is_comprehension and broken model detection

### vllm-omni-contrib
- Source: [PR #2045](vllm-project/vllm-omni#2045) - [Voxtral] Improve example

### vllm-omni-cicd
- Source: [PR #2045](vllm-project/vllm-omni#2045) - [Voxtral] Improve example

### vllm-omni-api
- Source: [PR #2042](vllm-project/vllm-omni#2042) - [bugfix] /chat/completion doesn't read extra_body for diffusion model
- Changes:
  - Bug fix: [bugfix] /chat/completion doesn't read extra_body for diffusion model

### vllm-omni-perf
- Source: [PR #2042](vllm-project/vllm-omni#2042) - [bugfix] /chat/completion doesn't read extra_body for diffusion model
- Changes:
  - Bug fix: [bugfix] /chat/completion doesn't read extra_body for diffusion model

### vllm-omni-contrib
- Source: [PR #2038](vllm-project/vllm-omni#2038) - [Doc] Update docs and dockerfiles for rebase of vllm v0.18.0

### vllm-omni-serving
- Source: [PR #2037](vllm-project/vllm-omni#2037) - [Rebase] Rebase to vllm v0.18.0

### vllm-omni-contrib
- Source: [PR #2037](vllm-project/vllm-omni#2037) - [Rebase] Rebase to vllm v0.18.0

### vllm-omni-api
- Source: [PR #2037](vllm-project/vllm-omni#2037) - [Rebase] Rebase to vllm v0.18.0

### vllm-omni-cicd
- Source: [PR #2037](vllm-project/vllm-omni#2037) - [Rebase] Rebase to vllm v0.18.0

### vllm-omni-cicd
- Source: [PR #2032](vllm-project/vllm-omni#2032) - [CI] Change Bagel online test environment variable `VLLM_TEST_CLEAN_GPU_MEMORY` to `0`

### vllm-omni-cicd
- Source: [PR #2031](vllm-project/vllm-omni#2031) - [CI] Fix test.
- Changes:
  - Bug fix: [CI] Fix test.

### vllm-omni-cicd
- Source: [PR #2017](vllm-project/vllm-omni#2017) - [CI] [ROCm] Setup `test-ready.yml` and `test-merge.yml`

### vllm-omni-cicd
- Source: [PR #2014](vllm-project/vllm-omni#2014) - [Test] Implement mock HTTP request handling in benchmark CLI tests

### vllm-omni-perf
- Source: [PR #2014](vllm-project/vllm-omni#2014) - [Test] Implement mock HTTP request handling in benchmark CLI tests

### vllm-omni-serving
- Source: [PR #2012](vllm-project/vllm-omni#2012) - [Fixbug][Perf] Qwen3-omni: code predictor with re-prefill + SDPA and eliminate decode hot-path CPU round-trips
- Changes:
  - Bug fix: [Fixbug][Perf] Qwen3-omni: code predictor with re-prefill + SDPA and eliminate decode hot-path CPU round-trips

### vllm-omni-image-gen
- Source: [PR #2012](vllm-project/vllm-omni#2012) - [Fixbug][Perf] Qwen3-omni: code predictor with re-prefill + SDPA and eliminate decode hot-path CPU round-trips
- Changes:
  - Bug fix: [Fixbug][Perf] Qwen3-omni: code predictor with re-prefill + SDPA and eliminate decode hot-path CPU round-trips

### vllm-omni-perf
- Source: [PR #2012](vllm-project/vllm-omni#2012) - [Fixbug][Perf] Qwen3-omni: code predictor with re-prefill + SDPA and eliminate decode hot-path CPU round-trips
- Changes:
  - Bug fix: [Fixbug][Perf] Qwen3-omni: code predictor with re-prefill + SDPA and eliminate decode hot-path CPU round-trips

### vllm-omni-serving
- Source: [PR #2009](vllm-project/vllm-omni#2009) - [Bugfix] revert PR#1758 which introduced the accuracy problem of qwen3-omni
- Changes:
  - Bug fix: [Bugfix] revert PR#1758 which introduced the accuracy problem of qwen3-omni

### vllm-omni-image-gen
- Source: [PR #2007](vllm-project/vllm-omni#2007) - [Bugfix]Fix bug of online server can not return mutli images
- Changes:
  - Bug fix: [Bugfix]Fix bug of online server can not return mutli images
- Additions:
  - Qwen-Image-Layered
  - Qwen-Image-Layered
  - Qwen-Image-Layered

### vllm-omni-api
- Source: [PR #2007](vllm-project/vllm-omni#2007) - [Bugfix]Fix bug of online server can not return mutli images
- Changes:
  - Bug fix: [Bugfix]Fix bug of online server can not return mutli images

### vllm-omni-cicd
- Source: [PR #1998](vllm-project/vllm-omni#1998) - [CI] Split BAGEL tests into dummy/real weight tiers (L2/L3)

### vllm-omni-serving
- Source: [PR #1985](vllm-project/vllm-omni#1985) - [Perf] [Qwen3-TTS] Keep audio_codes and last_talker_hidden on GPU to eliminate per-step sync stalls
- Changes:
  - Performance improvement: [Perf] [Qwen3-TTS] Keep audio_codes and last_talker_hidden on GPU to eliminate per-step sync stalls

### vllm-omni-audio-tts
- Source: [PR #1985](vllm-project/vllm-omni#1985) - [Perf] [Qwen3-TTS] Keep audio_codes and last_talker_hidden on GPU to eliminate per-step sync stalls
- Changes:
  - Performance improvement: [Perf] [Qwen3-TTS] Keep audio_codes and last_talker_hidden on GPU to eliminate per-step sync stalls

### vllm-omni-perf
- Source: [PR #1985](vllm-project/vllm-omni#1985) - [Perf] [Qwen3-TTS] Keep audio_codes and last_talker_hidden on GPU to eliminate per-step sync stalls
- Changes:
  - Performance improvement: [Perf] [Qwen3-TTS] Keep audio_codes and last_talker_hidden on GPU to eliminate per-step sync stalls

### vllm-omni-serving
- Source: [PR #1984](vllm-project/vllm-omni#1984) - [CI] [ROCm] Bugfix device environment issue
- Changes:
  - Bug fix: [CI] [ROCm] Bugfix device environment issue

### vllm-omni-api
- Source: [PR #1984](vllm-project/vllm-omni#1984) - [CI] [ROCm] Bugfix device environment issue
- Changes:
  - Bug fix: [CI] [ROCm] Bugfix device environment issue

### vllm-omni-serving
- Source: [PR #1982](vllm-project/vllm-omni#1982) - [Fix] Fix slow hasattr in CUDAGraphWrapper.__getattr__
- Changes:
  - Bug fix: [Fix] Fix slow hasattr in CUDAGraphWrapper.__getattr__

### vllm-omni-cicd
- Source: [PR #1982](vllm-project/vllm-omni#1982) - [Fix] Fix slow hasattr in CUDAGraphWrapper.__getattr__
- Changes:
  - Bug fix: [Fix] Fix slow hasattr in CUDAGraphWrapper.__getattr__

### vllm-omni-api
- Source: [PR #1979](vllm-project/vllm-omni#1979) - [Bugfix] Fix config misalignment between offline and online diffusion inference (Wan2.2, Qwen-Image series)
- Changes:
  - Bug fix: [Bugfix] Fix config misalignment between offline and online diffusion inference (Wan2.2, Qwen-Image series)
- Additions:
  - `/v1/chat/completions`

### vllm-omni-perf
- Source: [PR #1979](vllm-project/vllm-omni#1979) - [Bugfix] Fix config misalignment between offline and online diffusion inference (Wan2.2, Qwen-Image series)
- Changes:
  - Bug fix: [Bugfix] Fix config misalignment between offline and online diffusion inference (Wan2.2, Qwen-Image series)

### vllm-omni-contrib
- Source: [PR #1976](vllm-project/vllm-omni#1976) - [skip ci][Docs] Update WeChat QR code (fix filename case)
- Changes:
  - Bug fix: [skip ci][Docs] Update WeChat QR code (fix filename case)

### vllm-omni-contrib
- Source: [PR #1974](vllm-project/vllm-omni#1974) - [Docs] Update WeChat QR code for community support

### vllm-omni-cicd
- Source: [PR #1945](vllm-project/vllm-omni#1945) - Fix Base voice clone streaming quality and stop-token crash
- Changes:
  - Bug fix: Fix Base voice clone streaming quality and stop-token crash

### vllm-omni-cicd
- Source: [PR #1938](vllm-project/vllm-omni#1938) - [Test] L4 complete diffusion feature test for Bagel models
- Changes:
  - New feature: [Test] L4 complete diffusion feature test for Bagel models

### vllm-omni-perf
- Source: [PR #1938](vllm-project/vllm-omni#1938) - [Test] L4 complete diffusion feature test for Bagel models
- Changes:
  - New feature: [Test] L4 complete diffusion feature test for Bagel models

### vllm-omni-perf
- Source: [PR #1934](vllm-project/vllm-omni#1934) - Fix OmniGen2 transformer config loading for HF models
- Changes:
  - Bug fix: Fix OmniGen2 transformer config loading for HF models

### vllm-omni-audio-tts
- Source: [PR #1930](vllm-project/vllm-omni#1930) - [Bug][Qwen3TTS][Streaming] remove dynamic initial chunk and only compute on initial request

### vllm-omni-perf
- Source: [PR #1930](vllm-project/vllm-omni#1930) - [Bug][Qwen3TTS][Streaming] remove dynamic initial chunk and only compute on initial request

### vllm-omni-audio-tts
- Source: [PR #1926](vllm-project/vllm-omni#1926) - [Misc] removed qwen3_tts.py as it is out-dated

### vllm-omni-contrib
- Source: [PR #1920](vllm-project/vllm-omni#1920) - [Docs] Add Wan2.1-T2V as supported video generation models
- Changes:
  - New feature: [Docs] Add Wan2.1-T2V as supported video generation models

### vllm-omni-video-gen
- Source: [PR #1915](vllm-project/vllm-omni#1915) - [Bugfix] fix helios video generate use cpu device
- Changes:
  - Bug fix: [Bugfix] fix helios video generate use cpu device

### vllm-omni-perf
- Source: [PR #1915](vllm-project/vllm-omni#1915) - [Bugfix] fix helios video generate use cpu device
- Changes:
  - Bug fix: [Bugfix] fix helios video generate use cpu device

### vllm-omni-audio-tts
- Source: [PR #1913](vllm-project/vllm-omni#1913) - [Optim][Qwen3TTS][CodePredictor] support torch.compile with reduce-overhead and dynamic False

### vllm-omni-perf
- Source: [PR #1913](vllm-project/vllm-omni#1913) - [Optim][Qwen3TTS][CodePredictor] support torch.compile with reduce-overhead and dynamic False

### vllm-omni-api
- Source: [PR #1908](vllm-project/vllm-omni#1908) - [Entrypoint][Refactor] vLLM-Omni Entrypoint Refactoring

### vllm-omni-perf
- Source: [PR #1908](vllm-project/vllm-omni#1908) - [Entrypoint][Refactor] vLLM-Omni Entrypoint Refactoring

### vllm-omni-contrib
- Source: [PR #1908](vllm-project/vllm-omni#1908) - [Entrypoint][Refactor] vLLM-Omni Entrypoint Refactoring

### vllm-omni-serving
- Source: [PR #1908](vllm-project/vllm-omni#1908) - [Entrypoint][Refactor] vLLM-Omni Entrypoint Refactoring

### vllm-omni-cicd
- Source: [PR #1908](vllm-project/vllm-omni#1908) - [Entrypoint][Refactor] vLLM-Omni Entrypoint Refactoring

### vllm-omni-image-gen
- Source: [PR #1900](vllm-project/vllm-omni#1900) - [Feat] support HSDP for Flux family
- Changes:
  - New feature: [Feat] support HSDP for Flux family

### vllm-omni-contrib
- Source: [PR #1900](vllm-project/vllm-omni#1900) - [Feat] support HSDP for Flux family
- Changes:
  - New feature: [Feat] support HSDP for Flux family

### vllm-omni-distributed
- Source: [PR #1898](vllm-project/vllm-omni#1898) - [Feature]: Remove some useless `hf_overrides` in yaml
- Changes:
  - New feature: [Feature]: Remove some useless `hf_overrides` in yaml

### vllm-omni-quantization
- Source: [PR #1898](vllm-project/vllm-omni#1898) - [Feature]: Remove some useless `hf_overrides` in yaml
- Changes:
  - New feature: [Feature]: Remove some useless `hf_overrides` in yaml

### vllm-omni-cicd
- Source: [PR #1898](vllm-project/vllm-omni#1898) - [Feature]: Remove some useless `hf_overrides` in yaml
- Changes:
  - New feature: [Feature]: Remove some useless `hf_overrides` in yaml

### vllm-omni-perf
- Source: [PR #1898](vllm-project/vllm-omni#1898) - [Feature]: Remove some useless `hf_overrides` in yaml
- Changes:
  - New feature: [Feature]: Remove some useless `hf_overrides` in yaml

### vllm-omni-contrib
- Source: [PR #1890](vllm-project/vllm-omni#1890) - [NPU] Upgrade to v0.17.0

### vllm-omni-contrib
- Source: [PR #1889](vllm-project/vllm-omni#1889) - Add `Governance` section
- Changes:
  - New feature: Add `Governance` section

### vllm-omni-distributed
- Source: [PR #1881](vllm-project/vllm-omni#1881) - [Feat] Support T5 Tensor Parallelism
- Changes:
  - New feature: [Feat] Support T5 Tensor Parallelism

### vllm-omni-cicd
- Source: [PR #1881](vllm-project/vllm-omni#1881) - [Feat] Support T5 Tensor Parallelism
- Changes:
  - New feature: [Feat] Support T5 Tensor Parallelism
JuanPZuluaga added a commit to JuanPZuluaga/vllm-omni that referenced this pull request Mar 31, 2026
… warmup

- torch.compile(dynamic=False) with epilogue_fusion=False
- Power-of-2 batch-size bucket warmup to eliminate recompilation
- Pre-allocated projection buffer (avoids per-call torch.zeros)
- Fixed-shape padded slicing for compiled forward
- Fix double softmax in top-p sampling
- Fix async_chunk audio concat in end2end.py
- Fix hasattr bug in stage_init_utils.py

Same pattern as qwen3_tts_code_predictor_vllm.py (PR vllm-project#2059).

Signed-off-by: JuanPZuluaga <juanz9312@gmail.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
clodaghwalsh17 pushed a commit to clodaghwalsh17/nm-vllm-omni-ent that referenced this pull request May 12, 2026
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
Co-authored-by: JuanPZuluaga <juanz9312@gmal.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready label to trigger buildkite CI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants