Skip to content

[Model] Add unified Qwen3-TTS model definition and Triton serving example with TensorRT codec#3221

Open
vklimkov-nvidia wants to merge 14 commits into
vllm-project:mainfrom
vklimkov-nvidia:vklimkov/qwen3tts_nv
Open

[Model] Add unified Qwen3-TTS model definition and Triton serving example with TensorRT codec#3221
vklimkov-nvidia wants to merge 14 commits into
vllm-project:mainfrom
vklimkov-nvidia:vklimkov/qwen3tts_nv

Conversation

@vklimkov-nvidia
Copy link
Copy Markdown

@vklimkov-nvidia vklimkov-nvidia commented Apr 28, 2026

Purpose

This PR introduces a new model definition for Qwen3-TTS that follows a unified architecture, alongside a comprehensive end-to-end Triton Inference Server serving example.
Instead of modifying the existing Qwen3-TTS model implementation, this PR provides a parallel model definition built on the principle of keeping the code predictor as an internal component of the model.

Key Architectural Choices:

  • Localizes Qwen3-TTS quirks inside the model module, keeping the GPU model runner clean of multi-token-predictor (MTP) specific branches.
  • Captures the talker and code predictor as a single full CUDA graph during decode-only batches.
  • Captures piecewise CUDA graphs during prefill or mixed batches.
  • Enables faster end-to-end execution with fewer graph launches and a single replay on decode.

Triton Serving Example:
Added a production-ready recipe in examples/online_serving/qwen3_tts_nv_triton/ to serve the two distinct stages of Qwen3-TTS efficiently.

  • Talker: Served with vLLM-Omni as a Python Triton backend to leverage continuous batching and paged KV-cache for autoregressive decoding.
  • Codec Decoder: Exported to TensorRT with dynamic batching and sequence-length profiles, served via Triton's native tensorrt_plan backend to efficiently batch independent frame chunks.
  • Orchestration: Uses Triton's Business Logic Scripting (BLS) to stream codes from the talker to the codec, and streams the final waveform chunks back to the client over a decoupled gRPC stream.

Test Plan

  • Added unit tests for the new model definition in tests/model_executor/models/qwen3_tts_nv/test_qwen3_tts_talker_nv.py, covering the talker forward path, code predictor integration, and CUDA graph capture/replay behavior.
  • Added an end-to-end integration recipe under examples/online_serving/qwen3_tts_nv_triton/ that wires the new talker into a Triton Inference Server deployment (vLLM-Omni Python backend for the talker + TensorRT backend for the codec, orchestrated via BLS over decoupled gRPC streaming).
  • Added examples/online_serving/qwen3_tts_nv_triton/benchmark_service.py to benchmark the full Triton service end-to-end (throughput, RTF, TTFA) and to optionally dump the synthesized waveforms for offline quality inspection.
  • Quality validation: ran the benchmark at concurrency 32, dumped the synthesized audio, and ran ASR (speech recognition) over the dumped clips to verify intelligibility and synthesis stability under load.

Test Result

  • Unit tests: tests/model_executor/models/qwen3_tts_nv/test_qwen3_tts_talker_nv.py passes.
  • Triton integration: Server starts cleanly with both the Python backend (vLLM-Omni talker) and the TensorRT codec backend loaded; decoupled gRPC streaming returns audio chunks end-to-end.
  • Quality under load: Audio dumped from benchmark_service.py at concurrency 32 was transcribed via ASR — no intelligibility regressions and no synthesis-stability issues (no truncation, repetition, or collapse) were observed.
  • Performance (single RTX A6000, default max_num_seqs / engine config; latencies reported as mean / p95 in ms):
    End-to-end service (benchmark_service.py, talker + codec via Triton):
Concurrency Throughput (req/s) RTF TTFA mean / p95 (ms)
1 1.14 4.71x 72.8 / 76.9
4 2.69 13.52x 117.2 / 140.0
8 4.42 21.33x 161.8 / 189.5
32 7.34 37.05x 373.9 / 425.4

Talker only (benchmark_model.py, codec tokens only, no waveform):

Concurrency Throughput (req/s) TTFT mean / p95 (ms) ITL mean / p95 (ms)
1 0.73 28.32 / 31.28 15.44 / 16.70
4 2.59 46.84 / 57.45 17.09 / 21.19
8 4.39 55.85 / 64.12 19.87 / 26.98
32 9.89 100.31 / 112.5 33.04 / 45.13

Signed-off-by: Viacheslav Klimkov <vklimkov@nvidia.com>
Signed-off-by: Viacheslav Klimkov <vklimkov@nvidia.com>
Signed-off-by: Viacheslav Klimkov <vklimkov@nvidia.com>
Signed-off-by: Viacheslav Klimkov <vklimkov@nvidia.com>
…3tts nv using triton

Signed-off-by: Viacheslav Klimkov <vklimkov@nvidia.com>
…nfig.pbtxt

Signed-off-by: Viacheslav Klimkov <vklimkov@nvidia.com>
…sting multiple codecs

Signed-off-by: Viacheslav Klimkov <vklimkov@nvidia.com>
Signed-off-by: Viacheslav Klimkov <vklimkov@nvidia.com>
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 491217bbbe

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +1490 to +1492
full = self._build_prompt_embeds(
text=text, speaker=speaker, language=language
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Guard prefill prompt build to first pipeline rank

preprocess always calls _build_prompt_embeds for prefill spans, but this model only creates a real text_embedding on the first PP rank and uses PPMissingLayer elsewhere. Since OmniGPUModelRunner invokes preprocess per request on every rank, any run with pipeline_parallel_size > 1 and prefill traffic can hit this path on non-first ranks and fail before forward. Add a rank guard (or avoid PPMissingLayer here) so only the first rank performs prompt-embedding construction.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dropped PP support for now

Comment on lines +1312 to +1313
dialect = spk_is_dialect.get(speaker.lower())
if isinstance(dialect, str) and dialect:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Normalize speaker key before dialect lookup

Dialect resolution in _build_prompt_embeds uses spk_is_dialect.get(speaker.lower()) without trimming whitespace, while other paths (including prompt-length estimation and speaker-id lookup) use stripped speaker keys. If a request sends a valid speaker name with leading/trailing spaces, prefill can miss dialect language conditioning and diverge from the estimated prompt layout, which changes control tokens and can misalign placeholder length assumptions.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

@vklimkov-nvidia
Copy link
Copy Markdown
Author

@Sy0307

I benchmarked the proposed model separately (just acoustic codes prediction) and the end-to-end (producing waveform from text) using triton inf server.

Model-Only Benchmark (Talker only)

Concurrency Implementation Throughput (seq/s) TTFT mean / p95 (ms) ITL mean / p95 (ms)
1 vllm omni main 0.66 15.09 / 17.06 17.96 / 19.88
1 PR 0.66 16.55 / 19.00 17.12 / 18.81
1 fork 1.13 17.27 / 19.73 14.72 / 15.45
4 vllm omni main 2.34 35.96 / 32.96 19.91 / 22.64
4 PR 2.32 38.15 / 48.70 19.36 / 22.51
4 fork 3.73 39.13 / 46.78 17.22 / 22.92
8 vllm omni main 3.63 48.92 / 59.66 24.88 / 30.43
8 PR 3.76 47.29 / 53.66 23.23 / 29.94
8 fork 5.55 49.92 / 72.52 21.26 / 29.91
32 vllm omni main 7.84 121.99 / 266.75 44.46 / 58.89
32 PR 7.91 126.25 / 279.79 41.31 / 56.48
32 fork 10.26 124.98 / 317.39 38.16 / 50.49

End-to-End Service Benchmark (Talker + Codec)

Concurrency Implementation Throughput (seq/s) RTF TTFA mean / p95 (ms)
1 vllm omni main 0.43 4.09 63.9 / 66.75
1 PR 0.85 4.39 62.4 / 65.1
1 fork 0.98 4.81 63.2 / 68.5
4 vllm omni main 1.23 12.75 152.69 / 191.7
4 PR 2.55 12.88 103.3 / 120.3
4 fork 2.88 14.37 97.7 / 119.4
8 vllm omni main 1.81 19.46 4150.0 / 5851.0
8 PR 3.82 19.96 143.1 / 165.4
8 fork 4.09 21.42 129.0 / 162.4
32 vllm omni main 2.45 24.91 8367.0 / 10836.0
32 PR 5.68 28.39 375.7 / 495.0
32 fork 8.06 39.27 274.9 / 367.2

Fork is this one: https://github.com/vklimkov-nvidia/vllm/tree/vklimkov/qwen3_tts_voices
The PR and main for vllm-omni are roughly similar performance when it comes to predicting acoustic tokens. Indeed, it should not make much difference performance-wise whether to call 2 or 1 cuda graph. The difference for e2e performance shows up however for higher concurrencies.

@hsliuustc0106
The vllm fork i was developing seems to be faster. I checked the profile and thats because of gaps between decoding steps. That is a problem for both vllm-omni main and PR. around 2ms overhead per each decoding step. Attaching profile picture, will look into it more.
Screenshot 2026-04-28 at 16 00 28

Overall I would advocate to have this change, since it does not touch core of the vllm-omni but
a) provides community with very performance qwen3tts endpoint;
b) provides example of how to keep sampling inside the model definition;
c) provides example of how to serve model using triton inference server and vllm-omni. i think that can be a useful addition in some cases.

…d back on request

Signed-off-by: Viacheslav Klimkov <vklimkov@nvidia.com>
Signed-off-by: Viacheslav Klimkov <vklimkov@nvidia.com>
@Sy0307
Copy link
Copy Markdown
Contributor

Sy0307 commented Apr 28, 2026

Nice work! I will take a look later and work on finding out why decoding steps are slow. Thanks for your contribution :)

@linyueqian
Copy link
Copy Markdown
Collaborator

linyueqian commented Apr 29, 2026

Reproduced on H20: throughput claim holds, with caveats on bring-up

Spun up the full Triton + TRT-batched codec stack on a single H20 (1x H20-3e, driver 570.133.20, CUDA 12.8 host, Triton container 25.12-py3 since 26.02-py3 needs CUDA 13.1 / driver >= 580). Same Qwen3-TTS-12Hz-1.7B-CustomVoice model, same 20-prompt English file at every concurrency level (30 reqs each, with warmup). For the baseline I ran current origin/main's vllm-omni serve against the same prompts via a small streaming OpenAI client mirroring benchmark_service.py's metrics.

A/B vs current main (origin/main @ dd0fa02)

Concurrency main req/s main RTF main TTFA mean / p95 PR req/s PR RTF PR TTFA mean / p95 speedup (req/s)
1 0.80 6.75x 47 / 52 ms 1.08 7.50x 38 / 41 ms +35%
8 2.29 19.15x 280 / 386 ms 4.41 29.98x 173 / 397 ms +93%
16 2.44 21.52x 2191 / 4437 ms 4.52 31.30x 202 / 258 ms +85%
32 2.47 21.47x 4359 / 8626 ms 6.73 46.11x 420 / 463 ms +173%

Two takeaways worth flagging beyond raw throughput:

  • main's throughput plateaus at ~2.47 req/s starting at c=8, and additional concurrency just queues. TTFA on main grows from 0.4 s at c=8 to 4.4 s mean / 8.6 s p95 at c=32, which is functionally unusable for streaming TTS.
  • The PR's stack keeps TTFA p95 under 470 ms even at c=32. This is the user-visible win; the req/s improvement is real but the TTFA improvement is what makes high-concurrency streaming actually work.

So the claim in the PR description holds. Worth merging the serving recipe.

Bring-up friction (suggest folding these into examples/online_serving/qwen3_tts_nv_triton/README.md)

The recipe took ~3 hours of patch-and-rebuild on H20 (non-CUDA-13 driver). Most of that is captured below as concrete fixes; happy to send a follow-up PR with a Dockerfile.cu12 variant + README section if useful.

P1 (correctness, broken as-shipped):

  • Dockerfile clones the wrong branch: --branch qwen3tts_refactor is the closed PR [Model] Qwen3-TTS: integrate code predictor into model CUDA graph #3071 branch, not this one. Should be vklimkov/qwen3tts_nv (or whatever this PR's final branch name resolves to).
  • README step 1 says cd examples/online_serving/qwen3_tts_triton, actual path is qwen3_tts_nv_triton. Copy-paste fails.
  • python3-libnvinfer=10.15.1.29-1+cuda13.1 pin is too tight; on a CUDA-12 base (e.g. tritonserver:25.12-py3) the base image already ships TRT 10.x with cuda12 builds. Recommend dropping the explicit version pin and letting apt resolve to whatever the base image carries, or shipping two Dockerfiles.

P2 (robustness, will hit anyone reproducing on a non-author machine):

  • Final pip step pip install ... transformers==4.57.3 fails with ERROR: THESE PACKAGES DO NOT MATCH THE HASHES FROM THE REQUIREMENTS FILE because vllm 0.19.0 leaves hash-pinned dep constraints from earlier layers. Workaround: pip install --no-cache-dir --force-reinstall --no-deps transformers==4.57.3 in its own RUN.
  • numpy-1.26.4.dist-info ends up with invalid metadata entry 'name' after the layered installs; transformers crashes at import with Unable to compare versions for numpy>=1.17: need=1.17 found=None. Workaround: append RUN pip install --no-cache-dir --force-reinstall --no-deps "numpy==1.26.4" at the end of the Dockerfile.
  • For systems on driver < 580 (CUDA 13.1 not directly supported), the container's TRT 10.16 needs forward-compat libs. Adding ENV LD_LIBRARY_PATH=/usr/local/cuda-13.1/compat:/usr/local/cuda/lib64 makes trtexec and TRT codec runtime work via NVIDIA's forward-compat path. Worth either documenting or switching base image to a CUDA-12 tag (e.g. 25.12-py3).

P3 (config defaults that don't match the perf claim):

  • model_repository/codec_decoder/config.pbtxt ships max_queue_delay_microseconds: 100. At 100 microseconds Triton's dynamic batcher will rarely form batches > 1 in practice; the codec batching that drives the c=32 win comes from concurrent requests aligning by chance. Recommend a default in the 1000 to 5000 us range, or annotating that this knob is the headline perf knob (the README does mention tweaking it, but the default value should be one that already shows the gain).

Repro setup

  • Triton: nvcr.io/nvidia/tritonserver:25.12-py3 (CUDA 12.8 base, runs on driver 570.x without forward compat issues).
  • GPU: H20-3e, GPU 1, otherwise idle.
  • Codec engine: --minShapes 1x30x16 --optShapes 8x30x16 --maxShapes 32x30x16 --fp16, parity vs ONNX max_abs_diff 1.4e-5 PASSED.
  • Both stacks: same 20 English prompts, 30 requests per concurrency level, with warmup.

cc @hsliuustc0106 @Sy0307

Signed-off-by: Viacheslav Klimkov <vklimkov@nvidia.com>
@vklimkov-nvidia
Copy link
Copy Markdown
Author

@Sy0307 I looked more into gap between forward passes. Turned out shm_broadcast was due to distributed_executor_backend="mp". I switched to "uni". It didn't work out of box either. OmniARScheduler inherits regular scheduler so it wasn't accounting for the token that was about to be generated. I added AsyncOmniARScheduler which simply inherits from both OmniARScheduler and AsyncScheduler. That ensures that scheduler doesn't wait between forward runs. It improves the throughput for 32 concurrent streams from 28xrt to 37xrt. You can also see the effect on a single stream, Inter-token-latency goes down from 17.12ms to 15.44ms.

Attaching updated profile picture, gap between forward passes is reduced from 2.9ms to 0.8ms
Screenshot from 2026-04-29 15-22-45

updated benchmark numbers in PR description and in README. The TTFA and TTFT growed by roughly 1 decoder step. I think decoder does one extra step now that goes into the measurement. Will look further into fixing it, but overall - that might be important finding for other models too. Gap between forward passes a bit overinflated now and can be reduced.

Signed-off-by: Viacheslav Klimkov <vklimkov@nvidia.com>
@vklimkov-nvidia
Copy link
Copy Markdown
Author

@linyueqian

thanks for the big effort of checking it and sorry the recipe was raw.

happy to send a follow-up PR with a Dockerfile.cu12

Yes! cu12 would be definitely appreciated! Please push a commit with a docker file you end up with.

Dockerfile clones the wrong branch: --branch qwen3tts_refactor

my bad! fixed

README step 1 says

fixed

python3-libnvinfer=10.15.1.29-1+cuda13.1 pin is too tight

dropped it. its not required for the run. container doesnt have tensorrt python package.
it is needed if the codec trt conversion fails. then it's very handy for debug.
the way to install it - is to install appropriate version of python3-libnvinfer from apt-get.

transformers==4.57.3, numpy-1.26.4.dist-info

added a separate entry to dockerfile at the end installing those without dependencies.
I am not sure if thats the most elegant way to do it, but we need this pinned version
of transformers to convert the codec correctly. vllm 0.19 would install transformers==5.5.0,
which doesn't work with https://github.com/QwenLM/Qwen3-TTS.git.
Maybe a PR to original repo updating transformers would be a cleaner way to do it.

switching base image to a CUDA-12 tag (e.g. 25.12-py3).

happy to do that. please share your docker file, would be happy to check it in my env.
If 12 works we can certainly go with that as default

Triton's dynamic batcher will rarely form batches > 1 in practice

actually even with 0 queue delay i was getting average batch size of ~4 according to metrics port (8002).
but i realize it may be specific to my set up. changed to 1000us default.

Btw, i analyzed the gap between forward passes, and was able to push throughput to 37xrt on my hardware, i.e. another +23%. See comment to @Sy0307 above

@Sy0307
Copy link
Copy Markdown
Contributor

Sy0307 commented Apr 30, 2026

I think this PR is great, but perhaps we could split the scheduler work into a separate PR. Could you create a dedicated PR for it and write some tests? Or I could also take on some of that work.

In fact, we are also researching what the best scheduler design would be for TTS. A potentially relevant reference: #2568

ischencheng added a commit to ischencheng/vllm-omni that referenced this pull request May 2, 2026
… gap

Borrows AsyncOmniARScheduler from PR vllm-project#3221 and wires the LLM_AR scheduler
selection so any stage with async_scheduling=true automatically picks the
async-bookkeeping variant.

Background:

When async_scheduling=true, vLLM's EngineCoreProc drives
step_with_batch_queue, which speculatively schedules the next batch while
the current one is still on the GPU. For the queue to stay full, the
scheduler must increment request.num_output_placeholders after each
scheduled step (so the next schedule() call knows to launch one more decode
token before the previous step's output has merged) and decrement it again
when the output arrives. Base OmniARScheduler skips this bookkeeping, so
schedule() returns 0 tokens on every other step, the engine sleeps 1 ms,
and the alternating empty-step pattern adds a ~2-3 ms gap between every
talker forward - visible in nsys profiles and confirmed by PR vllm-project#3221's
reviewer.

AsyncOmniARScheduler injects vllm.v1.core.sched.AsyncScheduler into the
OmniARScheduler MRO so the placeholder bookkeeping takes effect while
preserving every Omni-specific behaviour (OmniNewRequestData wrapping,
KV-transfer metadata, chunk-transfer adapter, streaming-session hooks).

Wiring:

* New _resolve_scheduler_cls(execution_type, async_scheduling) helper in
  stage_config.py picks AsyncOmniARScheduler for LLM_AR stages whenever
  async_scheduling=true; sync stages continue to use OmniARScheduler.
* Re-exported from vllm_omni.core.sched for downstream callers.

Measured impact (single H100 80 GB, Qwen3-TTS-12Hz-0.6B-Base, default
qwen3_tts.yaml = both stages max_num_seqs=10, 30/60/80/128 reqs at
c=1/4/8/32 with 96-req warmup):

| Concurrency | TTFA mean (default) | TTFA mean (+Async) | rps default | rps +Async |
| ----------: | ------------------: | -----------------: | ----------: | ---------: |
|           1 |              259 ms |             260 ms |       0.93  |       0.94 |
|           4 |              761 ms |             728 ms |       1.26  |       1.39 |
|           8 |             1220 ms |            1129 ms |       1.75  |       1.55 |
|          32 |             7286 ms |            5775 ms |       3.24  |       3.91 |

c=32 sees TTFA mean -21% and rps +20% vs the base RFC vllm-project#3163 P0 fix; rps
also exceeds main (3.51) on the same workload. c=1 is unchanged.

Co-Authored-By: Viacheslav Klimkov (PR vllm-project#3221)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@ischencheng
Copy link
Copy Markdown

ischencheng commented May 3, 2026

Tested this PR on a single H100 80GB (RunPod, Triton 25.12 base image, vLLM 0.19.0 per the Dockerfile) using the SeedTTS English testset, 100 prompts each, `voice=Vivian` on both sides, closed-loop concurrency sweep:

Conc main req/s main TTFP p95 this PR req/s this PR TTFA p95
8 3.69 387 ms 2.44 589 ms
16 4.34 606 ms 3.27 898 ms
32 5.68 2085 ms 4.09 1711 ms

For reference, the H20-3e numbers from #3238's description on the same workload:

Conc main req/s main TTFA p95 this PR req/s this PR TTFA p95
8 2.29 386 ms 4.41 397 ms
16 2.44 4437 ms 4.52 258 ms
32 2.47 8626 ms 6.73 463 ms

On H100 main does not saturate at c=32 the way it does on H20-3e (5.68 vs 2.47 req/s), so the PR win that's clearly visible on H20 doesn't appear here — req/s comes out behind main and TTFA tails are roughly comparable.

Signed-off-by: Viacheslav Klimkov <vklimkov@nvidia.com>
@vklimkov-nvidia
Copy link
Copy Markdown
Author

Thanks @ischencheng for having a look!

On the baseline comparison:
Good catch on #3326. I missed that the OpenAI endpoint in main was producing 2x more audio than the benchmark in my PR. I agree the previous comparison wasn't exactly apples-to-apples. However, the numbers for this PR remain valid, as I benchmarked the Triton gRPC service directly and verified with ASR that the generated audio matches the submitted text.

On H100 Performance:
I had some trouble reproducing the lower performance numbers you reported. Since an H100 should be significantly more powerful than my local RTX A6000, I provisioned an NVIDIA H100 80GB HBM3 to verify. Following the README from this PR, I obtained the following results:

Concurrency Throughput (req/s) RTF TTFA mean / p95 (ms)
1 1.76 8.83x 40.9 / 42.2
4 4.56 25.05x 72.5 / 94.9
8 7.78 40.56x 88.4 / 113.2
32 17.53 85.22x 183.9 / 230.3

These numbers align with the expectation that the H100 should provide a performance lift. Since TRT models usually perform on par with or better than CUDA graphs, and Talker changes specifically target speed-ups, the PR should ideally be faster than main.

Regarding Mainline:
I tried to benchmark the mainline branch I originated from (dc8a9e2f179), but due to the bug you reported, the throughput numbers are difficult to compare. Interestingly, the Time-to-First-Audio (TTFA) is still quite inflated there for 32 streams (3376 / 4406 ms).

Let me know if I should be benchmarking against a specific newer version or a different vllm-omni tag to better match your environment.

…ient to docker for benchmark

Signed-off-by: Viacheslav Klimkov <vklimkov@nvidia.com>
ischencheng added a commit to ischencheng/vllm-omni that referenced this pull request May 5, 2026
Implements RFC vllm-project#3163 P0:

* Fix the per-request for-loop in Qwen3TTSCode2Wav.forward(): pad
  scheduler-delivered sequences to a single [B, Q, F_max] and run one
  chunked_decode call. The previous loop forced bs=1 even when the
  scheduler had already grouped concurrent requests, regressing
  per-step throughput introduced by PR vllm-project#1426.
* Extend CUDAGraphDecoderWrapper to capture (batch_size, seq_len)
  pairs. Default set keeps the existing bs=1 seq buckets and adds
  (bs in {2,4,8}, seq=streaming_hot) so the new bs>1 hot path stays
  on the graph. Bs/seq misses fall back to eager.
* Plumb capture_pairs / max_batch_size through enable_cudagraph().
  YAML adds an optional code2wav_capture_pairs override.
* Raise Stage 1 max_num_seqs from 1 to 10 (matches Stage 0). Update
  the inline comment about engine-level CUDA Graph compatibility.

Tests:
* New tests/model_executor/models/qwen3_tts/test_code2wav_batching.py
  covers bs=1 parity, bs>1 per-request parity, padding-no-bleed,
  per-request left_context, and malformed-mixed batches.
* tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py
  switches the fixture and tests to the new (bs, seq) API and adds
  multi-bs capture/replay, uncaptured-bs fallback, and
  compute_capture_pairs cases.
* New tests/e2e/online_serving/test_qwen3_tts_concurrent_ttfb.py
  follows the bench shape of PR vllm-project#3221's benchmark_service.py
  (Throughput, RTF, TTFA mean / p95) and asserts sub-linear TTFA
  scaling per the RFC target (TTFA c=4 / TTFA c=1 <= 4.0x).

Resolves vllm-project#3163.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: ischencheng <cheng21@seas.upenn.edu>
@hsliuustc0106
Copy link
Copy Markdown
Collaborator

Hi @vklimkov-nvidia, friendly reminder — this PR hasn't had any activity (commits or reviews) in the past 11 days. 🕐

Could you please provide an update?

  • If you're still working on it, that's great — just let us know.
  • If you're blocked on something, feel free to ask for help.
  • If this PR is no longer being pursued, please consider closing it so we can keep the review queue manageable.

Thanks for your contribution! 🙏

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants