[Feat][Qwen3TTS][Code2wav] triton SnakeBeta and Cuda Graph#1797
Conversation
Signed-off-by: pablo <pablo@agigo.ai>
Signed-off-by: pablo <pablo@agigo.ai>
Signed-off-by: pablo <pablo@agigo.ai>
… feat/code2wav-batch-cuda-graph
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 2fc84a4135
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| if hidden_states.is_cuda and self._init_triton(): | ||
| return self._triton_forward(hidden_states) |
There was a problem hiding this comment.
Add eager fallback around Triton execution failures
The CUDA path now calls _triton_forward whenever _init_triton() returns true, but there is no runtime fallback if Triton kernel compilation or launch fails. In environments where Triton imports successfully but cannot execute (for example unsupported GPU/driver combinations or Triton runtime incompatibilities), this will raise and break decoding instead of preserving the prior eager behavior, so requests can fail entirely rather than degrade gracefully.
Useful? React with 👍 / 👎.
| try: | ||
| import triton | ||
| import triton.language as tl | ||
| except ImportError: | ||
| return False |
There was a problem hiding this comment.
Memoize Triton-unavailable state after import failure
If Triton is not installed, _init_triton() returns False but leaves _triton_kernel as None, so every CUDA forward re-attempts import triton and pays repeated ImportError costs. Because SnakeBeta is called many times per decode, this repeated exception path can materially hurt throughput on non-Triton CUDA deployments; caching a negative detection result would keep fallback overhead to a one-time check.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Pull request overview
This PR improves Qwen3-TTS decoder inference performance by adding a fused Triton implementation for the SnakeBeta activation and by making CUDA Graph capture sizing more adaptive to streaming/chunking configurations.
Changes:
- Add a fused Triton kernel path for
SnakeBeta(with eager fallback). - Extend CUDA graph enable/warmup plumbing to incorporate codec chunk/left-context sizes and compute better capture buckets.
- Adjust Code2Wav CUDA graph enablement to pass chunk/left-context config directly; simplify per-request decode loop; add targeted tests.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py |
Adds Triton-accelerated SnakeBeta and extends decoder CUDA-graph enablement parameters. |
vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py |
Passes codec chunk/left-context config into decoder CUDA-graph warmup and simplifies decode iteration. |
vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py |
Adds adaptive compute_capture_sizes, refines warmup/capture behavior and decode fallback checks. |
tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py |
Adds tests for compute_capture_sizes and Triton-vs-eager equivalence for SnakeBeta. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| Applies the function to the input elementwise. | ||
| SnakeBeta ∶= x + 1/b * sin^2 (xa) | ||
| """ | ||
| if hidden_states.is_cuda and self._init_triton(): |
There was a problem hiding this comment.
The Triton fast path will run whenever hidden_states is CUDA, even when autograd is enabled / hidden_states.requires_grad=True. Since _triton_forward writes into a freshly allocated tensor and there’s no custom backward, this will silently break gradients. Please gate the Triton path behind not torch.is_grad_enabled() (or not hidden_states.requires_grad) and fall back to _eager_forward when gradients are needed (or implement a proper autograd.Function).
| if hidden_states.is_cuda and self._init_triton(): | |
| # Use Triton fast path only when gradients are not needed to avoid | |
| # silently breaking autograd. When autograd is enabled, fall back | |
| # to the eager PyTorch implementation, which is fully differentiable. | |
| if hidden_states.is_cuda and not torch.is_grad_enabled() and self._init_triton(): |
| except ImportError: | ||
| return False | ||
|
|
||
| @triton.jit |
There was a problem hiding this comment.
@linyueqian @tzhouam where should we place triton kernels?
There was a problem hiding this comment.
I think we should follow vLLM IR if we plan to introduce many triton kernels. As a light abstraction, can consider CustomOp, but vLLM is deprecating it.
There was a problem hiding this comment.
I think inline is fine for now since it's the only triton kernel in the repo.
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
… feat/code2wav-batch-cuda-graph
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
lishunyang12
left a comment
There was a problem hiding this comment.
the silent except: pass on the triton path could mask persistent failures
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
|
most things fixed. |
… feat/code2wav-batch-cuda-graph
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
… feat/code2wav-batch-cuda-graph
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
… feat/code2wav-batch-cuda-graph
… feat/code2wav-batch-cuda-graph
… feat/code2wav-batch-cuda-graph
hsliuustc0106
left a comment
There was a problem hiding this comment.
Gate Status
| Check | Status |
|---|---|
| DCO / pre-commit / build | ✅ |
| Main CI | ✅ |
| AMD CI | ❌ (may be unrelated) |
Evidence ✅
Comprehensive benchmark data provided:
- TTFP improvement: 4.6%–19.9%
- E2E improvement: 1.7%–15.2%
- Audio samples and comparison plot included
Code Quality ✅
- Triton fallback is properly handled with
logger.warning(..., exc_info=True)+ disables future attempts - Dynamic capture size computation is clean
- Test coverage for both features
Prior concern about silent exception handling is addressed. LGTM once AMD CI is investigated (likely unrelated to this PR).
|
I was thinking about whether any of these would help further:
|
… feat/code2wav-batch-cuda-graph
…uanPZuluaga/vllm-omni into feat/code2wav-batch-cuda-graph
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
… feat/code2wav-batch-cuda-graph
|
@linyueqian thanks for the comments!
done. i added
done. raised
I internally benchmarked I am adding the new results after rebasing merge main:
// EDIT I used this YAML -- relevant params changed in YAML: |
… feat/code2wav-batch-cuda-graph
… feat/code2wav-batch-cuda-graph
… feat/code2wav-batch-cuda-graph
…ect#1797) Signed-off-by: pablo <pablo@agigo.ai> Signed-off-by: JuanPZuluaga <juanz9312@gmal.com> Co-authored-by: pablo <pablo@agigo.ai> Co-authored-by: JuanPZuluaga <juanz9312@gmal.com> Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com> Co-authored-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>

PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTTOM) HAVE BEEN CONSIDERED.
Purpose
The code2wav processes codecs to audio using a conv-based pipeline with 29 SnakeBeta activation layers. Each SnakeBeta executes 4 separate elementwise GPU kernels (exp, sin, pow, add), creating ~116 small kernel launches per forward pass. this makes Code2Wav a small bottleneck under high concurrent load. thus, we introduce two optimizations:
Fused Triton SnakeBeta kernel: replaces 4 elementwise PyTorch ops with a single Triton kernel that reads and writes memory once. Auto-detected at runtime — uses Triton when available on CUDA, falls back to eager PyTorch otherwise. Zero configuration needed.
Smart CUDA graph capture sizes: instead of a hardcoded list [25, 50, 100, 150, 200, 250, 300], capture sizes are computed dynamically from the streaming config (codec_chunk_frames, codec_left_context_frames). This ensures exact graph hits for streaming chunk sizes (e.g., 33 and 58 for c=33/ctx=25) and includes power-of-2 small sizes [2, 4, 8, 16, 32, 64] aligned with the dynamic IC sizing in [feat][Qwen3TTS] Simple dynamic TTFA based on Code2Wav load #1714.
The capture size computation also generates bucket sizes for variable-length last chunks, ensuring high graph hit rate across all decode calls.
WIP: batched decoding.
Test Plan
Test Result
Benchmark Results
Improvement (triton vs main)
in the plot we can see that the TTFP/E2E latency, and everything is better.
some audio files:
audio from main: sample_2_stream_false.wav
audio from this PR: sample_2_stream_false.wav
Config YAML
relevant params changed in YAML:
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model. Please runmkdocs serveto sync the documentation editions to./docs.BEFORE SUBMITTING, PLEASE READ https://github.com/vllm-project/vllm-omni/blob/main/CONTRIBUTING.md (anything written below this line will be removed by GitHub Actions)