
[Feat][Qwen3TTS][Code2wav] triton SnakeBeta and Cuda Graph#1797

Merged
hsliuustc0106 merged 34 commits into
vllm-project:mainfrom
JuanPZuluaga:feat/code2wav-batch-cuda-graph
Mar 20, 2026

Conversation

@JuanPZuluaga
Contributor

@JuanPZuluaga JuanPZuluaga commented Mar 10, 2026


Purpose

Code2Wav decodes codec frames to audio with a convolution-based pipeline that contains 29 SnakeBeta activation layers. Each SnakeBeta executes 4 separate elementwise GPU kernels (exp, sin, pow, add), which adds up to ~116 small kernel launches per forward pass (29 × 4). This makes Code2Wav a small bottleneck under high concurrent load, so this PR introduces two optimizations:

  1. Fused Triton SnakeBeta kernel: replaces 4 elementwise PyTorch ops with a single Triton kernel that reads and writes memory once. The path is auto-detected at runtime: Triton is used when available on CUDA, with a fallback to eager PyTorch otherwise. No configuration is needed.

  2. Smart CUDA graph capture sizes: instead of a hardcoded list [25, 50, 100, 150, 200, 250, 300], capture sizes are computed dynamically from the streaming config (codec_chunk_frames, codec_left_context_frames). This ensures exact graph hits for streaming chunk sizes (e.g., 33 and 58 for c=33/ctx=25) and includes power-of-2 small sizes [2, 4, 8, 16, 32, 64] aligned with the dynamic IC sizing in [feat][Qwen3TTS] Simple dynamic TTFA based on Code2Wav load #1714.

The capture size computation also generates bucket sizes for variable-length last chunks, ensuring high graph hit rate across all decode calls.
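The capture-size idea above can be sketched in plain Python. This is only an illustration under stated assumptions, not the PR's `compute_capture_sizes`: the helper names `sketch_capture_sizes` and `pick_capture_size`, the `bucket_stride` parameter, and the "pad up to the next captured size" lookup are all hypothetical. Only the power-of-2 list and the two exact streaming sizes (chunk, and chunk plus left context) come from the description.

```python
from bisect import bisect_left

# Small power-of-2 sizes listed in the PR description.
SMALL_POW2_SIZES = (2, 4, 8, 16, 32, 64)


def sketch_capture_sizes(chunk_frames, left_context_frames, bucket_stride=8):
    """Hypothetical helper: derive CUDA graph capture sizes from the streaming config."""
    full = chunk_frames + left_context_frames
    sizes = set(SMALL_POW2_SIZES)
    # Exact hits for the two streaming shapes: first chunk and chunk + left context.
    sizes.update((chunk_frames, full))
    # Assumed bucketing for variable-length last chunks: one bucket every
    # `bucket_stride` frames between the left context and the full window.
    sizes.update(range(left_context_frames + bucket_stride, full, bucket_stride))
    return sorted(sizes)


def pick_capture_size(num_frames, sizes):
    """Smallest captured size that fits `num_frames`, or None (run eagerly)."""
    i = bisect_left(sizes, num_frames)
    return sizes[i] if i < len(sizes) else None


sizes = sketch_capture_sizes(33, 25)
print(sizes)                          # [2, 4, 8, 16, 32, 33, 41, 49, 57, 58, 64]
print(pick_capture_size(58, sizes))   # 58 (exact hit for c=33/ctx=25)
print(pick_capture_size(45, sizes))   # 49 (last chunk padded to a bucket)
print(pick_capture_size(100, sizes))  # None (no graph, eager fallback)
```

With c=33/ctx=25 the sketch contains both 33 and 58, matching the example sizes named in the description.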

WIP: batched decoding.

Test Plan

python -m pytest tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py -v

Test Result

Benchmark Results

| Metric | Concurrency | triton | main |
| --- | --- | --- | --- |
| TTFP (ms) | 1 | 62.6 | 69.8 |
| TTFP (ms) | 2 | 91.2 | 95.5 |
| TTFP (ms) | 4 | 105.0 | 114.9 |
| TTFP (ms) | 6 | 117.8 | 147.0 |
| TTFP (ms) | 8 | 138.5 | 159.7 |
| TTFP (ms) | 10 | 499.5 | 562.8 |
| E2E (ms) | 1 | 1227.5 | 1248.4 |
| E2E (ms) | 2 | 1390.2 | 1447.2 |
| E2E (ms) | 4 | 1564.7 | 1665.2 |
| E2E (ms) | 6 | 1770.8 | 2002.2 |
| E2E (ms) | 8 | 1895.9 | 2236.1 |
| E2E (ms) | 10 | 2300.8 | 2591.4 |
| RTF | 1 | 0.217 | 0.221 |
| RTF | 2 | 0.246 | 0.253 |
| RTF | 4 | 0.274 | 0.295 |
| RTF | 6 | 0.314 | 0.356 |
| RTF | 8 | 0.338 | 0.386 |
| RTF | 10 | 0.407 | 0.465 |
| Throughput (audio-s/s) | 1 | 4.61 | 4.53 |
| Throughput (audio-s/s) | 2 | 7.99 | 7.79 |
| Throughput (audio-s/s) | 4 | 14.00 | 13.06 |
| Throughput (audio-s/s) | 6 | 17.29 | 15.92 |
| Throughput (audio-s/s) | 8 | 21.72 | 19.37 |
| Throughput (audio-s/s) | 10 | 21.61 | 18.99 |

Improvement (triton vs main)

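| Metric | Concurrency | Improvement |
| --- | --- | --- |
| TTFP | 1 | +10.3% |
| TTFP | 2 | +4.6% |
| TTFP | 4 | +8.6% |
| TTFP | 6 | +19.9% |
| TTFP | 8 | +13.3% |
| TTFP | 10 | +11.2% |
| E2E | 1 | +1.7% |
| E2E | 2 | +3.9% |
| E2E | 4 | +6.0% |
| E2E | 6 | +11.6% |
| E2E | 8 | +15.2% |
| E2E | 10 | +11.2% |
| RTF | 1 | +1.6% |
| RTF | 2 | +2.8% |
| RTF | 4 | +7.1% |
| RTF | 6 | +11.8% |
| RTF | 8 | +12.6% |
| RTF | 10 | +12.5% |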
Plot saved to vllm_omni/results/comparison.png
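For reference, the improvement figures appear to be the relative reduction versus main for lower-is-better metrics. A minimal sketch of that arithmetic, assuming the formula `(main - triton) / main`; the helper name `improvement_pct` is hypothetical and the benchmark script itself is not shown in this PR:

```python
def improvement_pct(main, triton):
    """Relative reduction versus main, in percent (for lower-is-better metrics)."""
    return (main - triton) / main * 100.0


# (triton, main) pairs copied from the benchmark table.
rows = {
    ("TTFP", 1): (62.6, 69.8),
    ("TTFP", 6): (117.8, 147.0),
    ("E2E", 8): (1895.9, 2236.1),
}

for (metric, concurrency), (triton, main) in rows.items():
    print(f"{metric} @ {concurrency}: +{improvement_pct(main, triton):.1f}%")
```

Recomputing from the rounded table values reproduces most rows exactly. A few (for example TTFP at concurrency 2) land about 0.1 percentage points off, presumably because the published figures were derived from unrounded measurements.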

Consistent with the tables above, the plot shows lower TTFP and E2E latency with the Triton path, and the other metrics improve as well.

[Image: comparison plot]

some audio files:

Config YAML

relevant params changed in YAML:

  - stage_id: 1
    stage_type: llm
    runtime:
      devices: "0"
      max_batch_size: 8
....
    max_inflight: 8
....
        codec_chunk_frames: 33
        codec_left_context_frames: 25
        initial_codec_chunk_frames: 2


pablo added 4 commits March 10, 2026 19:29
Signed-off-by: pablo <pablo@agigo.ai>
Signed-off-by: pablo <pablo@agigo.ai>
Signed-off-by: pablo <pablo@agigo.ai>
@JuanPZuluaga JuanPZuluaga changed the title [Feat][Qwen3TTS][Code2wav] cuda graph [Feat][Qwen3TTS][Code2wav] triton SnakeBeta and Cuda Graph Mar 10, 2026

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 2fc84a4135


Comment on lines +675 to +676
    if hidden_states.is_cuda and self._init_triton():
        return self._triton_forward(hidden_states)

[P1] Add eager fallback around Triton execution failures

The CUDA path now calls _triton_forward whenever _init_triton() returns true, but there is no runtime fallback if Triton kernel compilation or launch fails. In environments where Triton imports successfully but cannot execute (for example unsupported GPU/driver combinations or Triton runtime incompatibilities), this will raise and break decoding instead of preserving the prior eager behavior, so requests can fail entirely rather than degrade gracefully.
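A minimal sketch of the graceful-degradation pattern this comment asks for: try the fast path, and on any failure log with a traceback, disable the fast path, and use the eager path. The class `FastPathWithFallback` and its callables are hypothetical stand-ins (`fast_fn` for a Triton launch, `eager_fn` for the PyTorch implementation). This is not the PR's actual code.

```python
import logging

logger = logging.getLogger(__name__)


class FastPathWithFallback:
    """Hypothetical wrapper that degrades from a fast path to an eager path."""

    def __init__(self, fast_fn, eager_fn):
        self._fast_fn = fast_fn
        self._eager_fn = eager_fn
        self._fast_enabled = True

    def __call__(self, x):
        if self._fast_enabled:
            try:
                return self._fast_fn(x)
            except Exception:
                # Log once with the traceback so persistent failures stay visible,
                # then stop retrying the fast path on later calls.
                logger.warning(
                    "Fast path failed; using the eager path from now on.",
                    exc_info=True,
                )
                self._fast_enabled = False
        return self._eager_fn(x)
```

Disabling the fast path after the first failure keeps a broken kernel from raising, and being caught, on every later call.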


Comment on lines +624 to +628
    try:
        import triton
        import triton.language as tl
    except ImportError:
        return False

[P2] Memoize Triton-unavailable state after import failure

If Triton is not installed, _init_triton() returns False but leaves _triton_kernel as None, so every CUDA forward re-attempts import triton and pays repeated ImportError costs. Because SnakeBeta is called many times per decode, this repeated exception path can materially hurt throughput on non-Triton CUDA deployments; caching a negative detection result would keep fallback overhead to a one-time check.
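The negative-result caching suggested here can be sketched with a tri-state flag: `None` means not checked yet, while `True` and `False` are cached outcomes. The class `OptionalBackend` and its attribute names are hypothetical, and the module name is a parameter so the sketch can be exercised without Triton installed.

```python
import importlib


class OptionalBackend:
    """Hypothetical helper that probes an optional module at most once."""

    def __init__(self, module_name):
        self._module_name = module_name
        self._available = None  # None = not probed yet
        self.import_attempts = 0

    def available(self):
        if self._available is None:
            self.import_attempts += 1
            try:
                importlib.import_module(self._module_name)
                self._available = True
            except ImportError:
                # Cache the failure so later calls skip the import machinery.
                self._available = False
        return self._available


backend = OptionalBackend("module_that_is_not_installed_xyz")
results = [backend.available() for _ in range(3)]
print(results, backend.import_attempts)  # [False, False, False] 1
```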


Contributor

Copilot AI left a comment


Pull request overview

This PR improves Qwen3-TTS decoder inference performance by adding a fused Triton implementation for the SnakeBeta activation and by making CUDA Graph capture sizing more adaptive to streaming/chunking configurations.

Changes:

  • Add a fused Triton kernel path for SnakeBeta (with eager fallback).
  • Extend CUDA graph enable/warmup plumbing to incorporate codec chunk/left-context sizes and compute better capture buckets.
  • Adjust Code2Wav CUDA graph enablement to pass chunk/left-context config directly; simplify per-request decode loop; add targeted tests.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py | Adds Triton-accelerated SnakeBeta and extends decoder CUDA-graph enablement parameters. |
| vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py | Passes codec chunk/left-context config into decoder CUDA-graph warmup and simplifies decode iteration. |
| vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py | Adds adaptive compute_capture_sizes, refines warmup/capture behavior and decode fallback checks. |
| tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py | Adds tests for compute_capture_sizes and Triton-vs-eager equivalence for SnakeBeta. |


Applies the function to the input elementwise.
SnakeBeta ∶= x + 1/b * sin^2 (xa)
"""
if hidden_states.is_cuda and self._init_triton():

Copilot AI Mar 10, 2026


The Triton fast path will run whenever hidden_states is CUDA, even when autograd is enabled / hidden_states.requires_grad=True. Since _triton_forward writes into a freshly allocated tensor and there’s no custom backward, this will silently break gradients. Please gate the Triton path behind not torch.is_grad_enabled() (or not hidden_states.requires_grad) and fall back to _eager_forward when gradients are needed (or implement a proper autograd.Function).

Suggested change
- if hidden_states.is_cuda and self._init_triton():
+ # Use Triton fast path only when gradients are not needed to avoid
+ # silently breaking autograd. When autograd is enabled, fall back
+ # to the eager PyTorch implementation, which is fully differentiable.
+ if hidden_states.is_cuda and not torch.is_grad_enabled() and self._init_triton():

Comment thread vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py Outdated
    except ImportError:
        return False

    @triton.jit
Collaborator


@linyueqian @tzhouam where should we place triton kernels?

Collaborator


I think we should follow vLLM IR if we plan to introduce many triton kernels. As a light abstraction, can consider CustomOp, but vLLM is deprecating it.

Collaborator


I think inline is fine for now since it's the only triton kernel in the repo.

JuanPZuluaga added 3 commits March 11, 2026 06:46
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
Collaborator

@lishunyang12 lishunyang12 left a comment


the silent except: pass on the triton path could mask persistent failures

JuanPZuluaga added 2 commits March 11, 2026 07:45
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
@JuanPZuluaga
Contributor Author

most things fixed.

JuanPZuluaga added 4 commits March 11, 2026 12:57
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
Comment thread vllm_omni/entrypoints/async_omni.py Outdated
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
@Gaohan123 Gaohan123 added this to the v0.18.0 milestone Mar 12, 2026
@hsliuustc0106 hsliuustc0106 added the ready label to trigger buildkite CI label Mar 18, 2026
Collaborator

@hsliuustc0106 hsliuustc0106 left a comment


Gate Status

| Check | Status |
| --- | --- |
| DCO / pre-commit / build | |
| Main CI | |
| AMD CI | ❌ (may be unrelated) |

Evidence ✅

Comprehensive benchmark data provided:

  • TTFP improvement: 4.6%–19.9%
  • E2E improvement: 1.7%–15.2%
  • Audio samples and comparison plot included

Code Quality ✅

  • Triton fallback is properly handled with logger.warning(..., exc_info=True) + disables future attempts
  • Dynamic capture size computation is clean
  • Test coverage for both features

Prior concern about silent exception handling is addressed. LGTM once AMD CI is investigated (likely unrelated to this PR).

@linyueqian
Collaborator

I was thinking about whether any of these would help further:

  1. Cache the exp() calls. Alpha and beta are frozen at inference time, no reason to recompute exp() in every kernel launch. Just precompute exp(alpha) and 1/(exp(beta)+eps) once after weight loading, store as buffers, and have the kernel load them directly. That's 2 transcendental ops saved per element across all 29 layers.

  2. Block size cap. Any reason for capping at 1024? After upsampling T is usually in the thousands, might be worth trying 2048/4096 to reduce grid launches.

  3. Have you compared against torch.compile? For a pure pointwise fusion like this, torch.compile(mode="reduce-overhead") on the eager forward might get you close with zero custom code. Would be interesting to see as a baseline.
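To make the first suggestion concrete, here is a scalar, pure-Python sketch of SnakeBeta with and without the cached terms. It assumes the log-scale parameterization implied above, `y = x + 1/(exp(beta) + eps) * sin^2(x * exp(alpha))`. The `EPS` value and all function names are illustrative assumptions, not the model's code.

```python
import math

EPS = 1e-9  # assumed small constant guarding the division


def snake_beta_eager(x, alpha, beta):
    """Recomputes both exp() terms on every call."""
    return x + (1.0 / (math.exp(beta) + EPS)) * math.sin(x * math.exp(alpha)) ** 2


def precompute_terms(alphas, betas):
    """One-time work after weight loading: per-channel exp(alpha) and 1/(exp(beta)+eps)."""
    exp_alpha = [math.exp(a) for a in alphas]
    inv_beta = [1.0 / (math.exp(b) + EPS) for b in betas]
    return exp_alpha, inv_beta


def snake_beta_cached(x, exp_alpha, inv_beta):
    """Hot path: only sin, multiply, and add remain."""
    return x + inv_beta * math.sin(x * exp_alpha) ** 2


alphas, betas = [0.0, 0.5], [0.0, -0.3]
exp_alpha, inv_beta = precompute_terms(alphas, betas)
for c in range(2):
    x = 0.7
    assert math.isclose(
        snake_beta_eager(x, alphas[c], betas[c]),
        snake_beta_cached(x, exp_alpha[c], inv_beta[c]),
    )
```

Because alpha and beta are frozen at inference time, the cached form gives the same result while dropping two `exp()` evaluations per element.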

@JuanPZuluaga
Contributor Author

JuanPZuluaga commented Mar 19, 2026

@linyueqian thanks for the comments!

> Cache the exp() calls. Alpha and beta are frozen at inference time, no reason to recompute exp() in every kernel launch. Just precompute exp(alpha) and 1/(exp(beta)+eps) once after weight loading, store as buffers, and have the kernel load them directly. That's 2 transcendental ops saved per element across all 29 layers.

Done. I added precompute_exp_cache(), which precomputes them as persistent buffers.

> Block size cap. Any reason for capping at 1024? After upsampling T is usually in the thousands, might be worth trying 2048/4096 to reduce grid launches.

Done. I raised _TRITON_MAX_BLOCK_T from 1024 to 4096 and already noticed a small improvement. Thanks.
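As a rough illustration of why the cap matters, the sketch below counts how many blocks are needed to cover a time axis of length T. It assumes the block size is the next power of two of T, clamped to the cap, and that one program instance handles one block. The real kernel's tiling is not shown in this thread, so `launch_shape` and its logic are hypothetical.

```python
def next_power_of_2(n):
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def launch_shape(t, max_block_t):
    """Return (block size, number of blocks along T) under the assumed tiling."""
    block_t = min(next_power_of_2(t), max_block_t)
    num_blocks = -(-t // block_t)  # ceiling division
    return block_t, num_blocks


print(launch_shape(6000, 1024))  # (1024, 6)
print(launch_shape(6000, 4096))  # (4096, 2)
```

Under these assumptions, raising the cap from 1024 to 4096 cuts the blocks along T from 6 to 2 for T = 6000.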

> Have you compared against torch.compile? For a pure pointwise fusion like this, torch.compile(mode="reduce-overhead") on the eager forward might get you close with zero custom code. Would be interesting to see as a baseline.

I benchmarked torch.compile(mode="reduce-overhead") against the Triton kernel internally. With the new caching and the larger max block size, the kernel is ~2x faster than the compiled version across all configs.

Here are the new results after syncing with main:

[Image: comparison plot]

// EDIT

I used this YAML -- relevant params changed in YAML:

  - stage_id: 0
      max_batch_size: 16
      max_num_batched_tokens: 4096
  - stage_id: 1
      max_batch_size: 16
....
    max_inflight: 16
....
        codec_chunk_frames: 25
        codec_left_context_frames: 25

Collaborator

@linyueqian linyueqian left a comment


lgtm

@hsliuustc0106 hsliuustc0106 merged commit 81a90d2 into vllm-project:main Mar 20, 2026
7 checks passed
@JuanPZuluaga JuanPZuluaga deleted the feat/code2wav-batch-cuda-graph branch March 20, 2026 05:09
clodaghwalsh17 pushed a commit to clodaghwalsh17/nm-vllm-omni-ent that referenced this pull request May 12, 2026
…ect#1797)

Signed-off-by: pablo <pablo@agigo.ai>
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
Co-authored-by: pablo <pablo@agigo.ai>
Co-authored-by: JuanPZuluaga <juanz9312@gmal.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
Co-authored-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>

Labels

ready label to trigger buildkite CI


7 participants