-
Notifications
You must be signed in to change notification settings - Fork 1k
[Optim][Qwen3TTS] big boost model throughput+latency high concurrency #1852
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
62ed367
a72111f
bec45ba
75930ff
9e0e6a8
89f7a63
a00d8e4
f4b49cd
a99a19e
b535b22
db1e0bc
5b7a5fb
ce4053b
badbfde
340f38e
78200a9
dec614c
d4b9cd0
f3c72aa
8e14b52
4e2bf4d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,7 +2,6 @@ | |
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| import threading | ||
| import time | ||
| from collections import deque | ||
| from typing import Any | ||
|
|
||
|
|
@@ -33,6 +32,8 @@ def __init__(self, config: Any): | |
| self._finished_save_reqs = set() | ||
|
|
||
| self.stop_event = threading.Event() | ||
| self._recv_cond = threading.Condition() | ||
| self._save_cond = threading.Condition() | ||
|
|
||
| self.recv_thread = threading.Thread(target=self.recv_loop, daemon=True) | ||
| self.recv_thread.start() | ||
|
|
@@ -45,22 +46,37 @@ def create_connector(cls, model_config: Any): | |
| raise NotImplementedError | ||
|
|
||
| def recv_loop(self): | ||
| """Loop to poll for incoming data.""" | ||
| """Loop to poll for incoming data. | ||
|
|
||
| Process each pending request exactly once per pass. When no request | ||
| made progress, back off 1 ms instead of tight-spinning on failed | ||
| shm_open syscalls (which can burn a full CPU core). | ||
| """ | ||
| while not self.stop_event.is_set(): | ||
| # Iterate over a snapshot of pending requests | ||
| while self._pending_load_reqs: | ||
| n = len(self._pending_load_reqs) | ||
| any_success = False | ||
| for _ in range(n): | ||
| if not self._pending_load_reqs: | ||
| break | ||
| request = self._pending_load_reqs.popleft() | ||
| request_id = request.request_id | ||
| self.request_ids_mapping[request_id] = request.external_req_id | ||
| try: | ||
| is_success = self._poll_single_request(request) | ||
| if not is_success: | ||
| if is_success: | ||
| any_success = True | ||
| else: | ||
| self._pending_load_reqs.append(request) | ||
| except Exception as e: | ||
| self._pending_load_reqs.append(request) | ||
| logger.warning(f"Error receiving data for {request_id}: {e}") | ||
|
|
||
| time.sleep(0.001) | ||
| # Timeout is the fallback for lock-free append/notify races. | ||
| with self._recv_cond: | ||
| if not self._pending_load_reqs and not self.stop_event.is_set(): | ||
| self._recv_cond.wait(timeout=0.1) | ||
| elif not any_success and not self.stop_event.is_set(): | ||
| self._recv_cond.wait(timeout=0.001) | ||
|
Comment on lines
+56
to
+79
|
||
|
|
||
| def save_loop(self): | ||
| """Loop to send outgoing data.""" | ||
|
|
@@ -71,7 +87,10 @@ def save_loop(self): | |
| self._send_single_request(task) | ||
| except Exception as e: | ||
| logger.warning(f"Error saving data for {task.get('request_id')}: {e}") | ||
| time.sleep(0.001) | ||
|
|
||
| with self._save_cond: | ||
| if not self._pending_save_reqs and not self.stop_event.is_set(): | ||
| self._save_cond.wait(timeout=0.1) | ||
|
|
||
| def _poll_single_request(self, *args, **kwargs): | ||
| """Poll connector for a single request task. | ||
|
|
@@ -105,4 +124,13 @@ def get_finished_requests(self): | |
|
|
||
| def shutdown(self): | ||
| """Stop background loops and close the connector.""" | ||
| raise NotImplementedError | ||
| self.stop_event.set() | ||
| with self._recv_cond: | ||
| self._recv_cond.notify_all() | ||
| with self._save_cond: | ||
| self._save_cond.notify_all() | ||
|
Comment on lines
+127
to
+131
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The new […] (review comment truncated in extraction) Useful? React with 👍 / 👎. |
||
| if self.connector is not None: | ||
| try: | ||
| self.connector.close() | ||
| except Exception: | ||
| pass | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -95,6 +95,8 @@ def load_async(self, request: Request): | |
| if not hasattr(request, "additional_information"): | ||
| request.additional_information = None | ||
| self._pending_load_reqs.append(request) | ||
| with self._recv_cond: | ||
| self._recv_cond.notify() | ||
|
Comment on lines
95
to
+99
|
||
|
|
||
| def save_async( | ||
| self, | ||
|
|
@@ -116,6 +118,8 @@ def save_async( | |
| "is_finished": request.is_finished(), | ||
| } | ||
| self._pending_save_reqs.append(task) | ||
| with self._save_cond: | ||
| self._save_cond.notify() | ||
|
|
||
| def _poll_single_request(self, request: Request): | ||
| stage_id = self.connector.stage_id | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -393,6 +393,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): | |||||||||||||||||||||||||
| codec_mask[self._codec_eos_token_id] = True | ||||||||||||||||||||||||||
| self.register_buffer("_codec_allowed_mask", codec_mask, persistent=False) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Keys that should stay on GPU in model_intermediate_buffer to avoid | ||||||||||||||||||||||||||
| # CPU-to-GPU round-trips on every decode step. | ||||||||||||||||||||||||||
| self.gpu_resident_buffer_keys: set[str] = { | ||||||||||||||||||||||||||
| "last_talker_hidden", | ||||||||||||||||||||||||||
| "tts_pad_embed", | ||||||||||||||||||||||||||
| "tailing_text_hidden", | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Tokenizer for prompt building. | ||||||||||||||||||||||||||
| self._tokenizer = None | ||||||||||||||||||||||||||
| self._speech_tokenizer: Qwen3TTSTokenizer | None = None | ||||||||||||||||||||||||||
|
|
@@ -547,12 +555,13 @@ def preprocess( | |||||||||||||||||||||||||
| full_prompt_embeds, tailing_text_hidden, tts_pad_embed, ref_code_len, ref_code = ( | ||||||||||||||||||||||||||
| self._build_prompt_embeds(task_type=task_type, info_dict=info_dict) | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| # Store full prompt embeddings + trailing queue on CPU for later chunks/steps. | ||||||||||||||||||||||||||
| # Store full prompt embeddings on CPU (large, prefill-only). | ||||||||||||||||||||||||||
| # tailing_text_hidden and tts_pad_embed stay on GPU (gpu_resident_buffer_keys). | ||||||||||||||||||||||||||
| prompt_embeds_cpu = full_prompt_embeds.detach().to("cpu").contiguous() | ||||||||||||||||||||||||||
| info_update: dict[str, Any] = { | ||||||||||||||||||||||||||
| "talker_prompt_embeds": prompt_embeds_cpu, | ||||||||||||||||||||||||||
| "tailing_text_hidden": tailing_text_hidden.detach().to("cpu").contiguous(), | ||||||||||||||||||||||||||
| "tts_pad_embed": tts_pad_embed.detach().to("cpu").contiguous(), | ||||||||||||||||||||||||||
| "tailing_text_hidden": tailing_text_hidden.detach(), | ||||||||||||||||||||||||||
| "tts_pad_embed": tts_pad_embed.detach(), | ||||||||||||||||||||||||||
| "talker_prefill_offset": 0, | ||||||||||||||||||||||||||
| "codec_streaming": codec_streaming, | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
@@ -568,7 +577,7 @@ def preprocess( | |||||||||||||||||||||||||
| take = prompt_embeds_cpu[s:e] | ||||||||||||||||||||||||||
| if int(take.shape[0]) < span_len: | ||||||||||||||||||||||||||
| pad_n = int(span_len - int(take.shape[0])) | ||||||||||||||||||||||||||
| pad_rows = tts_pad_embed.detach().to("cpu").contiguous().reshape(1, -1).expand(pad_n, -1) | ||||||||||||||||||||||||||
| pad_rows = tts_pad_embed.reshape(1, -1).to("cpu").expand(pad_n, -1) | ||||||||||||||||||||||||||
| take = torch.cat([take, pad_rows], dim=0) | ||||||||||||||||||||||||||
| prompt_embeds = take.to(device=input_ids.device, dtype=torch.bfloat16) | ||||||||||||||||||||||||||
| info_update["talker_prefill_offset"] = int(offset + span_len) | ||||||||||||||||||||||||||
|
|
@@ -584,7 +593,7 @@ def preprocess( | |||||||||||||||||||||||||
| take = prompt_embeds_cpu[s:e] | ||||||||||||||||||||||||||
| if int(take.shape[0]) < span_len: | ||||||||||||||||||||||||||
| pad_n = int(span_len - int(take.shape[0])) | ||||||||||||||||||||||||||
| pad_rows = tts_pad_embed.detach().to("cpu").contiguous().reshape(1, -1).expand(pad_n, -1) | ||||||||||||||||||||||||||
| pad_rows = tts_pad_embed.reshape(1, -1).to("cpu").expand(pad_n, -1) | ||||||||||||||||||||||||||
| take = torch.cat([take, pad_rows], dim=0) | ||||||||||||||||||||||||||
| prompt_embeds = take.to(device=input_ids.device, dtype=torch.bfloat16) | ||||||||||||||||||||||||||
| info_update = {"talker_prefill_offset": int(offset + span_len)} | ||||||||||||||||||||||||||
|
|
@@ -604,23 +613,24 @@ def preprocess( | |||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Decode: span_len == 1 | ||||||||||||||||||||||||||
| # Pop one text-step vector from tailing_text_hidden queue. | ||||||||||||||||||||||||||
| tts_pad_embed_cpu = info_dict.get("tts_pad_embed") | ||||||||||||||||||||||||||
| if not isinstance(tts_pad_embed_cpu, torch.Tensor): | ||||||||||||||||||||||||||
| # These tensors stay on GPU via gpu_resident_buffer_keys - .to() is a no-op. | ||||||||||||||||||||||||||
| tts_pad_embed_buf = info_dict.get("tts_pad_embed") | ||||||||||||||||||||||||||
| if not isinstance(tts_pad_embed_buf, torch.Tensor): | ||||||||||||||||||||||||||
| raise RuntimeError("Missing `tts_pad_embed` in additional_information; prefill must run first.") | ||||||||||||||||||||||||||
| tts_pad_embed = tts_pad_embed_cpu.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1) | ||||||||||||||||||||||||||
| tts_pad_embed = tts_pad_embed_buf.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| tail_cpu = info_dict.get("tailing_text_hidden") | ||||||||||||||||||||||||||
| if isinstance(tail_cpu, torch.Tensor) and tail_cpu.ndim == 2 and tail_cpu.shape[0] > 0: | ||||||||||||||||||||||||||
| text_step = tail_cpu[:1].to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1) | ||||||||||||||||||||||||||
| new_tail = tail_cpu[1:].detach().to("cpu").contiguous() if tail_cpu.shape[0] > 1 else tail_cpu[:0] | ||||||||||||||||||||||||||
| tail = info_dict.get("tailing_text_hidden") | ||||||||||||||||||||||||||
| if isinstance(tail, torch.Tensor) and tail.ndim == 2 and tail.shape[0] > 0: | ||||||||||||||||||||||||||
| text_step = tail[:1].to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1) | ||||||||||||||||||||||||||
| new_tail = tail[1:] if tail.shape[0] > 1 else tail[:0] | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
| new_tail = tail[1:] if tail.shape[0] > 1 else tail[:0] | |
| if tail.shape[0] > 1: | |
| # Materialize a new tensor for the remaining queue to avoid keeping | |
| # the original (potentially large) GPU storage alive via a view. | |
| new_tail = tail[1:].contiguous() | |
| else: | |
| # Create a truly empty tensor with matching feature dimension. | |
| new_tail = torch.empty( | |
| (0, tail.shape[1]), | |
| device=tail.device, | |
| dtype=tail.dtype, | |
| ) |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -93,7 +93,6 @@ def talker2code2wav_async_chunk( | |||||||||||||||||||
| if isinstance(ref_code, torch.Tensor) and ref_code.numel() > 0 and request_payload.get(request_id) is None: | ||||||||||||||||||||
| request_payload[request_id] = ref_code.to(torch.long).cpu().contiguous() | ||||||||||||||||||||
| elif not finished: | ||||||||||||||||||||
| # Some steps may not produce pooling_output. Only flush on finish. | ||||||||||||||||||||
| return None | ||||||||||||||||||||
|
|
||||||||||||||||||||
| connector = getattr(transfer_manager, "connector", None) | ||||||||||||||||||||
|
|
@@ -150,7 +149,7 @@ def talker2code2wav_async_chunk( | |||||||||||||||||||
| if finished: | ||||||||||||||||||||
| return { | ||||||||||||||||||||
| "code_predictor_codes": [], | ||||||||||||||||||||
| "finished": torch.tensor(True, dtype=torch.bool), | ||||||||||||||||||||
| "finished": True, | ||||||||||||||||||||
| } | ||||||||||||||||||||
| return None | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -191,10 +190,12 @@ def talker2code2wav_async_chunk( | |||||||||||||||||||
| window_frames = ref_frames + window_frames | ||||||||||||||||||||
| left_context_size += len(ref_frames) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
||||||||||||||||||||
| # Handle potential empty window to avoid IndexError and follow empty-window behavior. | |
| if not window_frames: | |
| return { | |
| "code_predictor_codes": [], | |
| "left_context_size": left_context_size, | |
| "finished": finished, | |
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @amy-why-3459