From 1bcf6b82221ad56ab6f73549170d95944b007652 Mon Sep 17 00:00:00 2001 From: thrashingstate Date: Fri, 13 Mar 2026 11:09:46 -0700 Subject: [PATCH 1/2] feat: streaming text input for Qwen3-TTS --- vllm_omni/core/sched/omni_ar_scheduler.py | 101 +++++++++++++- vllm_omni/core/sched/output.py | 5 + vllm_omni/entrypoints/async_omni.py | 15 +++ vllm_omni/entrypoints/omni_stage.py | 20 ++- vllm_omni/entrypoints/openai/api_server.py | 119 +++++++++++++++++ .../entrypoints/openai/serving_speech.py | 123 ++++++++++++++++++ vllm_omni/entrypoints/stage_utils.py | 1 + .../models/qwen3_tts/qwen3_tts_talker.py | 89 +++++++++++-- vllm_omni/outputs.py | 3 + vllm_omni/patch.py | 39 ++++++ vllm_omni/worker/gpu_ar_model_runner.py | 6 + vllm_omni/worker/gpu_model_runner.py | 28 +++- 12 files changed, 531 insertions(+), 18 deletions(-) diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index 540fc391c51..8b6b2860fbe 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections import defaultdict +from collections import defaultdict, deque from dataclasses import asdict, dataclass from time import time from typing import Any @@ -48,6 +48,13 @@ class OmniARScheduler(VLLMScheduler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + # Pending additional_information updates for running requests. + # Populated by update_request_additional_info(), consumed by schedule(). + self._pending_additional_info_updates: dict[str, list[dict[str, Any]]] = defaultdict(list) + # Early-arriving updates for requests not yet registered. + # Keyed by external request ID, flushed when the request arrives. + self._early_additional_info_updates: dict[str, list[dict[str, Any]]] = defaultdict(list) + # Track requests that need KV cache transfer when finished # Value is {"seq_len": int, "block_ids": list[int]} self.requests_needing_kv_transfer: dict[str, dict[str, Any]] = {} @@ -68,6 +75,69 @@ def __init__(self, *args, **kwargs): if getattr(model_config, "async_chunk", False): self.chunk_transfer_adapter = OmniChunkTransferAdapter(self.vllm_config) + # ------------------------------------------------------------------ # + # Incremental additional_information updates for running requests # + # ------------------------------------------------------------------ # + + def update_request_additional_info( + self, request_id: str, update: dict[str, Any] + ) -> None: + """Queue an additional_information update for a running request. + + The update is delivered to the model runner on the next scheduler + step via OmniSchedulerOutput. Resolves the external request ID + immediately if possible (needed to resume paused requests without + waiting for the next schedule() call). + """ + resolved_id = self._resolve_streaming_request_id(request_id) + if resolved_id is not None: + self._pending_additional_info_updates[resolved_id].append(update) + self._maybe_resume_request(resolved_id) + else: + self._early_additional_info_updates[request_id].append(update) + + def _resolve_streaming_request_id(self, external_id: str) -> str | None: + """Resolve an external request ID to the internal (suffixed) ID. + + vLLM appends a random suffix to request IDs (input_processor.py:519). + """ + if external_id in self.requests: + return external_id + for rid, req in self.requests.items(): + if getattr(req, "external_req_id", None) == external_id: + return rid + return None + + def _maybe_resume_request(self, req_id: str) -> None: + """Resume a paused request if it was waiting for an update.""" + req = self.requests.get(req_id) + if req is not None and req.status == RequestStatus.WAITING_FOR_CHUNK: + req.status = RequestStatus.RUNNING + if req not in self.running: + self.running.append(req) + + def _flush_early_updates(self) -> None: + """Drain buffered updates for requests that have now been registered.""" + if not self._early_additional_info_updates: + return + flushed_keys = [] + for ext_id, updates in list(self._early_additional_info_updates.items()): + resolved = self._resolve_streaming_request_id(ext_id) + if resolved is not None: + self._pending_additional_info_updates[resolved].extend(updates) + flushed_keys.append(ext_id) + self._maybe_resume_request(resolved) + for k in flushed_keys: + del self._early_additional_info_updates[k] + + def _drain_pending_additional_info_updates(self) -> dict[str, list[dict[str, Any]]]: + """Pop all pending updates. Called during schedule().""" + if not self._pending_additional_info_updates: + return {} + updates = dict(self._pending_additional_info_updates) + self._pending_additional_info_updates = defaultdict(list) + return updates + def _get_kv_transfer_criteria(self) -> dict | None: # Note: vllm_config is available in Scheduler after super().__init__ if not hasattr(self, "vllm_config"): @@ -135,7 +205,17 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int return False + def add_request(self, *args, **kwargs): + """Override to flush early updates when a new request is registered.""" + super().add_request(*args, **kwargs) + # Flush early-arriving streaming updates that were buffered + # before this request was registered. + self._flush_early_updates() + def schedule(self) -> SchedulerOutput: # type: ignore[override] + # Flush early-arriving updates for newly registered requests. + self._flush_early_updates() + if self.chunk_transfer_adapter: self.chunk_transfer_adapter.process_pending_chunks(self.waiting, self.running) @@ -185,9 +265,13 @@ def schedule(self) -> SchedulerOutput: # type: ignore[override] # Wrap in omni scheduler output to carry transfer metadata. base_fields = SchedulerOutput.__dataclass_fields__.keys() base_data = {name: getattr(scheduler_output, name) for name in base_fields} + # Drain pending additional_information updates for this step. + ai_updates = self._drain_pending_additional_info_updates() + return OmniSchedulerOutput( **base_data, finished_requests_needing_kv_transfer=finished_reqs, + additional_information_updates=ai_updates, ) def update_from_output( @@ -195,6 +279,21 @@ def update_from_output( scheduler_output: SchedulerOutput, model_runner_output: ModelRunnerOutput, ) -> dict[int, EngineCoreOutputs]: + # Pause requests that signalled _streaming_needs_text via the model + # runner output. Skip pausing if updates are already queued. + pause_ids = getattr(model_runner_output, "streaming_pause_req_ids", None) + if pause_ids: + paused = set() + for req_id in pause_ids: + if req_id in self._pending_additional_info_updates: + continue + req = self.requests.get(req_id) + if req is not None and not req.is_finished(): + req.status = RequestStatus.WAITING_FOR_CHUNK + paused.add(req) + if paused: + self.running = deque(r for r in self.running if r not in paused) + sampled_token_ids = model_runner_output.sampled_token_ids logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict diff --git a/vllm_omni/core/sched/output.py b/vllm_omni/core/sched/output.py index c7a8c07c5ce..8503f783128 100644 --- a/vllm_omni/core/sched/output.py +++ b/vllm_omni/core/sched/output.py @@ -75,3 +75,8 @@ class OmniSchedulerOutput(SchedulerOutput): """Scheduler output with omni-specific transfer metadata.""" finished_requests_needing_kv_transfer: dict[str, dict] = field(default_factory=dict) + + # Per-request additional_information updates for running requests. + # Maps request_id → list of update dicts to merge into the request's + # additional_information_cpu on the model runner. + additional_information_updates: dict[str, list[dict]] = field(default_factory=dict) diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 1ba95124a93..8aa35360217 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -840,6 +840,21 @@ async def abort(self, request_id: str | Iterable[str]) -> None: stage.submit(abort_task) return None + def update_request( + self, request_id: str, update: dict, *, stage_id: int = 0 + ) -> None: + """Send an additional_information update to a running request. + + The update dict is merged into the request's additional_information_cpu + on the model runner side. Flow: stage worker → EngineCore → Scheduler + → SchedulerOutput → ModelRunner. + """ + self.stage_list[stage_id].submit({ + "type": OmniStageTaskType.UPDATE, + "request_id": request_id, + "update": update, + }) + async def get_vllm_config(self) -> VllmConfig: for stage in self.stage_list: if stage.is_comprehension: diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index 1298322676c..fba3a4d318d 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -1504,6 +1504,7 @@ async def generation_single_request(task: dict[str, Any]): ein = cast(PromptType, ein) llm_sampling_params: SamplingParams = task["sampling_params"] gen_output = None + async for res in cast(AsyncLLM, stage_engine).generate(ein, llm_sampling_params, rid): gen_output = res _gen_t1 = _time.time() @@ -1520,10 +1521,15 @@ async def generation_single_request(task: dict[str, Any]): } ) + # Cache the UPDATE request type for the main loop (avoid per-iteration import). + from vllm.v1.engine import EngineCoreRequestType as _ECRType + _update_request_type = _ECRType.UPDATE + _batch_gen_t0 = _time.time() while True: try: task = in_q.get_nowait() + task_type = task.get("type", OmniStageTaskType.GENERATE) if task_type == OmniStageTaskType.SHUTDOWN: logger.debug("Received shutdown signal") @@ -1532,9 +1538,21 @@ async def generation_single_request(task: dict[str, Any]): elif task_type == OmniStageTaskType.ABORT: rid = task["request_id"] asyncio.create_task(stage_engine.abort(rid)) + elif task_type == OmniStageTaskType.UPDATE: + # Route additional_information update to the EngineCore's scheduler + # via the AsyncLLM engine_core client (fire-and-forget IPC). + rid = task["request_id"] + upd = task["update"] + try: + _engine_core = getattr(stage_engine, "engine_core", None) + if _engine_core is not None: + _engine_core._send_input( + _update_request_type, (rid, upd) + ) + except Exception as e: + logger.warning(f"Failed to send UPDATE for {rid}: {e}") elif is_profiler_task(task_type): profiler_data = await handle_profiler_task_async(task_type) - # Send result back to orchestrator for STOP command if task_type == OmniStageTaskType.PROFILER_STOP: out_q.put({"type": "profiler_result", "data": profiler_data}) elif task_type == OmniStageTaskType.COLLECTIVE_RPC: diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 1507ce8c3f7..6b0c40c4caf 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -235,6 +235,125 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None, await omni_init_app_state(engine_client, app.state, args) + # --- Streaming TTS WebSocket endpoint --- + from fastapi import WebSocket, WebSocketDisconnect + + @app.websocket("/v1/audio/speech/stream") + async def stream_speech_ws(ws: WebSocket): + """Streaming text-to-speech via WebSocket. + + Client sends text chunks incrementally, server generates + one continuous audio stream with no voice discontinuity. + + Protocol: + Client → Server: + {"type": "start", "voice": "dylan", "language": "English", + "initial_text": "First words from LLM"} + {"type": "text", "content": "more words"} + {"type": "end"} + + Server → Client: + {"type": "started", "request_id": "..."} + {"type": "audio", "data": ""} + {"type": "done"} + """ + import asyncio as _aio + + await ws.accept() + _req_id = f"stream-{uuid.uuid4().hex[:16]}" + + try: + start_msg = await _aio.wait_for(ws.receive_json(), timeout=30) + if start_msg.get("type") != "start": + await ws.send_json({"type": "error", "message": "Expected 'start'"}) + return + + _voice = start_msg.get("voice", "dylan") + _lang = start_msg.get("language", "English") + _instr = start_msg.get("instructions", "") + _init_text = start_msg.get("initial_text") or start_msg.get("text", "") + + if not _init_text: + msg = await _aio.wait_for(ws.receive_json(), timeout=30) + if msg.get("type") == "text": + _init_text = msg["content"] + elif msg.get("type") == "end": + await ws.send_json({"type": "done"}) + return + + if not _init_text.strip(): + await ws.send_json({"type": "error", "message": "No text"}) + return + + logger.info(f"[{_req_id}] Streaming TTS: voice={_voice}, initial='{_init_text[:50]}'") + await ws.send_json({"type": "started", "request_id": _req_id}) + + # Async generator that yields text chunks from the WebSocket + _text_queue: _aio.Queue[str | None] = _aio.Queue() + + async def _recv_text(): + try: + while True: + msg = await _aio.wait_for(ws.receive_json(), timeout=120) + if msg.get("type") == "text": + await _text_queue.put(msg["content"]) + elif msg.get("type") == "end": + await _text_queue.put(None) + break + except (_aio.TimeoutError, WebSocketDisconnect): + await _text_queue.put(None) + + async def _text_stream(): + while True: + chunk = await _text_queue.get() + if chunk is None: + break + yield chunk + + _recv_task = _aio.create_task(_recv_text()) + + # Use the streaming speech method + _ss = app.state.openai_serving_speech + _t0 = time.time() + _cc = 0 + + async for audio_chunk in _ss.create_speech_streaming( + initial_text=_init_text, + text_stream=_text_stream(), + voice=_voice, + language=_lang, + instructions=_instr, + ): + if audio_chunk: + await ws.send_json({ + "type": "audio", + "data": base64.b64encode(audio_chunk).decode("utf-8"), + }) + _cc += 1 + if _cc == 1: + logger.info(f"[{_req_id}] First audio at {time.time()-_t0:.3f}s") + + logger.info(f"[{_req_id}] Done: {_cc} chunks, {time.time()-_t0:.2f}s") + await ws.send_json({"type": "done"}) + _recv_task.cancel() + + except WebSocketDisconnect: + logger.info(f"[{_req_id}] Client disconnected") + except Exception as e: + logger.error(f"[{_req_id}] Streaming TTS error: {e}", exc_info=True) + try: + await ws.send_json({"type": "error", "message": str(e)}) + except Exception: + pass + + logger.info("Registered /v1/audio/speech/stream WebSocket endpoint") + # --- End streaming TTS WebSocket endpoint --- + + # Conditionally register profiler endpoints based on config or env var + if _should_enable_profiler_endpoints(args): + logger.warning("Profiler endpoints are enabled. This should ONLY be used for local development!") + app.include_router(profiler_router) + vllm_config = await engine_client.get_vllm_config() # Check if pure diffusion mode (vllm_config will be None) diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index 3b5d67542be..acf3db1651c 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -1038,3 +1038,126 @@ async def create_speech( except Exception as e: logger.exception("Speech generation failed: %s", e) return self.create_error_response(f"Speech generation failed: {e}") + + async def create_speech_streaming( + self, + initial_text: str, + text_stream, # AsyncGenerator[str, None] — yields text chunks + voice: str = "dylan", + language: str = "English", + instructions: str = "", + task_type: str = "CustomVoice", + ): + """Create streaming speech with true real-time text input. + + Text token IDs are sent via UPDATE_REQUEST and embedded on GPU. + The model's tailing_text_hidden queue is extended on the fly — + tts_eos_embed (from the initial prompt) stays at the end of the + queue, protected by a guard that only releases when + streaming_tts_eos is signalled. + + Args: + initial_text: First text chunk to start generation with + text_stream: Async generator yielding additional text chunks + voice: Speaker voice name + language: Language code + instructions: Style/emotion instructions + task_type: "CustomVoice", "VoiceDesign", or "Base" + + Yields: + bytes: PCM audio chunks as they're generated + """ + import time as _time + + from transformers import AutoTokenizer + + request_id = f"stream-tts-{random_uuid()}" + + model_path = self.engine_client.model_config.model + if not hasattr(self, "_streaming_tokenizer"): + self._streaming_tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True + ) + tok = self._streaming_tokenizer + + tts_params = { + "text": [initial_text], + "task_type": [task_type], + "language": [language], + "instruct": [instructions], + # non_streaming_mode=False places tts_eos_embed at end of + # tailing_text_hidden. _streaming_request enables guard_eos + # and scheduler pausing in the decode path. + "non_streaming_mode": [False], + "_streaming_request": [True], + } + if voice: + tts_params["speaker"] = [voice] + tts_params["max_new_tokens"] = [4096] + + ph_len = self._estimate_prompt_len(tts_params) + prompt = { + "prompt_token_ids": [1] * ph_len, + "additional_information": tts_params, + } + + sampling_params_list = self.engine_client.default_sampling_params_list + + generator = self.engine_client.generate( + prompt=prompt, + request_id=request_id, + sampling_params_list=sampling_params_list, + output_modalities=["audio"], + ) + + _total_stream_tokens = 0 + _eos_sent_time = None + + async def _feed_text(): + nonlocal _total_stream_tokens, _eos_sent_time + try: + async for text_chunk in text_stream: + if not text_chunk or not text_chunk.strip(): + continue + token_ids = tok.encode(text_chunk, add_special_tokens=False) + if not token_ids: + continue + _total_stream_tokens += len(token_ids) + self.engine_client.update_request(request_id, { + "streaming_text_token_ids": token_ids, + }) + except Exception as e: + logger.error(f"[{request_id}] Text stream error: {e}") + finally: + # Signal end of text: allows the model to pop tts_eos_embed + # from tailing_text_hidden (the "don't pop if not done" guard). + self.engine_client.update_request(request_id, { + "streaming_tts_eos": True, + }) + _eos_sent_time = _time.monotonic() + + feed_task = asyncio.create_task(_feed_text()) + + # Safety abort: if the model doesn't generate natural codec_eos, + # stop after an audio budget proportional to the text length. + # ~4 codec frames per text token × 80ms/frame × 24kHz × 2 bytes/sample × 2x margin. + BYTES_PER_TEXT_TOKEN = int(4 * 0.08 * 24000 * 2 * 2.0) # ~30720 + _total_audio_bytes = 0 + initial_tokens = len(tok.encode(initial_text, add_special_tokens=False)) + + async for pcm_bytes in self._generate_pcm_chunks(generator, request_id): + _total_audio_bytes += len(pcm_bytes) + yield pcm_bytes + + if _eos_sent_time is not None: + total_tokens = initial_tokens + _total_stream_tokens + max_audio_bytes = total_tokens * BYTES_PER_TEXT_TOKEN + if _total_audio_bytes > max_audio_bytes: + logger.info( + f"[{request_id}] Audio budget exceeded " + f"({_total_audio_bytes} > {max_audio_bytes}), aborting" + ) + await self.engine_client.abort(request_id) + break + + await feed_task diff --git a/vllm_omni/entrypoints/stage_utils.py b/vllm_omni/entrypoints/stage_utils.py index 99f2042e24f..7870126ab56 100644 --- a/vllm_omni/entrypoints/stage_utils.py +++ b/vllm_omni/entrypoints/stage_utils.py @@ -34,6 +34,7 @@ class OmniStageTaskType(enum.Enum): GENERATE = "generate" ABORT = "abort" SHUTDOWN = "shutdown" + UPDATE = "update" PROFILER_START = "profiler_start" PROFILER_STOP = "profiler_stop" COLLECTIVE_RPC = "collective_rpc" diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py index 72597249cdf..27bbb8c90a3 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py @@ -544,7 +544,7 @@ def preprocess( # Subsequent prefill rounds (multi-chunk): prompt_embeds_cpu is a Tensor stored by the first round. is_first_prefill = not isinstance(prompt_embeds_cpu, torch.Tensor) or prompt_embeds_cpu.ndim != 2 if is_first_prefill: - full_prompt_embeds, tailing_text_hidden, tts_pad_embed, ref_code_len, ref_code = ( + full_prompt_embeds, tailing_text_hidden, tts_pad_embed, tts_eos_embed, ref_code_len, ref_code = ( self._build_prompt_embeds(task_type=task_type, info_dict=info_dict) ) # Store full prompt embeddings + trailing queue on CPU for later chunks/steps. @@ -553,6 +553,7 @@ def preprocess( "talker_prompt_embeds": prompt_embeds_cpu, "tailing_text_hidden": tailing_text_hidden.detach().to("cpu").contiguous(), "tts_pad_embed": tts_pad_embed.detach().to("cpu").contiguous(), + "tts_eos_embed": tts_eos_embed.detach().to("cpu").contiguous(), "talker_prefill_offset": 0, "codec_streaming": codec_streaming, } @@ -573,14 +574,18 @@ def preprocess( prompt_embeds = take.to(device=input_ids.device, dtype=torch.bfloat16) info_update["talker_prefill_offset"] = int(offset + span_len) else: - # Subsequent prefill chunk: slice from stored embeddings at running offset. + # Subsequent prefill chunk: continuation of initial prompt. if tts_pad_embed is None: raise RuntimeError("Missing `tts_pad_embed` in additional_information; prefill must initialize it.") offset = int(info_dict.get("talker_prefill_offset", 0) or 0) if offset < 0: offset = 0 - s = max(0, min(offset, int(prompt_embeds_cpu.shape[0]))) - e = max(0, min(offset + span_len, int(prompt_embeds_cpu.shape[0]))) + + prompt_embeds_len = int(prompt_embeds_cpu.shape[0]) + + # ---- CHUNKED PREFILL (original prompt continuation) ---- + s = max(0, min(offset, prompt_embeds_len)) + e = max(0, min(offset + span_len, prompt_embeds_len)) take = prompt_embeds_cpu[s:e] if int(take.shape[0]) < span_len: pad_n = int(span_len - int(take.shape[0])) @@ -609,10 +614,57 @@ def preprocess( raise RuntimeError("Missing `tts_pad_embed` in additional_information; prefill must run first.") tts_pad_embed = tts_pad_embed_cpu.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1) + # ---- STREAMING TEXT INJECTION ---- + # streaming_text_token_ids: list[int] — new text to embed and enqueue + # streaming_tts_eos: True — signal that text stream is done + # + # tailing_text_hidden ends with tts_eos_embed from _build_prompt_embeds. + # The pop logic below will NOT pop the last element (eos) until + # streaming_tts_eos is received — this prevents premature EOS. + + token_ids_data = info_dict.get("streaming_text_token_ids") + eos_signal = info_dict.get("streaming_tts_eos") + + # Embed new text tokens on GPU and append to tailing_text_hidden + if token_ids_data is not None: + if isinstance(token_ids_data, list): + ids_tensor = torch.tensor(token_ids_data, dtype=torch.long, device=input_ids.device) + elif isinstance(token_ids_data, torch.Tensor): + ids_tensor = token_ids_data.to(device=input_ids.device, dtype=torch.long) + else: + ids_tensor = None + if ids_tensor is not None and ids_tensor.numel() > 0: + with torch.no_grad(): + new_embeds = self.text_projection( + self.text_embedding(ids_tensor.unsqueeze(0)) + ).squeeze(0).detach().to("cpu").contiguous() + tail = info_dict.get("tailing_text_hidden") + if isinstance(tail, torch.Tensor) and tail.ndim == 2 and tail.shape[0] > 0: + # Insert before last element (tts_eos_embed). + # Works for any tail length: [:-1] is empty when len==1. + info_dict["tailing_text_hidden"] = torch.cat( + [tail[:-1], new_embeds, tail[-1:]], dim=0 + ) + else: + info_dict["tailing_text_hidden"] = new_embeds + + # Pop one text embedding from the front of tailing_text_hidden. + # Special case: if only 1 element remains, only pop it if the text + # stream is done (eos has been appended). Otherwise use pad embed + # and wait for more text to arrive. tail_cpu = info_dict.get("tailing_text_hidden") + is_streaming = bool(info_dict.get("_streaming_request")) + text_done = bool(info_dict.get("_streaming_text_done")) or bool(eos_signal) + if isinstance(tail_cpu, torch.Tensor) and tail_cpu.ndim == 2 and tail_cpu.shape[0] > 0: - text_step = tail_cpu[:1].to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1) - new_tail = tail_cpu[1:].detach().to("cpu").contiguous() if tail_cpu.shape[0] > 1 else tail_cpu[:0] + if tail_cpu.shape[0] == 1 and is_streaming and not text_done: + # Last element in queue (eos) but text stream not done — wait. + text_step = tts_pad_embed + new_tail = tail_cpu + else: + # Normal pop + text_step = tail_cpu[:1].to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1) + new_tail = tail_cpu[1:].detach().to("cpu").contiguous() if tail_cpu.shape[0] > 1 else tail_cpu[:0] else: text_step = tts_pad_embed new_tail = tail_cpu if isinstance(tail_cpu, torch.Tensor) else torch.empty((0, tts_pad_embed.shape[-1])) @@ -628,10 +680,25 @@ def preprocess( ) inputs_embeds_out = last_id_hidden.reshape(1, -1) + # Determine whether the model should pause for more text. + # If only eos remains and text stream is not done, signal the + # scheduler to pause this request until more text arrives. + needs_text = ( + is_streaming and not text_done + and isinstance(new_tail, torch.Tensor) and new_tail.ndim == 2 + and new_tail.shape[0] <= 1 + ) + info_update = { "tailing_text_hidden": new_tail, "mtp_inputs": (past_hidden, text_step), "codec_streaming": codec_streaming, + # Clear consumed updates + "streaming_text_token_ids": None, + "streaming_tts_eos": None, + # Persist streaming state + "_streaming_text_done": text_done if text_done else None, + "_streaming_needs_text": True if needs_text else None, } return input_ids, inputs_embeds_out, info_update @@ -1190,7 +1257,7 @@ def _build_prompt_embeds( *, task_type: str, info_dict: dict[str, Any], - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int | None, torch.Tensor | None]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int | None, torch.Tensor | None]: text = (info_dict.get("text") or [""])[0] language = (info_dict.get("language") or ["Auto"])[0] non_streaming_mode_val = info_dict.get("non_streaming_mode") @@ -1417,12 +1484,9 @@ def _normalize_voice_clone_prompt(raw: object) -> dict[str, object] | None: else: first_text = self.text_projection(self.text_embedding(input_ids[:, 3:4])) + codec_input[:, -1:] talker_prompt = torch.cat([talker_prompt, first_text], dim=1) + text_embeds = self.text_projection(self.text_embedding(input_ids[:, 4:-5])) trailing_text_hidden = torch.cat( - ( - self.text_projection(self.text_embedding(input_ids[:, 4:-5])), - tts_eos_embed, - ), - dim=1, + (text_embeds, tts_eos_embed), dim=1, ) elif task_type == "CustomVoice": @@ -1530,6 +1594,7 @@ def _normalize_voice_clone_prompt(raw: object) -> dict[str, object] | None: talker_prompt.squeeze(0), # [prompt_len, H] trailing_text_hidden.squeeze(0), # [T, H] tts_pad_embed.squeeze(0), # [1, H] + tts_eos_embed.squeeze(0), # [1, H] ref_code_len, ref_code_prompt.contiguous() if isinstance(ref_code_prompt, torch.Tensor) else None, ) diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py index 6f035a81ec7..a1d1bd7fb6a 100644 --- a/vllm_omni/outputs.py +++ b/vllm_omni/outputs.py @@ -24,6 +24,9 @@ class OmniModelRunnerOutput(ModelRunnerOutput): # IDs of requests whose KV cache has been extracted from GPU/NPU to CPU. # The Scheduler can safely free the block tables for these requests. kv_extracted_req_ids: list[str] | None = None + # IDs of requests that need more input before continuing. + # Scheduler pauses these until an UPDATE arrives. + streaming_pause_req_ids: list[str] | None = None @dataclass diff --git a/vllm_omni/patch.py b/vllm_omni/patch.py index c9dea609e61..b2f2c13a547 100644 --- a/vllm_omni/patch.py +++ b/vllm_omni/patch.py @@ -3,6 +3,13 @@ from aenum import extend_enum from vllm.config import ModelConfig as _ModelConfig +from vllm.v1.engine import EngineCoreRequestType as _ECRType + +# Add UPDATE request type for streaming additional_information updates. +# This follows the same fire-and-forget pattern as ABORT. +if not hasattr(_ECRType, "UPDATE"): + extend_enum(_ECRType, "UPDATE", b"\x05") + from vllm.inputs.data import TokensPrompt as _OriginalTokensPrompt from vllm.model_executor.layers.rotary_embedding import ( MRotaryEmbedding as _OriginalMRotaryEmbedding, @@ -65,6 +72,38 @@ def _patched_is_mm_prefix_lm(self) -> bool: # as a non-finished state and remains compatible with existing comparisons. extend_enum(RequestStatus, "WAITING_FOR_CHUNK", -1) + +# --------------------------------------------------------------------------- +# Patch EngineCore._handle_client_request to support UPDATE messages. +# UPDATE carries (request_id, update_dict) and routes to the scheduler's +# update_request_additional_info() method. Fire-and-forget, same as ABORT. +# --------------------------------------------------------------------------- +from vllm.v1.engine.core import EngineCoreProc as _EngineCoreProc + +_original_handle_client_request = _EngineCoreProc._handle_client_request + + +def _patched_handle_client_request(self, request_type, request): + if request_type == _ECRType.UPDATE: + try: + req_id, update_dict = request + except (TypeError, ValueError): + return + if hasattr(self.scheduler, "update_request_additional_info"): + self.scheduler.update_request_additional_info(req_id, update_dict) + return + return _original_handle_client_request(self, request_type, request) + + +_EngineCoreProc._handle_client_request = _patched_handle_client_request + +# Also patch DPEngineCoreProc if it exists +try: + from vllm.v1.engine.core import DPEngineCoreProc as _DPEngineCoreProc + _DPEngineCoreProc._handle_client_request = _patched_handle_client_request +except ImportError: + pass + for module_name, module in sys.modules.items(): # only do patch on module of vllm, pass others if "vllm" not in module_name: diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index d7d45031af4..a8d21a97cb1 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -69,6 +69,7 @@ class only overrides sample_tokens to expose hidden states + multimodal def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self._streaming_pause_req_ids: list[str] | None = None self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) # each model stage has their own hidden size self.hidden_size = self.model_config.hf_text_config.hidden_size @@ -610,6 +611,11 @@ def propose_draft_token_ids(sampled_token_ids): cudagraph_stats=cudagraph_stats, ) output.kv_extracted_req_ids = kv_extracted_req_ids + # Propagate pause signals from preprocess to scheduler + pause_ids = getattr(self, "_streaming_pause_req_ids", None) + if pause_ids: + output.streaming_pause_req_ids = pause_ids + self._streaming_pause_req_ids = None if not self.use_async_scheduling: return output diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index e17d52bdd53..a32ef1faed1 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -1070,6 +1070,13 @@ def _update_additional_information(self, scheduler_output: "SchedulerOutput") -> for req_id, req_infos in cached_infos.items(): self._update_intermediate_buffer(req_id, req_infos) + # Process incremental additional_information updates from OmniSchedulerOutput. + ai_updates = getattr(scheduler_output, "additional_information_updates", None) + if ai_updates: + for req_id, update_list in ai_updates.items(): + for upd in update_list: + self._update_intermediate_buffer(req_id, upd) + def _maybe_attach_mimo_audio_req_infos( self, req_state: CachedRequestState | None, @@ -1212,9 +1219,10 @@ def _preprocess( # cached requests. This is required for stages without preprocess # (e.g., code2wav) so runtime_additional_information can be refreshed # from scheduler cached infos on every step. + # Always call _update_additional_information (handles both regular + # and incremental updates from OmniSchedulerOutput). if hasattr(self.model, "has_preprocess") or hasattr(self.model, "enable_update_additional_information"): - if self.vllm_config.model_config.async_chunk: - self._update_additional_information(scheduler_output) + self._update_additional_information(scheduler_output) if hasattr(self.model, "has_preprocess") and self.model.has_preprocess: # Overlay custom prompt_embeds per request for the prompt portion; @@ -1253,8 +1261,14 @@ def _preprocess( self.text_step.gpu[decode_slice].copy_(text_step) decode_req_ids.append(req_id) + # Propagate pause signal if the model needs more input + if update_dict.get("_streaming_needs_text"): + if not getattr(self, "_streaming_pause_req_ids", None): + self._streaming_pause_req_ids = [] + self._streaming_pause_req_ids.append(req_id) + # TODO(Peiqi): the merge stage could move out from the critical path - self._merge_additional_information_update(req_id, update_dict) + self._update_intermediate_buffer(req_id, update_dict) # update the inputs_embeds and input_ids seg_len = min(span_len, req_embeds.shape[0]) @@ -1333,6 +1347,9 @@ def _model_forward( self._omni_last_model_output = model_output return model_output + # Keys where multiple updates should be concatenated rather than replaced. + _APPEND_KEYS = frozenset({"streaming_text_token_ids"}) + def _update_intermediate_buffer(self, req_id: str, upd: dict) -> None: if not isinstance(upd, dict) or not upd: return @@ -1345,7 +1362,10 @@ def _update_intermediate_buffer(self, req_id: str, upd: dict) -> None: gpu_keys = self.model.gpu_resident_buffer_keys existing = self.model_intermediate_buffer.setdefault(req_id, {}) for k, v in upd.items(): - if isinstance(v, torch.Tensor): + # For append-mode keys, concatenate rather than replace. + if k in self._APPEND_KEYS and isinstance(v, list) and isinstance(existing.get(k), list): + existing[k] = existing[k] + v + elif isinstance(v, torch.Tensor): if k in gpu_keys: existing[k] = v.detach().clone() else: From 25fc012eb2a4458d833ea07b12d8bcc036a51fd1 Mon Sep 17 00:00:00 2001 From: thrashingstate Date: Fri, 13 Mar 2026 11:09:50 -0700 Subject: [PATCH 2/2] test: add streaming text input tests --- .../test_qwen3_tts_streaming.py | 286 +++++++++++++++++ tests/entrypoints/test_streaming_tts.py | 287 ++++++++++++++++++ 2 files changed, 573 insertions(+) create mode 100644 tests/e2e/online_serving/test_qwen3_tts_streaming.py create mode 100644 tests/entrypoints/test_streaming_tts.py diff --git a/tests/e2e/online_serving/test_qwen3_tts_streaming.py b/tests/e2e/online_serving/test_qwen3_tts_streaming.py new file mode 100644 index 00000000000..5ed226f2de7 --- /dev/null +++ b/tests/e2e/online_serving/test_qwen3_tts_streaming.py @@ -0,0 +1,286 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +E2E Online tests for Qwen3-TTS streaming text input via WebSocket. + +These tests verify the /v1/audio/speech/stream endpoint works correctly +with actual model inference, sending text incrementally and receiving +progressive audio output. +""" + +import asyncio +import base64 +import json +import os + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0" + +from pathlib import Path + +import httpx +import pytest +import websockets + +from tests.conftest import OmniServer +from tests.utils import hardware_test + +MODEL = "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice" + +# Minimum expected audio size for a short sentence (~1 second of 24kHz 16-bit mono) +MIN_AUDIO_BYTES = 10000 + + +def get_stage_config(): + return str( + Path(__file__).parent.parent.parent.parent + / "vllm_omni" + / "model_executor" + / "stage_configs" + / "qwen3_tts.yaml" + ) + + +def verify_pcm_audio(chunks: list[bytes]) -> bool: + """Verify that audio chunks contain valid PCM data. + + Checks: + - At least one chunk received + - Total size above minimum threshold + - Each chunk has even byte count (int16 alignment) + - Audio data is not all zeros + """ + if not chunks: + return False + total = sum(len(c) for c in chunks) + if total < MIN_AUDIO_BYTES: + return False + # int16 PCM requires even byte count per chunk + if any(len(c) % 2 != 0 for c in chunks): + return False + # At least some non-zero audio data + all_zero = all(all(b == 0 for b in c) for c in chunks) + return not all_zero + + +@pytest.fixture(scope="module") +def omni_server(): + stage_config_path = get_stage_config() + with OmniServer( + MODEL, + [ + "--stage-configs-path", + stage_config_path, + "--stage-init-timeout", + "120", + "--trust-remote-code", + "--enforce-eager", + "--disable-log-stats", + ], + ) as server: + yield server + + +async def streaming_speech_request( + host: str, + port: int, + initial_text: str, + streaming_chunks: list[str] | None = None, + voice: str = "vivian", + chunk_delay: float = 0.03, + timeout: float = 60.0, +) -> tuple[int, list[bytes], str | None]: + """Send a streaming TTS request via WebSocket. + + Returns (total_audio_bytes, list_of_audio_chunks, error_message_or_none). + """ + uri = f"ws://{host}:{port}/v1/audio/speech/stream" + audio_chunks: list[bytes] = [] + error_msg = None + + async with asyncio.timeout(timeout): + async with websockets.connect(uri) as ws: + await ws.send(json.dumps({ + "type": "start", + "text": initial_text, + "voice": voice, + })) + + async def send_text(): + if streaming_chunks: + for chunk in streaming_chunks: + await asyncio.sleep(chunk_delay) + await ws.send(json.dumps({ + "type": "text", + "content": chunk, + })) + await ws.send(json.dumps({"type": "end"})) + + async def recv_audio(): + nonlocal error_msg + async for msg in ws: + data = json.loads(msg) + if data["type"] == "started": + continue + elif data["type"] == "audio": + audio_chunks.append(base64.b64decode(data["data"])) + elif data["type"] == "error": + error_msg = data.get("message", "unknown error") + break + elif data["type"] == "done": + break + + await asyncio.gather(send_text(), recv_audio()) + + total = sum(len(c) for c in audio_chunks) + return total, audio_chunks, error_msg + + +def make_baseline_request(host: str, port: int, text: str, voice: str = "vivian") -> int: + """Non-streaming baseline for comparison. Returns audio size in bytes.""" + url = f"http://{host}:{port}/v1/audio/speech" + with httpx.Client(timeout=120.0) as client: + resp = client.post(url, json={ + "input": text, + "voice": voice, + "non_streaming_mode": False, + }) + assert resp.status_code == 200 + return len(resp.content) + + +class TestQwen3TTSStreaming: + """E2E tests for streaming TTS text input via WebSocket.""" + + @pytest.mark.core_model + @pytest.mark.omni + @hardware_test(res={"cuda": "L4"}, num_cards=1) + def test_streaming_all_text_in_initial(self, omni_server) -> None: + """All text in initial message, no streaming chunks. Should produce + valid audio comparable to the non-streaming baseline.""" + text = "Hello, how are you today?" + baseline_bytes = make_baseline_request(omni_server.host, omni_server.port, text) + + total, chunks, error = asyncio.run(streaming_speech_request( + omni_server.host, omni_server.port, + initial_text=text, + )) + + assert error is None, f"Server returned error: {error}" + assert verify_pcm_audio(chunks), "Invalid PCM audio data" + assert len(chunks) >= 1, "Expected at least 1 audio chunk" + assert total < baseline_bytes * 2.5, ( + f"Streaming audio ({total}) much larger than baseline ({baseline_bytes})" + ) + + @pytest.mark.core_model + @pytest.mark.omni + @hardware_test(res={"cuda": "L4"}, num_cards=1) + def test_streaming_chunked_text(self, omni_server) -> None: + """Text split into initial + streaming chunks at typical LLM rate.""" + text = "Hello, I am going to tell you a story. Once upon a time." + words = text.split() + initial = " ".join(words[:4]) + remaining = [" " + w for w in words[4:]] + + baseline_bytes = make_baseline_request(omni_server.host, omni_server.port, text) + + total, chunks, error = asyncio.run(streaming_speech_request( + omni_server.host, omni_server.port, + initial_text=initial, + streaming_chunks=remaining, + chunk_delay=0.03, + )) + + assert error is None, f"Server returned error: {error}" + assert verify_pcm_audio(chunks), "Invalid PCM audio data" + assert len(chunks) > 1, "Expected multiple progressive audio chunks" + assert total < baseline_bytes * 2.5, ( + f"Streaming audio ({total}) much larger than baseline ({baseline_bytes})" + ) + + @pytest.mark.core_model + @pytest.mark.omni + @hardware_test(res={"cuda": "L4"}, num_cards=1) + def test_streaming_slow_delivery(self, omni_server) -> None: + """Text delivered slowly (100ms per word). Scheduler pausing should + prevent pad steps and the model should still stop naturally.""" + text = "Hello, this is a slow delivery test." + words = text.split() + initial = " ".join(words[:3]) + remaining = [" " + w for w in words[3:]] + + total, chunks, error = asyncio.run(streaming_speech_request( + omni_server.host, omni_server.port, + initial_text=initial, + streaming_chunks=remaining, + chunk_delay=0.1, + )) + + assert error is None, f"Server returned error: {error}" + assert verify_pcm_audio(chunks), "Invalid PCM audio data" + + @pytest.mark.core_model + @pytest.mark.omni + @hardware_test(res={"cuda": "L4"}, num_cards=1) + def test_streaming_sequential_requests(self, omni_server) -> None: + """Multiple sequential streaming requests should all complete + without hangs or state leaks between requests.""" + text = "Hello test." + for i in range(3): + total, chunks, error = asyncio.run(streaming_speech_request( + omni_server.host, omni_server.port, + initial_text=text, + )) + assert error is None, f"Request {i+1} returned error: {error}" + assert verify_pcm_audio(chunks), f"Request {i+1}: invalid PCM audio" + + @pytest.mark.core_model + @pytest.mark.omni + @hardware_test(res={"cuda": "L4"}, num_cards=1) + def test_streaming_response_is_audio_not_error(self, omni_server) -> None: + """Regression test: verify streaming returns binary audio data, + not JSON error messages disguised as audio chunks.""" + total, chunks, error = asyncio.run(streaming_speech_request( + omni_server.host, omni_server.port, + initial_text="This should return audio, not an error.", + )) + + assert error is None, f"Server returned error: {error}" + assert len(chunks) > 0, "No audio chunks received" + + # Verify chunks are binary audio, not JSON error strings + for i, chunk in enumerate(chunks): + try: + text = chunk.decode("utf-8") + assert not text.startswith("{"), ( + f"Chunk {i} appears to be JSON, not audio: {text[:100]}" + ) + except UnicodeDecodeError: + pass # Expected — binary audio can't be decoded as UTF-8 + + @pytest.mark.core_model + @pytest.mark.omni + @hardware_test(res={"cuda": "L4"}, num_cards=1) + def test_non_streaming_still_works(self, omni_server) -> None: + """Non-streaming /v1/audio/speech endpoint should still work + correctly after streaming requests.""" + # Do a streaming request first + asyncio.run(streaming_speech_request( + omni_server.host, omni_server.port, + initial_text="Streaming first.", + )) + + # Then verify non-streaming still works + url = f"http://{omni_server.host}:{omni_server.port}/v1/audio/speech" + with httpx.Client(timeout=120.0) as client: + resp = client.post(url, json={ + "input": "Non-streaming after streaming.", + "voice": "vivian", + }) + + assert resp.status_code == 200, f"Request failed: {resp.text}" + assert resp.headers.get("content-type") == "audio/wav" + from tests.e2e.online_serving.test_qwen3_tts import verify_wav_audio + assert verify_wav_audio(resp.content), "Response is not valid WAV audio" + assert len(resp.content) > MIN_AUDIO_BYTES diff --git a/tests/entrypoints/test_streaming_tts.py b/tests/entrypoints/test_streaming_tts.py new file mode 100644 index 00000000000..7ea70155a3a --- /dev/null +++ b/tests/entrypoints/test_streaming_tts.py @@ -0,0 +1,287 @@ +"""Tests for streaming TTS text input via UPDATE_REQUEST mechanism.""" +from collections import defaultdict, deque +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock + +import pytest +import torch + +from vllm_omni.core.sched.output import OmniSchedulerOutput +from vllm_omni.entrypoints.stage_utils import OmniStageTaskType + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +# --------------------------------------------------------------------------- +# Scheduler: update_request_additional_info +# --------------------------------------------------------------------------- + + +def _make_scheduler(): + """Create a minimal OmniARScheduler-like object for unit testing.""" + from vllm.v1.request import RequestStatus + + sched = SimpleNamespace( + requests={}, + running=deque(), + _pending_additional_info_updates=defaultdict(list), + _early_additional_info_updates=defaultdict(list), + ) + + # Bind methods from OmniARScheduler + from vllm_omni.core.sched.omni_ar_scheduler import OmniARScheduler + + sched.update_request_additional_info = ( + OmniARScheduler.update_request_additional_info.__get__(sched) + ) + sched._resolve_streaming_request_id = ( + OmniARScheduler._resolve_streaming_request_id.__get__(sched) + ) + sched._maybe_resume_request = ( + OmniARScheduler._maybe_resume_request.__get__(sched) + ) + sched._flush_early_updates = ( + OmniARScheduler._flush_early_updates.__get__(sched) + ) + sched._drain_pending_additional_info_updates = ( + OmniARScheduler._drain_pending_additional_info_updates.__get__(sched) + ) + return sched + + +def _make_request(req_id, external_req_id=None, status=None): + from vllm.v1.request import RequestStatus + + req = SimpleNamespace( + request_id=req_id, + external_req_id=external_req_id, + status=status or RequestStatus.RUNNING, + ) + req.is_finished = lambda: False + return req + + +class TestSchedulerUpdateRequest: + def test_update_goes_to_pending_when_request_exists(self): + sched = _make_scheduler() + req = _make_request("req-abc-1234", external_req_id="req-abc") + sched.requests["req-abc-1234"] = req + sched.running.append(req) + + sched.update_request_additional_info("req-abc", {"key": "val"}) + + assert "req-abc-1234" in sched._pending_additional_info_updates + assert len(sched._pending_additional_info_updates["req-abc-1234"]) == 1 + + def test_update_goes_to_early_when_request_not_registered(self): + sched = _make_scheduler() + + sched.update_request_additional_info("req-unknown", {"key": "val"}) + + assert "req-unknown" in sched._early_additional_info_updates + assert len(sched._pending_additional_info_updates) == 0 + + def test_early_updates_flushed_on_resolve(self): + sched = _make_scheduler() + + # Updates arrive before request registers + sched.update_request_additional_info("req-abc", {"k": 1}) + sched.update_request_additional_info("req-abc", {"k": 2}) + assert len(sched._early_additional_info_updates["req-abc"]) == 2 + + # Request registers + req = _make_request("req-abc-5678", external_req_id="req-abc") + sched.requests["req-abc-5678"] = req + sched.running.append(req) + + sched._flush_early_updates() + + assert len(sched._early_additional_info_updates) == 0 + assert len(sched._pending_additional_info_updates["req-abc-5678"]) == 2 + + def test_drain_clears_pending(self): + sched = _make_scheduler() + sched._pending_additional_info_updates["req-1"].append({"a": 1}) + sched._pending_additional_info_updates["req-1"].append({"b": 2}) + + result = sched._drain_pending_additional_info_updates() + + assert "req-1" in result + assert len(result["req-1"]) == 2 + assert len(sched._pending_additional_info_updates) == 0 + + def test_drain_returns_empty_when_nothing_pending(self): + sched = _make_scheduler() + assert sched._drain_pending_additional_info_updates() == {} + + def test_resume_paused_request_on_update(self): + from vllm.v1.request import RequestStatus + + sched = _make_scheduler() + req = _make_request( + "req-abc-1234", + external_req_id="req-abc", + status=RequestStatus.WAITING_FOR_CHUNK, + ) + sched.requests["req-abc-1234"] = req + + sched.update_request_additional_info("req-abc", {"key": "val"}) + + assert req.status == RequestStatus.RUNNING + assert req in sched.running + + def test_resolve_by_external_req_id(self): + sched = _make_scheduler() + req = _make_request("req-abc-suffix", external_req_id="req-abc") + sched.requests["req-abc-suffix"] = req + + resolved = sched._resolve_streaming_request_id("req-abc") + assert resolved == "req-abc-suffix" + + def test_resolve_returns_none_for_unknown(self): + sched = _make_scheduler() + assert sched._resolve_streaming_request_id("unknown") is None + + +# --------------------------------------------------------------------------- +# Model runner: _APPEND_KEYS and merge semantics +# --------------------------------------------------------------------------- + + +class TestModelRunnerMerge: + def test_append_keys_concatenate(self): + from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner + + # Verify _APPEND_KEYS contains the streaming key + assert "streaming_text_token_ids" in OmniGPUModelRunner._APPEND_KEYS + + def test_merge_appends_for_append_keys(self): + """Simulate the merge behavior for streaming_text_token_ids.""" + # The merge logic: if key in _APPEND_KEYS and both are lists, concatenate + existing = {"streaming_text_token_ids": [100, 200]} + new_update = {"streaming_text_token_ids": [300, 400]} + + # Simulate merge + from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner + + merged = dict(existing) + for k, v in new_update.items(): + if ( + k in OmniGPUModelRunner._APPEND_KEYS + and isinstance(v, list) + and isinstance(merged.get(k), list) + ): + merged[k] = merged[k] + v + else: + merged[k] = v + + assert merged["streaming_text_token_ids"] == [100, 200, 300, 400] + + def test_merge_replaces_when_existing_is_none(self): + """After clearing (None), new update should set, not append.""" + existing = {"streaming_text_token_ids": None} + new_update = {"streaming_text_token_ids": [500]} + + from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner + + merged = dict(existing) + for k, v in new_update.items(): + if ( + k in OmniGPUModelRunner._APPEND_KEYS + and isinstance(v, list) + and isinstance(merged.get(k), list) + ): + merged[k] = merged[k] + v + else: + merged[k] = v + + assert merged["streaming_text_token_ids"] == [500] + + +# --------------------------------------------------------------------------- +# OmniSchedulerOutput: additional_information_updates field +# --------------------------------------------------------------------------- + + +class TestSchedulerOutput: + def test_additional_information_updates_default_empty(self): + from vllm.v1.core.sched.output import SchedulerOutput + + base_fields = {name: None for name in SchedulerOutput.__dataclass_fields__} + base_fields["num_scheduled_tokens"] = {} + base_fields["scheduled_new_reqs"] = [] + base_fields["total_num_scheduled_tokens"] = 0 + base_fields["scheduled_spec_decode_tokens"] = {} + base_fields["scheduled_encoder_inputs"] = {} + base_fields["num_common_prefix_blocks"] = [0] + base_fields["finished_req_ids"] = [] + base_fields["free_encoder_mm_hashes"] = [] + base_fields["preempted_req_ids"] = [] + + output = OmniSchedulerOutput( + **base_fields, + finished_requests_needing_kv_transfer={}, + ) + assert output.additional_information_updates == {} + + +# --------------------------------------------------------------------------- +# OmniModelRunnerOutput: streaming_pause_req_ids field +# --------------------------------------------------------------------------- + + +class TestModelRunnerOutput: + def test_streaming_pause_req_ids_default_none(self): + from vllm_omni.outputs import OmniModelRunnerOutput + + output = OmniModelRunnerOutput( + req_ids=["r1"], + req_id_to_index={"r1": 0}, + ) + assert output.streaming_pause_req_ids is None + + +# --------------------------------------------------------------------------- +# OmniStageTaskType: UPDATE variant +# --------------------------------------------------------------------------- + + +class TestStageTaskType: + def test_update_task_type_exists(self): + assert hasattr(OmniStageTaskType, "UPDATE") + assert OmniStageTaskType.UPDATE.value == "update" + + +# --------------------------------------------------------------------------- +# AsyncOmni: update_request routing +# --------------------------------------------------------------------------- + + +class TestAsyncOmniUpdateRequest: + def test_update_request_submits_to_stage(self): + stage = MagicMock() + omni = SimpleNamespace(stage_list=[stage]) + + from vllm_omni.entrypoints.async_omni import AsyncOmni + + AsyncOmni.update_request(omni, "req-1", {"key": "val"}, stage_id=0) + + stage.submit.assert_called_once() + call_args = stage.submit.call_args[0][0] + assert call_args["type"] == OmniStageTaskType.UPDATE + assert call_args["request_id"] == "req-1" + assert call_args["update"] == {"key": "val"} + + +# --------------------------------------------------------------------------- +# Patch: EngineCoreRequestType.UPDATE +# --------------------------------------------------------------------------- + + +class TestEngineCoreRequestTypePatch: + def test_update_type_exists(self): + from vllm.v1.engine import EngineCoreRequestType + + assert hasattr(EngineCoreRequestType, "UPDATE") + assert EngineCoreRequestType.UPDATE.value == b"\x05"