Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
62ed367
replace polling loops with instant-wakeup
Mar 12, 2026
a72111f
eliminate per-step CPU -> GPU
Mar 12, 2026
bec45ba
transpose in python
Mar 12, 2026
75930ff
return numpy object instead of doing .list()
Mar 12, 2026
9e0e6a8
revert back wav file transport
JuanPZuluaga Mar 12, 2026
89f7a63
add self.connector.close()
JuanPZuluaga Mar 12, 2026
a00d8e4
couple of fixes
JuanPZuluaga Mar 12, 2026
f4b49cd
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Mar 12, 2026
a99a19e
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Mar 13, 2026
b535b22
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Mar 13, 2026
db1e0bc
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Mar 13, 2026
5b7a5fb
fix tests
JuanPZuluaga Mar 13, 2026
ce4053b
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Mar 15, 2026
badbfde
Merge branch 'main' into feat/qwen3tts-optimize-decode
JuanPZuluaga Mar 16, 2026
340f38e
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Mar 16, 2026
78200a9
Merge branch 'feat/qwen3tts-optimize-decode' of https://github.com/Ju…
JuanPZuluaga Mar 16, 2026
dec614c
Merge branch 'main' into feat/qwen3tts-optimize-decode
JuanPZuluaga Mar 16, 2026
d4b9cd0
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Mar 17, 2026
f3c72aa
Merge branch 'feat/qwen3tts-optimize-decode' of https://github.com/Ju…
JuanPZuluaga Mar 17, 2026
8e14b52
Merge branch 'main' into feat/qwen3tts-optimize-decode
hsliuustc0106 Mar 18, 2026
4e2bf4d
Merge branch 'main' into feat/qwen3tts-optimize-decode
JuanPZuluaga Mar 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def _fake_base_init(self, config):
self._pending_save_reqs = deque()
self._finished_save_reqs = set()
self.stop_event = threading.Event()
self._recv_cond = threading.Condition()
self._save_cond = threading.Condition()

monkeypatch.setattr(OmniTransferAdapterBase, "__init__", _fake_base_init)
monkeypatch.setattr(
Expand Down
44 changes: 36 additions & 8 deletions vllm_omni/distributed/omni_connectors/transfer_adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import threading
import time
from collections import deque
from typing import Any

Expand Down Expand Up @@ -33,6 +32,8 @@ def __init__(self, config: Any):
self._finished_save_reqs = set()

self.stop_event = threading.Event()
self._recv_cond = threading.Condition()
self._save_cond = threading.Condition()

self.recv_thread = threading.Thread(target=self.recv_loop, daemon=True)
self.recv_thread.start()
Expand All @@ -45,22 +46,37 @@ def create_connector(cls, model_config: Any):
raise NotImplementedError

def recv_loop(self):
"""Loop to poll for incoming data."""
"""Loop to poll for incoming data.

Process each pending request exactly once per pass. When no request
made progress, back off 1 ms instead of tight-spinning on failed
shm_open syscalls (which can burn a full CPU core).
"""
while not self.stop_event.is_set():
# Iterate over a snapshot of pending requests
while self._pending_load_reqs:
n = len(self._pending_load_reqs)
any_success = False
for _ in range(n):
if not self._pending_load_reqs:
break
request = self._pending_load_reqs.popleft()
request_id = request.request_id
self.request_ids_mapping[request_id] = request.external_req_id
try:
is_success = self._poll_single_request(request)
if not is_success:
if is_success:
any_success = True
else:
self._pending_load_reqs.append(request)
except Exception as e:
self._pending_load_reqs.append(request)
logger.warning(f"Error receiving data for {request_id}: {e}")

time.sleep(0.001)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# Timeout is the fallback for lock-free append/notify races.
with self._recv_cond:
if not self._pending_load_reqs and not self.stop_event.is_set():
self._recv_cond.wait(timeout=0.1)
elif not any_success and not self.stop_event.is_set():
self._recv_cond.wait(timeout=0.001)
Comment on lines +56 to +79
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Condition is not used to protect access to _pending_load_reqs. Because producers can append without holding _recv_cond’s lock, recv_loop can observe the deque as empty and go into wait() while a request is appended/notify() happens just before it starts waiting (missed wakeup), delaying processing until timeout. Fix by using the condition lock as the mutex for all _pending_load_reqs mutations + checks (append/popleft/len/emptiness) or by switching to a thread-safe queue (queue.Queue) and blocking get() with timeout.

Copilot uses AI. Check for mistakes.

def save_loop(self):
"""Loop to send outgoing data."""
Expand All @@ -71,7 +87,10 @@ def save_loop(self):
self._send_single_request(task)
except Exception as e:
logger.warning(f"Error saving data for {task.get('request_id')}: {e}")
time.sleep(0.001)

with self._save_cond:
if not self._pending_save_reqs and not self.stop_event.is_set():
self._save_cond.wait(timeout=0.1)

def _poll_single_request(self, *args, **kwargs):
"""Poll connector for a single request task.
Expand Down Expand Up @@ -105,4 +124,13 @@ def get_finished_requests(self):

def shutdown(self):
"""Stop background loops and close the connector."""
raise NotImplementedError
self.stop_event.set()
with self._recv_cond:
self._recv_cond.notify_all()
with self._save_cond:
self._save_cond.notify_all()
Comment on lines +127 to +131
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Close connector resources in adapter shutdown

The new shutdown implementation wakes the worker threads but never closes the underlying connector, even though connector cleanup is defined via OmniConnectorBase.close(). When shutdown is used during stage teardown/restart, transports like shared memory or Mooncake can keep resources open (handles, pools, executors), causing leaks and cross-run interference; call self.connector.close() as part of shutdown after signaling threads.

Useful? React with 👍 / 👎.

if self.connector is not None:
try:
self.connector.close()
except Exception:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def load_async(self, request: Request):
if not hasattr(request, "additional_information"):
request.additional_information = None
self._pending_load_reqs.append(request)
with self._recv_cond:
self._recv_cond.notify()
Comment on lines 95 to +99
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This notify() doesn’t prevent missed wakeups because the enqueue (_pending_load_reqs.append) is not done while holding the same _recv_cond lock that the consumer uses to decide whether to wait. Move the append inside the with self._recv_cond: block (and ensure the consumer also checks/consumes under that same lock), or use a blocking queue abstraction to avoid subtle races.

Copilot uses AI. Check for mistakes.

def save_async(
self,
Expand All @@ -116,6 +118,8 @@ def save_async(
"is_finished": request.is_finished(),
}
self._pending_save_reqs.append(task)
with self._save_cond:
self._save_cond.notify()

def _poll_single_request(self, request: Request):
stage_id = self.connector.stage_id
Expand Down
46 changes: 28 additions & 18 deletions vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
codec_mask[self._codec_eos_token_id] = True
self.register_buffer("_codec_allowed_mask", codec_mask, persistent=False)

# Keys that should stay on GPU in model_intermediate_buffer to avoid
# CPU-to-GPU round-trips on every decode step.
self.gpu_resident_buffer_keys: set[str] = {
"last_talker_hidden",
"tts_pad_embed",
"tailing_text_hidden",
}

# Tokenizer for prompt building.
self._tokenizer = None
self._speech_tokenizer: Qwen3TTSTokenizer | None = None
Expand Down Expand Up @@ -547,12 +555,13 @@ def preprocess(
full_prompt_embeds, tailing_text_hidden, tts_pad_embed, ref_code_len, ref_code = (
self._build_prompt_embeds(task_type=task_type, info_dict=info_dict)
)
# Store full prompt embeddings + trailing queue on CPU for later chunks/steps.
# Store full prompt embeddings on CPU (large, prefill-only).
# tailing_text_hidden and tts_pad_embed stay on GPU (gpu_resident_buffer_keys).
prompt_embeds_cpu = full_prompt_embeds.detach().to("cpu").contiguous()
info_update: dict[str, Any] = {
"talker_prompt_embeds": prompt_embeds_cpu,
"tailing_text_hidden": tailing_text_hidden.detach().to("cpu").contiguous(),
"tts_pad_embed": tts_pad_embed.detach().to("cpu").contiguous(),
"tailing_text_hidden": tailing_text_hidden.detach(),
"tts_pad_embed": tts_pad_embed.detach(),
"talker_prefill_offset": 0,
"codec_streaming": codec_streaming,
}
Expand All @@ -568,7 +577,7 @@ def preprocess(
take = prompt_embeds_cpu[s:e]
if int(take.shape[0]) < span_len:
pad_n = int(span_len - int(take.shape[0]))
pad_rows = tts_pad_embed.detach().to("cpu").contiguous().reshape(1, -1).expand(pad_n, -1)
pad_rows = tts_pad_embed.reshape(1, -1).to("cpu").expand(pad_n, -1)
take = torch.cat([take, pad_rows], dim=0)
prompt_embeds = take.to(device=input_ids.device, dtype=torch.bfloat16)
info_update["talker_prefill_offset"] = int(offset + span_len)
Expand All @@ -584,7 +593,7 @@ def preprocess(
take = prompt_embeds_cpu[s:e]
if int(take.shape[0]) < span_len:
pad_n = int(span_len - int(take.shape[0]))
pad_rows = tts_pad_embed.detach().to("cpu").contiguous().reshape(1, -1).expand(pad_n, -1)
pad_rows = tts_pad_embed.reshape(1, -1).to("cpu").expand(pad_n, -1)
take = torch.cat([take, pad_rows], dim=0)
prompt_embeds = take.to(device=input_ids.device, dtype=torch.bfloat16)
info_update = {"talker_prefill_offset": int(offset + span_len)}
Expand All @@ -604,23 +613,24 @@ def preprocess(

# Decode: span_len == 1
# Pop one text-step vector from tailing_text_hidden queue.
tts_pad_embed_cpu = info_dict.get("tts_pad_embed")
if not isinstance(tts_pad_embed_cpu, torch.Tensor):
# These tensors stay on GPU via gpu_resident_buffer_keys - .to() is a no-op.
tts_pad_embed_buf = info_dict.get("tts_pad_embed")
if not isinstance(tts_pad_embed_buf, torch.Tensor):
raise RuntimeError("Missing `tts_pad_embed` in additional_information; prefill must run first.")
tts_pad_embed = tts_pad_embed_cpu.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1)
tts_pad_embed = tts_pad_embed_buf.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1)

tail_cpu = info_dict.get("tailing_text_hidden")
if isinstance(tail_cpu, torch.Tensor) and tail_cpu.ndim == 2 and tail_cpu.shape[0] > 0:
text_step = tail_cpu[:1].to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1)
new_tail = tail_cpu[1:].detach().to("cpu").contiguous() if tail_cpu.shape[0] > 1 else tail_cpu[:0]
tail = info_dict.get("tailing_text_hidden")
if isinstance(tail, torch.Tensor) and tail.ndim == 2 and tail.shape[0] > 0:
text_step = tail[:1].to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1)
new_tail = tail[1:] if tail.shape[0] > 1 else tail[:0]
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On GPU, new_tail = tail[1:] is a view into the original tensor storage. Even as the logical length shrinks, the full original allocation can remain live, and repeated slicing can keep large GPU buffers resident longer than intended. Consider tracking an integer offset into tailing_text_hidden (no slicing / no copies) or, if you do want to physically drop consumed elements, make new_tail a new tensor (e.g., contiguous()/clone()) before storing it back so the old storage can be released.

Suggested change
new_tail = tail[1:] if tail.shape[0] > 1 else tail[:0]
if tail.shape[0] > 1:
# Materialize a new tensor for the remaining queue to avoid keeping
# the original (potentially large) GPU storage alive via a view.
new_tail = tail[1:].contiguous()
else:
# Create a truly empty tensor with matching feature dimension.
new_tail = torch.empty(
(0, tail.shape[1]),
device=tail.device,
dtype=tail.dtype,
)

Copilot uses AI. Check for mistakes.
else:
text_step = tts_pad_embed
new_tail = tail_cpu if isinstance(tail_cpu, torch.Tensor) else torch.empty((0, tts_pad_embed.shape[-1]))
new_tail = tail if isinstance(tail, torch.Tensor) else torch.empty((0, tts_pad_embed.shape[-1]))

last_hidden_cpu = info_dict.get("last_talker_hidden")
if not isinstance(last_hidden_cpu, torch.Tensor):
last_hidden = info_dict.get("last_talker_hidden")
if not isinstance(last_hidden, torch.Tensor):
raise RuntimeError("Missing `last_talker_hidden` in additional_information; postprocess must run.")
past_hidden = last_hidden_cpu.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1)
past_hidden = last_hidden.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1)

# Use OmniGPUModelRunner talker_mtp fast-path for residual codebooks and per-step inputs_embeds update.
last_id_hidden = self.embed_input_ids(input_ids.reshape(1, 1).to(torch.long)).to(
Expand All @@ -637,10 +647,10 @@ def preprocess(

def postprocess(self, hidden_states: torch.Tensor, **_: Any) -> dict[str, Any]:
# Keep the last token hidden for the next decode step's code predictor.
# Stays on GPU - gpu_resident_buffer_keys avoids the CPU round-trip.
if hidden_states.numel() == 0:
return {}
last = hidden_states[-1, :].detach().to("cpu").contiguous()
return {"last_talker_hidden": last}
return {"last_talker_hidden": hidden_states[-1, :].detach()}

# -------------------- prompt construction helpers --------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def talker2code2wav_async_chunk(
if isinstance(ref_code, torch.Tensor) and ref_code.numel() > 0 and request_payload.get(request_id) is None:
request_payload[request_id] = ref_code.to(torch.long).cpu().contiguous()
elif not finished:
# Some steps may not produce pooling_output. Only flush on finish.
return None

connector = getattr(transfer_manager, "connector", None)
Expand Down Expand Up @@ -150,7 +149,7 @@ def talker2code2wav_async_chunk(
if finished:
return {
"code_predictor_codes": [],
"finished": torch.tensor(True, dtype=torch.bool),
"finished": True,
}
return None

Expand Down Expand Up @@ -191,10 +190,12 @@ def talker2code2wav_async_chunk(
window_frames = ref_frames + window_frames
left_context_size += len(ref_frames)

Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

window_frames[0] will raise IndexError if window_frames is ever empty. The previous torch.tensor(window_frames) path would also fail on empty input, but if empty windows are possible in some streaming boundary conditions, this needs an explicit guard (e.g., handle empty by returning code_predictor_codes = [] and still emitting finished/left_context_size).

Suggested change
# Handle potential empty window to avoid IndexError and follow empty-window behavior.
if not window_frames:
return {
"code_predictor_codes": [],
"left_context_size": left_context_size,
"finished": finished,
}

Copilot uses AI. Check for mistakes.
code_predictor_codes = torch.tensor(window_frames).transpose(0, 1).reshape(-1).tolist()
num_quantizers = len(window_frames[0])
num_frames = len(window_frames)
code_predictor_codes = [window_frames[f][q] for q in range(num_quantizers) for f in range(num_frames)]

return {
"code_predictor_codes": code_predictor_codes,
"left_context_size": left_context_size,
"finished": torch.tensor(finished, dtype=torch.bool),
"finished": finished,
}
Loading