
[Optim][Qwen3TTS] big boost model throughput+latency high concurrency #1852

Merged
linyueqian merged 21 commits into vllm-project:main from JuanPZuluaga:feat/qwen3tts-optimize-decode on Mar 18, 2026
Conversation

@JuanPZuluaga (Contributor) commented Mar 12, 2026


Purpose

In the Qwen3-TTS Talker we move three large tensors (last_talker_hidden, tts_pad_embed, tailing_text_hidden) from GPU to CPU after every decode step, then back from CPU to GPU at the start of the next step. At batch_size=32 with ~200 decode steps per utterance (i.e., high-concurrency settings), this creates ~25k cudaMemcpy calls plus GPU pipeline stalls per generation. That is why GPU utilization stays well below 100% even at high concurrency.

This PR eliminates that round-trip and fixes two additional inefficiencies:

  • GPU-resident buffer keys for the Talker: declares the three tensors as gpu_resident_buffer_keys (the mechanism already exists for qwen3_omni; it just wasn't used here). The tensors stay on GPU, so the per-step .to(device, dtype) calls become no-ops: zero copies per decode step. A minimal sketch of the mechanism follows this list.
  • Recv-loop busy-spin: the shared-memory polling loop used to tight-spin when pending requests had no data ready. It now processes each request exactly once per pass, then backs off 1 ms when no request made progress (100 ms when the queue is empty and waiting for new work). A threading.Condition() also replaces time.sleep(0.001) in both the recv and save loops, so the threads wake instantly when data arrives.
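
For context, a minimal sketch of what the GPU-resident mechanism does, assuming a simplified runner; the key set matches the PR, but the staging function here is hypothetical:

```python
import torch

# Keys the Talker now declares; buffers under these keys stay on the GPU
# between decode steps instead of round-tripping through host memory.
gpu_resident_buffer_keys = {"last_talker_hidden", "tts_pad_embed", "tailing_text_hidden"}

def stage_step_buffers(info_dict: dict) -> dict:
    """Illustrative per-step staging: resident keys skip the device copy."""
    staged = {}
    for key, tensor in info_dict.items():
        if key in gpu_resident_buffer_keys:
            staged[key] = tensor  # already on-device: .to() is a no-op, no cudaMemcpy
        else:
            staged[key] = tensor.to("cpu")  # non-resident buffers still leave the GPU
    return staged
```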

Test Plan

Run the benchmark with this PR vs. main.

Config YAML (the settings I changed):

```yaml
stage_args:
  - stage_id: 0
    stage_type: llm
    runtime:
      devices: "0"
      max_batch_size: 16

  - stage_id: 1
    stage_type: llm
    runtime:
      devices: "0"
      max_batch_size: 16

runtime:
  max_inflight: 16
  connectors:
    connector_of_shared_memory:
      codec_streaming: true
      codec_chunk_frames: 32
      codec_left_context_frames: 32
```

Test Result

See the plot below; the results are clear, and we improve in every setting.

(comparison plot)
  Warming up with 3 requests...
  Warmup done.
  Running 50 requests with concurrency=2...
  concurrency=2: 100%|██████████| 50/50 [00:38<00:00,  1.31it/s]

==================================================
             Serving Benchmark Result
==================================================
Successful requests:                    50
Failed requests:                        0
Maximum request concurrency:            2
Benchmark duration (s):                 38.05
Request throughput (req/s):             1.31
--------------------------------------------------
                End-to-end Latency
--------------------------------------------------
Mean E2EL (ms):                         1510.48
Median E2EL (ms):                       1499.88
P99 E2EL (ms):                          2071.66
==================================================
                   Audio Result
==================================================
Total audio duration generated (s):     283.52
Audio throughput (audio duration/s):    7.45
--------------------------------------------------
               Time to First Packet
--------------------------------------------------
Mean AUDIO_TTFP (ms):                   92.07
Median AUDIO_TTFP (ms):                 91.81
P99 AUDIO_TTFP (ms):                    111.66
--------------------------------------------------
                 Real Time Factor
--------------------------------------------------
Mean AUDIO_RTF:                         0.267
Median AUDIO_RTF:                       0.266
P99 AUDIO_RTF:                          0.307
==================================================



  Warming up with 3 requests...
  Warmup done.
  Running 50 requests with concurrency=16...
  concurrency=16: 100%|██████████| 50/50 [00:07<00:00,  6.56it/s]

==================================================
             Serving Benchmark Result
==================================================
Successful requests:                    50
Failed requests:                        0
Maximum request concurrency:            16
Benchmark duration (s):                 7.62
Request throughput (req/s):             6.56
--------------------------------------------------
                End-to-end Latency
--------------------------------------------------
Mean E2EL (ms):                         2214.74
Median E2EL (ms):                       2245.88
P99 E2EL (ms):                          3096.15
==================================================
                   Audio Result
==================================================
Total audio duration generated (s):     320.00
Audio throughput (audio duration/s):    41.99
--------------------------------------------------
               Time to First Packet
--------------------------------------------------
Mean AUDIO_TTFP (ms):                   556.57
Median AUDIO_TTFP (ms):                 593.41
P99 AUDIO_TTFP (ms):                    672.01
--------------------------------------------------
                 Real Time Factor
--------------------------------------------------
Mean AUDIO_RTF:                         0.360
Median AUDIO_RTF:                       0.378
P99 AUDIO_RTF:                          0.530
==================================================


pablo added 4 commits March 12, 2026 (each Signed-off-by: pablo <pablo@agigo.ai>)
@JuanPZuluaga changed the title from "[Optim][Qwen3TTS] boost model throughput" to "[Optim][Qwen3TTS] big boost model throughput+latency high concurrency" on Mar 12, 2026
@chatgpt-codex-connector (Bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 75930ff3e7


```diff
-wav_list, sr = await self._resolve_ref_audio(request.ref_audio)
-tts_params["ref_audio"] = [[wav_list, sr]]
+wav_np, sr = await self._resolve_ref_audio(request.ref_audio)
+tts_params["ref_audio"] = [(wav_np, sr)]
```

P1: Keep ref_audio list_data msgspec-serializable

This now places a NumPy array inside additional_information["ref_audio"] ([(wav_np, sr)]), but OmniInputProcessor._build_engine_request forwards list values into AdditionalInformationEntry(list_data=...) without converting nested objects, and AdditionalInformationEntry is documented as requiring msgspec-serializable list content. In multiprocess/distributed engine paths that msgpack-encode OmniEngineCoreRequest, Base voice-cloning requests with ref_audio can fail serialization because nested np.ndarray is not encoded like tensor payloads.
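
For illustration, one hedged way to keep the payload serializable; the helper name is hypothetical and this is not necessarily the fix that landed:

```python
import numpy as np

def make_ref_audio_serializable(wav_np: np.ndarray, sr: int) -> list:
    # msgspec's msgpack encoder handles builtin containers and scalars but
    # not np.ndarray, so flatten to plain Python types before the value
    # reaches AdditionalInformationEntry(list_data=...).
    return [[wav_np.tolist(), int(sr)]]
```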


Comment on lines +126 to +130
```python
self.stop_event.set()
with self._recv_cond:
    self._recv_cond.notify_all()
with self._save_cond:
    self._save_cond.notify_all()
```
P2: Close connector resources in adapter shutdown

The new shutdown implementation wakes the worker threads but never closes the underlying connector, even though connector cleanup is defined via OmniConnectorBase.close(). When shutdown is used during stage teardown/restart, transports like shared memory or Mooncake can keep resources open (handles, pools, executors), causing leaks and cross-run interference; call self.connector.close() as part of shutdown after signaling threads.
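
A sketch of the shutdown shape being requested here (the later collaborator review notes the landed version does call connector.close(), swallowing cleanup errors); names are taken from the snippet above:

```python
def shutdown(self) -> None:
    # Signal both worker loops to exit, then wake any thread blocked in wait().
    self.stop_event.set()
    with self._recv_cond:
        self._recv_cond.notify_all()
    with self._save_cond:
        self._save_cond.notify_all()
    # Release connector-held resources (shared-memory handles, pools, executors).
    try:
        self.connector.close()
    except Exception:
        pass  # best-effort cleanup; failures here shouldn't mask shutdown
```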


@hsliuustc0106 (Collaborator) commented:

@linyueqian PTAL

JuanPZuluaga added 2 commits (each Signed-off-by: JuanPZuluaga <juanz9312@gmail.com>)
@JuanPZuluaga (Contributor, Author) commented:

Addressed the Codex review comments.

@hsliuustc0106 (Collaborator) left a comment

Summary

This is a well-structured optimization PR with clear benchmark evidence. The three main changes are:

  1. GPU-resident buffer keys - Correctly follows the existing pattern from qwen3_omni.py. The gpu_resident_buffer_keys set is properly consumed by gpu_model_runner.py to keep last_talker_hidden, tts_pad_embed, and tailing_text_hidden on GPU, eliminating CPU round-trips.

  2. Threading.Condition for recv/save loops - Good pattern change from busy-spin polling to proper synchronization. The logic correctly processes exactly n requests per pass (prevents starvation), waits 100ms when idle, and backs off 1ms on no progress.

  3. Pure Python transpose - Verified the list comprehension produces identical output to the original torch.tensor().transpose().reshape(-1).tolist() chain. Avoids tensor allocation overhead.
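
For concreteness, a self-contained check of the equivalence being claimed here; the variable names are illustrative, not from the diff:

```python
import torch

# codes: per-codebook token lists, shape [num_codebooks][num_frames]
codes = [[1, 2, 3], [4, 5, 6]]

# Original chain: build a tensor, transpose to frame-major order, flatten.
via_torch = torch.tensor(codes).transpose(0, 1).reshape(-1).tolist()

# Pure-Python equivalent: interleave codebooks frame by frame.
via_python = [row[i] for i in range(len(codes[0])) for row in codes]

assert via_torch == via_python == [1, 4, 2, 5, 3, 6]
```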

Validated

  • ✅ DCO signed
  • ✅ Benchmark evidence with clear before/after comparison
  • ✅ GPU-resident pattern matches existing qwen3_omni implementation
  • ✅ finished: bool type change is correct (matches OmniRequestOutput.finished: bool)
  • ✅ Transpose optimization is semantically equivalent
  • ✅ shutdown() implementation properly signals stop and wakes waiting threads

Minor Notes

  • The shutdown() implementation silently swallows errors from connector.close() via try/except pass. This is acceptable for cleanup code but worth being aware of.

Test Coverage

The existing tests in tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py and tests/distributed/omni_connectors/test_chunk_transfer_adapter.py should cover the modified paths. The benchmark results demonstrate functional correctness.


Nice performance wins - especially the GPU-resident buffers which eliminate 25k+ cudaMemcpy calls per generation at high concurrency.

@hsliuustc0106 requested a review from Copilot on March 12, 2026 15:35
Copilot AI (Contributor) left a comment

Pull request overview

Improve Qwen3-TTS high-concurrency throughput/latency by reducing per-step device transfers and lowering connector polling CPU overhead.

Changes:

  • Keep selected Talker intermediate tensors GPU-resident across decode steps to avoid repeated CPU↔GPU round-trips.
  • Replace tight polling sleeps in transfer adapter loops with condition-variable wakeups and backoff.
  • Optimize chunk-streaming payload construction by avoiding heavy tensor/list conversions.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 6 comments.

| File | Description |
| --- | --- |
| vllm_omni/model_executor/stage_input_processors/qwen3_tts.py | Stream payload now uses Python-native transforms and bool finished flags |
| vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py | Introduces GPU-resident intermediate keys and removes CPU staging for select tensors |
| vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py | Notifies recv/save conditions when new pending work is queued |
| vllm_omni/distributed/omni_connectors/transfer_adapter/base.py | Refactors recv/save loops to use Conditions + backoff and implements shutdown |


Comment on lines +56 to +78
n = len(self._pending_load_reqs)
any_success = False
for _ in range(n):
if not self._pending_load_reqs:
break
request = self._pending_load_reqs.popleft()
request_id = request.request_id
self.request_ids_mapping[request_id] = request.external_req_id
try:
is_success = self._poll_single_request(request)
if not is_success:
if is_success:
any_success = True
else:
self._pending_load_reqs.append(request)
except Exception as e:
self._pending_load_reqs.append(request)
logger.warning(f"Error receiving data for {request_id}: {e}")

time.sleep(0.001)
with self._recv_cond:
if not self._pending_load_reqs and not self.stop_event.is_set():
self._recv_cond.wait(timeout=0.1)
elif not any_success and not self.stop_event.is_set():
self._recv_cond.wait(timeout=0.001)

Copilot AI Mar 12, 2026


The Condition is not used to protect access to _pending_load_reqs. Because producers can append without holding _recv_cond's lock, recv_loop can observe the deque as empty and enter wait() just after a request is appended and notify() fires (a missed wakeup), delaying processing until the timeout. Fix by using the condition lock as the mutex for all _pending_load_reqs mutations and checks (append/popleft/len/emptiness), or by switching to a thread-safe queue (queue.Queue) with a blocking get() and timeout. A sketch of the first option follows.
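
A minimal sketch of the race-free shape this comment is suggesting (the condition lock guards every deque access); the class and method names are illustrative:

```python
import threading
from collections import deque

class PendingQueue:
    """Illustrative race-free variant of the recv loop (names hypothetical)."""

    def __init__(self):
        self._cond = threading.Condition()
        self._pending = deque()
        self.stop_event = threading.Event()

    def enqueue(self, request):
        # Mutate the deque while holding the condition's lock so the
        # consumer cannot miss a wakeup between its check and wait().
        with self._cond:
            self._pending.append(request)
            self._cond.notify()

    def recv_loop(self, poll_one):
        while not self.stop_event.is_set():
            with self._cond:
                while not self._pending and not self.stop_event.is_set():
                    self._cond.wait(timeout=0.1)   # idle: long backoff
                batch = list(self._pending)
                self._pending.clear()
            progressed = False
            for req in batch:
                if poll_one(req):                  # data was ready
                    progressed = True
                else:
                    with self._cond:
                        self._pending.append(req)  # retry on a later pass
            if not progressed and not self.stop_event.is_set():
                with self._cond:
                    self._cond.wait(timeout=0.001) # pending but no data yet
```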

Comment on lines 95 to +99
```diff
 if not hasattr(request, "additional_information"):
     request.additional_information = None
 self._pending_load_reqs.append(request)
+with self._recv_cond:
+    self._recv_cond.notify()
```

Copilot AI Mar 12, 2026


This notify() doesn’t prevent missed wakeups because the enqueue (_pending_load_reqs.append) is not done while holding the same _recv_cond lock that the consumer uses to decide whether to wait. Move the append inside the with self._recv_cond: block (and ensure the consumer also checks/consumes under that same lock), or use a blocking queue abstraction to avoid subtle races.

```python
tail = info_dict.get("tailing_text_hidden")
if isinstance(tail, torch.Tensor) and tail.ndim == 2 and tail.shape[0] > 0:
    text_step = tail[:1].to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1)
    new_tail = tail[1:] if tail.shape[0] > 1 else tail[:0]
```

Copilot AI Mar 12, 2026


On GPU, new_tail = tail[1:] is a view into the original tensor storage. Even as the logical length shrinks, the full original allocation can remain live, and repeated slicing can keep large GPU buffers resident longer than intended. Consider tracking an integer offset into tailing_text_hidden (no slicing / no copies) or, if you do want to physically drop consumed elements, make new_tail a new tensor (e.g., contiguous()/clone()) before storing it back so the old storage can be released.

Suggested change:

```diff
-new_tail = tail[1:] if tail.shape[0] > 1 else tail[:0]
+if tail.shape[0] > 1:
+    # Materialize a new tensor for the remaining queue to avoid keeping
+    # the original (potentially large) GPU storage alive via a view.
+    new_tail = tail[1:].contiguous()
+else:
+    # Create a truly empty tensor with matching feature dimension.
+    new_tail = torch.empty(
+        (0, tail.shape[1]),
+        device=tail.device,
+        dtype=tail.dtype,
+    )
```
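
The comment's first alternative, tracking an integer offset rather than storing slices, might look like this (function and dict-key names are hypothetical):

```python
import torch

def next_text_step(info_dict: dict, tail: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Hypothetical offset-tracking alternative: leave tailing_text_hidden
    # intact and advance an integer cursor, so no slices or views are
    # stored back and the original storage's lifetime stays unambiguous.
    offset = info_dict.get("tailing_text_offset", 0)
    step = tail[offset:offset + 1].to(device=device, dtype=torch.bfloat16).reshape(1, -1)
    info_dict["tailing_text_offset"] = offset + 1
    return step
```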

```diff
 if int(take.shape[0]) < span_len:
     pad_n = int(span_len - int(take.shape[0]))
-    pad_rows = tts_pad_embed.detach().to("cpu").contiguous().reshape(1, -1).expand(pad_n, -1)
+    pad_rows = tts_pad_embed.reshape(1, -1).expand(pad_n, -1).to("cpu")
```

Copilot AI Mar 12, 2026


This expands on GPU and then copies the expanded view to CPU, which can transfer pad_n * hidden_dim elements over the interconnect. If the intent is to keep prefill prompt building CPU-side, copy the 1×D row to CPU first and then expand/repeat on CPU (or construct the padding directly on CPU) to reduce GPU→CPU transfer volume. Same pattern appears again at line 596.
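
A sketch of the suggested ordering, assuming a hypothetical helper name:

```python
import torch

def build_pad_rows(tts_pad_embed: torch.Tensor, pad_n: int) -> torch.Tensor:
    # Copy the single 1 x D row to the CPU first (D elements over the bus),
    # then expand there; expand() is a zero-copy view, so no pad_n * D transfer.
    pad_row_cpu = tts_pad_embed.reshape(1, -1).to("cpu")
    return pad_row_cpu.expand(pad_n, -1)
```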

```diff
 return {}
-last = hidden_states[-1, :].detach().to("cpu").contiguous()
-return {"last_talker_hidden": last}
+return {"last_talker_hidden": hidden_states[-1, :]}
```

Copilot AI Mar 12, 2026


Previously this value was .detach()’d before storing. Returning a view of hidden_states without detaching can keep autograd history alive if this code is ever run with gradients enabled (or under certain tracing setups), increasing memory usage. Keep the GPU-resident behavior but still detach (and optionally make it contiguous if downstream relies on contiguous reshapes).

Suggested change:

```diff
-return {"last_talker_hidden": hidden_states[-1, :]}
+last_hidden = hidden_states[-1, :].detach()
+last_hidden = last_hidden.contiguous()
+return {"last_talker_hidden": last_hidden}
```

```diff
@@ -189,10 +188,12 @@ def talker2code2wav_async_chunk(
     window_frames = ref_frames + window_frames
     left_context_size += len(ref_frames)
```


Copilot AI Mar 12, 2026


window_frames[0] will raise IndexError if window_frames is ever empty. The previous torch.tensor(window_frames) path would also fail on empty input, but if empty windows are possible in some streaming boundary conditions, this needs an explicit guard (e.g., handle empty by returning code_predictor_codes = [] and still emitting finished/left_context_size).

Suggested change:

```diff
+# Handle potential empty window to avoid IndexError and follow empty-window behavior.
+if not window_frames:
+    return {
+        "code_predictor_codes": [],
+        "left_context_size": left_context_size,
+        "finished": finished,
+    }
```


JuanPZuluaga added a commit (Signed-off-by: JuanPZuluaga <juanz9312@gmail.com>)
@linyueqian (Collaborator) commented:

Nice work on the optimization. I ran an independent benchmark on an H200 (same config for both, max_batch_size=4, max_inflight=4, 20 requests per concurrency level, async_chunk enabled):

| Metric | main | PR #1852 | Delta |
| --- | --- | --- | --- |
| **Concurrency 1** | | | |
| Mean E2EL (ms) | 3530 | 2960 | -16% |
| Mean TTFP (ms) | 263 | 228 | -13% |
| Mean RTF | 0.560 | 0.471 | -16% |
| Audio throughput (x) | 1.78 | 2.12 | +19% |
| **Concurrency 4** | | | |
| Mean E2EL (ms) | 6245 | 3472 | -44% |
| Mean TTFP (ms) | 1426 | 869 | -39% |
| Mean RTF | 1.124 | 0.634 | -44% |
| Audio throughput (x) | 3.80 | 6.28 | +65% |

Results confirm the PR's claims. The gains scale with concurrency as expected from eliminating per-step GPU<->CPU round-trips.

A couple of nits on the PR description:

  1. "numpy array for ref_audio instead of .tolist()" is mentioned in the description but doesn't appear in the diff. Might want to remove it or save for a follow-up.

  2. Backoff timing: Description says "backs off 1ms if none succeeded" - the 1ms applies when requests exist but none had data ready, but when the queue is completely empty it's actually 100ms. Minor, but worth clarifying.

@linyueqian added the ready label (triggers Buildkite CI) on Mar 12, 2026
@JuanPZuluaga (Contributor, Author) commented Mar 13, 2026

  1. "numpy array for ref_audio instead of .tolist()" is mentioned in the description but doesn't appear in the diff. Might want to remove it or save for a follow-up.
  2. Backoff timing: Description says "backs off 1ms if none succeeded" - the 1ms applies when requests exist but none had data ready, but when the queue is completely empty it's actually 100ms. Minor, but worth clarifying.

I have updated the PR description. Let me know if there's something else to be done.

@linyueqian added this to the v0.18.0 milestone on Mar 16, 2026
@tzhouam (Collaborator) commented Mar 17, 2026

@natureofnature Does this align with your design?

@natureofnature (Contributor) replied:

> @natureofnature Does this align with your design?

Yes, I can rebase with this PR for the refactor.

@linyueqian merged commit fc5e8f8 into vllm-project:main on Mar 18, 2026
6 of 7 checks passed
yiliu30 pushed a commit to yiliu30/vllm-omni-fork that referenced this pull request Mar 20, 2026
…vllm-project#1852)

Signed-off-by: pablo <pablo@agigo.ai>
Signed-off-by: JuanPZuluaga <juanz9312@gmail.com>
Co-authored-by: pablo <pablo@agigo.ai>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>

Signed-off-by: yiliu30 <yi4.liu@intel.com>
linyueqian added a commit to linyueqian/vllm-omni that referenced this pull request Mar 31, 2026
Fix test assertions and mocks that fell out of sync with source code
changes in qwen3_tts.py across PRs vllm-project#1930, vllm-project#1852, and vllm-project#2104.

- test_flush_on_finish: `finished` is now a plain bool, not a tensor;
  remove `.item()` call
- test_ic_load_change_mid_request: IC is cached per request since vllm-project#1930;
  update expected emission frames to match current logic
- test_non_async_processor_prepends_ref_code_and_sets_trim_context:
  add missing `finished=True` and `token_ids` to mock (required since vllm-project#2104)
- test_non_async_processor_filters_out_of_range_codec_values: same fix

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: linyueqian <linyueqian@outlook.com>
clodaghwalsh17 pushed a commit to clodaghwalsh17/nm-vllm-omni-ent that referenced this pull request May 12, 2026
…vllm-project#1852)

Signed-off-by: pablo <pablo@agigo.ai>
Signed-off-by: JuanPZuluaga <juanz9312@gmail.com>
Co-authored-by: pablo <pablo@agigo.ai>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>