Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def _build(
stage_id: int = 1,
model_mode: str = "ar",
max_num_seqs: int = 2,
save_queue_max_size: int = 0,
active_stream_window: int = 0,
connector_extra: dict | None = None,
):
Expand All @@ -59,13 +60,16 @@ def _build(
connector.get.return_value = None
connector.put.return_value = (True, 1, {})

cap = save_queue_max_size

def _fake_base_init(self, config):
self.config = config
self._pending_load_reqs = deque()
self._finished_load_reqs = set()
self._cancelled_load_reqs = set()
self._pending_save_reqs = deque()
self._finished_save_reqs = set()
self._save_semaphore = threading.Semaphore(cap) if cap > 0 else None
self.stop_event = threading.Event()
self._recv_cond = threading.Condition()
self._save_cond = threading.Condition()
Expand Down Expand Up @@ -1170,3 +1174,87 @@ def test_deferred_finish_not_finished_still_emits_output(mocker: MockerFixture):
assert eco.outputs[0].finish_reason == "stop"
assert eco.outputs[0].kv_transfer_params is None
assert scheduler._pending_finish_reqs == []


def test_save_queue_disabled_by_default(build_adapter):
adapter, _ = build_adapter(stage_id=1, save_queue_max_size=0)
adapter.custom_process_next_stage_input_func = lambda **kwargs: {"x": [1], "finished": False}
adapter.save_async(pooling_output=None, request=_req("r0", RequestStatus.WAITING, external_req_id="e0"))

assert adapter._save_semaphore is None
assert len(adapter._pending_save_reqs) == 1


def test_save_queue_blocks_when_full(build_adapter):
adapter, _ = build_adapter(stage_id=1, save_queue_max_size=2)
adapter.custom_process_next_stage_input_func = lambda **kwargs: {"x": [1], "finished": False}
for i in range(2):
adapter.save_async(
pooling_output=None,
request=_req(f"r{i}", RequestStatus.WAITING, external_req_id=f"e{i}"),
)

blocked = threading.Thread(
target=adapter.save_async,
kwargs={
"pooling_output": None,
"request": _req("r2", RequestStatus.WAITING, external_req_id="e2"),
},
daemon=True,
)
blocked.start()
blocked.join(timeout=0.2)
assert blocked.is_alive()
assert len(adapter._pending_save_reqs) == 2

adapter._save_semaphore.release()
blocked.join(timeout=1.0)
assert not blocked.is_alive()
assert len(adapter._pending_save_reqs) == 3


def test_save_queue_stop_event_unblocks_producer(build_adapter):
adapter, _ = build_adapter(stage_id=1, save_queue_max_size=1)
adapter.custom_process_next_stage_input_func = lambda **kwargs: {"x": [1], "finished": False}
adapter.save_async(pooling_output=None, request=_req("r0", RequestStatus.WAITING, external_req_id="e0"))

done = threading.Event()

def producer():
adapter.save_async(pooling_output=None, request=_req("r1", RequestStatus.WAITING, external_req_id="e1"))
done.set()

t = threading.Thread(target=producer, daemon=True)
t.start()
assert not done.wait(timeout=0.2)

adapter.stop_event.set()
assert done.wait(timeout=1.0)
t.join(timeout=1.0)
assert len(adapter._pending_save_reqs) == 1


def test_save_queue_releases_slot_on_send_exception(build_adapter):
adapter, _ = build_adapter(stage_id=1, save_queue_max_size=1)
adapter.custom_process_next_stage_input_func = lambda **kwargs: {"x": [1], "finished": False}

def _boom(task):
raise RuntimeError("send failed")

adapter._send_single_request = _boom

# Acquires the only slot and enqueues the task.
adapter.save_async(pooling_output=None, request=_req("r0", RequestStatus.WAITING, external_req_id="e0"))
assert not adapter._save_semaphore.acquire(blocking=False)

loop = threading.Thread(target=adapter.save_loop, daemon=True)
loop.start()
try:
# save_loop pops the task, _send_single_request raises, the finally clause still releases.
assert adapter._save_semaphore.acquire(timeout=1.0)
finally:
adapter.stop_event.set()
with adapter._save_cond:
adapter._save_cond.notify_all()
loop.join(timeout=1.0)
assert not loop.is_alive()
2 changes: 2 additions & 0 deletions vllm_omni/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ class OmniModelConfig(ModelConfig):
task_type: str | None = None
enable_sleep_mode: bool = False
has_sampling_extra_args: bool = False
# Max in-flight chunk-transfer save tasks before producers block. 0 disables.
save_queue_max_size: int = 0

@property
def registry(self):
Expand Down
4 changes: 4 additions & 0 deletions vllm_omni/config/stage_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,8 @@ class DeployConfig:
data_parallel_size: int | None = None
pipeline_parallel_size: int | None = None
custom_voice_dir: str | None = None
# Cap in-flight chunk-transfer save tasks; 0 disables the bound.
save_queue_max_size: int = 0


_STAGE_RESERVED_KEYS = frozenset(
Expand Down Expand Up @@ -695,6 +697,7 @@ def load_deploy_config(path: str | Path) -> DeployConfig:
"data_parallel_size",
"pipeline_parallel_size",
"custom_voice_dir",
"save_queue_max_size",
):
if name in raw_dict:
kwargs[name] = raw_dict[name]
Expand Down Expand Up @@ -818,6 +821,7 @@ def _select_processor_funcs(
"pipeline_parallel_size",
"active_stream_window",
"custom_voice_dir",
"save_queue_max_size",
)


Expand Down
3 changes: 3 additions & 0 deletions vllm_omni/deploy/qwen3_tts_high_concurrency.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
#
async_chunk: true

# Bound the in-flight save queue at high concurrency; default is disabled.
save_queue_max_size: 64

connectors:
connector_of_shared_memory:
name: SharedMemoryConnector
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def __init__(self, config: Any):
self._pending_save_reqs = deque()
# Requests that have successfully saved data
self._finished_save_reqs = set()
# Bound the in-flight save queue; 0 disables backpressure.
cap = int(getattr(config, "save_queue_max_size", 0) or 0)
self._save_semaphore = threading.Semaphore(cap) if cap > 0 else None

self.stop_event = threading.Event()
self._recv_cond = threading.Condition()
Expand Down Expand Up @@ -91,6 +94,9 @@ def save_loop(self):
self._send_single_request(task)
except Exception as e:
logger.warning(f"Error saving data for {task.get('request_id')}: {e}")
finally:
if self._save_semaphore is not None:
self._save_semaphore.release()

with self._save_cond:
if not self._pending_save_reqs and not self.stop_event.is_set():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,13 @@ def save_async(
"is_finished": is_finished,
"is_segment_finished": is_segment_finished,
}
# Block when the in-flight queue is full; bail on shutdown.
if self._save_semaphore is not None:
while not self.stop_event.is_set():
if self._save_semaphore.acquire(timeout=0.05):
break
else:
return
self._pending_save_reqs.append(task)
with self._save_cond:
self._save_cond.notify()
Expand Down
3 changes: 3 additions & 0 deletions vllm_omni/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ class OmniEngineArgs(EngineArgs):
# in __post_init__ based on worker_type (ar/generation), so None is safe here.
enable_sleep_mode: bool = False
omni: bool = False
# Max in-flight chunk-transfer save tasks before producers block. 0 disables.
save_queue_max_size: int = 0

@classmethod
def _add_omni_specific_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
Expand Down Expand Up @@ -330,6 +332,7 @@ def create_model_config(self) -> OmniModelConfig:
omni_kv_config=self.omni_kv_config,
task_type=self.task_type,
has_sampling_extra_args=self.has_sampling_extra_args,
save_queue_max_size=self.save_queue_max_size,
)
return omni_config

Expand Down
2 changes: 1 addition & 1 deletion vllm_omni/engine/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ async def _orchestration_loop(self) -> None:
idle = False

if idle:
await asyncio.sleep(0.001)
await asyncio.sleep(0.0001)
else:
await asyncio.sleep(0)

Expand Down
Loading