From 95de2a1263a16043819f502f64fa242d1fdac442 Mon Sep 17 00:00:00 2001
From: ZhengWG
Date: Tue, 31 Mar 2026 18:08:32 +0800
Subject: [PATCH 01/76] feat: support multi-stage
Signed-off-by: ZhengWG
Made-with: Cursor
---
vllm_omni/engine/async_omni_engine.py | 75 ++++-
vllm_omni/engine/orchestrator.py | 272 +++++++++++++------
vllm_omni/engine/stage_engine_core_client.py | 12 +-
vllm_omni/engine/stage_init_utils.py | 7 +
4 files changed, 280 insertions(+), 86 deletions(-)
diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py
index a4d87c96e4a..f7a5a5186a7 100644
--- a/vllm_omni/engine/async_omni_engine.py
+++ b/vllm_omni/engine/async_omni_engine.py
@@ -435,10 +435,19 @@ def _attach_llm_stage(
return stage_client, output_processor, started.vllm_config, input_processor
def _initialize_stages(self, stage_init_timeout: int) -> None:
- """Initialize stage clients/processors in orchestrator thread and assign to self."""
+ """Initialize stage clients/processors in orchestrator thread and assign to self.
+
+ Multi-replica support: when a stage config contains
+ ``runtime.num_replicas > 1``, multiple clients are created for the same
+ logical stage and the flat ``stage_clients`` list grows accordingly.
+ ``logical_stage_to_clients`` maps each logical stage id to the list of
+ client indices that belong to it.
+ """
device_control_env = current_omni_platform.device_control_env_var
num_stages = self.num_stages
+ # These are indexed by *logical* stage_id during initialization, then
+ # expanded to flat client-indexed lists at the end.
stage_clients: list[Any | None] = [None] * num_stages
output_processors: list[Any | None] = [None] * num_stages
stage_vllm_configs: list[Any | None] = [None] * num_stages
@@ -448,6 +457,15 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
started_llm_stages: dict[int, StartedLlmStage] = {}
llm_stage_launch_lock = threading.Lock()
+ # Track per-logical-stage replica count from config
+ replicas_per_stage: list[int] = []
+ for stage_cfg in self.stage_configs:
+ runtime_cfg = getattr(stage_cfg, "runtime", {})
+ num_replicas = int(
+ runtime_cfg.get("num_replicas", 1) if hasattr(runtime_cfg, "get") else getattr(runtime_cfg, "num_replicas", 1)
+ )
+ replicas_per_stage.append(max(1, num_replicas))
+
async_chunk = self.async_chunk
prompt_expand_func = None
llm_stage_count = sum(
@@ -549,21 +567,61 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
)
raise
- self.stage_clients = initialized_stage_clients
- self.output_processors = output_processors
- self.stage_vllm_configs = stage_vllm_configs
+ # ---- Multi-replica expansion ----
+ # Expand the logical-indexed lists into flat client-indexed lists.
+ # For replica_index > 0 the same client/processor/config is shared
+ # (they point to the same underlying EngineCore) — full per-replica
+ # process isolation is out-of-scope for this first iteration.
+ flat_clients: list[Any] = []
+ flat_output_processors: list[Any] = []
+ flat_vllm_configs: list[Any] = []
+ logical_stage_to_clients: list[list[int]] = []
+ # sampling_params and stage_metadata are per-logical-stage
+ logical_default_sampling_params: list[Any] = []
+ logical_stage_metadata: list[dict[str, Any]] = []
+
+ for logical_id, client in enumerate(initialized_stage_clients):
+ num_replicas = replicas_per_stage[logical_id]
+ client_indices: list[int] = []
+ for replica_idx in range(num_replicas):
+ ci = len(flat_clients)
+ client_indices.append(ci)
+ if replica_idx == 0:
+ # First replica uses the already-created objects
+ flat_clients.append(client)
+ flat_output_processors.append(output_processors[logical_id])
+ flat_vllm_configs.append(stage_vllm_configs[logical_id])
+ else:
+ # Additional replicas: for now, share the same client.
+ # True per-replica process isolation will be added later.
+ # TODO: launch separate EngineCore processes for replica_idx > 0
+ flat_clients.append(client)
+ flat_output_processors.append(output_processors[logical_id])
+ flat_vllm_configs.append(stage_vllm_configs[logical_id])
+ logger.info(
+ "[AsyncOmniEngine] Logical stage %s replica %s → client %s (shared)",
+ logical_id, replica_idx, ci,
+ )
+ logical_stage_to_clients.append(client_indices)
+ logical_default_sampling_params.append(default_sampling_params_list[logical_id])
+ logical_stage_metadata.append(stage_metadata[logical_id])
+
+ self.stage_clients = flat_clients
+ self.output_processors = flat_output_processors
+ self.stage_vllm_configs = flat_vllm_configs
+ self.logical_stage_to_clients = logical_stage_to_clients
self.input_processor = input_processor
self.prompt_expand_func = prompt_expand_func
# TODO(Peiqi): Hack here
supported_tasks: set[str] = set()
- if any(getattr(stage_client, "is_comprehension", False) for stage_client in initialized_stage_clients):
+ if any(getattr(stage_client, "is_comprehension", False) for stage_client in flat_clients):
supported_tasks.add("generate")
- if any(metadata.get("final_output_type") == "audio" for metadata in stage_metadata):
+ if any(metadata.get("final_output_type") == "audio" for metadata in logical_stage_metadata):
supported_tasks.add("speech")
self.supported_tasks = tuple(supported_tasks) if supported_tasks else ("generate",)
- self.default_sampling_params_list = default_sampling_params_list
- self.stage_metadata = stage_metadata
+ self.default_sampling_params_list = logical_default_sampling_params
+ self.stage_metadata = logical_stage_metadata
def _initialize_janus_queues(self) -> None:
"""Initialize janus queues inside orchestrator thread loop context."""
@@ -594,6 +652,7 @@ async def _run_orchestrator() -> None:
stage_clients=self.stage_clients,
output_processors=self.output_processors,
stage_vllm_configs=self.stage_vllm_configs,
+ logical_stage_to_clients=getattr(self, "logical_stage_to_clients", None),
)
if not startup_future.done():
startup_future.set_result(asyncio.get_running_loop())
diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py
index e6373ec96ea..3106b876809 100644
--- a/vllm_omni/engine/orchestrator.py
+++ b/vllm_omni/engine/orchestrator.py
@@ -104,6 +104,11 @@ class OrchestratorRequestState:
# Metrics: timestamp when request was submitted to each stage
stage_submit_ts: dict[int, float] = field(default_factory=dict)
+ # Multi-replica: maps logical_stage_id -> client_index chosen for this
+ # request. Ensures the same request always hits the same replica within
+ # a given logical stage (KV / intermediate-state affinity).
+ chosen_client_index: dict[int, int] = field(default_factory=dict)
+
class Orchestrator:
"""Runs inside a background thread's asyncio event loop.
@@ -122,18 +127,39 @@ def __init__(
stage_vllm_configs: list[Any],
*,
async_chunk: bool = False,
+ logical_stage_to_clients: list[list[int]] | None = None,
) -> None:
self.request_async_queue = request_async_queue
self.output_async_queue = output_async_queue
self.rpc_async_queue = rpc_async_queue
- self.num_stages = len(stage_clients)
+ self.num_clients = len(stage_clients)
self.async_chunk = bool(async_chunk)
self.stage_clients: list[Any] = stage_clients
self.output_processors: list[Any] = output_processors
self.stage_vllm_configs: list[Any] = stage_vllm_configs
+ # Multi-replica mapping: logical_stage_id -> list of client indices.
+ # When not provided (single-replica), default to identity mapping.
+ if logical_stage_to_clients is not None:
+ self.logical_stage_to_clients = logical_stage_to_clients
+ else:
+ self.logical_stage_to_clients = [[i] for i in range(self.num_clients)]
+ self.num_logical_stages = len(self.logical_stage_to_clients)
+
+ # Reverse mapping: client_index -> logical_stage_id
+ self._client_to_logical: list[int] = [0] * self.num_clients
+ for logical_id, client_indices in enumerate(self.logical_stage_to_clients):
+ for ci in client_indices:
+ self._client_to_logical[ci] = logical_id
+
+ # Round-robin counters for replica selection per logical stage
+ self._replica_rr: list[int] = [0] * self.num_logical_stages
+
+ # Backward compat: num_stages now means num_logical_stages
+ self.num_stages = self.num_logical_stages
+
# Per-request state
self.request_states: dict[str, OrchestratorRequestState] = {}
@@ -144,15 +170,41 @@ def __init__(
self._companion_done: dict[str, set[str]] = {}
self._deferred_parents: dict[str, dict[str, Any]] = {}
- # Per-stage metrics accumulators.
- self._batch_seq: list[int] = [0] * self.num_stages
- self._agg_total_tokens: list[int] = [0] * self.num_stages
- self._agg_total_gen_time_ms: list[float] = [0.0] * self.num_stages
+ # Per-client metrics accumulators.
+ self._batch_seq: list[int] = [0] * self.num_clients
+ self._agg_total_tokens: list[int] = [0] * self.num_clients
+ self._agg_total_gen_time_ms: list[float] = [0.0] * self.num_clients
# Shutdown coordination
self._shutdown_event = asyncio.Event()
self._stages_shutdown = False
+ def _choose_client_index(
+ self,
+ logical_stage_id: int,
+ req_state: OrchestratorRequestState,
+ ) -> int:
+ """Pick a client for *logical_stage_id* and record the choice.
+
+ If this request already has a chosen client for the logical stage,
+ return the existing one (affinity). Otherwise round-robin among the
+ available replicas.
+ """
+ existing = req_state.chosen_client_index.get(logical_stage_id)
+ if existing is not None:
+ return existing
+
+ candidates = self.logical_stage_to_clients[logical_stage_id]
+ if len(candidates) == 1:
+ chosen = candidates[0]
+ else:
+ rr = self._replica_rr[logical_stage_id]
+ chosen = candidates[rr % len(candidates)]
+ self._replica_rr[logical_stage_id] = rr + 1
+
+ req_state.chosen_client_index[logical_stage_id] = chosen
+ return chosen
+
async def run(self) -> None:
"""Main entry point for the Orchestrator event loop."""
logger.info("[Orchestrator] Starting event loop")
@@ -226,31 +278,38 @@ async def _orchestration_loop(self) -> None:
"""Inner loop for _orchestration_output_handler (clean cancellation).
Control flow: poll raw → process through output processor → route.
+
+ Multi-replica: iterates over every *client_index* (not logical stage),
+ and resolves the logical_stage_id via ``_client_to_logical`` for routing.
"""
while not self._shutdown_event.is_set():
idle = True
- for stage_id in range(self.num_stages):
+ for client_index in range(self.num_clients):
if self._shutdown_event.is_set():
return
+ logical_stage_id = self._client_to_logical[client_index]
+
# 1) Diffusion stage: poll non-blocking queue
- # TODO (Peiqi): the output of diffusion stage is OmniRequestOutput,
- # which is different from EngineCoreOutputs (LLM stages). We may want to unify
- # the output format in the future to simplify the processing logic in Orchestrator.
- stage_client = self.stage_clients[stage_id]
+ stage_client = self.stage_clients[client_index]
if stage_client.stage_type == "diffusion":
output = stage_client.get_diffusion_output_async()
if output is not None:
idle = False
req_state = self.request_states.get(output.request_id)
if req_state is not None:
- stage_metrics = self._build_stage_metrics(stage_id, output.request_id, [output], req_state)
- await self._route_output(stage_id, output, req_state, stage_metrics)
+ stage_metrics = self._build_stage_metrics(
+ client_index, output.request_id, [output], req_state
+ )
+ await self._route_output(
+ logical_stage_id, output, req_state, stage_metrics,
+ client_index=client_index,
+ )
continue
- # 1) Poll raw outputs from the stage
+ # 1) Poll raw outputs from the client
try:
- raw_outputs = await asyncio.wait_for(self._poll_stage_raw(stage_id), timeout=0.001)
+ raw_outputs = await asyncio.wait_for(self._poll_stage_raw(client_index), timeout=0.001)
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
@@ -259,8 +318,9 @@ async def _orchestration_loop(self) -> None:
if self._shutdown_event.is_set():
return
logger.exception(
- "[Orchestrator] _poll_stage_raw failed for stage-%s",
- stage_id,
+ "[Orchestrator] _poll_stage_raw failed for client-%s (logical stage-%s)",
+ client_index,
+ logical_stage_id,
)
raise
@@ -269,28 +329,33 @@ async def _orchestration_loop(self) -> None:
idle = False
# 2) Process raw outputs through the output processor
- request_outputs = await self._process_stage_outputs(stage_id, raw_outputs)
+ request_outputs = await self._process_stage_outputs(client_index, raw_outputs)
# 3) Route each processed output
for output in request_outputs:
req_state = self.request_states.get(output.request_id)
if req_state is None:
logger.warning(
- "[Orchestrator] Dropping output for unknown req %s at stage-%s (known reqs: %s)",
+ "[Orchestrator] Dropping output for unknown req %s at client-%s "
+ "(logical stage-%s, known reqs: %s)",
output.request_id,
- stage_id,
+ client_index,
+ logical_stage_id,
list(self.request_states.keys()),
)
continue
stage_metrics = None
if output.finished:
stage_metrics = self._build_stage_metrics(
- stage_id,
+ client_index,
output.request_id,
[output],
req_state,
)
- await self._route_output(stage_id, output, req_state, stage_metrics)
+ await self._route_output(
+ logical_stage_id, output, req_state, stage_metrics,
+ client_index=client_index,
+ )
if idle:
await asyncio.sleep(0.001)
@@ -303,12 +368,22 @@ async def _route_output(
output: Any,
req_state: OrchestratorRequestState,
stage_metrics: Any,
+ *,
+ client_index: int | None = None,
) -> None:
- """Route a processed output: send to main thread and/or forward to next stage."""
+ """Route a processed output: send to main thread and/or forward to next stage.
+
+ Args:
+ stage_id: Logical stage id.
+ client_index: Physical client index that produced this output.
+ Defaults to stage_id for backward compat.
+ """
+ if client_index is None:
+ client_index = stage_id
req_id = output.request_id
finished = output.finished
submit_ts = req_state.stage_submit_ts.get(stage_id)
- stage_client = self.stage_clients[stage_id]
+ stage_client = self.stage_clients[client_index]
# CFG companion handling: companions don't produce user-visible output
# and don't forward to the next stage directly.
@@ -331,6 +406,7 @@ async def _route_output(
deferred["stage_id"],
deferred["output"],
parent_state,
+ client_index=deferred.get("client_index", deferred["stage_id"]),
)
self.request_states.pop(req_id, None)
return
@@ -364,13 +440,16 @@ async def _route_output(
self._deferred_parents[req_id] = {
"stage_id": stage_id,
"output": output,
+ "client_index": client_index,
}
logger.debug(
"[Orchestrator] Parent %s deferred, waiting for CFG companions",
req_id,
)
else:
- await self._forward_to_next_stage(req_id, stage_id, output, req_state)
+ await self._forward_to_next_stage(
+ req_id, stage_id, output, req_state, client_index=client_index,
+ )
if finished and stage_id == req_state.final_stage_id:
self._cleanup_companion_state(req_id)
@@ -395,35 +474,36 @@ def _all_companions_done(self, parent_id: str) -> bool:
def _build_stage_metrics(
self,
- stage_id: int,
+ client_index: int,
req_id: str,
request_outputs: list[RequestOutput],
req_state: OrchestratorRequestState,
) -> StageRequestMetrics:
- """Build StageRequestMetrics for a finished request at a stage.
+ """Build StageRequestMetrics for a finished request at a client.
Reuses StageRequestMetrics so OrchestratorMetrics and downstream
metric handlers can consume a stable schema.
"""
+ logical_stage_id = self._client_to_logical[client_index]
now = _time.time()
- submit_ts = req_state.stage_submit_ts.get(stage_id, now)
+ submit_ts = req_state.stage_submit_ts.get(logical_stage_id, now)
stage_gen_time_ms = (now - submit_ts) * 1000.0
num_tokens_out = count_tokens_from_outputs(request_outputs)
num_tokens_in = 0
- if stage_id == 0:
+ if logical_stage_id == 0:
for ro in request_outputs:
ptids = getattr(ro, "prompt_token_ids", None)
if ptids is not None:
num_tokens_in += len(ptids)
- # Monotonic batch counter per stage.
- self._batch_seq[stage_id] += 1
- batch_id = self._batch_seq[stage_id]
+ # Monotonic batch counter per client.
+ self._batch_seq[client_index] += 1
+ batch_id = self._batch_seq[client_index]
# Accumulate for running-average stage_stats
- self._agg_total_tokens[stage_id] += num_tokens_out
- self._agg_total_gen_time_ms[stage_id] += stage_gen_time_ms
+ self._agg_total_tokens[client_index] += num_tokens_out
+ self._agg_total_gen_time_ms[client_index] += stage_gen_time_ms
return StageRequestMetrics(
num_tokens_in=num_tokens_in,
@@ -435,8 +515,8 @@ def _build_stage_metrics(
rx_transfer_bytes=0,
rx_in_flight_time_ms=0.0,
stage_stats=StageStats(
- total_token=self._agg_total_tokens[stage_id],
- total_gen_time_ms=self._agg_total_gen_time_ms[stage_id],
+ total_token=self._agg_total_tokens[client_index],
+ total_gen_time_ms=self._agg_total_gen_time_ms[client_index],
),
)
@@ -446,18 +526,28 @@ async def _forward_to_next_stage(
stage_id: int,
output: Any,
req_state: OrchestratorRequestState,
+ *,
+ client_index: int | None = None,
) -> None:
"""Forward output from current stage to the next stage.
Handles the full pipeline: set outputs on current stage, compute
next-stage inputs, build lightweight requests, and submit them.
+
+ Args:
+ stage_id: Logical stage id that produced the output.
+ client_index: Physical client index that produced the output.
"""
- next_stage_id = stage_id + 1
- next_client = self.stage_clients[next_stage_id]
- params = req_state.sampling_params_list[next_stage_id]
+ if client_index is None:
+ client_index = stage_id
+
+ next_logical = stage_id + 1
+ next_ci = self._choose_client_index(next_logical, req_state)
+ next_client = self.stage_clients[next_ci]
+ params = req_state.sampling_params_list[next_logical]
if next_client.stage_type == "diffusion":
- self.stage_clients[stage_id].set_engine_outputs([output])
+ self.stage_clients[client_index].set_engine_outputs([output])
if next_client.custom_process_input_func is not None:
diffusion_prompt = next_client.custom_process_input_func(
self.stage_clients,
@@ -493,22 +583,25 @@ async def _forward_to_next_stage(
)
else:
await next_client.add_request_async(req_id, diffusion_prompt, params)
- req_state.stage_submit_ts[next_stage_id] = _time.time()
+ req_state.stage_submit_ts[next_logical] = _time.time()
return
- self.stage_clients[stage_id].set_engine_outputs([output])
+ # Set outputs on the client that actually produced them
+ self.stage_clients[client_index].set_engine_outputs([output])
# Process inputs for next stage
try:
next_inputs = next_client.process_engine_inputs(
stage_list=self.stage_clients,
prompt=req_state.prompt,
+ source_client_index=client_index,
)
except Exception:
logger.exception(
- "[Orchestrator] req=%s process_engine_inputs FAILED for stage-%s",
+ "[Orchestrator] req=%s process_engine_inputs FAILED for logical stage-%s (client-%s)",
req_id,
- next_stage_id,
+ next_logical,
+ next_ci,
)
raise
@@ -518,13 +611,13 @@ async def _forward_to_next_stage(
request_id=req_id,
prompt=next_input,
params=params,
- model_config=self.stage_vllm_configs[next_stage_id].model_config,
+ model_config=self.stage_vllm_configs[next_ci].model_config,
)
# TODO: Here we directly use the req id to assign.
request.external_req_id = request.request_id
- self.output_processors[next_stage_id].add_request(
+ self.output_processors[next_ci].add_request(
request=request,
prompt=None,
parent_req=None,
@@ -534,26 +627,26 @@ async def _forward_to_next_stage(
await next_client.add_request_async(request)
- # Record submit timestamp for the next stage
- req_state.stage_submit_ts[next_stage_id] = _time.time()
+ # Record submit timestamp for the next logical stage
+ req_state.stage_submit_ts[next_logical] = _time.time()
- async def _poll_stage_raw(self, stage_id: int) -> EngineCoreOutputs | None:
+ async def _poll_stage_raw(self, client_index: int) -> EngineCoreOutputs | None:
"""Pull raw EngineCoreOutputs from a stage client without processing.
Returns the raw outputs object, or None when there is nothing
to consume.
"""
- outputs = await self.stage_clients[stage_id].get_output_async()
+ outputs = await self.stage_clients[client_index].get_output_async()
if not outputs.outputs:
return None
return outputs
- async def _process_stage_outputs(self, stage_id: int, raw_outputs: EngineCoreOutputs) -> list[RequestOutput]:
+ async def _process_stage_outputs(self, client_index: int, raw_outputs: EngineCoreOutputs) -> list[RequestOutput]:
"""Run the output processor on raw outputs, returning RequestOutputs.
Also handles abort forwarding and scheduler stats updates.
"""
- processor = self.output_processors[stage_id]
+ processor = self.output_processors[client_index]
processed = processor.process_outputs(
raw_outputs.outputs,
@@ -562,7 +655,7 @@ async def _process_stage_outputs(self, stage_id: int, raw_outputs: EngineCoreOut
)
if processed.reqs_to_abort:
- await self.stage_clients[stage_id].abort_requests_async(processed.reqs_to_abort)
+ await self.stage_clients[client_index].abort_requests_async(processed.reqs_to_abort)
if raw_outputs.scheduler_stats is not None:
processor.update_scheduler_stats(raw_outputs.scheduler_stats)
@@ -571,7 +664,7 @@ async def _process_stage_outputs(self, stage_id: int, raw_outputs: EngineCoreOut
async def _handle_add_request(self, msg: dict[str, Any]) -> None:
"""Handle an add_request message from the main thread."""
- stage_id = 0
+ logical_stage_id = 0
request_id = msg["request_id"]
prompt = msg["prompt"]
original_prompt = msg.get("original_prompt", prompt)
@@ -585,7 +678,7 @@ async def _handle_add_request(self, msg: dict[str, Any]) -> None:
"[Orchestrator] _handle_add_request: stage=%s req=%s "
"prompt_type=%s original_prompt_type=%s final_stage=%s "
"num_sampling_params=%d",
- stage_id,
+ logical_stage_id,
request_id,
type(prompt).__name__,
type(original_prompt).__name__,
@@ -601,14 +694,17 @@ async def _handle_add_request(self, msg: dict[str, Any]) -> None:
sampling_params_list=sampling_params_list,
final_stage_id=final_stage_id,
)
- req_state.stage_submit_ts[stage_id] = _time.time()
self.request_states[request_id] = req_state
+ # Choose a replica for logical stage 0
+ client_index = self._choose_client_index(logical_stage_id, req_state)
+ req_state.stage_submit_ts[logical_stage_id] = _time.time()
+
# Stage-0 prompt is already a fully-formed OmniEngineCoreRequest
# (pre-processed by AsyncOmniEngine.add_request, output processor
# already registered there) - submit directly.
request = prompt
- stage_client = self.stage_clients[stage_id]
+ stage_client = self.stage_clients[client_index]
if stage_client.stage_type == "diffusion":
if isinstance(prompt, list):
await stage_client.add_batch_request_async(
@@ -621,7 +717,7 @@ async def _handle_add_request(self, msg: dict[str, Any]) -> None:
else:
await stage_client.add_request_async(request)
- if self.async_chunk and stage_id == 0 and final_stage_id > 0:
+ if self.async_chunk and logical_stage_id == 0 and final_stage_id > 0:
await self._prewarm_async_chunk_stages(request_id, request, req_state)
async def _prewarm_async_chunk_stages(
@@ -635,6 +731,9 @@ async def _prewarm_async_chunk_stages(
In async-chunk mode, stages exchange data through connectors/chunk adapters,
so downstream stages should be armed once at request start instead of waiting
for stage-finished forwarding.
+
+ Multi-replica: uses _choose_client_index so the prewarm targets align
+ with the replicas chosen during the orchestration phase.
"""
if req_state.final_stage_id <= 0:
return
@@ -661,24 +760,25 @@ async def _prewarm_async_chunk_stages(
base_input["multi_modal_data"] = None
base_input["mm_processor_kwargs"] = None
- for next_stage_id in range(1, req_state.final_stage_id + 1):
- next_client = self.stage_clients[next_stage_id]
- params = req_state.sampling_params_list[next_stage_id]
+ for next_logical in range(1, req_state.final_stage_id + 1):
+ next_ci = self._choose_client_index(next_logical, req_state)
+ next_client = self.stage_clients[next_ci]
+ params = req_state.sampling_params_list[next_logical]
if next_client.stage_type == "diffusion":
await next_client.add_request_async(request_id, req_state.prompt, params)
- req_state.stage_submit_ts[next_stage_id] = _time.time()
+ req_state.stage_submit_ts[next_logical] = _time.time()
continue
request = build_engine_core_request_from_tokens(
request_id=request_id,
prompt=base_input,
params=params,
- model_config=self.stage_vllm_configs[next_stage_id].model_config,
+ model_config=self.stage_vllm_configs[next_ci].model_config,
)
request.external_req_id = request.request_id
- self.output_processors[next_stage_id].add_request(
+ self.output_processors[next_ci].add_request(
request=request,
prompt=None,
parent_req=None,
@@ -686,7 +786,7 @@ async def _prewarm_async_chunk_stages(
queue=None,
)
await next_client.add_request_async(request)
- req_state.stage_submit_ts[next_stage_id] = _time.time()
+ req_state.stage_submit_ts[next_logical] = _time.time()
async def _handle_add_companion(self, msg: dict[str, Any]) -> None:
"""Handle an add_companion_request message: submit companion to stage 0."""
@@ -710,18 +810,28 @@ async def _handle_add_companion(self, msg: dict[str, Any]) -> None:
sampling_params_list=sampling_params_list,
final_stage_id=0,
)
- companion_state.stage_submit_ts[0] = _time.time()
self.request_states[companion_id] = companion_state
+ # Use same replica as the parent for affinity, or choose one
+ parent_state = self.request_states.get(parent_id)
+ if parent_state is not None and 0 in parent_state.chosen_client_index:
+ client_index = parent_state.chosen_client_index[0]
+ companion_state.chosen_client_index[0] = client_index
+ else:
+ client_index = self._choose_client_index(0, companion_state)
+
+ companion_state.stage_submit_ts[0] = _time.time()
+
request = companion_prompt # Already a processed OmniEngineCoreRequest
- stage_client = self.stage_clients[0]
+ stage_client = self.stage_clients[client_index]
await stage_client.add_request_async(request)
logger.info(
- "[Orchestrator] CFG companion submitted: %s (role=%s, parent=%s)",
+ "[Orchestrator] CFG companion submitted: %s (role=%s, parent=%s, client=%s)",
companion_id,
role,
parent_id,
+ client_index,
)
async def _handle_abort(self, msg: dict[str, Any]) -> None:
@@ -740,8 +850,8 @@ async def _handle_abort(self, msg: dict[str, Any]) -> None:
self._deferred_parents.pop(req_id, None)
all_ids_to_abort = list(request_ids) + companion_ids_to_abort
- for stage_id in range(self.num_stages):
- await self.stage_clients[stage_id].abort_requests_async(all_ids_to_abort)
+ for ci in range(self.num_clients):
+ await self.stage_clients[ci].abort_requests_async(all_ids_to_abort)
for req_id in request_ids:
self.request_states.pop(req_id, None)
logger.info("[Orchestrator] Aborted request(s) %s", request_ids)
@@ -759,16 +869,26 @@ async def _handle_collective_rpc(self, msg: dict[str, Any]) -> None:
args = tuple(msg.get("args", ()))
kwargs = dict(msg.get("kwargs") or {})
requested_stage_ids = msg.get("stage_ids")
- stage_ids = list(range(self.num_stages)) if requested_stage_ids is None else list(requested_stage_ids)
+ # When stage_ids are provided they refer to logical stages; expand
+ # to all client indices belonging to those logical stages.
+ if requested_stage_ids is None:
+ stage_ids = list(range(self.num_clients))
+ else:
+ stage_ids = []
+ for lid in requested_stage_ids:
+ if 0 <= lid < self.num_logical_stages:
+ stage_ids.extend(self.logical_stage_to_clients[lid])
+ else:
+ stage_ids.append(lid) # keep invalid id for error reporting
results: list[Any] = []
for stage_id in stage_ids:
- if stage_id < 0 or stage_id >= self.num_stages:
+ if stage_id < 0 or stage_id >= self.num_clients:
results.append(
{
"supported": False,
"todo": True,
- "error": f"Invalid stage id {stage_id}",
+ "error": f"Invalid client index {stage_id}",
}
)
continue
@@ -817,10 +937,10 @@ def _shutdown_stages(self) -> None:
return
self._stages_shutdown = True
- logger.info("[Orchestrator] Shutting down all stages")
- for stage_id, stage_client in enumerate(self.stage_clients):
+ logger.info("[Orchestrator] Shutting down all %d client(s)", self.num_clients)
+ for ci, stage_client in enumerate(self.stage_clients):
try:
stage_client.shutdown()
- logger.info(f"[Orchestrator] Stage {stage_id} shut down")
+ logger.info("[Orchestrator] Client %d shut down", ci)
except Exception as e:
- logger.warning(f"[Orchestrator] Failed to shutdown stage {stage_id}: {e}")
+ logger.warning("[Orchestrator] Failed to shutdown client %d: %s", ci, e)
diff --git a/vllm_omni/engine/stage_engine_core_client.py b/vllm_omni/engine/stage_engine_core_client.py
index 284cc2d31a2..395a9d84550 100644
--- a/vllm_omni/engine/stage_engine_core_client.py
+++ b/vllm_omni/engine/stage_engine_core_client.py
@@ -123,8 +123,16 @@ def process_engine_inputs(
self,
stage_list: list[Any],
prompt: OmniTokensPrompt | list[OmniTokensPrompt] | None = None,
+ source_client_index: int | None = None,
) -> list[OmniTokensPrompt]:
- """Process inputs from upstream stages."""
+ """Process inputs from upstream stages.
+
+ Args:
+ source_client_index: When multi-replica is enabled, specifies the
+ exact client index in *stage_list* that produced the upstream
+ output. Falls back to ``engine_input_source[0]`` for backward
+ compat.
+ """
from vllm_omni.inputs.data import OmniTokensPrompt
if self.custom_process_input_func is not None:
@@ -138,7 +146,7 @@ def process_engine_inputs(
if not self.engine_input_source:
raise ValueError(f"engine_input_source empty for stage {self.stage_id}")
- source_id = self.engine_input_source[0]
+ source_id = source_client_index if source_client_index is not None else self.engine_input_source[0]
source_outputs = stage_list[source_id].engine_outputs
if not isinstance(prompt, list):
diff --git a/vllm_omni/engine/stage_init_utils.py b/vllm_omni/engine/stage_init_utils.py
index 9c246ce6eb3..73c63255dcc 100644
--- a/vllm_omni/engine/stage_init_utils.py
+++ b/vllm_omni/engine/stage_init_utils.py
@@ -111,6 +111,11 @@ class StageMetadata:
runtime_cfg: Any
prompt_expand_func: Callable | None = None
cfg_kv_collect_func: Callable | None = None
+ # Multi-replica fields: logical_stage_id is the original stage_id from
+ # the YAML config (set by extract_stage_metadata; the dataclass default is
+ # -1); replica_index distinguishes replicas of one logical stage (default 0).
+ logical_stage_id: int = -1
+ replica_index: int = 0
@dataclass
@@ -172,6 +177,7 @@ def extract_stage_metadata(stage_config: Any) -> StageMetadata:
model_stage=None,
runtime_cfg=runtime_cfg,
cfg_kv_collect_func=cfg_kv_collect_func,
+ logical_stage_id=stage_id,
)
model_stage = getattr(engine_args, "model_stage", None)
@@ -193,6 +199,7 @@ def extract_stage_metadata(stage_config: Any) -> StageMetadata:
model_stage=model_stage,
runtime_cfg=runtime_cfg,
prompt_expand_func=prompt_expand_func,
+ logical_stage_id=stage_id,
)
From 0b774ca0d77002f2ffd9c546b330eb5a7d7b3454 Mon Sep 17 00:00:00 2001
From: ZhengWG
Date: Wed, 1 Apr 2026 12:18:57 +0800
Subject: [PATCH 02/76] feat: init multi engine-cores
Signed-off-by: ZhengWG
Made-with: Cursor
---
vllm_omni/engine/async_omni_engine.py | 232 +++++++++++++++++---------
vllm_omni/engine/stage_init_utils.py | 45 +++++
2 files changed, 194 insertions(+), 83 deletions(-)
diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py
index f7a5a5186a7..49b5bd5b161 100644
--- a/vllm_omni/engine/async_omni_engine.py
+++ b/vllm_omni/engine/async_omni_engine.py
@@ -9,6 +9,7 @@
import asyncio
import concurrent.futures
+import copy
import dataclasses
import json
import os
@@ -55,11 +56,13 @@
extract_stage_metadata,
finalize_initialized_stages,
get_stage_connector_spec,
+ get_stage_tp_size,
initialize_diffusion_stage,
load_omni_transfer_config_for_model,
prepare_engine_environment,
release_device_locks,
setup_stage_devices,
+ split_devices_for_replicas,
)
from vllm_omni.entrypoints.utils import (
load_and_resolve_stage_configs,
@@ -438,24 +441,22 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
"""Initialize stage clients/processors in orchestrator thread and assign to self.
Multi-replica support: when a stage config contains
- ``runtime.num_replicas > 1``, multiple clients are created for the same
- logical stage and the flat ``stage_clients`` list grows accordingly.
- ``logical_stage_to_clients`` maps each logical stage id to the list of
- client indices that belong to it.
+ ``runtime.num_replicas > 1``, each replica launches its own EngineCore
+ process with a dedicated slice of devices. The flat ``stage_clients``
+ list contains all replica clients; ``logical_stage_to_clients`` maps
+ each logical stage id to the list of client indices that belong to it.
"""
device_control_env = current_omni_platform.device_control_env_var
num_stages = self.num_stages
- # These are indexed by *logical* stage_id during initialization, then
- # expanded to flat client-indexed lists at the end.
- stage_clients: list[Any | None] = [None] * num_stages
- output_processors: list[Any | None] = [None] * num_stages
- stage_vllm_configs: list[Any | None] = [None] * num_stages
input_processor: InputProcessor | None = None
- llm_stage_ids: list[int] = []
- llm_launch_futures: dict[int, concurrent.futures.Future[StartedLlmStage]] = {}
- started_llm_stages: dict[int, StartedLlmStage] = {}
+ # Keyed by (logical_stage_id, replica_idx)
+ llm_launch_keys: list[tuple[int, int]] = []
+ llm_launch_futures: dict[tuple[int, int], concurrent.futures.Future[StartedLlmStage]] = {}
+ started_llm_stages: dict[tuple[int, int], StartedLlmStage] = {}
llm_stage_launch_lock = threading.Lock()
+ # Diffusion stages (no multi-replica support yet)
+ diffusion_clients: dict[int, Any] = {}
# Track per-logical-stage replica count from config
replicas_per_stage: list[int] = []
@@ -466,22 +467,47 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
)
replicas_per_stage.append(max(1, num_replicas))
+ # Pre-compute per-replica device assignments for multi-replica stages
+ replica_devices_map: dict[tuple[int, int], str] = {}
+ for logical_id, stage_cfg in enumerate(self.stage_configs):
+ num_replicas = replicas_per_stage[logical_id]
+ if num_replicas <= 1:
+ continue
+ runtime_cfg = getattr(stage_cfg, "runtime", {})
+ devices_str = (
+ runtime_cfg.get("devices") if hasattr(runtime_cfg, "get")
+ else getattr(runtime_cfg, "devices", None)
+ )
+ tp_size = get_stage_tp_size(stage_cfg)
+ per_replica = split_devices_for_replicas(devices_str, num_replicas, tp_size, logical_id)
+ for r, dev_str in enumerate(per_replica):
+ replica_devices_map[(logical_id, r)] = dev_str
+ logger.info(
+ "[AsyncOmniEngine] Stage %s: %d replicas, tp=%d, devices split: %s",
+ logical_id, num_replicas, tp_size, per_replica,
+ )
+
async_chunk = self.async_chunk
prompt_expand_func = None
- llm_stage_count = sum(
- 1 for stage_cfg in self.stage_configs if getattr(stage_cfg, "stage_type", "llm") != "diffusion"
+ total_llm_replicas = sum(
+ replicas_per_stage[i]
+ for i, cfg in enumerate(self.stage_configs)
+ if getattr(cfg, "stage_type", "llm") != "diffusion"
)
prepare_engine_environment()
omni_transfer_config = load_omni_transfer_config_for_model(self.model, self.config_path)
+ # Initialized outside try so error handler can always access them
+ flat_clients: list[Any] = []
+ all_clients: dict[tuple[int, int], Any] = {}
+
try:
with concurrent.futures.ThreadPoolExecutor(
- max_workers=max(1, llm_stage_count),
+ max_workers=max(1, total_llm_replicas),
thread_name_prefix="llm-stage-launch",
) as launch_executor:
for stage_id, stage_cfg in enumerate(self.stage_configs):
- logger.info("[AsyncOmniEngine] Initializing stage %s", stage_id)
metadata = extract_stage_metadata(stage_cfg)
if metadata.prompt_expand_func is not None:
prompt_expand_func = metadata.prompt_expand_func
@@ -505,7 +531,7 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
inject_omni_kv_config(stage_cfg, omni_conn_cfg, omni_from, omni_to)
_inject_kv_stage_info(stage_cfg, stage_id)
- stage_clients[stage_id] = initialize_diffusion_stage(
+ diffusion_clients[stage_id] = initialize_diffusion_stage(
self.model,
stage_cfg,
metadata,
@@ -523,89 +549,129 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
current_omni_platform.set_device_control_env_var(previous_visible_devices)
continue
- llm_stage_ids.append(stage_id)
- llm_launch_futures[stage_id] = launch_executor.submit(
- self._launch_llm_stage,
- stage_cfg,
- metadata,
- stage_connector_spec,
- stage_init_timeout,
- llm_stage_launch_lock,
- omni_kv_connector,
- )
+ # Submit one launch future per replica
+ num_replicas = replicas_per_stage[stage_id]
+ for replica_idx in range(num_replicas):
+ key = (stage_id, replica_idx)
+ llm_launch_keys.append(key)
+
+ # For replica > 0, deep-copy stage_cfg and override devices
+ if replica_idx > 0:
+ replica_cfg = copy.deepcopy(stage_cfg)
+ else:
+ replica_cfg = stage_cfg
+
+ if key in replica_devices_map:
+ replica_cfg.runtime.devices = replica_devices_map[key]
+
+ replica_metadata = extract_stage_metadata(replica_cfg)
+ replica_metadata.replica_index = replica_idx
+
+ logger.info(
+ "[AsyncOmniEngine] Launching stage %s replica %s (devices=%s)",
+ stage_id, replica_idx,
+ getattr(getattr(replica_cfg, "runtime", None), "devices", "default"),
+ )
+
+ llm_launch_futures[key] = launch_executor.submit(
+ self._launch_llm_stage,
+ replica_cfg,
+ replica_metadata,
+ stage_connector_spec,
+ stage_init_timeout,
+ llm_stage_launch_lock,
+ omni_kv_connector,
+ )
concurrent.futures.wait(list(llm_launch_futures.values()))
- for stage_id in llm_stage_ids:
- started_llm_stages[stage_id] = llm_launch_futures[stage_id].result()
+ for key in llm_launch_keys:
+ started_llm_stages[key] = llm_launch_futures[key].result()
+
+ # ---- Build flat client lists directly ----
+ # Attach each launched replica and build the flat index structures.
+ flat_output_processors: list[Any] = []
+ flat_vllm_configs: list[Any] = []
+ logical_stage_to_clients: list[list[int]] = []
+
+ # Per-logical-stage lists (not per-client)
+ logical_stage_clients_for_finalize: list[Any | None] = [None] * num_stages
+ all_output_processors: dict[tuple[int, int], Any] = {}
+ all_vllm_configs: dict[tuple[int, int], Any] = {}
+
+ for key in llm_launch_keys:
+ stage_id, replica_idx = key
+ started = started_llm_stages[key]
+ client, output_proc, vllm_cfg, stage0_inp = self._attach_llm_stage(started)
+ all_clients[key] = client
+ all_output_processors[key] = output_proc
+ all_vllm_configs[key] = vllm_cfg
+ if stage0_inp is not None:
+ input_processor = stage0_inp
+ # Use first replica for finalize_initialized_stages
+ if replica_idx == 0:
+ logical_stage_clients_for_finalize[stage_id] = client
- for stage_id in llm_stage_ids:
- started = started_llm_stages[stage_id]
- stage_client, output_processor, vllm_config, stage0_input_processor = self._attach_llm_stage(started)
- stage_clients[stage_id] = stage_client
- output_processors[stage_id] = output_processor
- stage_vllm_configs[stage_id] = vllm_config
- if stage0_input_processor is not None:
- input_processor = stage0_input_processor
+ # Place diffusion clients into the logical list
+ for stage_id, diff_client in diffusion_clients.items():
+ logical_stage_clients_for_finalize[stage_id] = diff_client
initialized_stage_clients, default_sampling_params_list, stage_metadata = finalize_initialized_stages(
- stage_clients,
+ logical_stage_clients_for_finalize,
input_processor,
)
+
+ # Now build flat lists in logical-stage order, replicas within
+ logical_default_sampling_params: list[Any] = []
+ logical_stage_metadata: list[dict[str, Any]] = []
+
+ for logical_id in range(num_stages):
+ num_replicas = replicas_per_stage[logical_id]
+ client_indices: list[int] = []
+
+ if logical_id in diffusion_clients:
+ # Diffusion: single client, no multi-replica
+ ci = len(flat_clients)
+ client_indices.append(ci)
+ flat_clients.append(diffusion_clients[logical_id])
+ flat_output_processors.append(None)
+ flat_vllm_configs.append(None)
+ else:
+ for replica_idx in range(num_replicas):
+ key = (logical_id, replica_idx)
+ ci = len(flat_clients)
+ client_indices.append(ci)
+ flat_clients.append(all_clients[key])
+ flat_output_processors.append(all_output_processors[key])
+ flat_vllm_configs.append(all_vllm_configs[key])
+ if num_replicas > 1:
+ logger.info(
+ "[AsyncOmniEngine] Logical stage %s replica %s → client %s (isolated)",
+ logical_id, replica_idx, ci,
+ )
+
+ logical_stage_to_clients.append(client_indices)
+ logical_default_sampling_params.append(default_sampling_params_list[logical_id])
+ logical_stage_metadata.append(stage_metadata[logical_id])
+
except Exception:
- for stage_id, future in llm_launch_futures.items():
+ for key, future in llm_launch_futures.items():
if not future.done() or future.cancelled() or future.exception() is not None:
continue
- started_llm_stages.setdefault(stage_id, future.result())
+ started_llm_stages.setdefault(key, future.result())
+ # Collect all initialized clients for cleanup
+ cleanup_clients: list[Any] = list(diffusion_clients.values()) + list(all_clients.values())
+ cleanup_clients = [c for c in cleanup_clients if c is not None]
logger.exception(
- "[AsyncOmniEngine] Stage initialization failed; shutting down %s initialized stage(s)",
- len([stage_client for stage_client in stage_clients if stage_client is not None]),
+ "[AsyncOmniEngine] Stage initialization failed; shutting down %s initialized client(s)",
+ len(cleanup_clients),
)
cleanup_failed_stage_initialization(
- stage_clients,
- [started_llm_stages[stage_id] for stage_id in llm_stage_ids if stage_id in started_llm_stages],
+ cleanup_clients,
+ list(started_llm_stages.values()),
)
raise
- # ---- Multi-replica expansion ----
- # Expand the logical-indexed lists into flat client-indexed lists.
- # For replica_index > 0 the same client/processor/config is shared
- # (they point to the same underlying EngineCore) — full per-replica
- # process isolation is out-of-scope for this first iteration.
- flat_clients: list[Any] = []
- flat_output_processors: list[Any] = []
- flat_vllm_configs: list[Any] = []
- logical_stage_to_clients: list[list[int]] = []
- # sampling_params and stage_metadata are per-logical-stage
- logical_default_sampling_params: list[Any] = []
- logical_stage_metadata: list[dict[str, Any]] = []
-
- for logical_id, client in enumerate(initialized_stage_clients):
- num_replicas = replicas_per_stage[logical_id]
- client_indices: list[int] = []
- for replica_idx in range(num_replicas):
- ci = len(flat_clients)
- client_indices.append(ci)
- if replica_idx == 0:
- # First replica uses the already-created objects
- flat_clients.append(client)
- flat_output_processors.append(output_processors[logical_id])
- flat_vllm_configs.append(stage_vllm_configs[logical_id])
- else:
- # Additional replicas: for now, share the same client.
- # True per-replica process isolation will be added later.
- # TODO: launch separate EngineCore processes for replica_idx > 0
- flat_clients.append(client)
- flat_output_processors.append(output_processors[logical_id])
- flat_vllm_configs.append(stage_vllm_configs[logical_id])
- logger.info(
- "[AsyncOmniEngine] Logical stage %s replica %s → client %s (shared)",
- logical_id, replica_idx, ci,
- )
- logical_stage_to_clients.append(client_indices)
- logical_default_sampling_params.append(default_sampling_params_list[logical_id])
- logical_stage_metadata.append(stage_metadata[logical_id])
-
self.stage_clients = flat_clients
self.output_processors = flat_output_processors
self.stage_vllm_configs = flat_vllm_configs
diff --git a/vllm_omni/engine/stage_init_utils.py b/vllm_omni/engine/stage_init_utils.py
index 73c63255dcc..1f360d85aa8 100644
--- a/vllm_omni/engine/stage_init_utils.py
+++ b/vllm_omni/engine/stage_init_utils.py
@@ -218,6 +218,51 @@ def prepare_engine_environment() -> None:
pass
+def split_devices_for_replicas(
+    devices_str: str | None,
+    num_replicas: int,
+    tp_size: int,
+    stage_id: int,
+) -> list[str | None]:
+    """Split a comma-separated devices string into one entry per replica.
+
+    With no devices string or a single replica, ``devices_str`` is passed
+    through unchanged, repeated once per replica so callers can index the
+    result by replica. Otherwise the number of device IDs must equal
+    ``num_replicas * tp_size`` (else ``ValueError``, tagged with ``stage_id``)
+    and each replica gets ``tp_size`` consecutive IDs.
+
+    Example::
+
+        split_devices_for_replicas("1,2,3,4", num_replicas=2, tp_size=2, stage_id=1)
+        # → ["1,2", "3,4"]
+    """
+    if num_replicas <= 1 or devices_str is None:
+        return [devices_str] * max(1, num_replicas)
+
+    device_list = [d.strip() for d in devices_str.split(",") if d.strip()]
+    required = num_replicas * tp_size
+    if len(device_list) != required:
+        raise ValueError(
+            f"Stage {stage_id}: num_replicas={num_replicas}, "
+            f"tensor_parallel_size={tp_size} requires "
+            f"{required} devices, got {len(device_list)}: {devices_str}"
+        )
+
+    # Consecutive, non-overlapping tp_size-wide slices, one per replica.
+    return [
+        ",".join(device_list[r * tp_size : (r + 1) * tp_size]) for r in range(num_replicas)
+    ]
+
+
+def get_stage_tp_size(stage_cfg: Any) -> int:
+    """Return ``engine_args.tensor_parallel_size`` as an int; a missing or falsy value yields 1."""
+    args = getattr(stage_cfg, "engine_args", {})
+    # Mapping-like args expose ``get``; anything else is read by attribute.
+    raw = args.get("tensor_parallel_size", 1) if hasattr(args, "get") else getattr(args, "tensor_parallel_size", 1)
+    return int(raw or 1)
+
+
def setup_stage_devices(stage_id: int, runtime_cfg: Any) -> None:
"""Device mapping via set_stage_devices for a single stage."""
physical_devices = set_stage_devices(
From 4786f5eafdcbd20bb466626c9037d53d57f27ed3 Mon Sep 17 00:00:00 2001
From: ZhengWG
Date: Wed, 1 Apr 2026 14:30:46 +0800
Subject: [PATCH 03/76] fix lint
Signed-off-by: ZhengWG
Made-with: Cursor
---
vllm_omni/engine/async_omni_engine.py | 19 +++++++++++++------
vllm_omni/engine/orchestrator.py | 16 +++++++++++++---
2 files changed, 26 insertions(+), 9 deletions(-)
diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py
index 49b5bd5b161..c875fca82ac 100644
--- a/vllm_omni/engine/async_omni_engine.py
+++ b/vllm_omni/engine/async_omni_engine.py
@@ -463,7 +463,9 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
for stage_cfg in self.stage_configs:
runtime_cfg = getattr(stage_cfg, "runtime", {})
num_replicas = int(
- runtime_cfg.get("num_replicas", 1) if hasattr(runtime_cfg, "get") else getattr(runtime_cfg, "num_replicas", 1)
+ runtime_cfg.get("num_replicas", 1)
+ if hasattr(runtime_cfg, "get")
+ else getattr(runtime_cfg, "num_replicas", 1)
)
replicas_per_stage.append(max(1, num_replicas))
@@ -475,8 +477,7 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
continue
runtime_cfg = getattr(stage_cfg, "runtime", {})
devices_str = (
- runtime_cfg.get("devices") if hasattr(runtime_cfg, "get")
- else getattr(runtime_cfg, "devices", None)
+ runtime_cfg.get("devices") if hasattr(runtime_cfg, "get") else getattr(runtime_cfg, "devices", None)
)
tp_size = get_stage_tp_size(stage_cfg)
per_replica = split_devices_for_replicas(devices_str, num_replicas, tp_size, logical_id)
@@ -484,7 +485,10 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
replica_devices_map[(logical_id, r)] = dev_str
logger.info(
"[AsyncOmniEngine] Stage %s: %d replicas, tp=%d, devices split: %s",
- logical_id, num_replicas, tp_size, per_replica,
+ logical_id,
+ num_replicas,
+ tp_size,
+ per_replica,
)
async_chunk = self.async_chunk
@@ -569,7 +573,8 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
logger.info(
"[AsyncOmniEngine] Launching stage %s replica %s (devices=%s)",
- stage_id, replica_idx,
+ stage_id,
+ replica_idx,
getattr(getattr(replica_cfg, "runtime", None), "devices", "default"),
)
@@ -647,7 +652,9 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
if num_replicas > 1:
logger.info(
"[AsyncOmniEngine] Logical stage %s replica %s → client %s (isolated)",
- logical_id, replica_idx, ci,
+ logical_id,
+ replica_idx,
+ ci,
)
logical_stage_to_clients.append(client_indices)
diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py
index 3106b876809..9c44f8ab605 100644
--- a/vllm_omni/engine/orchestrator.py
+++ b/vllm_omni/engine/orchestrator.py
@@ -302,7 +302,10 @@ async def _orchestration_loop(self) -> None:
client_index, output.request_id, [output], req_state
)
await self._route_output(
- logical_stage_id, output, req_state, stage_metrics,
+ logical_stage_id,
+ output,
+ req_state,
+ stage_metrics,
client_index=client_index,
)
continue
@@ -353,7 +356,10 @@ async def _orchestration_loop(self) -> None:
req_state,
)
await self._route_output(
- logical_stage_id, output, req_state, stage_metrics,
+ logical_stage_id,
+ output,
+ req_state,
+ stage_metrics,
client_index=client_index,
)
@@ -448,7 +454,11 @@ async def _route_output(
)
else:
await self._forward_to_next_stage(
- req_id, stage_id, output, req_state, client_index=client_index,
+ req_id,
+ stage_id,
+ output,
+ req_state,
+ client_index=client_index,
)
if finished and stage_id == req_state.final_stage_id:
From 11976f10a51291aeb0801f7de81501ea9f5cd582 Mon Sep 17 00:00:00 2001
From: ZhengWG
Date: Thu, 2 Apr 2026 16:13:42 +0800
Subject: [PATCH 04/76] refactor: keep naming consistent
Signed-off-by: ZhengWG
Made-with: Cursor
---
vllm_omni/engine/async_omni_engine.py | 112 +++++----
vllm_omni/engine/orchestrator.py | 229 ++++++++++---------
vllm_omni/engine/stage_engine_core_client.py | 14 +-
3 files changed, 196 insertions(+), 159 deletions(-)
diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py
index c875fca82ac..d0dca006324 100644
--- a/vllm_omni/engine/async_omni_engine.py
+++ b/vllm_omni/engine/async_omni_engine.py
@@ -450,10 +450,10 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
num_stages = self.num_stages
input_processor: InputProcessor | None = None
- # Keyed by (logical_stage_id, replica_idx)
- llm_launch_keys: list[tuple[int, int]] = []
- llm_launch_futures: dict[tuple[int, int], concurrent.futures.Future[StartedLlmStage]] = {}
- started_llm_stages: dict[tuple[int, int], StartedLlmStage] = {}
+ # Per-stage launch futures and results: stage_id → [replicas]
+ llm_stage_ids: list[int] = []
+ llm_launch_futures: dict[int, list[concurrent.futures.Future[StartedLlmStage]]] = {}
+ started_llm_stages: dict[int, list[StartedLlmStage]] = {}
llm_stage_launch_lock = threading.Lock()
# Diffusion stages (no multi-replica support yet)
diffusion_clients: dict[int, Any] = {}
@@ -470,7 +470,8 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
replicas_per_stage.append(max(1, num_replicas))
# Pre-compute per-replica device assignments for multi-replica stages
- replica_devices_map: dict[tuple[int, int], str] = {}
+ # stage_id → [devices_str_per_replica]
+ replica_devices_map: dict[int, list[str]] = {}
for logical_id, stage_cfg in enumerate(self.stage_configs):
num_replicas = replicas_per_stage[logical_id]
if num_replicas <= 1:
@@ -480,15 +481,15 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
runtime_cfg.get("devices") if hasattr(runtime_cfg, "get") else getattr(runtime_cfg, "devices", None)
)
tp_size = get_stage_tp_size(stage_cfg)
- per_replica = split_devices_for_replicas(devices_str, num_replicas, tp_size, logical_id)
- for r, dev_str in enumerate(per_replica):
- replica_devices_map[(logical_id, r)] = dev_str
+ replica_devices_map[logical_id] = split_devices_for_replicas(
+ devices_str, num_replicas, tp_size, logical_id,
+ )
logger.info(
"[AsyncOmniEngine] Stage %s: %d replicas, tp=%d, devices split: %s",
logical_id,
num_replicas,
tp_size,
- per_replica,
+ replica_devices_map[logical_id],
)
async_chunk = self.async_chunk
@@ -504,7 +505,8 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
# Initialized outside try so error handler can always access them
flat_clients: list[Any] = []
- all_clients: dict[tuple[int, int], Any] = {}
+ # stage_id → [client_per_replica]
+ all_clients: dict[int, list[Any]] = {}
try:
with concurrent.futures.ThreadPoolExecutor(
@@ -554,19 +556,19 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
continue
# Submit one launch future per replica
+ llm_stage_ids.append(stage_id)
num_replicas = replicas_per_stage[stage_id]
- for replica_idx in range(num_replicas):
- key = (stage_id, replica_idx)
- llm_launch_keys.append(key)
+ stage_futures: list[concurrent.futures.Future[StartedLlmStage]] = []
+ for replica_idx in range(num_replicas):
# For replica > 0, deep-copy stage_cfg and override devices
if replica_idx > 0:
replica_cfg = copy.deepcopy(stage_cfg)
else:
replica_cfg = stage_cfg
- if key in replica_devices_map:
- replica_cfg.runtime.devices = replica_devices_map[key]
+ if stage_id in replica_devices_map:
+ replica_cfg.runtime.devices = replica_devices_map[stage_id][replica_idx]
replica_metadata = extract_stage_metadata(replica_cfg)
replica_metadata.replica_index = replica_idx
@@ -578,7 +580,7 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
getattr(getattr(replica_cfg, "runtime", None), "devices", "default"),
)
- llm_launch_futures[key] = launch_executor.submit(
+ stage_futures.append(launch_executor.submit(
self._launch_llm_stage,
replica_cfg,
replica_metadata,
@@ -586,12 +588,18 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
stage_init_timeout,
llm_stage_launch_lock,
omni_kv_connector,
- )
+ ))
- concurrent.futures.wait(list(llm_launch_futures.values()))
+ llm_launch_futures[stage_id] = stage_futures
- for key in llm_launch_keys:
- started_llm_stages[key] = llm_launch_futures[key].result()
+ # Wait for all futures across all stages
+ all_futures = [f for futures in llm_launch_futures.values() for f in futures]
+ concurrent.futures.wait(all_futures)
+
+ for stage_id in llm_stage_ids:
+ started_llm_stages[stage_id] = [
+ f.result() for f in llm_launch_futures[stage_id]
+ ]
# ---- Build flat client lists directly ----
# Attach each launched replica and build the flat index structures.
@@ -601,21 +609,27 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
# Per-logical-stage lists (not per-client)
logical_stage_clients_for_finalize: list[Any | None] = [None] * num_stages
- all_output_processors: dict[tuple[int, int], Any] = {}
- all_vllm_configs: dict[tuple[int, int], Any] = {}
-
- for key in llm_launch_keys:
- stage_id, replica_idx = key
- started = started_llm_stages[key]
- client, output_proc, vllm_cfg, stage0_inp = self._attach_llm_stage(started)
- all_clients[key] = client
- all_output_processors[key] = output_proc
- all_vllm_configs[key] = vllm_cfg
- if stage0_inp is not None:
- input_processor = stage0_inp
+ all_output_processors: dict[int, list[Any]] = {}
+ all_vllm_configs: dict[int, list[Any]] = {}
+
+ for stage_id in llm_stage_ids:
+ stage_clients_list: list[Any] = []
+ stage_output_procs: list[Any] = []
+ stage_vllm_cfgs: list[Any] = []
+
+ for replica_idx, started in enumerate(started_llm_stages[stage_id]):
+ client, output_proc, vllm_cfg, stage0_inp = self._attach_llm_stage(started)
+ stage_clients_list.append(client)
+ stage_output_procs.append(output_proc)
+ stage_vllm_cfgs.append(vllm_cfg)
+ if stage0_inp is not None:
+ input_processor = stage0_inp
+
+ all_clients[stage_id] = stage_clients_list
+ all_output_processors[stage_id] = stage_output_procs
+ all_vllm_configs[stage_id] = stage_vllm_cfgs
# Use first replica for finalize_initialized_stages
- if replica_idx == 0:
- logical_stage_clients_for_finalize[stage_id] = client
+ logical_stage_clients_for_finalize[stage_id] = stage_clients_list[0]
# Place diffusion clients into the logical list
for stage_id, diff_client in diffusion_clients.items():
@@ -643,15 +657,14 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
flat_vllm_configs.append(None)
else:
for replica_idx in range(num_replicas):
- key = (logical_id, replica_idx)
ci = len(flat_clients)
client_indices.append(ci)
- flat_clients.append(all_clients[key])
- flat_output_processors.append(all_output_processors[key])
- flat_vllm_configs.append(all_vllm_configs[key])
+ flat_clients.append(all_clients[logical_id][replica_idx])
+ flat_output_processors.append(all_output_processors[logical_id][replica_idx])
+ flat_vllm_configs.append(all_vllm_configs[logical_id][replica_idx])
if num_replicas > 1:
logger.info(
- "[AsyncOmniEngine] Logical stage %s replica %s → client %s (isolated)",
+ "[AsyncOmniEngine] Stage %s replica %s → client %s (isolated)",
logical_id,
replica_idx,
ci,
@@ -662,21 +675,22 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
logical_stage_metadata.append(stage_metadata[logical_id])
except Exception:
- for key, future in llm_launch_futures.items():
- if not future.done() or future.cancelled() or future.exception() is not None:
- continue
- started_llm_stages.setdefault(key, future.result())
+ for stage_id, futures in llm_launch_futures.items():
+ for f in futures:
+ if not f.done() or f.cancelled() or f.exception() is not None:
+ continue
+ started_llm_stages.setdefault(stage_id, []).append(f.result())
# Collect all initialized clients for cleanup
- cleanup_clients: list[Any] = list(diffusion_clients.values()) + list(all_clients.values())
+ cleanup_clients: list[Any] = list(diffusion_clients.values())
+ for clients in all_clients.values():
+ cleanup_clients.extend(clients)
cleanup_clients = [c for c in cleanup_clients if c is not None]
+ all_started = [s for stages in started_llm_stages.values() for s in stages]
logger.exception(
"[AsyncOmniEngine] Stage initialization failed; shutting down %s initialized client(s)",
len(cleanup_clients),
)
- cleanup_failed_stage_initialization(
- cleanup_clients,
- list(started_llm_stages.values()),
- )
+ cleanup_failed_stage_initialization(cleanup_clients, all_started)
raise
self.stage_clients = flat_clients
@@ -725,7 +739,7 @@ async def _run_orchestrator() -> None:
stage_clients=self.stage_clients,
output_processors=self.output_processors,
stage_vllm_configs=self.stage_vllm_configs,
- logical_stage_to_clients=getattr(self, "logical_stage_to_clients", None),
+ logical_stage_to_clients=self.logical_stage_to_clients,
)
if not startup_future.done():
startup_future.set_result(asyncio.get_running_loop())
diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py
index 9c44f8ab605..7e5e75a6ae2 100644
--- a/vllm_omni/engine/orchestrator.py
+++ b/vllm_omni/engine/orchestrator.py
@@ -148,11 +148,13 @@ def __init__(
self.logical_stage_to_clients = [[i] for i in range(self.num_clients)]
self.num_logical_stages = len(self.logical_stage_to_clients)
- # Reverse mapping: client_index -> logical_stage_id
+ # Reverse mappings: client_index -> (logical_stage_id, replica_index)
self._client_to_logical: list[int] = [0] * self.num_clients
+ self._client_to_replica: list[int] = [0] * self.num_clients
for logical_id, client_indices in enumerate(self.logical_stage_to_clients):
- for ci in client_indices:
+ for ri, ci in enumerate(client_indices):
self._client_to_logical[ci] = logical_id
+ self._client_to_replica[ci] = ri
# Round-robin counters for replica selection per logical stage
self._replica_rr: list[int] = [0] * self.num_logical_stages
@@ -205,6 +207,10 @@ def _choose_client_index(
req_state.chosen_client_index[logical_stage_id] = chosen
return chosen
+ def _resolve_client_index(self, stage_id: int, replica_index: int = 0) -> int:
+ """Return the flat client index of a (stage_id, replica_index) pair; raises IndexError if out of range."""
+ return self.logical_stage_to_clients[stage_id][replica_index]
+
async def run(self) -> None:
"""Main entry point for the Orchestrator event loop."""
logger.info("[Orchestrator] Starting event loop")
@@ -279,90 +285,92 @@ async def _orchestration_loop(self) -> None:
Control flow: poll raw → process through output processor → route.
- Multi-replica: iterates over every *client_index* (not logical stage),
- and resolves the logical_stage_id from client metadata for routing.
+ Multi-replica: iterates over every (stage_id, replica_index) pair,
+ resolves to a flat client_index internally for resource access.
"""
while not self._shutdown_event.is_set():
idle = True
- for client_index in range(self.num_clients):
- if self._shutdown_event.is_set():
- return
-
- logical_stage_id = self._client_to_logical[client_index]
-
- # 1) Diffusion stage: poll non-blocking queue
- stage_client = self.stage_clients[client_index]
- if stage_client.stage_type == "diffusion":
- output = stage_client.get_diffusion_output_async()
- if output is not None:
- idle = False
- req_state = self.request_states.get(output.request_id)
- if req_state is not None:
- stage_metrics = self._build_stage_metrics(
- client_index, output.request_id, [output], req_state
- )
- await self._route_output(
- logical_stage_id,
- output,
- req_state,
- stage_metrics,
- client_index=client_index,
- )
- continue
-
- # 1) Poll raw outputs from the client
- try:
- raw_outputs = await asyncio.wait_for(self._poll_stage_raw(client_index), timeout=0.001)
- except asyncio.TimeoutError:
- continue
- except asyncio.CancelledError:
- raise
- except Exception:
+ for stage_id in range(self.num_logical_stages):
+ for replica_index in range(len(self.logical_stage_to_clients[stage_id])):
if self._shutdown_event.is_set():
return
- logger.exception(
- "[Orchestrator] _poll_stage_raw failed for client-%s (logical stage-%s)",
- client_index,
- logical_stage_id,
- )
- raise
-
- if raw_outputs is None:
- continue
- idle = False
-
- # 2) Process raw outputs through the output processor
- request_outputs = await self._process_stage_outputs(client_index, raw_outputs)
-
- # 3) Route each processed output
- for output in request_outputs:
- req_state = self.request_states.get(output.request_id)
- if req_state is None:
- logger.warning(
- "[Orchestrator] Dropping output for unknown req %s at client-%s "
- "(logical stage-%s, known reqs: %s)",
- output.request_id,
- client_index,
- logical_stage_id,
- list(self.request_states.keys()),
+
+ client_index = self._resolve_client_index(stage_id, replica_index)
+
+ # 1) Diffusion stage: poll non-blocking queue
+ # TODO (Peiqi): the output of diffusion stage is OmniRequestOutput,
+ # which is different from EngineCoreOutputs (LLM stages). We may want to unify
+ # the output format in the future to simplify the processing logic in Orchestrator.
+ stage_client = self.stage_clients[client_index]
+ if stage_client.stage_type == "diffusion":
+ output = stage_client.get_diffusion_output_async()
+ if output is not None:
+ idle = False
+ req_state = self.request_states.get(output.request_id)
+ if req_state is not None:
+ stage_metrics = self._build_stage_metrics(
+ stage_id, output.request_id, [output], req_state,
+ replica_index=replica_index,
+ )
+ await self._route_output(
+ stage_id, output, req_state, stage_metrics,
+ replica_index=replica_index,
+ )
+ continue
+
+ # 1) Poll raw outputs from the stage replica
+ try:
+ raw_outputs = await asyncio.wait_for(
+ self._poll_stage_raw(stage_id, replica_index=replica_index),
+ timeout=0.001,
)
+ except asyncio.TimeoutError:
continue
- stage_metrics = None
- if output.finished:
- stage_metrics = self._build_stage_metrics(
- client_index,
- output.request_id,
- [output],
- req_state,
+ except asyncio.CancelledError:
+ raise
+ except Exception:
+ if self._shutdown_event.is_set():
+ return
+ logger.exception(
+ "[Orchestrator] _poll_stage_raw failed for stage-%s replica-%s",
+ stage_id,
+ replica_index,
)
- await self._route_output(
- logical_stage_id,
- output,
- req_state,
- stage_metrics,
- client_index=client_index,
+ raise
+
+ if raw_outputs is None:
+ continue
+ idle = False
+
+ # 2) Process raw outputs through the output processor
+ request_outputs = await self._process_stage_outputs(
+ stage_id, raw_outputs, replica_index=replica_index,
)
+ # 3) Route each processed output
+ for output in request_outputs:
+ req_state = self.request_states.get(output.request_id)
+ if req_state is None:
+ logger.warning(
+ "[Orchestrator] Dropping output for unknown req %s "
+ "at stage-%s replica-%s (known reqs: %s)",
+ output.request_id,
+ stage_id,
+ replica_index,
+ list(self.request_states.keys()),
+ )
+ continue
+ stage_metrics = None
+ if output.finished:
+ stage_metrics = self._build_stage_metrics(
+ stage_id, output.request_id, [output], req_state,
+ replica_index=replica_index,
+ )
+ await self._route_output(
+ stage_id, output, req_state, stage_metrics,
+ replica_index=replica_index,
+ )
+
if idle:
await asyncio.sleep(0.001)
else:
@@ -375,17 +383,15 @@ async def _route_output(
req_state: OrchestratorRequestState,
stage_metrics: Any,
*,
- client_index: int | None = None,
+ replica_index: int = 0,
) -> None:
"""Route a processed output: send to main thread and/or forward to next stage.
Args:
stage_id: Logical stage id.
- client_index: Physical client index that produced this output.
- Defaults to stage_id for backward compat.
+ replica_index: Replica index within the logical stage.
"""
- if client_index is None:
- client_index = stage_id
+ client_index = self._resolve_client_index(stage_id, replica_index)
req_id = output.request_id
finished = output.finished
submit_ts = req_state.stage_submit_ts.get(stage_id)
@@ -412,7 +418,7 @@ async def _route_output(
deferred["stage_id"],
deferred["output"],
parent_state,
- client_index=deferred.get("client_index", deferred["stage_id"]),
+ replica_index=deferred.get("replica_index", 0),
)
self.request_states.pop(req_id, None)
return
@@ -446,7 +452,7 @@ async def _route_output(
self._deferred_parents[req_id] = {
"stage_id": stage_id,
"output": output,
- "client_index": client_index,
+ "replica_index": replica_index,
}
logger.debug(
"[Orchestrator] Parent %s deferred, waiting for CFG companions",
@@ -458,7 +464,7 @@ async def _route_output(
stage_id,
output,
req_state,
- client_index=client_index,
+ replica_index=replica_index,
)
if finished and stage_id == req_state.final_stage_id:
@@ -484,24 +490,26 @@ def _all_companions_done(self, parent_id: str) -> bool:
def _build_stage_metrics(
self,
- client_index: int,
+ stage_id: int,
req_id: str,
request_outputs: list[RequestOutput],
req_state: OrchestratorRequestState,
+ *,
+ replica_index: int = 0,
) -> StageRequestMetrics:
- """Build StageRequestMetrics for a finished request at a client.
+ """Build StageRequestMetrics for a finished request at a stage replica.
Reuses StageRequestMetrics so OrchestratorMetrics and downstream
metric handlers can consume a stable schema.
"""
- logical_stage_id = self._client_to_logical[client_index]
+ client_index = self._resolve_client_index(stage_id, replica_index)
now = _time.time()
- submit_ts = req_state.stage_submit_ts.get(logical_stage_id, now)
+ submit_ts = req_state.stage_submit_ts.get(stage_id, now)
stage_gen_time_ms = (now - submit_ts) * 1000.0
num_tokens_out = count_tokens_from_outputs(request_outputs)
num_tokens_in = 0
- if logical_stage_id == 0:
+ if stage_id == 0:
for ro in request_outputs:
ptids = getattr(ro, "prompt_token_ids", None)
if ptids is not None:
@@ -537,7 +545,7 @@ async def _forward_to_next_stage(
output: Any,
req_state: OrchestratorRequestState,
*,
- client_index: int | None = None,
+ replica_index: int = 0,
) -> None:
"""Forward output from current stage to the next stage.
@@ -546,10 +554,9 @@ async def _forward_to_next_stage(
Args:
stage_id: Logical stage id that produced the output.
- client_index: Physical client index that produced the output.
+ replica_index: Replica index of the stage that produced the output.
"""
- if client_index is None:
- client_index = stage_id
+ client_index = self._resolve_client_index(stage_id, replica_index)
next_logical = stage_id + 1
next_ci = self._choose_client_index(next_logical, req_state)
@@ -604,14 +611,14 @@ async def _forward_to_next_stage(
next_inputs = next_client.process_engine_inputs(
stage_list=self.stage_clients,
prompt=req_state.prompt,
- source_client_index=client_index,
+ source_client=self.stage_clients[client_index],
)
except Exception:
logger.exception(
- "[Orchestrator] req=%s process_engine_inputs FAILED for logical stage-%s (client-%s)",
+ "[Orchestrator] req=%s process_engine_inputs FAILED for stage-%s replica-%s",
req_id,
next_logical,
- next_ci,
+ self._client_to_replica[next_ci],
)
raise
@@ -640,22 +647,28 @@ async def _forward_to_next_stage(
# Record submit timestamp for the next logical stage
req_state.stage_submit_ts[next_logical] = _time.time()
- async def _poll_stage_raw(self, client_index: int) -> EngineCoreOutputs | None:
- """Pull raw EngineCoreOutputs from a stage client without processing.
+ async def _poll_stage_raw(
+ self, stage_id: int, *, replica_index: int = 0,
+ ) -> EngineCoreOutputs | None:
+ """Pull raw EngineCoreOutputs from a stage replica without processing.
Returns the raw outputs object, or None when there is nothing
to consume.
"""
+ client_index = self._resolve_client_index(stage_id, replica_index)
outputs = await self.stage_clients[client_index].get_output_async()
if not outputs.outputs:
return None
return outputs
- async def _process_stage_outputs(self, client_index: int, raw_outputs: EngineCoreOutputs) -> list[RequestOutput]:
+ async def _process_stage_outputs(
+ self, stage_id: int, raw_outputs: EngineCoreOutputs, *, replica_index: int = 0,
+ ) -> list[RequestOutput]:
"""Run the output processor on raw outputs, returning RequestOutputs.
Also handles abort forwarding and scheduler stats updates.
"""
+ client_index = self._resolve_client_index(stage_id, replica_index)
processor = self.output_processors[client_index]
processed = processor.process_outputs(
@@ -837,11 +850,12 @@ async def _handle_add_companion(self, msg: dict[str, Any]) -> None:
await stage_client.add_request_async(request)
logger.info(
- "[Orchestrator] CFG companion submitted: %s (role=%s, parent=%s, client=%s)",
+ "[Orchestrator] CFG companion submitted: %s (role=%s, parent=%s, "
+ "stage-0 replica-%s)",
companion_id,
role,
parent_id,
- client_index,
+ self._client_to_replica[client_index],
)
async def _handle_abort(self, msg: dict[str, Any]) -> None:
@@ -951,6 +965,15 @@ def _shutdown_stages(self) -> None:
for ci, stage_client in enumerate(self.stage_clients):
try:
stage_client.shutdown()
- logger.info("[Orchestrator] Client %d shut down", ci)
+ logger.info(
+ "[Orchestrator] Stage %d replica %d shut down",
+ self._client_to_logical[ci],
+ self._client_to_replica[ci],
+ )
except Exception as e:
- logger.warning("[Orchestrator] Failed to shutdown client %d: %s", ci, e)
+ logger.warning(
+ "[Orchestrator] Failed to shutdown stage %d replica %d: %s",
+ self._client_to_logical[ci],
+ self._client_to_replica[ci],
+ e,
+ )
diff --git a/vllm_omni/engine/stage_engine_core_client.py b/vllm_omni/engine/stage_engine_core_client.py
index 395a9d84550..dd8c1b4fd55 100644
--- a/vllm_omni/engine/stage_engine_core_client.py
+++ b/vllm_omni/engine/stage_engine_core_client.py
@@ -123,15 +123,14 @@ def process_engine_inputs(
self,
stage_list: list[Any],
prompt: OmniTokensPrompt | list[OmniTokensPrompt] | None = None,
- source_client_index: int | None = None,
+ source_client: Any | None = None,
) -> list[OmniTokensPrompt]:
"""Process inputs from upstream stages.
Args:
- source_client_index: When multi-replica is enabled, specifies the
- exact client index in *stage_list* that produced the upstream
- output. Falls back to ``engine_input_source[0]`` for backward
- compat.
+ source_client: When multi-replica is enabled, the upstream client
+ object that produced the output. Falls back to
+ ``stage_list[engine_input_source[0]]`` for backward compat.
"""
from vllm_omni.inputs.data import OmniTokensPrompt
@@ -146,8 +145,9 @@ def process_engine_inputs(
if not self.engine_input_source:
raise ValueError(f"engine_input_source empty for stage {self.stage_id}")
- source_id = source_client_index if source_client_index is not None else self.engine_input_source[0]
- source_outputs = stage_list[source_id].engine_outputs
+ if source_client is None:
+ source_client = stage_list[self.engine_input_source[0]]
+ source_outputs = source_client.engine_outputs
if not isinstance(prompt, list):
prompt = [prompt]
From eaa9dfdab2df94f016166350a34613c9d3ebc3ed Mon Sep 17 00:00:00 2001
From: ZhengWG
Date: Sun, 5 Apr 2026 00:41:59 +0800
Subject: [PATCH 05/76] fix lint
Signed-off-by: ZhengWG
---
vllm_omni/engine/async_omni_engine.py | 29 +++++++++---------
vllm_omni/engine/orchestrator.py | 42 ++++++++++++++++++++-------
2 files changed, 48 insertions(+), 23 deletions(-)
diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py
index b7177ebe10f..050c635de23 100644
--- a/vllm_omni/engine/async_omni_engine.py
+++ b/vllm_omni/engine/async_omni_engine.py
@@ -482,7 +482,10 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
)
tp_size = get_stage_tp_size(stage_cfg)
replica_devices_map[logical_id] = split_devices_for_replicas(
- devices_str, num_replicas, tp_size, logical_id,
+ devices_str,
+ num_replicas,
+ tp_size,
+ logical_id,
)
logger.info(
"[AsyncOmniEngine] Stage %s: %d replicas, tp=%d, devices split: %s",
@@ -580,15 +583,17 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
getattr(getattr(replica_cfg, "runtime", None), "devices", "default"),
)
- stage_futures.append(launch_executor.submit(
- self._launch_llm_stage,
- replica_cfg,
- replica_metadata,
- stage_connector_spec,
- stage_init_timeout,
- llm_stage_launch_lock,
- omni_kv_connector,
- ))
+ stage_futures.append(
+ launch_executor.submit(
+ self._launch_llm_stage,
+ replica_cfg,
+ replica_metadata,
+ stage_connector_spec,
+ stage_init_timeout,
+ llm_stage_launch_lock,
+ omni_kv_connector,
+ )
+ )
llm_launch_futures[stage_id] = stage_futures
@@ -597,9 +602,7 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
concurrent.futures.wait(all_futures)
for stage_id in llm_stage_ids:
- started_llm_stages[stage_id] = [
- f.result() for f in llm_launch_futures[stage_id]
- ]
+ started_llm_stages[stage_id] = [f.result() for f in llm_launch_futures[stage_id]]
# ---- Build flat client lists directly ----
# Attach each launched replica and build the flat index structures.
diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py
index daabdbe4d8b..b79f88933ff 100644
--- a/vllm_omni/engine/orchestrator.py
+++ b/vllm_omni/engine/orchestrator.py
@@ -309,11 +309,17 @@ async def _orchestration_loop(self) -> None:
req_state = self.request_states.get(output.request_id)
if req_state is not None:
stage_metrics = self._build_stage_metrics(
- stage_id, output.request_id, [output], req_state,
+ stage_id,
+ output.request_id,
+ [output],
+ req_state,
replica_index=replica_index,
)
await self._route_output(
- stage_id, output, req_state, stage_metrics,
+ stage_id,
+ output,
+ req_state,
+ stage_metrics,
replica_index=replica_index,
)
continue
@@ -344,12 +350,16 @@ async def _orchestration_loop(self) -> None:
# Handle prefill-finished KV-ready signals before finished outputs.
await self._handle_kv_ready_raw_outputs(
- stage_id, raw_outputs, replica_index=replica_index,
+ stage_id,
+ raw_outputs,
+ replica_index=replica_index,
)
# 2) Process raw outputs through the output processor
request_outputs = await self._process_stage_outputs(
- stage_id, raw_outputs, replica_index=replica_index,
+ stage_id,
+ raw_outputs,
+ replica_index=replica_index,
)
# 3) Route each processed output
@@ -368,11 +378,17 @@ async def _orchestration_loop(self) -> None:
stage_metrics = None
if output.finished:
stage_metrics = self._build_stage_metrics(
- stage_id, output.request_id, [output], req_state,
+ stage_id,
+ output.request_id,
+ [output],
+ req_state,
replica_index=replica_index,
)
await self._route_output(
- stage_id, output, req_state, stage_metrics,
+ stage_id,
+ output,
+ req_state,
+ stage_metrics,
replica_index=replica_index,
)
@@ -698,7 +714,10 @@ async def _forward_to_next_stage(
req_state.stage_submit_ts[next_logical] = _time.time()
async def _poll_stage_raw(
- self, stage_id: int, *, replica_index: int = 0,
+ self,
+ stage_id: int,
+ *,
+ replica_index: int = 0,
) -> EngineCoreOutputs | None:
"""Pull raw EngineCoreOutputs from a stage replica without processing.
@@ -712,7 +731,11 @@ async def _poll_stage_raw(
return outputs
async def _process_stage_outputs(
- self, stage_id: int, raw_outputs: EngineCoreOutputs, *, replica_index: int = 0,
+ self,
+ stage_id: int,
+ raw_outputs: EngineCoreOutputs,
+ *,
+ replica_index: int = 0,
) -> list[RequestOutput]:
"""Run the output processor on raw outputs, returning RequestOutputs.
@@ -900,8 +923,7 @@ async def _handle_add_companion(self, msg: dict[str, Any]) -> None:
await stage_client.add_request_async(request)
logger.info(
- "[Orchestrator] CFG companion submitted: %s (role=%s, parent=%s, "
- "stage-0 replica-%s)",
+ "[Orchestrator] CFG companion submitted: %s (role=%s, parent=%s, stage-0 replica-%s)",
companion_id,
role,
parent_id,
From e12250119bc7f90745354a5349e550f391fa123b Mon Sep 17 00:00:00 2001
From: NATURE
Date: Mon, 13 Apr 2026 11:36:20 +0800
Subject: [PATCH 06/76] [Bugfix] Fix Bagel online mode for 1. Hang after
several requests 2. Non-deterministic image quality regression. (#2458)
Signed-off-by: natureofnature
---
vllm_omni/core/sched/omni_ar_scheduler.py | 105 +++++-----
.../model_executor/models/bagel/bagel.py | 195 ++++++------------
.../npu/worker/npu_ar_model_runner.py | 26 ++-
vllm_omni/worker/gpu_ar_model_runner.py | 35 +++-
4 files changed, 164 insertions(+), 197 deletions(-)
diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py
index eac737b6e66..0ee8cd16a3a 100644
--- a/vllm_omni/core/sched/omni_ar_scheduler.py
+++ b/vllm_omni/core/sched/omni_ar_scheduler.py
@@ -59,6 +59,11 @@ def __init__(self, *args, **kwargs):
# Track ACTIVE transfers (submitted to runner but not yet acked via kv_extracted_req_ids)
self.active_kv_transfers: set[str] = set()
+ # Requests marked for deferred stop: keep running until KV extraction
+ # completes so that kv_ready can be emitted while the request is still
+ # alive. Stopped on the first scheduler step after extraction ack.
+ self.pending_stop_after_extraction: set[str] = set()
+
# [Omni] Pre-parse KV transfer criteria
self.kv_transfer_criteria = self._get_kv_transfer_criteria()
@@ -126,11 +131,16 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int
stop_decode_on_trigger = self.kv_transfer_criteria.get("stop_after_transfer", True)
if request.request_id in self.transfer_triggered_requests:
- # Already triggered. When stop_decode_on_trigger is True AND
- # transfer was actually queued, the request was already stopped
- # at trigger time (see below). Any request that reaches this
- # point either has stop_decode_on_trigger=False (continue
- # decoding) or was not actually queued (should not be stopped).
+ # Deferred stop: once KV extraction is complete (no longer in
+ # active_kv_transfers), stop the request. This guarantees the
+ # kv_ready signal was emitted while the request was still alive.
+ if (
+ request.request_id in self.pending_stop_after_extraction
+ and request.request_id not in self.active_kv_transfers
+ ):
+ self.pending_stop_after_extraction.discard(request.request_id)
+ request.status = RequestStatus.FINISHED_STOPPED
+ return True
return False
if criteria_type == "prefill_finished":
@@ -140,14 +150,11 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int
actually_queued = request.request_id in self.requests_needing_kv_transfer
if stop_decode_on_trigger and actually_queued:
- # Stop immediately so the request is NOT scheduled in
- # the next step, freeing scheduling budget for companion
- # requests whose chunked-prefill boundaries must be
- # deterministic. waiting_for_transfer_free keeps blocks
- # alive until the model runner finishes KV extraction.
- self.waiting_for_transfer_free.add(request.request_id)
- request.status = RequestStatus.FINISHED_STOPPED
- return True
+ # Defer the stop until KV extraction completes so that
+ # the kv_ready signal can be emitted while the request
+ # is still alive. The request will be stopped on the
+ # next scheduler step after extraction ack arrives.
+ self.pending_stop_after_extraction.add(request.request_id)
return False
@@ -167,9 +174,7 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int
actually_queued = request.request_id in self.requests_needing_kv_transfer
if stop_decode_on_trigger and actually_queued:
- self.waiting_for_transfer_free.add(request.request_id)
- request.status = RequestStatus.FINISHED_STOPPED
- return True
+ self.pending_stop_after_extraction.add(request.request_id)
return False
@@ -268,6 +273,26 @@ def update_from_output(
num_scheduled_tokens,
)
+ # Pre-process KV extraction acks so that the per-request loop below
+ # can see up-to-date active_kv_transfers state and emit kv_ready
+ # signals while requests are still alive (before any deferred stop).
+ kv_extracted_ids = getattr(model_runner_output, "kv_extracted_req_ids", None)
+ if kv_extracted_ids:
+ for req_id in kv_extracted_ids:
+ try:
+ self.active_kv_transfers.discard(req_id)
+ req = self.requests.get(req_id)
+ if req is not None and not req.is_finished():
+ outputs[req.client_index].append(
+ EngineCoreOutput(
+ request_id=req_id,
+ new_token_ids=[],
+ kv_transfer_params={"kv_ready": True},
+ )
+ )
+ except Exception:
+ init_logger(__name__).exception("Failed to pre-process KV extraction for %s", req_id)
+
# NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
# the below loop can be a performance bottleneck. We should do our best
# to avoid expensive operations inside the loop.
@@ -436,6 +461,7 @@ def update_from_output(
self.transfer_triggered_requests.remove(req.request_id)
if req.request_id in self.active_kv_transfers:
self.active_kv_transfers.remove(req.request_id)
+ self.pending_stop_after_extraction.discard(req.request_id)
# Same for preempted
for req in stopped_preempted_reqs:
@@ -444,6 +470,8 @@ def update_from_output(
self.transfer_triggered_requests.remove(req.request_id)
if req.request_id in self.active_kv_transfers:
self.active_kv_transfers.remove(req.request_id)
+ self.pending_stop_after_extraction.discard(req.request_id)
+
# KV Connector: update state for finished KV Transfers.
if kv_connector_output:
self._update_from_kv_xfer_finished(kv_connector_output)
@@ -489,35 +517,12 @@ def update_from_output(
engine_core_outputs[0] = eco = EngineCoreOutputs()
eco.scheduler_stats = stats
- # This is where we free blocks that were held for transfer
- try:
- kv_extracted_ids = getattr(model_runner_output, "kv_extracted_req_ids", None)
- if kv_extracted_ids:
- for req_id in kv_extracted_ids:
- # Emit a kv_ready signal so the orchestrator can forward
- # the request to the DiT stage immediately after KV
- # extraction, without waiting for AR decode to finish.
- req = self.requests.get(req_id)
- if req is not None and not req.is_finished():
- eco = engine_core_outputs.get(req.client_index)
- if eco is None:
- eco = EngineCoreOutputs()
- engine_core_outputs[req.client_index] = eco
- eco.outputs.append(
- EngineCoreOutput(
- request_id=req_id,
- new_token_ids=[],
- kv_transfer_params={"kv_ready": True},
- )
- )
-
- # Mark transfer as finished
- if req_id in self.active_kv_transfers:
- self.active_kv_transfers.remove(req_id)
- logger.debug(f"[Omni] KV Transfer finished for {req_id}")
-
+ # Free blocks that were held for transfer (kv_ready and
+ # active_kv_transfers updates already done before the per-request loop).
+ if kv_extracted_ids:
+ for req_id in kv_extracted_ids:
+ try:
if req_id in self.waiting_for_transfer_free:
- # Now it's safe to free blocks
req = self.requests.get(req_id)
if req:
self.kv_cache_manager.free(req)
@@ -525,13 +530,12 @@ def update_from_output(
del self.requests[req_id]
if req_id in self.transfer_triggered_requests:
self.transfer_triggered_requests.remove(req_id)
- if req_id in self.active_kv_transfers:
- self.active_kv_transfers.remove(req_id)
-
+ self.active_kv_transfers.discard(req_id)
+ self.pending_stop_after_extraction.discard(req_id)
logger.debug(f"Freed blocks for {req_id} after transfer extraction")
self.waiting_for_transfer_free.remove(req_id)
- except Exception:
- init_logger(__name__).exception("Failed to process finished transfer requests")
+ except Exception:
+ init_logger(__name__).exception("Failed to free blocks for %s after transfer", req_id)
return engine_core_outputs
@@ -564,8 +568,7 @@ def _free_request(self, request: Request, delay_free_blocks: bool = False) -> di
kv_xfer_params = None
return kv_xfer_params
elif request_id in self.waiting_for_transfer_free:
- # Stopped immediately by stop_decode_on_trigger; blocks are
- # held until KV extraction completes in a future step.
+ # Blocks held until KV extraction completes in a future step.
return None
else:
logger.debug(
diff --git a/vllm_omni/model_executor/models/bagel/bagel.py b/vllm_omni/model_executor/models/bagel/bagel.py
index acbbc28b4cf..cbb775680cc 100644
--- a/vllm_omni/model_executor/models/bagel/bagel.py
+++ b/vllm_omni/model_executor/models/bagel/bagel.py
@@ -1,4 +1,3 @@
-from collections import deque
from collections.abc import Iterable, Mapping, Sequence
from math import isqrt
from typing import Any
@@ -442,14 +441,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self._pending_img2img_info: list[tuple[int, int, int, int]] = []
self._ropes_pending: list[dict[str, Any]] = []
self._ropes_metadata: dict[str, dict[str, Any]] = {}
- self._cfg_companion_queue: deque[tuple[tuple[int, int, int, int], int]] = deque()
-
- # Per-request position offset for decode after img2img prefill.
- # Prefill rewrites positions (VAE→0, ViT→1, text→2..N) but the model
- # runner assigns decode positions starting from prefill_len, not N+1.
- # offset = rope - prefill_len (a negative number).
- self._pending_decode_offsets: list[int] = []
- self._decode_position_offsets: dict[str, int] = {}
+ self._last_img2img_info: tuple[int, int, int, int] | None = None
from transformers import AutoTokenizer
@@ -461,7 +453,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self._start_of_image_id = int(_tok.convert_tokens_to_ids("<|vision_start|>"))
self._end_of_image_id = int(_tok.convert_tokens_to_ids("<|vision_end|>"))
self._img2img_token_id = int(_tok.convert_tokens_to_ids("<|fim_middle|>"))
-
self._vae_token_mask: torch.Tensor | None = None
self.device = get_local_device()
self._install_mot_modules(config)
@@ -540,9 +531,7 @@ def _clear_warmup_state(self):
self._ropes_pending.clear()
self._ropes_metadata.clear()
self._pending_img2img_info.clear()
- self._cfg_companion_queue.clear()
- self._pending_decode_offsets.clear()
- self._decode_position_offsets.clear()
+ self._last_img2img_info = None
self._vae_token_mask = None
def get_kv_transfer_metadata(
@@ -554,12 +543,10 @@ def get_kv_transfer_metadata(
meta = self._ropes_metadata.pop(req_id, None)
if meta is None:
return None
- # In think-mode img2img the prefill rope doesn't account for decoded
- # thinking tokens; correct it to num_computed_tokens + offset.
- # Skip correction when num_computed_tokens is unavailable (None).
- offset = self._decode_position_offsets.pop(req_id, 0)
- if offset != 0 and "ropes" in meta and num_computed_tokens is not None:
- meta["ropes"] = [num_computed_tokens + offset]
+ if num_computed_tokens is not None and "image_shape" in meta:
+ prefill_rope = meta["ropes"][0] if meta.get("ropes") else 0
+ if num_computed_tokens > prefill_rope:
+ meta["ropes"] = [num_computed_tokens]
return meta
def prepare_runner_inputs(
@@ -572,48 +559,29 @@ def prepare_runner_inputs(
num_scheduled_tokens: list[int],
input_ids_buffer: torch.Tensor | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
- """Model-runner hook: adjust inputs before ``forward()``.
-
- Returns ``(input_ids, positions)`` — possibly modified.
-
- Two adjustments for BAGEL img2img:
-
- 1. **Restore input_ids** when ``inputs_embeds`` is present so that
- ``_adjust_positions_for_img2img`` can locate the
- ``<|fim_middle|>`` placeholder.
- 2. **Decode position offset**: prefill rewrites positions to a
- compact scheme (rope ≪ prefill_len). The runner assigns decode
- positions from ``num_computed_tokens``, which is far too large;
- apply the stored per-request offset.
- """
+ """Restore input_ids so _adjust_positions_for_img2img can locate
+ the <|fim_middle|> placeholder for thinking-mode pre_text_len
+ detection."""
if inputs_embeds is not None and input_ids is None and input_ids_buffer is not None:
input_ids = input_ids_buffer
-
- if self._decode_position_offsets and positions is not None:
- token_start = 0
- for i, rid in enumerate(req_ids):
- sched = num_scheduled_tokens[i]
- offset = self._decode_position_offsets.get(rid, 0)
- if offset != 0 and num_computed_tokens[i] > 0:
- positions[token_start : token_start + sched] += offset
- token_start += sched
-
return input_ids, positions
def flush_pending_metadata(self, req_ids: list[str]) -> None:
- """Map pending metadata (batch order) to req_ids after forward()."""
+ """Map pending metadata (batch order) to req_ids after forward().
+
+ Guard: if a request already has metadata with ``image_shape``
+ (written during img2img prefill), don't overwrite it with
+ decode-step metadata that lacks ``image_shape``.
+ """
pending = self._ropes_pending
self._ropes_pending = []
for i, meta in enumerate(pending):
if i < len(req_ids):
- if req_ids[i] not in self._ropes_metadata:
- self._ropes_metadata[req_ids[i]] = meta
-
- pending_offsets = self._pending_decode_offsets
- self._pending_decode_offsets = []
- for i, offset in enumerate(pending_offsets):
- if i < len(req_ids) and offset != 0:
- self._decode_position_offsets[req_ids[i]] = offset
+ rid = req_ids[i]
+ existing = self._ropes_metadata.get(rid)
+ if existing and "image_shape" in existing and "image_shape" not in meta:
+ continue
+ self._ropes_metadata[rid] = meta
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
mm_input_by_modality = {}
@@ -727,16 +695,7 @@ def _process_img2img_input(self, multimodal_input):
num_vit = vit_emb.shape[0] + 2
info = (num_vae, num_vit, int(H), int(W))
self._pending_img2img_info.append(info)
- # Only the gen (main) request should add a companion queue entry.
- # Companion requests (cfg_text, cfg_img) also call this method with
- # the same image, so guard by checking whether this exact info
- # tuple is already enqueued. For batched img2img with multiple
- # concurrent gen requests this correctly adds one entry per unique
- # image; images with identical (num_vae, num_vit, H, W) that arrive
- # in the same batch are indistinguishable here and will share one
- # entry, but that is an uncommon edge case.
- if not any(entry[0] == info for entry in self._cfg_companion_queue):
- self._cfg_companion_queue.append((info, 2)) # cfg_text + cfg_img
+ self._last_img2img_info = info
return tuple(results)
@@ -755,31 +714,18 @@ def forward(
positions = self._adjust_positions_for_img2img(positions, input_ids)
use_mot = True
- elif self._cfg_companion_queue:
- # Guard: if this looks like a pure decode step (small token count,
- # no multimodal embeddings), the queue has stale entries from a
- # previous prefill cycle — clear them instead of consuming.
- if inputs_embeds is None and seq_len <= 2:
- self._cfg_companion_queue.clear()
- else:
- cached, remaining = self._cfg_companion_queue[0]
- remaining -= 1
- num_vae, num_vit, img_H, img_W = cached
- num_img2img = num_vae + 1 + num_vit # +1 separator
- seq_len = inputs_embeds.shape[0] if inputs_embeds is not None else positions.shape[0]
-
- if inputs_embeds is not None and seq_len >= num_img2img:
- self._pending_img2img_info = [cached]
- positions = self._adjust_positions_for_img2img(positions, input_ids)
- use_mot = True
- else:
- rope = int(positions[seq_len - 1].item()) + 1
- self._ropes_pending.append({"ropes": [rope]})
+ elif self._last_img2img_info is not None:
+ info = self._last_img2img_info
+ num_vae, num_vit, _, _ = info
+ num_img2img = num_vae + 1 + num_vit
- if remaining == 0:
- self._cfg_companion_queue.popleft()
- else:
- self._cfg_companion_queue[0] = (cached, remaining)
+ if seq_len >= num_img2img:
+ self._pending_img2img_info = [info]
+ positions = self._adjust_positions_for_img2img(positions, input_ids)
+ use_mot = True
+ else:
+ rope = int(positions[seq_len - 1].item()) + 1
+ self._ropes_pending.append({"ropes": [rope]})
if use_mot:
return self._mot_forward(input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs)
@@ -790,27 +736,18 @@ def _adjust_positions_for_img2img(
positions: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
- """Rewrite position IDs to match the original BAGEL position scheme:
-
- If there are ``pre_text_len`` text tokens before the img2img block::
-
- pre_text → 0, 1, ..., M-1
- VAE → M (all share)
- separator→ M
- ViT → M+1 (all share)
- post_text→ M+2, M+3, ...
+ """Rewrite position IDs for img2img.
- When no text precedes the img2img block (M=0), this reduces to the
- simpler scheme: VAE→0, ViT→1, text→2, 3, ...
+ Supports an optional ``pre_text_len`` prefix (thinking-mode) detected
+ via the ``<|fim_middle|>`` token in *input_ids*:
- Also computes ``self._vae_token_mask`` (bool tensor, True for actual
- VAE latent patches that should use gen-mode weights) and pushes
- per-request ropes + image_shape to the FIFO consumed by
- ``get_kv_transfer_metadata``.
+ pre_text -> 0 .. M-1
+ VAE -> M (all share)
+ separator-> M
+ ViT -> M+1 (all share)
+ post_text-> M+2, M+3, ...
- For img2img requests, also stores a decode position offset so that
- subsequent autoregressive decode steps use positions that continue
- from the rewritten scheme rather than from the original prefill length.
+ When M=0 (standard img2img) this reduces to VAE->0, ViT->1, text->2..
"""
info_list = self._pending_img2img_info
self._pending_img2img_info = []
@@ -836,70 +773,64 @@ def _adjust_positions_for_img2img(
req_len = end - start
if img2img_idx < len(info_list):
- num_vae, num_vit, img_H, img_W = info_list[img2img_idx]
+ cur_info = info_list[img2img_idx]
+ elif self._last_img2img_info is not None:
+ cur_info = self._last_img2img_info
+ else:
+ cur_info = None
+
+ if cur_info is not None:
+ num_vae, num_vit, img_H, img_W = cur_info
num_img2img = num_vae + 1 + num_vit # +1 separator
if req_len >= num_img2img:
- # Detect offset of img2img tokens within this request
- # by searching for the img2img placeholder token ID.
pre_text_len = 0
if input_ids is not None:
- req_ids = input_ids[start:end]
- mask = req_ids == self._img2img_token_id
- indices = mask.nonzero(as_tuple=True)[0]
+ req_ids_slice = input_ids[start:end]
+ indices = (req_ids_slice == self._img2img_token_id).nonzero(as_tuple=True)[0]
if indices.numel() > 0:
pre_text_len = int(indices[0].item())
- img_start = start + pre_text_len
+ M = pre_text_len
+ img_start = start + M
post_text_start = img_start + num_img2img
- # pre_text_pos: position base for image tokens
- pre_text_pos = pre_text_len
- # Pre-image text: sequential positions 0..pre_text_pos-1
- if pre_text_len > 0:
+ if M > 0:
new_positions[start:img_start] = torch.arange(
- 0, pre_text_pos, device=positions.device, dtype=positions.dtype
+ 0, M, device=positions.device, dtype=positions.dtype
)
- # VAE tokens: all share position pre_text_pos
- new_positions[img_start : img_start + num_vae] = pre_text_pos
- # Separator: position pre_text_pos
- new_positions[img_start + num_vae] = pre_text_pos
- # ViT tokens: all share position pre_text_pos+1
+ new_positions[img_start : img_start + num_vae] = M
+ new_positions[img_start + num_vae] = M # separator
vit_start = img_start + num_vae + 1
- new_positions[vit_start : vit_start + num_vit] = pre_text_pos + 1
+ new_positions[vit_start : vit_start + num_vit] = M + 1
- # Post-image text: sequential positions pre_text_pos+2, pre_text_pos+3, ...
num_post_text = end - post_text_start
if num_post_text > 0:
new_positions[post_text_start:end] = torch.arange(
- pre_text_pos + 2,
- pre_text_pos + 2 + num_post_text,
+ M + 2,
+ M + 2 + num_post_text,
device=positions.device,
dtype=positions.dtype,
)
- # VAE gen-mode mask: only actual VAE latent patches (not markers)
- vae_patches_start = img_start + 1 # skip start_marker
- vae_patches_end = img_start + num_vae - 1 # before end_marker
+ vae_patches_start = img_start + 1
+ vae_patches_end = img_start + num_vae - 1
if vae_patches_end > vae_patches_start:
vae_mask[vae_patches_start:vae_patches_end] = True
- rope = pre_text_pos + 2 + num_post_text
+ rope = M + 2 + num_post_text
self._ropes_pending.append(
{
"ropes": [rope],
"image_shape": [img_H, img_W],
}
)
- decode_offset = rope - req_len
- self._pending_decode_offsets.append(decode_offset)
img2img_idx += 1
continue
rope = int(new_positions[end - 1].item()) + 1
self._ropes_pending.append({"ropes": [rope]})
- self._pending_decode_offsets.append(0)
self._vae_token_mask = vae_mask if vae_mask.any() else None
return new_positions
diff --git a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py
index 138948064ba..ffb997048bd 100644
--- a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py
+++ b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py
@@ -149,7 +149,15 @@ def execute_model(
encoder_cache=self.encoder_cache,
) as ec_connector_output:
self._execute_mm_encoder(scheduler_output)
- return make_empty_encoder_model_runner_output(scheduler_output)
+
+ kv_ids = self.kv_extracted_req_ids
+ self.kv_extracted_req_ids = None
+
+ output = make_empty_encoder_model_runner_output(scheduler_output)
+ if kv_ids:
+ output = copy(output)
+ output.kv_extracted_req_ids = kv_ids
+ return output
if not num_scheduled_tokens:
if (
@@ -163,10 +171,20 @@ def execute_model(
# dummy run to ensure coordinate_batch_across_dp
# is called into to avoid out of sync issues.
self._dummy_run(1)
+
+ kv_ids = self.kv_extracted_req_ids
+ self.kv_extracted_req_ids = None
+
if not has_kv_transfer_group():
- # Return empty ModelRunnerOutput if no work to do.
- return EMPTY_MODEL_RUNNER_OUTPUT
- return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
+ output = EMPTY_MODEL_RUNNER_OUTPUT
+ else:
+ output = self.kv_connector_no_forward(scheduler_output, self.vllm_config)
+
+ if kv_ids:
+ output = copy(output)
+ output.kv_extracted_req_ids = kv_ids
+
+ return output
if self.cache_config.kv_sharing_fast_prefill:
assert not self.num_prompt_logprobs, (
"--kv-sharing-fast-prefill produces incorrect "
diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py
index 01ec23acb47..554ac6355de 100644
--- a/vllm_omni/worker/gpu_ar_model_runner.py
+++ b/vllm_omni/worker/gpu_ar_model_runner.py
@@ -205,24 +205,39 @@ def execute_model(
encoder_cache=self.encoder_cache,
) as ec_connector_output:
self._execute_mm_encoder(scheduler_output)
- return make_empty_encoder_model_runner_output(scheduler_output)
+
+ kv_ids = self.kv_extracted_req_ids
+ self.kv_extracted_req_ids = None
+
+ output = make_empty_encoder_model_runner_output(scheduler_output)
+ if kv_ids:
+ output = copy(output)
+ output.kv_extracted_req_ids = kv_ids
+ return output
if not num_scheduled_tokens:
if (
self.parallel_config.distributed_executor_backend == "external_launcher"
and self.parallel_config.data_parallel_size > 1
):
- # this is a corner case when both external launcher
- # and DP are enabled, num_scheduled_tokens could be
- # 0, and has_unfinished_requests in the outer loop
- # returns True. before returning early here we call
- # dummy run to ensure coordinate_batch_across_dp
- # is called into to avoid out of sync issues.
self._dummy_run(1)
+
+ # Capture KV extraction results before early return;
+ # sample_tokens() is skipped on this path so the IDs
+ # would otherwise be silently overwritten next step.
+ kv_ids = self.kv_extracted_req_ids
+ self.kv_extracted_req_ids = None
+
if not has_kv_transfer_group():
- # Return empty ModelRunnerOutput if no work to do.
- return EMPTY_MODEL_RUNNER_OUTPUT
- return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
+ output = EMPTY_MODEL_RUNNER_OUTPUT
+ else:
+ output = self.kv_connector_no_forward(scheduler_output, self.vllm_config)
+
+ if kv_ids:
+ output = copy(output)
+ output.kv_extracted_req_ids = kv_ids
+
+ return output
if self.cache_config.kv_sharing_fast_prefill:
assert not self.num_prompt_logprobs, (
From cb4d13a65806d18337628da0768539ba97c6cd4d Mon Sep 17 00:00:00 2001
From: Sy03 <1370724210@qq.com>
Date: Mon, 13 Apr 2026 12:53:35 +0800
Subject: [PATCH 07/76] [Perf][Fish Speech] Enable CUDA Graph capture for Fast
AR code predictor (#2520)
Signed-off-by: Sy03 <1370724210@qq.com>
---
.../models/fish_speech/fish_speech_fast_ar.py | 22 +++++--
.../models/fish_speech/fish_speech_slow_ar.py | 39 ++++++------
vllm_omni/worker/gpu_ar_model_runner.py | 62 +++++++++++++++++++
vllm_omni/worker/gpu_model_runner.py | 6 +-
4 files changed, 99 insertions(+), 30 deletions(-)
diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py
index 8bbb643ebec..22a2744ff5d 100644
--- a/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py
+++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py
@@ -310,6 +310,7 @@ def __init__(
self._compiled_model_fwd: object | None = None
self._compile_attempted = False
self._compile_failed = False
+ self._disable_compile_for_graph = False
def _ensure_buffers(self, bsz: int, device: torch.device, dtype: torch.dtype) -> None:
max_seq = self._num_codebooks + 1 # hidden_state + num_codebooks codes
@@ -327,11 +328,20 @@ def _setup_compile(self) -> None:
if self._compile_attempted:
return
self._compile_attempted = True
+ if self._disable_compile_for_graph:
+ try:
+ self._compiled_model_fwd = torch.compile(
+ self.model.forward,
+ dynamic=True,
+ options={"epilogue_fusion": False},
+ )
+ except Exception as exc:
+ logger.warning("Fast AR torch.compile (graph mode) failed: %s", exc)
+ self._compiled_model_fwd = self.model.forward
+ return
try:
self._compiled_model_fwd = torch.compile(
self.model.forward,
- # Keep the helper compiler separate from vLLM's outer
- # cudagraph-managed Stage-0 execution.
mode="default",
dynamic=True,
fullgraph=False,
@@ -366,10 +376,10 @@ def warmup_compile(
@torch.inference_mode()
def _run_model(self, step_input: torch.Tensor, step_pos_ids: torch.Tensor, bsz: int) -> torch.Tensor:
- # Default-on compile only pays off for single-request decode. For
- # batched decode, eager preserves loaded throughput and avoids the
- # regression seen with batch>1 compiled execution.
- model_fwd = self._compiled_model_fwd if bsz == 1 else self.model.forward
+ if self._disable_compile_for_graph:
+ model_fwd = self._compiled_model_fwd or self.model.forward
+ else:
+ model_fwd = self._compiled_model_fwd if bsz == 1 else self.model.forward
try:
return model_fwd(step_input, step_pos_ids)
except Exception as exc:
diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py
index 3813597caad..62776cbb31f 100644
--- a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py
+++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py
@@ -194,6 +194,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.has_postprocess = True
self.mtp_hidden_size = int(self.text_config.hidden_size)
self.talker_mtp_output_key = "audio_codes"
+ self.talker_mtp_graph_safe = True
self.gpu_resident_buffer_keys: set[str] = {"last_slow_ar_hidden"}
# Qwen3 transformer backbone.
@@ -236,6 +237,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
slow_ar_config=self.text_config,
prefix="fast_ar",
)
+ if self.talker_mtp_graph_safe:
+ self.fast_ar._disable_compile_for_graph = True
# Constant logit mask: allow only semantic tokens + im_end.
vocab = int(self.text_config.vocab_size)
@@ -680,18 +683,13 @@ def talker_mtp(
inputs_embeds_out = input_embeds.reshape(bsz, -1).clone()
semantic_mask = (input_ids[:, 0] >= self._semantic_begin_id) & (input_ids[:, 0] <= self._semantic_end_id)
- if semantic_mask.any():
- semantic_codes = audio_codes[semantic_mask].clamp(min=0)
- offsets = (
- torch.arange(self._num_codebooks, device=dev, dtype=semantic_codes.dtype) * self._codebook_size
- ).unsqueeze(0)
- codebook_sum = self.codebook_embeddings(semantic_codes + offsets).sum(dim=1).to(dtype=torch.bfloat16)
-
- # Normalize by sqrt(num_codebooks + 1) as in the reference model
- # (scale_codebook_embeddings=True for fish_qwen3_omni).
- inputs_embeds_out[semantic_mask] = (inputs_embeds_out[semantic_mask] + codebook_sum) / math.sqrt(
- self._num_codebooks + 1
- )
+ semantic_codes = audio_codes.clamp(min=0, max=self._codebook_size - 1)
+ offsets = (
+ torch.arange(self._num_codebooks, device=dev, dtype=semantic_codes.dtype) * self._codebook_size
+ ).unsqueeze(0)
+ codebook_sum = self.codebook_embeddings(semantic_codes + offsets).sum(dim=1).to(dtype=torch.bfloat16)
+ norm_embeds = (inputs_embeds_out + codebook_sum) / math.sqrt(self._num_codebooks + 1)
+ inputs_embeds_out = torch.where(semantic_mask.unsqueeze(-1), norm_embeds, inputs_embeds_out)
return inputs_embeds_out, audio_codes.to(dtype=torch.long)
@@ -802,14 +800,15 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
if truncated:
logger.info("Truncated %d RoPE cos_sin_cache buffers to bf16 precision", truncated)
- try:
- self.fast_ar.warmup_compile(
- device=self.codebook_embeddings.weight.device,
- dtype=torch.bfloat16,
- batch_sizes=(1,),
- )
- except Exception as exc:
- logger.warning("Fish Speech Fast AR compile warmup failed: %s", exc)
+ if not getattr(self, "talker_mtp_graph_safe", False):
+ try:
+ self.fast_ar.warmup_compile(
+ device=self.codebook_embeddings.weight.device,
+ dtype=torch.bfloat16,
+ batch_sizes=(1,),
+ )
+ except Exception as exc:
+ logger.warning("Fish Speech Fast AR compile warmup failed: %s", exc)
codec_device = self.codebook_embeddings.weight.device
_load_dac_codec(
diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py
index 554ac6355de..72e745fb172 100644
--- a/vllm_omni/worker/gpu_ar_model_runner.py
+++ b/vllm_omni/worker/gpu_ar_model_runner.py
@@ -138,6 +138,68 @@ def _sampling_metadata_for_model_sampler(self, sampling_metadata):
return sampling_metadata
return replace(sampling_metadata, output_token_ids=output_token_ids)
+ def capture_model(self) -> int:  # extends the base-class capture with a talker_mtp graph capture pass
+ result = super().capture_model()
+ self._capture_talker_mtp_graphs()  # runs only after the base-class capture has returned
+ return result  # the base-class return value is passed through unchanged
+
+ def _capture_talker_mtp_graphs(self) -> None:  # warm up, then run talker_mtp in FULL cudagraph mode, once per capture size
+ from vllm_omni.worker.gpu_model_runner import CUDAGraphWrapper
+
+ if not self.has_talker_mtp or not isinstance(self.talker_mtp, CUDAGraphWrapper):  # no-op unless talker_mtp is a CUDAGraphWrapper
+ return
+
+ from vllm.compilation.monitor import set_cudagraph_capturing_enabled
+ from vllm.distributed.parallel_state import graph_capture
+
+ capture_sizes = self.compilation_config.cudagraph_capture_sizes
+ num_warmups = self.compilation_config.cudagraph_num_of_warmups
+ capture_sizes = sorted(capture_sizes, reverse=True)  # process the largest size first
+ logger.info("Capturing talker_mtp graphs for sizes %s", capture_sizes)
+
+ set_cudagraph_capturing_enabled(True)  # reset to False in the finally clause below
+ try:
+ with torch.inference_mode(), graph_capture(device=self.device):
+ for bsz in capture_sizes:
+ _, batch_desc, _, _, _ = self._determine_batch_execution_and_padding(  # only the batch descriptor is kept
+ num_tokens=bsz,
+ num_reqs=bsz,
+ num_scheduled_tokens_np=np.ones(bsz, dtype=np.int32),  # one scheduled token per request
+ max_num_scheduled_tokens=1,
+ use_cascade_attn=False,
+ )
+ n = batch_desc.num_tokens  # NOTE(review): presumably the padded token count for this size; confirm
+ ids = self.talker_mtp_input_ids.gpu[:n]  # first n entries of the runner's talker_mtp buffers
+ emb = self.talker_mtp_inputs_embeds.gpu[:n]
+ hid = self.last_talker_hidden.gpu[:n]
+ ts = self.text_step.gpu[:n]
+
+ for _ in range(num_warmups):  # warm-up passes run with cudagraph mode NONE
+ with set_forward_context(
+ None,
+ self.vllm_config,
+ cudagraph_runtime_mode=CUDAGraphMode.NONE,
+ batch_descriptor=batch_desc,
+ ):
+ self.talker_mtp(ids, emb, hid, ts)
+
+ with set_forward_context(  # same inputs, now with cudagraph mode FULL
+ None,
+ self.vllm_config,
+ cudagraph_runtime_mode=CUDAGraphMode.FULL,
+ batch_descriptor=batch_desc,
+ ):
+ self.talker_mtp(ids, emb, hid, ts)  # NOTE(review): presumably the call on which the wrapper records the graph; confirm
+ torch.cuda.synchronize()
+
+ logger.info("Captured talker_mtp graphs for %d sizes", len(capture_sizes))
+ except RuntimeError as e:  # re-raised as RuntimeError, chained, with the talker_mtp_graph_safe context added
+ raise RuntimeError(
+ f"talker_mtp graph capture failed for a model that declared talker_mtp_graph_safe=True: {e}"
+ ) from e
+ finally:
+ set_cudagraph_capturing_enabled(False)  # always clear the capturing flag, on success or failure
+
@torch.inference_mode()
def execute_model(
self,
diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py
index 35e15984355..1f678b579fa 100644
--- a/vllm_omni/worker/gpu_model_runner.py
+++ b/vllm_omni/worker/gpu_model_runner.py
@@ -83,11 +83,9 @@ def load_model(self, *args, **kwargs) -> None:
self.has_talker_mtp = True
cudagraph_mode = self.compilation_config.cudagraph_mode
assert cudagraph_mode is not None
- # Only wrap talker_mtp in CUDAGraphWrapper for Omni models that
- # have a separate .talker sub-module. TTS models' code predictor
- # has internal AR loops / torch.multinomial — not graph-safe.
has_separate_talker = getattr(self.model, "talker", None) is not None
- if cudagraph_mode.has_full_cudagraphs() and has_separate_talker:
+ talker_mtp_graph_safe = getattr(self.model, "talker_mtp_graph_safe", False)
+ if cudagraph_mode.has_full_cudagraphs() and (has_separate_talker or talker_mtp_graph_safe):
self.talker_mtp = CUDAGraphWrapper(talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL)
# TTS exposes mtp_hidden_size; Omni uses hf_text_config.hidden_size.
hidden_size = int(
From 8097747a5dc0d90f267050ae4b77d53bbaea88ae Mon Sep 17 00:00:00 2001
From: Jiaqian Liu <61532106+Celeste-jq@users.noreply.github.com>
Date: Mon, 13 Apr 2026 14:20:04 +0800
Subject: [PATCH 08/76] [Model] Adapt Wan2.2-I2V-A14B via LightX2V offline
conversion path (#2134)
Signed-off-by: Celeste-jq <591998922@qq.com>
Co-authored-by: Canlin Guo
---
docs/user_guide/diffusion/lora.md | 86 ++++
.../offline_inference/image_to_video.md | 6 +-
.../image_to_video/README.md | 6 +-
.../image_to_video/image_to_video.py | 13 +
.../online_serving/image_to_video/README.md | 49 +++
.../image_to_video/run_curl_image_to_video.sh | 5 +
.../openai_api/test_video_server.py | 22 +
tools/wan22/assemble_wan22_i2v_diffusers.py | 385 ++++++++++++++++++
.../models/wan2_2/pipeline_wan2_2.py | 58 ++-
.../models/wan2_2/pipeline_wan2_2_i2v.py | 21 +-
.../models/wan2_2/pipeline_wan2_2_ti2v.py | 21 +-
.../models/wan2_2/scheduling_wan_euler.py | 147 +++++++
.../models/wan2_2/wan2_2_transformer.py | 8 +
vllm_omni/engine/async_omni_engine.py | 2 +
14 files changed, 804 insertions(+), 25 deletions(-)
create mode 100644 tools/wan22/assemble_wan22_i2v_diffusers.py
create mode 100644 vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py
diff --git a/docs/user_guide/diffusion/lora.md b/docs/user_guide/diffusion/lora.md
index e45c033b848..256698752a1 100644
--- a/docs/user_guide/diffusion/lora.md
+++ b/docs/user_guide/diffusion/lora.md
@@ -56,6 +56,92 @@ outputs = omni.generate(
!!! note "Server-side Path Requirement"
The LoRA adapter path (`local_path`) must be readable on the **server** machine. If your client and server are on different machines, ensure the LoRA adapter is accessible via a shared mount or copied to the server.
+## Wan2.2 LightX2V Offline Assembly
+
+This workflow is LoRA-adjacent: it uses external LightX2V conversion plus
+`Wan2.2-Distill-Loras` to bake converted Wan2.2 I2V checkpoints into a local
+Diffusers directory, instead of loading LoRA adapters at runtime.
+
+### Required assets
+
+- Base model: `Wan-AI/Wan2.2-I2V-A14B`
+- Diffusers skeleton: `Wan-AI/Wan2.2-I2V-A14B-Diffusers`
+- Optional external converter from the LightX2V project (not shipped in this repository)
+- Optional LoRA weights: `lightx2v/Wan2.2-Distill-Loras`
+
+### Step 1: Optional - convert high/low-noise DiT weights with LightX2V
+
+Install or clone LightX2V from the upstream repository
+(`https://github.com/ModelTC/LightX2V`). After cloning, the converter used
+below is available at `/path/to/lightx2v/tools/convert/converter.py`.
+
+```bash
+python /path/to/lightx2v/tools/convert/converter.py \
+ --source /path/to/Wan2.2-I2V-A14B/high_noise_model \
+ --output /tmp/wan22_lightx2v/high_noise_out \
+ --output_ext .safetensors \
+ --output_name diffusion_pytorch_model \
+ --model_type wan_dit \
+ --direction forward \
+ --lora_path /path/to/wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors \
+ --lora_key_convert auto \
+ --single_file
+
+python /path/to/lightx2v/tools/convert/converter.py \
+ --source /path/to/Wan2.2-I2V-A14B/low_noise_model \
+ --output /tmp/wan22_lightx2v/low_noise_out \
+ --output_ext .safetensors \
+ --output_name diffusion_pytorch_model \
+ --model_type wan_dit \
+ --direction forward \
+ --lora_path /path/to/wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors \
+ --lora_key_convert auto \
+ --single_file
+```
+
+If you are not using LightX2V, skip this step and either keep the original
+Diffusers weights from the skeleton or point Step 2 at any other converted
+`transformer/` and `transformer_2/` checkpoints.
+
+### Step 2: Assemble a final Diffusers-style directory
+
+```bash
+python tools/wan22/assemble_wan22_i2v_diffusers.py \
+ --diffusers-skeleton /path/to/Wan2.2-I2V-A14B-Diffusers \
+ --transformer-weight /tmp/wan22_lightx2v/high_noise_out \
+ --transformer-2-weight /tmp/wan22_lightx2v/low_noise_out \
+ --output-dir /path/to/Wan2.2-I2V-A14B-Custom-Diffusers \
+ --asset-mode symlink \
+ --overwrite
+```
+
+`--transformer-weight` and `--transformer-2-weight` are optional. If you omit
+them, the tool keeps the original weights from the Diffusers skeleton.
+
+### Step 3: Run offline inference
+
+```bash
+python examples/offline_inference/image_to_video/image_to_video.py \
+ --model /path/to/Wan2.2-I2V-A14B-Custom-Diffusers \
+ --image /path/to/input.jpg \
+ --prompt "A cat playing with yarn" \
+ --num-frames 81 \
+ --num-inference-steps 4 \
+ --tensor-parallel-size 4 \
+ --height 480 \
+ --width 832 \
+ --flow-shift 12 \
+ --sample-solver euler \
+ --guidance-scale 1.0 \
+ --guidance-scale-high 1.0 \
+ --boundary-ratio 0.875
+```
+
+Notes:
+
+- This route avoids runtime LoRA loading changes in vLLM-Omni when you choose to bake converted weights into a local Diffusers directory.
+- Output quality and speed depend on the replacement checkpoints and sampling params you choose.
+
## See Also
diff --git a/docs/user_guide/examples/offline_inference/image_to_video.md b/docs/user_guide/examples/offline_inference/image_to_video.md
index 7a750aeff3b..6e105741a7e 100644
--- a/docs/user_guide/examples/offline_inference/image_to_video.md
+++ b/docs/user_guide/examples/offline_inference/image_to_video.md
@@ -62,12 +62,13 @@ Key arguments:
- `--negative-prompt`: Optional list of artifacts to suppress.
- `--boundary-ratio`: Boundary split ratio for two-stage MoE models.
- `--flow-shift`: Scheduler flow shift (5.0 for 720p, 12.0 for 480p).
+- `--sample-solver`: Wan2.2 sampling solver. Use `unipc` for the default multistep solver, or `euler` for Lightning/Distill checkpoints.
- `--num-inference-steps`: Number of denoising steps (default 50).
- `--fps`: Frames per second for the saved MP4 (requires `diffusers` export_to_video).
- `--output`: Path to save the generated video.
- `--vae-use-slicing`: Enable VAE slicing for memory optimization.
- `--vae-use-tiling`: Enable VAE tiling for memory optimization.
-- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](https://github.com/vllm-project/vllm-omni/tree/main/docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel).
+- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](https://github.com/vllm-project/vllm-omni/tree/main/docs/user_guide/diffusion/parallelism/cfg_parallel.md).
- `--tensor-parallel-size`: tensor parallel size (effective for models that support TP, e.g. LTX2).
- `--enable-cpu-offload`: enable CPU offloading for diffusion models.
- `--use-hsdp`: Enable Hybrid Sharded Data Parallel to shard model weights across GPUs.
@@ -78,6 +79,9 @@ Key arguments:
> ℹ️ If you encounter OOM errors, try using `--vae-use-slicing` and `--vae-use-tiling` to reduce memory usage.
+For Wan2.2 LightX2V-converted local Diffusers directories and related LoRA
+assets, see the [LoRA guide](../../diffusion/lora.md#wan22-lightx2v-offline-assembly).
+
## Example materials
??? abstract "image_to_video.py"
diff --git a/examples/offline_inference/image_to_video/README.md b/examples/offline_inference/image_to_video/README.md
index 2692c76df26..a458850a02b 100644
--- a/examples/offline_inference/image_to_video/README.md
+++ b/examples/offline_inference/image_to_video/README.md
@@ -59,12 +59,13 @@ Key arguments:
- `--negative-prompt`: Optional list of artifacts to suppress.
- `--boundary-ratio`: Boundary split ratio for two-stage MoE models.
- `--flow-shift`: Scheduler flow shift (5.0 for 720p, 12.0 for 480p).
+- `--sample-solver`: Wan2.2 sampling solver. Use `unipc` for the default multistep solver, or `euler` for Lightning/Distill checkpoints.
- `--num-inference-steps`: Number of denoising steps (default 50).
- `--fps`: Frames per second for the saved MP4 (requires `diffusers` export_to_video).
- `--output`: Path to save the generated video.
- `--vae-use-slicing`: Enable VAE slicing for memory optimization.
- `--vae-use-tiling`: Enable VAE tiling for memory optimization.
-- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel).
+- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](https://github.com/vllm-project/vllm-omni/tree/main/docs/user_guide/diffusion/parallelism/cfg_parallel.md).
- `--tensor-parallel-size`: tensor parallel size (effective for models that support TP, e.g. LTX2).
- `--enable-cpu-offload`: enable CPU offloading for diffusion models.
- `--use-hsdp`: Enable Hybrid Sharded Data Parallel to shard model weights across GPUs.
@@ -74,3 +75,6 @@ Key arguments:
> ℹ️ If you encounter OOM errors, try using `--vae-use-slicing` and `--vae-use-tiling` to reduce memory usage.
+
+For Wan2.2 LightX2V-converted local Diffusers directories and related LoRA
+assets, see the [LoRA guide](../../../docs/user_guide/diffusion/lora.md#wan22-lightx2v-offline-assembly).
diff --git a/examples/offline_inference/image_to_video/image_to_video.py b/examples/offline_inference/image_to_video/image_to_video.py
index 7e7cfbf84e8..53319c82211 100644
--- a/examples/offline_inference/image_to_video/image_to_video.py
+++ b/examples/offline_inference/image_to_video/image_to_video.py
@@ -84,6 +84,13 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--flow-shift", type=float, default=5.0, help="Scheduler flow_shift (5.0 for 720p, 12.0 for 480p)."
)
+ parser.add_argument(
+ "--sample-solver",
+ type=str,
+ default="unipc",
+ choices=["unipc", "euler"],
+ help="Sampling solver for Wan2.2 pipelines. Use 'euler' for Lightning/Distill setups.",
+ )
parser.add_argument("--output", type=str, default="i2v_output.mp4", help="Path to save the video (mp4).")
parser.add_argument("--fps", type=int, default=None, help="Frames per second for the output video.")
parser.add_argument(
@@ -305,6 +312,7 @@ def main():
print(f" Model: {args.model}")
print(f" Inference steps: {args.num_inference_steps}")
print(f" Frames: {args.num_frames}")
+ print(f" Solver: {args.sample_solver}")
print(
f" Parallel configuration: cfg_parallel_size={args.cfg_parallel_size},"
f" tensor_parallel_size={args.tensor_parallel_size}, vae_patch_parallel_size={args.vae_patch_parallel_size}"
@@ -326,9 +334,14 @@ def main():
generator=generator,
guidance_scale=guidance_scale,
guidance_scale_2=args.guidance_scale_high,
+ boundary_ratio=args.boundary_ratio,
num_inference_steps=num_inference_steps,
num_frames=num_frames,
frame_rate=frame_rate,
+ extra_args={
+ "sample_solver": args.sample_solver,
+ "flow_shift": args.flow_shift,
+ },
),
)
generation_end = time.perf_counter()
diff --git a/examples/online_serving/image_to_video/README.md b/examples/online_serving/image_to_video/README.md
index 49283bd9a06..285eeb27983 100644
--- a/examples/online_serving/image_to_video/README.md
+++ b/examples/online_serving/image_to_video/README.md
@@ -26,6 +26,23 @@ The script allows overriding:
- `CACHE_BACKEND` (default: `none`)
- `ENABLE_CACHE_DIT_SUMMARY` (default: `0`)
+### Ascend / Local LightX2V Example
+
+For a local Wan2.2-LightX2V Diffusers directory on Ascend/NPU, you can start the server like this:
+
+```bash
+vllm serve /path/to/Wan2.2-I2V-A14B-LightX2V-Diffusers-Lightning \
+ --omni \
+ --port 8091 \
+ --flow-shift 12 \
+ --cfg-parallel-size 1 \
+ --ulysses-degree 4 \
+ --use-hsdp \
+ --trust-remote-code \
+ --allowed-local-media-path / \
+ --seed 42
+```
+
## Async Job Behavior
`POST /v1/videos` is asynchronous. It creates a video job and immediately
@@ -69,10 +86,35 @@ curl -X POST http://localhost:8091/v1/videos/sync \
-F "guidance_scale_2=1.0" \
-F "boundary_ratio=0.875" \
-F "flow_shift=12.0" \
+ -F 'extra_params={"sample_solver":"euler"}' \
-F "seed=42" \
-o sync_i2v_output.mp4
```
+For Wan Lightning/Distill checkpoints, pass `{"sample_solver":"euler"}` via `extra_params`. The default solver is `unipc`.
+
+Example matching the local LightX2V deployment above:
+
+```bash
+curl -sS -X POST http://localhost:8091/v1/videos/sync \
+ -H "Accept: video/mp4" \
+ -F "prompt=A cat playing with yarn" \
+ -F "input_reference=@/path/to/input.jpg" \
+ -F "width=832" \
+ -F "height=480" \
+ -F "num_frames=81" \
+ -F "fps=16" \
+ -F "num_inference_steps=4" \
+ -F "guidance_scale=1.0" \
+ -F "guidance_scale_2=1.0" \
+ -F "boundary_ratio=0.875" \
+ -F "seed=42" \
+ -F 'extra_params={"sample_solver":"euler"}' \
+ -o ./output.mp4
+```
+
+Use `/v1/videos/sync` if you want to write the MP4 directly to a file. `POST /v1/videos` is async and returns job metadata, not inline `b64_json`.
+
## Storage
Generated video files are stored on local disk by the async video API.
@@ -96,6 +138,9 @@ export VLLM_OMNI_STORAGE_MAX_CONCURRENCY=8
# Basic image-to-video generation
bash run_curl_image_to_video.sh
+# Wan Lightning/Distill checkpoints
+SAMPLE_SOLVER=euler bash run_curl_image_to_video.sh
+
# Or execute directly (OpenAI-style multipart)
create_response=$(curl -s http://localhost:8091/v1/videos \
-H "Accept: application/json" \
@@ -111,6 +156,7 @@ create_response=$(curl -s http://localhost:8091/v1/videos \
-F "guidance_scale_2=1.0" \
-F "boundary_ratio=0.875" \
-F "flow_shift=12.0" \
+ -F 'extra_params={"sample_solver":"euler"}' \
-F "seed=42")
video_id=$(echo "$create_response" | jq -r '.id')
@@ -169,9 +215,12 @@ curl -X POST http://localhost:8091/v1/videos \
-F "guidance_scale_2=1.0" \
-F "boundary_ratio=0.875" \
-F "flow_shift=12.0" \
+ -F 'extra_params={"sample_solver":"euler"}' \
-F "seed=42"
```
+`sample_solver` is supported by Wan2.2 online serving through the existing `extra_params` field, which is merged into the pipeline `extra_args`. Use `unipc` for the default multistep solver, or `euler` for Lightning/Distill checkpoints.
+
## Create Response Format
`POST /v1/videos` returns a job record, not inline base64 video data.
diff --git a/examples/online_serving/image_to_video/run_curl_image_to_video.sh b/examples/online_serving/image_to_video/run_curl_image_to_video.sh
index f4c1496a69a..6f6a6f96d59 100644
--- a/examples/online_serving/image_to_video/run_curl_image_to_video.sh
+++ b/examples/online_serving/image_to_video/run_curl_image_to_video.sh
@@ -7,6 +7,7 @@ INPUT_IMAGE="${INPUT_IMAGE:-../../offline_inference/image_to_video/qwen-bear.png
BASE_URL="${BASE_URL:-http://localhost:8099}"
OUTPUT_PATH="${OUTPUT_PATH:-wan22_i2v_output.mp4}"
NEGATIVE_PROMPT="${NEGATIVE_PROMPT:-}"
+SAMPLE_SOLVER="${SAMPLE_SOLVER:-}"
POLL_INTERVAL="${POLL_INTERVAL:-2}"
if [ ! -f "$INPUT_IMAGE" ]; then
@@ -34,6 +35,10 @@ if [ -n "${NEGATIVE_PROMPT}" ]; then
create_cmd+=(-F "negative_prompt=${NEGATIVE_PROMPT}")
fi
+if [ -n "${SAMPLE_SOLVER}" ]; then
+ create_cmd+=(-F "extra_params={\"sample_solver\":\"${SAMPLE_SOLVER}\"}")
+fi
+
create_response="$("${create_cmd[@]}")"
video_id="$(echo "${create_response}" | jq -r '.id')"
if [ -z "${video_id}" ] || [ "${video_id}" = "null" ]; then
diff --git a/tests/entrypoints/openai_api/test_video_server.py b/tests/entrypoints/openai_api/test_video_server.py
index 0fdee7a77a8..fd7d4df60da 100644
--- a/tests/entrypoints/openai_api/test_video_server.py
+++ b/tests/entrypoints/openai_api/test_video_server.py
@@ -766,6 +766,28 @@ def test_extra_params_merged_with_existing_extra_args(test_client, mocker: Mocke
assert captured.extra_args["zero_steps"] == 2
+def test_sample_solver_forwarded_via_extra_params(test_client, mocker: MockerFixture):
+ """sample_solver can be passed through existing extra_params for Wan2.2 online serving."""
+ mocker.patch(  # replace video encoding with a fixed base64 payload
+ "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
+ return_value="Zg==",
+ )
+ response = test_client.post(
+ "/v1/videos",
+ data={
+ "prompt": "A fox running through snow.",
+ "extra_params": json.dumps({"sample_solver": "euler"}),  # extra_params is sent as a JSON-encoded form field
+ },
+ )
+
+ assert response.status_code == 200
+ video_id = response.json()["id"]
+ _wait_for_status(test_client, video_id, VideoGenerationStatus.COMPLETED.value)  # wait for the job to reach COMPLETED before inspecting the engine
+ engine = test_client.app.state.openai_serving_video._engine_client
+ captured = engine.captured_sampling_params_list[0]  # first sampling params recorded by the engine client
+ assert captured.extra_args["sample_solver"] == "euler"  # the solver name must arrive in extra_args unchanged
+
+
# ---------------------------------------------------------------------------
# Sync endpoint tests (POST /v1/videos/sync)
# ---------------------------------------------------------------------------
diff --git a/tools/wan22/assemble_wan22_i2v_diffusers.py b/tools/wan22/assemble_wan22_i2v_diffusers.py
new file mode 100644
index 00000000000..8e14ca3c26d
--- /dev/null
+++ b/tools/wan22/assemble_wan22_i2v_diffusers.py
@@ -0,0 +1,385 @@
+#!/usr/bin/env python3
+"""
+Assemble a Wan2.2-I2V-A14B-Diffusers-style model directory using a Diffusers
+skeleton and optional replacement transformer checkpoints.
+
+This tool does NOT run any external conversion step. You can use it in two
+ways:
+- keep the original weights from the Diffusers skeleton
+- replace transformer/transformer_2 with converted checkpoints such as
+  LightX2V outputs
+
+The legacy LightX2V argument names --high-noise-weight and --low-noise-weight
+are accepted as aliases for --transformer-weight and --transformer-2-weight.
+
+Typical use:
+ python tools/wan22/assemble_wan22_i2v_diffusers.py \
+ --diffusers-skeleton /path/to/Wan2.2-I2V-A14B-Diffusers \
+ --transformer-weight /path/to/high_noise_out/diffusion_pytorch_model.safetensors \
+ --transformer-2-weight /path/to/low_noise_out/diffusion_pytorch_model.safetensors \
+ --output-dir /path/to/Wan2.2-I2V-A14B-Custom-Diffusers
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import shutil
+import sys
+from dataclasses import dataclass
+from pathlib import Path
+
+# File names probed, in this order, when a weight path is a directory that
+# holds a single-file checkpoint.
+WEIGHT_CANDIDATES = (
+    "diffusion_pytorch_model.safetensors",
+    "diffusion_pytorch_model.bin",
+    "diffusion_pytorch_model.pt",
+    "model.safetensors",
+    "pytorch_model.bin",
+    "model.pt",
+)
+# Index-file names probed, in this order, when looking for a sharded checkpoint.
+WEIGHT_INDEX_CANDIDATES = (
+    "diffusion_pytorch_model.safetensors.index.json",
+    "model.safetensors.index.json",
+    "pytorch_model.bin.index.json",
+)
+
+# Entries the skeleton must contain (checked by _validate_skeleton).
+ROOT_REQUIRED_FILES = ("model_index.json",)
+ROOT_REQUIRED_DIRS = ("tokenizer", "text_encoder", "vae", "transformer", "transformer_2")
+# Directories carried over to the output only when the skeleton has them.
+OPTIONAL_DIRS = ("image_encoder", "image_processor", "scheduler", "feature_extractor")
+
+
+class AssembleError(RuntimeError):
+    """Failure the CLI reports as ``[ERROR] ...`` with exit status 2 instead of a traceback."""
+
+
+@dataclass(frozen=True)
+class WeightSpec:
+    """Resolved on-disk location of one transformer checkpoint.
+
+    A "single" spec carries ``single_file``; a "sharded" spec carries
+    ``index_file`` together with ``shard_files``.
+    """
+
+    kind: str  # "single" | "sharded"
+    single_file: Path | None = None  # checkpoint file when kind == "single"
+    index_file: Path | None = None  # *.index.json file when kind == "sharded"
+    shard_files: tuple[Path, ...] = ()  # shard files named by the index's weight_map
+
+
+def _load_shard_files_from_index(index_file: Path, role: str) -> tuple[Path, ...]:
+    """Return the shard files referenced by a sharded-checkpoint index.
+
+    Args:
+        index_file: Path to a ``*.index.json`` file. Shards are looked up
+            relative to its parent directory.
+        role: Label used in error messages (for example "transformer").
+
+    Returns:
+        The distinct shard paths named in ``weight_map``, sorted by name.
+
+    Raises:
+        AssembleError: If the index cannot be read or parsed, is not a JSON
+            object with a non-empty ``weight_map``, or names a shard that is
+            not a file next to the index.
+    """
+    try:
+        with index_file.open(encoding="utf-8") as f:
+            payload = json.load(f)
+    except (OSError, ValueError) as exc:
+        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
+        raise AssembleError(f"Failed to parse {role} index file: {index_file}. error={exc}") from exc
+
+    # Valid JSON may still be a list/string/number; guard before calling .get()
+    # so the user sees an AssembleError rather than an AttributeError traceback.
+    weight_map = payload.get("weight_map") if isinstance(payload, dict) else None
+    if not isinstance(weight_map, dict) or not weight_map:
+        raise AssembleError(f"Invalid {role} index file (missing/empty weight_map): {index_file}")
+
+    shard_names = sorted({str(v) for v in weight_map.values()})
+    shard_paths = tuple(index_file.parent / shard_name for shard_name in shard_names)
+
+    missing = [str(shard_path) for shard_path in shard_paths if not shard_path.is_file()]
+    if missing:
+        raise AssembleError(f"{role} index references missing shard file(s): " + ", ".join(missing))
+
+    # weight_map is non-empty here, so shard_paths holds at least one entry.
+    return shard_paths
+
+
+def _resolve_weight_spec(path: Path, role: str) -> WeightSpec:
+    """Work out which checkpoint layout ``path`` holds.
+
+    A file is taken as a single-file checkpoint. For a directory, the names in
+    WEIGHT_CANDIDATES are probed first, then the index names in
+    WEIGHT_INDEX_CANDIDATES; the first hit wins.
+
+    Raises:
+        AssembleError: If ``path`` is neither a file nor a directory, the
+            directory holds no recognised weight, or shard files are present
+            without an index.
+    """
+    if path.is_file():
+        return WeightSpec(kind="single", single_file=path)
+    if not path.is_dir():
+        raise AssembleError(f"{role} path does not exist: {path}")
+
+    single = next((path / name for name in WEIGHT_CANDIDATES if (path / name).is_file()), None)
+    if single is not None:
+        return WeightSpec(kind="single", single_file=single)
+
+    index = next((path / name for name in WEIGHT_INDEX_CANDIDATES if (path / name).is_file()), None)
+    if index is not None:
+        shards = _load_shard_files_from_index(index, role=role)
+        return WeightSpec(kind="sharded", index_file=index, shard_files=shards)
+
+    # Nothing usable found: distinguish "shards without an index" from "no weights at all".
+    if any(path.glob("diffusion_pytorch_model-*.safetensors")):
+        raise AssembleError(
+            f"Detected sharded {role} files under {path}, but index json is missing. "
+            f"Expected one of: {', '.join(WEIGHT_INDEX_CANDIDATES)}"
+        )
+
+    raise AssembleError(
+        f"Cannot find {role} weight under directory: {path}. "
+        f"Expected one of single files [{', '.join(WEIGHT_CANDIDATES)}] "
+        f"or sharded index files [{', '.join(WEIGHT_INDEX_CANDIDATES)}]."
+    )
+
+
+def _canonical_weight_name(weight_file: Path) -> str:
+    """Return the file name a single-file checkpoint gets in the output directory.
+
+    ``.safetensors``, ``.bin`` and ``.pt`` files (suffix matched
+    case-insensitively) become ``diffusion_pytorch_model<suffix>``; any other
+    file keeps its own name.
+    """
+    ext = weight_file.suffix.lower()
+    if ext in (".safetensors", ".bin", ".pt"):
+        return f"diffusion_pytorch_model{ext}"
+    return weight_file.name
+
+
+def _validate_skeleton(skeleton: Path) -> None:
+    """Check that ``skeleton`` contains every file and directory the assembly step reads.
+
+    Raises:
+        AssembleError: For the first missing entry found.
+    """
+    if not skeleton.is_dir():
+        raise AssembleError(f"--diffusers-skeleton is not a directory: {skeleton}")
+
+    for required_file in (skeleton / name for name in ROOT_REQUIRED_FILES):
+        if not required_file.is_file():
+            raise AssembleError(f"Missing required file in skeleton: {required_file}")
+
+    for required_dir in (skeleton / name for name in ROOT_REQUIRED_DIRS):
+        if not required_dir.is_dir():
+            raise AssembleError(f"Missing required directory in skeleton: {required_dir}")
+
+    # Both transformer components must ship their own config.json.
+    for component in ("transformer", "transformer_2"):
+        config_path = skeleton / f"{component}/config.json"
+        if not config_path.is_file():
+            raise AssembleError(f"Missing {component} config: {config_path}")
+
+
+def _ensure_clean_output(output_dir: Path, overwrite: bool) -> None:
+    """Create ``output_dir`` as a fresh, empty directory.
+
+    Args:
+        output_dir: Directory to create; missing parents are created as well.
+        overwrite: When True, an existing directory at ``output_dir`` is
+            removed recursively first.
+
+    Raises:
+        AssembleError: If ``output_dir`` exists and ``overwrite`` is False, or
+            if it exists but is a file or a symlink.
+    """
+    if output_dir.exists():
+        if not overwrite:
+            raise AssembleError(
+                f"Output directory already exists: {output_dir}. Use --overwrite to remove and recreate it."
+            )
+        # shutil.rmtree raises OSError on files and symlinks; report that as a
+        # normal CLI error instead of letting a traceback escape.
+        if output_dir.is_symlink() or not output_dir.is_dir():
+            raise AssembleError(f"Output path exists but is not a directory, refusing to remove it: {output_dir}")
+        shutil.rmtree(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=False)
+
+
+def _copy_or_link_dir(src: Path, dst: Path, asset_mode: str) -> None:
+    """Place directory ``src`` at ``dst``, as a full copy ("copy") or as a symlink ("symlink").
+
+    Raises:
+        AssembleError: If ``asset_mode`` is neither "copy" nor "symlink".
+    """
+    if asset_mode == "symlink":
+        dst.symlink_to(src, target_is_directory=True)
+        return
+    if asset_mode == "copy":
+        shutil.copytree(src, dst)
+        return
+    raise AssembleError(f"Unknown asset mode: {asset_mode}")
+
+
+def _materialize_weight(weight: WeightSpec, dst_dir: Path, role: str) -> tuple[Path, ...]:
+    """Copy the checkpoint described by ``weight`` into ``dst_dir``.
+
+    A single-file checkpoint is stored under its canonical name. A sharded
+    checkpoint keeps its file names; the index is copied first, then each
+    shard in ``weight.shard_files`` order.
+
+    Returns:
+        The destination paths, in the order they were written.
+
+    Raises:
+        AssembleError: If ``weight.kind`` is neither "single" nor "sharded".
+    """
+    if weight.kind not in ("single", "sharded"):
+        raise AssembleError(f"Unknown {role} weight kind: {weight.kind}")
+
+    if weight.kind == "single":
+        assert weight.single_file is not None
+        target = dst_dir / _canonical_weight_name(weight.single_file)
+        shutil.copy2(weight.single_file, target)
+        return (target,)
+
+    assert weight.index_file is not None
+    written: list[Path] = []
+    for source in (weight.index_file, *weight.shard_files):
+        target = dst_dir / source.name
+        shutil.copy2(source, target)
+        written.append(target)
+    return tuple(written)
+
+
+def _assemble(
+    skeleton: Path,
+    output_dir: Path,
+    transformer_weight: WeightSpec,
+    transformer_2_weight: WeightSpec,
+    asset_mode: str,
+) -> tuple[tuple[Path, ...], tuple[Path, ...]]:
+    """Populate an existing ``output_dir`` from the skeleton and the chosen weights.
+
+    ``model_index.json`` and both transformer ``config.json`` files are always
+    copied. The remaining skeleton directories are copied or symlinked
+    according to ``asset_mode``; optional ones only when the skeleton has them.
+
+    Returns:
+        ``(transformer_files, transformer_2_files)``: the paths written by
+        ``_materialize_weight`` for each transformer.
+    """
+    transformer_dirs = ("transformer", "transformer_2")
+
+    shutil.copy2(skeleton / "model_index.json", output_dir / "model_index.json")
+
+    # Non-transformer assets: required ones first, then whichever optional ones exist.
+    required_assets = [name for name in ROOT_REQUIRED_DIRS if name not in transformer_dirs]
+    for name in required_assets:
+        _copy_or_link_dir(skeleton / name, output_dir / name, asset_mode)
+    for name in OPTIONAL_DIRS:
+        if (skeleton / name).is_dir():
+            _copy_or_link_dir(skeleton / name, output_dir / name, asset_mode)
+
+    # Transformer directories are always real directories because their weights may be replaced.
+    for name in transformer_dirs:
+        (output_dir / name).mkdir(parents=True, exist_ok=True)
+    for name in transformer_dirs:
+        shutil.copy2(skeleton / name / "config.json", output_dir / name / "config.json")
+
+    transformer_files = _materialize_weight(transformer_weight, output_dir / "transformer", role="transformer")
+    transformer_2_files = _materialize_weight(
+        transformer_2_weight,
+        output_dir / "transformer_2",
+        role="transformer_2",
+    )
+    return transformer_files, transformer_2_files
+
+
+def _validate_output(
+    output_dir: Path,
+    transformer_copied: tuple[Path, ...],
+    transformer_2_copied: tuple[Path, ...],
+) -> None:
+    """Confirm the assembled directory contains everything that was meant to be written.
+
+    Raises:
+        AssembleError: If ``model_index.json`` is not a file, or if any required
+            directory, transformer config, or copied weight path does not exist.
+    """
+    if not (output_dir / "model_index.json").is_file():
+        raise AssembleError("Output validation failed: model_index.json missing")
+
+    expected = [output_dir / name for name in ("tokenizer", "text_encoder", "vae")]
+    expected += [output_dir / name / "config.json" for name in ("transformer", "transformer_2")]
+    expected += [*transformer_copied, *transformer_2_copied]
+
+    absent = [str(path) for path in expected if not path.exists()]
+    if absent:
+        raise AssembleError("Output validation failed, missing: " + ", ".join(absent))
+
+
+def parse_args() -> argparse.Namespace:
+    """Build the command-line parser and parse ``sys.argv``.
+
+    ``--high-noise-weight`` and ``--low-noise-weight`` are registered but
+    hidden from ``--help``.
+    """
+    cli = argparse.ArgumentParser(
+        description=(
+            "Assemble a Wan2.2-I2V-A14B-Diffusers directory while optionally "
+            "replacing transformer and transformer_2 weights."
+        )
+    )
+    cli.add_argument(
+        "--diffusers-skeleton",
+        type=Path,
+        required=True,
+        help="Path to a local Wan-AI/Wan2.2-I2V-A14B-Diffusers directory.",
+    )
+
+    # The two replacement-weight options differ only in the component they target.
+    for flag, component in (
+        ("--transformer-weight", "transformer"),
+        ("--transformer-2-weight", "transformer_2"),
+    ):
+        cli.add_argument(
+            flag,
+            type=Path,
+            help=(
+                "Optional checkpoint file, or directory containing either a single-file "
+                f"weight or sharded index+shards for {component}/. If omitted, keep the "
+                f"skeleton's original {component} weights."
+            ),
+        )
+
+    # Hidden options; main() maps them onto the two options above.
+    for hidden_flag in ("--high-noise-weight", "--low-noise-weight"):
+        cli.add_argument(hidden_flag, type=Path, help=argparse.SUPPRESS)
+
+    cli.add_argument(
+        "--output-dir",
+        type=Path,
+        required=True,
+        help="Output directory for the assembled model.",
+    )
+    cli.add_argument(
+        "--asset-mode",
+        choices=("symlink", "copy"),
+        default="symlink",
+        help=(
+            "How to materialize non-transformer assets (tokenizer/text_encoder/vae/optional dirs). "
+            "symlink saves disk and is default."
+        ),
+    )
+    cli.add_argument(
+        "--overwrite",
+        action="store_true",
+        help="Overwrite output-dir if it exists.",
+    )
+    return cli.parse_args()
+
+
+def main() -> int:
+    """Run the assembly CLI: validate inputs, build the output directory, print a summary.
+
+    Returns:
+        0 on success; 2 when arguments conflict, the output directory would
+        overlap an input path, or assembly fails.
+    """
+    args = parse_args()
+
+    skeleton = args.diffusers_skeleton.resolve()
+    output_dir = args.output_dir.resolve()
+
+    if args.transformer_weight is not None and args.high_noise_weight is not None:
+        print(
+            "[ERROR] --transformer-weight and --high-noise-weight are aliases; please provide only one.",
+            file=sys.stderr,
+        )
+        return 2
+    if args.transformer_2_weight is not None and args.low_noise_weight is not None:
+        print(
+            "[ERROR] --transformer-2-weight and --low-noise-weight are aliases; please provide only one.",
+            file=sys.stderr,
+        )
+        return 2
+
+    # Legacy names are aliases: high-noise -> transformer, low-noise -> transformer_2.
+    transformer_weight_arg = args.transformer_weight if args.transformer_weight is not None else args.high_noise_weight
+    transformer_2_weight_arg = (
+        args.transformer_2_weight if args.transformer_2_weight is not None else args.low_noise_weight
+    )
+
+    # Without an explicit replacement, fall back to the skeleton's own weights.
+    transformer_input = (
+        transformer_weight_arg.resolve() if transformer_weight_arg is not None else skeleton / "transformer"
+    )
+    transformer_2_input = (
+        transformer_2_weight_arg.resolve() if transformer_2_weight_arg is not None else skeleton / "transformer_2"
+    )
+
+    # _ensure_clean_output removes output_dir recursively under --overwrite, so an
+    # output path equal to (or a parent of) an input would delete that input
+    # before it is read.
+    for label, source in (
+        ("skeleton", skeleton),
+        ("transformer weight", transformer_input),
+        ("transformer_2 weight", transformer_2_input),
+    ):
+        if source == output_dir or output_dir in source.parents:
+            print(
+                f"[ERROR] --output-dir must not be the same as or contain the {label} path: {source}",
+                file=sys.stderr,
+            )
+            return 2
+
+    try:
+        _validate_skeleton(skeleton)
+        transformer_weight = _resolve_weight_spec(transformer_input, role="transformer")
+        transformer_2_weight = _resolve_weight_spec(transformer_2_input, role="transformer_2")
+
+        _ensure_clean_output(output_dir, overwrite=args.overwrite)
+        transformer_copied, transformer_2_copied = _assemble(
+            skeleton=skeleton,
+            output_dir=output_dir,
+            transformer_weight=transformer_weight,
+            transformer_2_weight=transformer_2_weight,
+            asset_mode=args.asset_mode,
+        )
+        _validate_output(output_dir, transformer_copied, transformer_2_copied)
+    except AssembleError as exc:
+        print(f"[ERROR] {exc}", file=sys.stderr)
+        return 2
+    except OSError as exc:
+        # Copy/link/remove failures (permissions, disk full, ...) take the same
+        # clean error path instead of escaping as a traceback.
+        print(f"[ERROR] Filesystem error during assembly: {exc}", file=sys.stderr)
+        return 2
+
+    def _weight_summary(copied: tuple[Path, ...]) -> str:
+        # One path means a single-file checkpoint; otherwise the first path is
+        # the index and the rest are shards.
+        if len(copied) == 1:
+            return copied[0].name
+        return f"{copied[0].name} + {len(copied) - 1} shard files"
+
+    print("[OK] Assembled Wan2.2 I2V Diffusers directory:")
+    print(f" output_dir: {output_dir}")
+    print(f" transformer weight: {_weight_summary(transformer_copied)}")
+    print(f" transformer_2 weight: {_weight_summary(transformer_2_copied)}")
+    print("\nUse it with vLLM-Omni, for example:")
+    print(f" vllm serve {output_dir} --omni --port 8091")
+    return 0
+
+
+if __name__ == "__main__":
+ raise SystemExit(main())
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
index a550e576f01..84d89619e86 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
@@ -24,6 +24,7 @@
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero
from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler
+from vllm_omni.diffusion.models.wan2_2.scheduling_wan_euler import WanEulerScheduler
from vllm_omni.diffusion.models.wan2_2.wan2_2_transformer import WanTransformer3DModel
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
@@ -32,6 +33,46 @@
logger = logging.getLogger(__name__)
DEBUG_PERF = False
+# Solver names accepted for the per-request "sample_solver" option.
+WAN_SAMPLE_SOLVER_CHOICES = {"unipc", "euler"}
+
+
+def build_wan_scheduler(sample_solver: str, flow_shift: float) -> Any:
+    """Instantiate the scheduler that implements ``sample_solver``.
+
+    Both schedulers are created with 1000 training timesteps and ``flow_shift``
+    as their shift; the UniPC scheduler additionally uses flow prediction.
+
+    Raises:
+        ValueError: If ``sample_solver`` is neither "unipc" nor "euler".
+    """
+    if sample_solver not in ("unipc", "euler"):
+        raise ValueError(
+            f"Unsupported Wan sample_solver: {sample_solver}. Expected one of: {sorted(WAN_SAMPLE_SOLVER_CHOICES)}"
+        )
+
+    shared_kwargs = {"num_train_timesteps": 1000, "shift": flow_shift}
+    if sample_solver == "euler":
+        return WanEulerScheduler(**shared_kwargs)
+    return FlowUniPCMultistepScheduler(prediction_type="flow_prediction", **shared_kwargs)
+
+
+def resolve_wan_sample_solver(req: OmniDiffusionRequest, default: str = "unipc") -> str:
+    """Pick the sampler name for a request.
+
+    Reads ``sample_solver`` from ``req.sampling_params.extra_args``. A missing
+    or ``None`` value falls back to ``default``. The value is stringified,
+    stripped and lower-cased before validation.
+
+    Args:
+        req: Request whose ``sampling_params`` may carry ``extra_args``.
+        default: Solver name used when the request does not set one.
+
+    Returns:
+        The normalized solver name.
+
+    Raises:
+        ValueError: If the resolved name is not in WAN_SAMPLE_SOLVER_CHOICES.
+    """
+    extra_args = getattr(req.sampling_params, "extra_args", {}) or {}
+    raw = extra_args.get("sample_solver")
+    if raw is None:
+        # An explicit null (e.g. JSON {"sample_solver": null}) means "not set",
+        # the same way resolve_wan_flow_shift treats a None flow_shift.
+        raw = default
+    sample_solver = str(raw).strip().lower()
+    if sample_solver not in WAN_SAMPLE_SOLVER_CHOICES:
+        raise ValueError(f"Invalid sample_solver={raw!r}. Expected one of: {sorted(WAN_SAMPLE_SOLVER_CHOICES)}")
+    return sample_solver
+
+
+def resolve_wan_flow_shift(req: OmniDiffusionRequest, od_config: OmniDiffusionConfig) -> float:
+    """Determine the flow shift for a request.
+
+    Precedence: ``flow_shift`` in ``req.sampling_params.extra_args``, then
+    ``od_config.flow_shift``, then 5.0.
+
+    Raises:
+        ValueError: If the chosen value cannot be converted to ``float``.
+    """
+    request_extras = getattr(req.sampling_params, "extra_args", {}) or {}
+    candidate = request_extras.get("flow_shift")
+    if candidate is None:
+        candidate = 5.0 if od_config.flow_shift is None else od_config.flow_shift
+
+    try:
+        value = float(candidate)
+    except (TypeError, ValueError) as exc:
+        raise ValueError(f"Invalid flow_shift={candidate!r}. flow_shift must be a float.") from exc
+    return value
def retrieve_latents(
@@ -296,13 +337,9 @@ def __init__(
else:
raise RuntimeError("No transformer loaded")
- # Initialize UniPC scheduler
- flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p
- self.scheduler = FlowUniPCMultistepScheduler(
- num_train_timesteps=1000,
- shift=flow_shift,
- prediction_type="flow_prediction",
- )
+ self._sample_solver = "unipc"
+ self._flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0
+ self.scheduler = build_wan_scheduler(self._sample_solver, self._flow_shift)
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
@@ -462,6 +499,13 @@ def forward(
current_omni_platform.synchronize()
_t_text_enc_ms = (time.perf_counter() - _t_text_enc_start) * 1000
+        # Resolve per-request overrides against the fixed pipeline default ("unipc",
+        # the solver built in __init__). Passing self._sample_solver as the default
+        # would make one request's solver choice stick for every later request
+        # that omits sample_solver, because the attribute is updated below.
+        sample_solver = resolve_wan_sample_solver(req)
+        flow_shift = resolve_wan_flow_shift(req, self.od_config)
+        # Rebuild the scheduler only when the effective settings actually changed.
+        if sample_solver != self._sample_solver or abs(flow_shift - self._flow_shift) > 1e-6:
+            self.scheduler = build_wan_scheduler(sample_solver, flow_shift)
+            self._sample_solver = sample_solver
+            self._flow_shift = flow_shift
+
# Timesteps
self.scheduler.set_timesteps(num_steps, device=device)
timesteps = self.scheduler.timesteps
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
index c05ecc9c9a2..46484cd789d 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
@@ -24,10 +24,12 @@
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero
-from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler
from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import (
+ build_wan_scheduler,
create_transformer_from_config,
load_transformer_config,
+ resolve_wan_flow_shift,
+ resolve_wan_sample_solver,
retrieve_latents,
)
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
@@ -230,13 +232,9 @@ def __init__(
else:
self.transformer_2 = None
- # Initialize UniPC scheduler
- flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p
- self.scheduler = FlowUniPCMultistepScheduler(
- num_train_timesteps=1000,
- shift=flow_shift,
- prediction_type="flow_prediction",
- )
+ self._sample_solver = "unipc"
+ self._flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0
+ self.scheduler = build_wan_scheduler(self._sample_solver, self._flow_shift)
# VAE scale factors
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if hasattr(self.vae, "config") else 4
@@ -440,6 +438,13 @@ def forward(
current_omni_platform.synchronize()
_t_img_enc_ms = (time.perf_counter() - _t_img_enc_start) * 1000
+        # Resolve per-request overrides against the fixed pipeline default ("unipc",
+        # the solver built in __init__). Passing self._sample_solver as the default
+        # would make one request's solver choice stick for every later request
+        # that omits sample_solver, because the attribute is updated below.
+        sample_solver = resolve_wan_sample_solver(req)
+        flow_shift = resolve_wan_flow_shift(req, self.od_config)
+        # Rebuild the scheduler only when the effective settings actually changed.
+        if sample_solver != self._sample_solver or abs(flow_shift - self._flow_shift) > 1e-6:
+            self.scheduler = build_wan_scheduler(sample_solver, flow_shift)
+            self._sample_solver = sample_solver
+            self._flow_shift = flow_shift
+
# Timesteps
self.scheduler.set_timesteps(num_steps, device=device)
timesteps = self.scheduler.timesteps
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
index 261f62fb798..939fe294a33 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
@@ -36,10 +36,12 @@
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin
-from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler
from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import (
+ build_wan_scheduler,
create_transformer_from_config,
load_transformer_config,
+ resolve_wan_flow_shift,
+ resolve_wan_sample_solver,
retrieve_latents,
)
from vllm_omni.diffusion.request import OmniDiffusionRequest
@@ -183,13 +185,9 @@ def __init__(
transformer_config = load_transformer_config(model, "transformer", local_files_only)
self.transformer = create_transformer_from_config(transformer_config)
- # Initialize UniPC scheduler
- flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p
- self.scheduler = FlowUniPCMultistepScheduler(
- num_train_timesteps=1000,
- shift=flow_shift,
- prediction_type="flow_prediction",
- )
+ self._sample_solver = "unipc"
+ self._flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0
+ self.scheduler = build_wan_scheduler(self._sample_solver, self._flow_shift)
# VAE scale factors
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if hasattr(self.vae, "config") else 4
@@ -323,6 +321,13 @@ def forward(
batch_size = prompt_embeds.shape[0]
+        # Resolve per-request overrides against the fixed pipeline default ("unipc",
+        # the solver built in __init__). Passing self._sample_solver as the default
+        # would make one request's solver choice stick for every later request
+        # that omits sample_solver, because the attribute is updated below.
+        sample_solver = resolve_wan_sample_solver(req)
+        flow_shift = resolve_wan_flow_shift(req, self.od_config)
+        # Rebuild the scheduler only when the effective settings actually changed.
+        if sample_solver != self._sample_solver or abs(flow_shift - self._flow_shift) > 1e-6:
+            self.scheduler = build_wan_scheduler(sample_solver, flow_shift)
+            self._sample_solver = sample_solver
+            self._flow_shift = flow_shift
+
# Timesteps
self.scheduler.set_timesteps(num_steps, device=device)
timesteps = self.scheduler.timesteps
diff --git a/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py b/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py
new file mode 100644
index 00000000000..25444044c2d
--- /dev/null
+++ b/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py
@@ -0,0 +1,147 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from types import SimpleNamespace
+
+import numpy as np
+import torch
+
+
+@dataclass
+class WanEulerSchedulerOutput:
+    """Value returned by ``WanEulerScheduler.step`` when ``return_dict`` is True."""
+
+    prev_sample: torch.FloatTensor  # sample after one Euler update
+
+
+def _unsqueeze_to_ndim(in_tensor: torch.Tensor, target_ndim: int) -> torch.Tensor:
+    """Append trailing size-1 dimensions until ``in_tensor`` has ``target_ndim`` dimensions.
+
+    A tensor that already has at least ``target_ndim`` dimensions is returned as is.
+    """
+    expanded = in_tensor
+    for _ in range(target_ndim - in_tensor.ndim):
+        expanded = expanded.unsqueeze(-1)
+    return expanded
+
+
+def _get_timesteps(num_steps: int, max_steps: int = 1000) -> np.ndarray:
+    """Return ``num_steps + 1`` evenly spaced float32 values from ``max_steps`` down to 0.
+
+    The extra point lets every Euler update read the sigma that follows the current one.
+    """
+    point_count = num_steps + 1
+    return np.linspace(start=max_steps, stop=0, num=point_count, dtype=np.float32)
+
+
+def _timestep_shift(timesteps: torch.Tensor, shift: float = 1.0) -> torch.Tensor:
+    """Warp ``timesteps`` with ``shift * t / (1 + (shift - 1) * t)``; ``shift=1`` is the identity."""
+    numerator = shift * timesteps
+    denominator = 1 + (shift - 1) * timesteps
+    return numerator / denominator
+
+
+class WanEulerScheduler:
+    """First-order Euler scheduler over a shifted sigma schedule.
+
+    ``set_timesteps(n)`` builds ``n + 1`` sigmas: evenly spaced values from 1
+    down to 0, warped by ``shift * s / (1 + (shift - 1) * s)``. The first ``n``
+    sigmas, scaled by ``num_train_timesteps``, are exposed as ``timesteps``; the
+    last sigma only serves as ``sigma_next`` for the final step. ``step``
+    computes ``sample + (sigma_next - sigma) * model_output``.
+
+    Args:
+        num_train_timesteps: Scale between sigmas and the public timesteps.
+        shift: Schedule shift factor; 1.0 leaves the linear schedule unchanged.
+        device: Device for the schedule tensors when ``set_timesteps`` is
+            called without a device.
+    """
+
+    order = 1
+
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        shift: float = 1.0,
+        device: torch.device | str = "cpu",
+    ) -> None:
+        self.num_train_timesteps = int(num_train_timesteps)
+        self._shift = float(shift)
+        self.device = device
+        self.config = SimpleNamespace(num_train_timesteps=self.num_train_timesteps)
+        self.init_noise_sigma = 1.0
+
+        self._step_index: int | None = None
+        self._begin_index: int | None = None
+
+        self.timesteps = torch.empty(0, dtype=torch.float32)
+        self.sigmas = torch.empty(0, dtype=torch.float32)
+        self.timesteps_ori = torch.empty(0, dtype=torch.float32)
+
+        # Start with a full-resolution schedule; callers replace it via set_timesteps.
+        self.set_timesteps(num_inference_steps=self.num_train_timesteps, device=self.device)
+
+    @property
+    def step_index(self) -> int | None:
+        """Index of the next sigma to use, or None before the first ``step``."""
+        return self._step_index
+
+    @property
+    def begin_index(self) -> int | None:
+        """Explicit start index set via ``set_begin_index``, or None."""
+        return self._begin_index
+
+    def set_begin_index(self, begin_index: int = 0) -> None:
+        """Make the first ``step`` start at ``begin_index`` instead of looking up its timestep."""
+        self._begin_index = int(begin_index)
+
+    def index_for_timestep(self, timestep: torch.Tensor) -> int:
+        """Return the schedule index of ``timestep`` (nearest entry if there is no exact match)."""
+        indices = (self.timesteps == timestep).nonzero()
+        if len(indices) > 0:
+            # When several entries match, the second one is used.
+            pos = 1 if len(indices) > 1 else 0
+            return int(indices[pos].item())
+        # Fallback for tiny float drift
+        return int(torch.argmin(torch.abs(self.timesteps - timestep)).item())
+
+    def _init_step_index(self, timestep: float | torch.Tensor) -> None:
+        """Set ``_step_index`` from ``begin_index`` if given, else from ``timestep``."""
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep_t = timestep.to(self.timesteps.device, dtype=self.timesteps.dtype)
+            else:
+                timestep_t = torch.tensor(timestep, device=self.timesteps.device, dtype=self.timesteps.dtype)
+            self._step_index = self.index_for_timestep(timestep_t)
+        else:
+            self._step_index = self._begin_index
+
+    def set_shift(self, shift: float = 1.0) -> None:
+        """Recompute ``sigmas`` and ``timesteps`` from ``timesteps_ori`` with a new shift."""
+        # Compute shifted sigma schedule on [0, 1].
+        sigmas_full = self.timesteps_ori / float(self.num_train_timesteps)
+        sigmas_full = _timestep_shift(sigmas_full, shift=float(shift))
+        self.sigmas = sigmas_full
+        # Public timesteps are the first N points; next point is consumed as sigma_next.
+        self.timesteps = self.sigmas[:-1] * self.num_train_timesteps
+        self._shift = float(shift)
+
+    def set_timesteps(
+        self,
+        num_inference_steps: int,
+        device: torch.device | str | int | None = None,
+        **kwargs,  # noqa: ARG002 - kept for scheduler API compatibility
+    ) -> None:
+        """Build the schedule for ``num_inference_steps`` steps and reset step tracking.
+
+        Args:
+            num_inference_steps: Number of ``step`` calls the schedule must support.
+            device: Device for the schedule tensors; ``None`` uses the device
+                given to the constructor.
+        """
+        timesteps = _get_timesteps(
+            num_steps=int(num_inference_steps),
+            max_steps=self.num_train_timesteps,
+        )
+        # Compare against None explicitly: an integer device index of 0 is a
+        # valid device but falsy, so ``device or self.device`` would drop it.
+        target_device = self.device if device is None else device
+        self.timesteps_ori = torch.from_numpy(timesteps).to(
+            dtype=torch.float32,
+            device=target_device,
+        )
+        self.set_shift(self._shift)
+        self._step_index = None
+        self._begin_index = None
+
+    def scale_model_input(self, sample: torch.Tensor, timestep: int | None = None) -> torch.Tensor:  # noqa: ARG002
+        """Return ``sample`` unchanged; this scheduler applies no input scaling."""
+        return sample
+
+    def step(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: float | torch.FloatTensor,
+        sample: torch.FloatTensor,
+        return_dict: bool = True,
+        **kwargs,  # noqa: ARG002 - kept for scheduler API compatibility
+    ) -> WanEulerSchedulerOutput | tuple[torch.FloatTensor]:
+        """Advance ``sample`` by one Euler update.
+
+        Args:
+            model_output: Model prediction for the current sample.
+            timestep: Current value taken from ``self.timesteps``; only used to
+                locate the step index on the first call after ``set_timesteps``.
+            sample: Current sample.
+            return_dict: Return a ``WanEulerSchedulerOutput`` when True, else a 1-tuple.
+
+        Returns:
+            The updated sample, cast to ``model_output``'s dtype.
+
+        Raises:
+            ValueError: If ``timestep`` is an integer or an integer tensor.
+        """
+        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
+            raise ValueError(
+                "Passing integer indices as timesteps is not supported. Use one value from scheduler.timesteps instead."
+            )
+
+        if self.step_index is None:
+            self._init_step_index(timestep)
+        assert self._step_index is not None
+
+        # Do the update in float32, then cast back to the model's dtype.
+        sample_fp32 = sample.to(torch.float32)
+        sigma = _unsqueeze_to_ndim(self.sigmas[self._step_index], sample_fp32.ndim).to(sample_fp32.device)
+        sigma_next = _unsqueeze_to_ndim(self.sigmas[self._step_index + 1], sample_fp32.ndim).to(sample_fp32.device)
+
+        prev_sample = sample_fp32 + (sigma_next - sigma) * model_output
+        prev_sample = prev_sample.to(model_output.dtype)
+
+        self._step_index += 1
+
+        if not return_dict:
+            return (prev_sample,)
+        return WanEulerSchedulerOutput(prev_sample=prev_sample)
+
+    def __len__(self) -> int:
+        return self.num_train_timesteps
diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py
index 65a2d4390ae..3b43f3eaf51 100644
--- a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py
+++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py
@@ -1015,6 +1015,14 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
if ".to_out.0." in lookup_name:
lookup_name = lookup_name.replace(".to_out.0.", ".to_out.")
+ # Compatibility: some Wan conversion pipelines still keep
+ # block modulation keys as `blocks.N.modulation` instead of
+ # `blocks.N.scale_shift_table`.
+ if lookup_name.endswith(".modulation"):
+ modulation_alias = lookup_name[: -len(".modulation")] + ".scale_shift_table"
+ if modulation_alias in params_dict:
+ lookup_name = modulation_alias
+
if lookup_name not in params_dict:
logger.warning(f"Skipping weight {original_name} -> {lookup_name}")
continue
diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py
index 8e0b2b2df11..32e8336f6da 100644
--- a/vllm_omni/engine/async_omni_engine.py
+++ b/vllm_omni/engine/async_omni_engine.py
@@ -1221,6 +1221,8 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list:
"enable_cpu_offload": kwargs.get("enable_cpu_offload", False),
"enable_layerwise_offload": kwargs.get("enable_layerwise_offload", False),
"enforce_eager": kwargs.get("enforce_eager", False),
+ "boundary_ratio": kwargs.get("boundary_ratio", None),
+ "flow_shift": kwargs.get("flow_shift", None),
"diffusion_load_format": kwargs.get("diffusion_load_format", "default"),
"custom_pipeline_args": kwargs.get("custom_pipeline_args", None),
"worker_extension_cls": kwargs.get("worker_extension_cls", None),
From d9e745ce2c562be06913cf27c3c9942a56154b93 Mon Sep 17 00:00:00 2001
From: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>
Date: Mon, 13 Apr 2026 02:30:56 -0400
Subject: [PATCH 09/76] [Fix] VoxCPM2: support raw audio for voice cloning via
OpenAI API (#2720)
Signed-off-by: Yueqian Lin
---
examples/online_serving/voxcpm2/README.md | 42 ++++++
.../voxcpm2/openai_speech_client.py | 108 +++++++++++++++
.../models/voxcpm2/voxcpm2_talker.py | 130 +++++++++++++++++-
3 files changed, 277 insertions(+), 3 deletions(-)
create mode 100644 examples/online_serving/voxcpm2/README.md
create mode 100644 examples/online_serving/voxcpm2/openai_speech_client.py
diff --git a/examples/online_serving/voxcpm2/README.md b/examples/online_serving/voxcpm2/README.md
new file mode 100644
index 00000000000..8735180f0ac
--- /dev/null
+++ b/examples/online_serving/voxcpm2/README.md
@@ -0,0 +1,42 @@
+# VoxCPM2 Online Serving
+
+Serve VoxCPM2 TTS via the OpenAI-compatible `/v1/audio/speech` endpoint.
+
+## Start the Server
+
+```bash
+python -m vllm_omni.entrypoints.openai.api_server \
+ --model openbmb/VoxCPM2 \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm2.yaml \
+ --host 0.0.0.0 --port 8000
+```
+
+## Zero-shot Synthesis
+
+```bash
+python openai_speech_client.py --text "Hello, this is VoxCPM2."
+```
+
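+Or with curl (VoxCPM2 has no predefined voices; the `voice` field is required by the request schema but ignored, so any placeholder works):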
+
+```bash
+curl -X POST http://localhost:8000/v1/audio/speech \
+ -H "Content-Type: application/json" \
+ -d '{"model": "voxcpm2", "input": "Hello, this is VoxCPM2.", "voice": "default"}' \
+ --output output.wav
+```
+
+## Voice Cloning
+
+Clone a speaker's voice using a reference audio file:
+
+```bash
+python openai_speech_client.py \
+ --text "This should sound like the reference speaker." \
+ --ref-audio /path/to/reference.wav
+```
+
+The `--ref-audio` parameter accepts:
+- Local file path (auto-encoded to base64)
+- URL (`https://...`)
+- Base64 data URI (`data:audio/wav;base64,...`)
diff --git a/examples/online_serving/voxcpm2/openai_speech_client.py b/examples/online_serving/voxcpm2/openai_speech_client.py
new file mode 100644
index 00000000000..a117d24fd1a
--- /dev/null
+++ b/examples/online_serving/voxcpm2/openai_speech_client.py
@@ -0,0 +1,108 @@
+"""OpenAI-compatible client for VoxCPM2 TTS via /v1/audio/speech endpoint.
+
+Examples:
+ # Zero-shot synthesis
+ python openai_speech_client.py --text "Hello, this is VoxCPM2."
+
+ # Voice cloning with a local reference audio file
+ python openai_speech_client.py --text "Hello world" \
+ --ref-audio /path/to/reference.wav
+
+ # Voice cloning with a URL
+ python openai_speech_client.py --text "Hello world" \
+ --ref-audio "https://example.com/reference.wav"
+
+Server setup:
+ python -m vllm_omni.entrypoints.openai.api_server \
+ --model openbmb/VoxCPM2 \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm2.yaml \
+ --host 0.0.0.0 --port 8000
+"""
+
+from __future__ import annotations
+
+import argparse
+import base64
+import os
+
+import httpx
+
+# Defaults for the --api-base and --api-key command-line options.
+DEFAULT_API_BASE = "http://localhost:8000"
+DEFAULT_API_KEY = "sk-empty"
+
+
+def encode_audio_to_base64(audio_path: str) -> str:
+    """Read a local audio file and return it as a ``data:<mime>;base64,...`` URL.
+
+    The MIME type is chosen from the lower-cased file extension (wav, mp3,
+    flac, ogg); anything else is labelled ``audio/wav``.
+
+    Raises:
+        FileNotFoundError: If ``audio_path`` does not exist.
+    """
+    if not os.path.exists(audio_path):
+        raise FileNotFoundError(f"Audio file not found: {audio_path}")
+
+    mime_by_extension = {
+        "wav": "audio/wav",
+        "mp3": "audio/mpeg",
+        "flac": "audio/flac",
+        "ogg": "audio/ogg",
+    }
+    # Text after the last dot, or the whole path when there is no dot.
+    extension = audio_path.lower().rpartition(".")[2]
+    mime = mime_by_extension.get(extension, "audio/wav")
+
+    with open(audio_path, "rb") as audio_file:
+        encoded = base64.b64encode(audio_file.read()).decode("utf-8")
+    return "data:" + mime + ";base64," + encoded
+
+
+def main() -> None:
+    """Parse CLI arguments, POST a speech request, and write the returned audio to disk.
+
+    Raises:
+        SystemExit: With a non-zero status when the server does not answer 200.
+        FileNotFoundError: If a local ``--ref-audio`` path does not exist.
+    """
+    parser = argparse.ArgumentParser(description="VoxCPM2 OpenAI speech client")
+    parser.add_argument("--text", type=str, required=True, help="Text to synthesize")
+    parser.add_argument(
+        "--ref-audio",
+        type=str,
+        default=None,
+        help="Reference audio for voice cloning (local path, URL, or data: URI)",
+    )
+    parser.add_argument("--model", type=str, default="voxcpm2")
+    parser.add_argument("--output", type=str, default="output.wav")
+    parser.add_argument("--api-base", type=str, default=DEFAULT_API_BASE)
+    parser.add_argument("--api-key", type=str, default=DEFAULT_API_KEY)
+    parser.add_argument("--response-format", type=str, default="wav")
+    args = parser.parse_args()
+
+    # VoxCPM2 has no predefined voices. The "voice" field is required by
+    # the OpenAI API schema but ignored by VoxCPM2 — use any placeholder.
+    # For voice cloning, pass --ref-audio instead.
+    payload: dict = {
+        "model": args.model,
+        "input": args.text,
+        "voice": "default",
+        "response_format": args.response_format,
+    }
+
+    if args.ref_audio:
+        ref = args.ref_audio
+        # URLs and data: URIs are sent as is; anything else is treated as a local file.
+        if ref.startswith(("http://", "https://", "data:")):
+            payload["ref_audio"] = ref
+        else:
+            payload["ref_audio"] = encode_audio_to_base64(ref)
+
+    # Tolerate a trailing slash on --api-base so the URL has no "//" in its path.
+    url = f"{args.api_base.rstrip('/')}/v1/audio/speech"
+    print(f"POST {url}")
+    print(f" text: {args.text}")
+    if args.ref_audio:
+        # Only mark the value as truncated when it actually was.
+        shown = args.ref_audio if len(args.ref_audio) <= 80 else f"{args.ref_audio[:80]}..."
+        print(f" ref_audio: {shown}")
+
+    with httpx.Client(timeout=300) as client:
+        resp = client.post(
+            url,
+            json=payload,
+            headers={"Authorization": f"Bearer {args.api_key}"},
+        )
+
+    if resp.status_code != 200:
+        # Exit non-zero (message goes to stderr) so shell scripts can detect the failure.
+        raise SystemExit(f"Error {resp.status_code}: {resp.text[:500]}")
+
+    with open(args.output, "wb") as f:
+        f.write(resp.content)
+    print(f"Saved: {args.output} ({len(resp.content):,} bytes)")
+
+if __name__ == "__main__":
+ main()
diff --git a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py
index ade68b673b7..b9faf9fa3b8 100644
--- a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py
+++ b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py
@@ -22,6 +22,7 @@
from collections.abc import Iterable
from typing import Any
+import librosa
import torch
import torch.nn as nn
from vllm.config import VllmConfig
@@ -41,6 +42,53 @@
logger = init_logger(__name__)
+def _encode_raw_audio(
+    tts: nn.Module,
+    samples: list[float] | torch.Tensor,
+    sr: int,
+    padding_mode: str = "right",
+) -> torch.Tensor:
+    """Encode raw audio samples using the native VoxCPM2 AudioVAE.
+
+    Mirrors ``VoxCPM2Model._encode_wav`` but accepts in-memory samples
+    instead of a file path. This is needed for the OpenAI speech API
+    where ``_resolve_ref_audio`` returns decoded audio data.
+
+    Args:
+        tts: Native VoxCPM2 tts_model instance.
+        samples: Audio samples (mono, float32). A tensor may live on any
+            device and may track gradients; it is detached and copied to
+            the host before processing.
+        sr: Sample rate of the input audio.
+        padding_mode: "left" pads at the start of the signal; any other
+            value pads at the end ("right" is the default).
+
+    Returns:
+        audio_feat: (T, P, D) tensor of latent patches.
+
+    Raises:
+        ValueError: If ``samples`` contains no audio.
+    """
+    if isinstance(samples, list):
+        audio = torch.tensor(samples, dtype=torch.float32)
+    else:
+        # Resampling below goes through NumPy, which rejects CUDA and
+        # grad-tracking tensors, so normalise to a detached host tensor.
+        audio = samples.detach().float().cpu()
+
+    if audio.numel() == 0:
+        # Fail early with a clear message instead of an opaque shape error
+        # raised later by the resampler or the VAE.
+        raise ValueError("Cannot encode empty audio")
+
+    if audio.ndim == 1:
+        audio = audio.unsqueeze(0)
+
+    # Resample to the model's expected encoding sample rate
+    encode_sr = tts._encode_sample_rate
+    if sr != encode_sr:
+        audio_np = audio.squeeze(0).numpy()
+        audio_np = librosa.resample(audio_np, orig_sr=sr, target_sr=encode_sr)
+        audio = torch.from_numpy(audio_np).unsqueeze(0)
+
+    # Pad to patch boundary
+    patch_len = tts.patch_size * tts.chunk_size
+    remainder = audio.size(1) % patch_len
+    if remainder != 0:
+        padding_size = patch_len - remainder
+        pad = (padding_size, 0) if padding_mode == "left" else (0, padding_size)
+        audio = torch.nn.functional.pad(audio, pad)
+
+    feat = tts.audio_vae.encode(audio.to(tts.device), encode_sr).cpu()
+    return feat.view(tts.audio_vae.latent_dim, -1, tts.patch_size).permute(1, 2, 0)
+
+
class VoxCPM2TalkerForConditionalGeneration(nn.Module):
"""VoxCPM2 talker using native MiniCPM4 base_lm.
@@ -83,6 +131,82 @@ def tts(self) -> nn.Module:
assert self._tts is not None, "Model not loaded yet"
return self._tts
+ def _build_prompt_cache(
+     self,
+     ref_audio: Any = None,
+     prompt_audio: Any = None,
+     prompt_text: str | None = None,
+ ) -> dict | None:
+     """Build prompt cache, handling both file paths and raw audio data.
+
+     The OpenAI speech API sends decoded audio as [samples_list, sr]
+     via ``_resolve_ref_audio``, while offline usage sends file paths.
+     This method detects the format and routes accordingly.
+
+     Args:
+         ref_audio: Reference audio, either a file path or a
+             ``[samples, sr]`` pair; ``None`` to omit.
+         prompt_audio: Prompt audio, in the same two formats. In the raw
+             audio path it is only used when ``prompt_text`` is also given.
+         prompt_text: Transcript that accompanies ``prompt_audio``.
+
+     Returns:
+         When neither input is raw audio, whatever the native
+         ``tts.build_prompt_cache`` returns. Otherwise a dict holding
+         ``ref_audio_feat`` and/or ``audio_feat`` (plus ``prompt_text``)
+         and a ``mode`` key, or ``None`` when no feature could be built
+         from the given inputs.
+     """
+     tts = self.tts
+
+     def _is_raw_audio(v: Any) -> bool:
+         """Check if value is [samples, sr] from serving_speech."""
+         return (
+             isinstance(v, (list, tuple))
+             and len(v) == 2
+             and isinstance(v[1], int)
+             and isinstance(v[0], (list, torch.Tensor))
+         )
+
+     def _encode(audio: Any, padding_mode: str) -> torch.Tensor:
+         """Encode one input, choosing the raw-sample or file-path encoder."""
+         if _is_raw_audio(audio):
+             samples, sr = audio
+             return _encode_raw_audio(tts, samples, sr, padding_mode=padding_mode)
+         return tts._encode_wav(audio, padding_mode=padding_mode)
+
+     # If all inputs are file paths (or None), use native build_prompt_cache
+     if not _is_raw_audio(ref_audio) and not _is_raw_audio(prompt_audio):
+         return tts.build_prompt_cache(
+             prompt_text=prompt_text,
+             prompt_wav_path=prompt_audio,
+             reference_wav_path=ref_audio,
+         )
+
+     # Raw audio path: encode directly
+     cache: dict[str, Any] = {}
+
+     if ref_audio is not None:
+         cache["ref_audio_feat"] = _encode(ref_audio, "right")
+
+     if prompt_audio is not None and prompt_text is not None:
+         cache["prompt_text"] = prompt_text
+         cache["audio_feat"] = _encode(prompt_audio, "left")
+
+     has_ref = "ref_audio_feat" in cache
+     has_prompt = "audio_feat" in cache
+     if not has_ref and not has_prompt:
+         # Raw prompt audio without its transcript (and no reference audio)
+         # yields no features; a cache carrying only a "mode" key would be
+         # unusable, so report that there is nothing to cache.
+         logger.warning("prompt_audio given without prompt_text and no ref_audio; no prompt cache built")
+         return None
+
+     if has_ref and has_prompt:
+         cache["mode"] = "ref_continuation"
+     elif has_ref:
+         cache["mode"] = "reference"
+     else:
+         cache["mode"] = "continuation"
+
+     return cache
+
# -------------------- vllm hooks --------------------
def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor:
@@ -482,10 +606,10 @@ def preprocess(
self._prompt_cache = None
if ref_audio or (prompt_audio and prompt_text):
try:
- self._prompt_cache = self.tts.build_prompt_cache(
+ self._prompt_cache = self._build_prompt_cache(
+ ref_audio=ref_audio,
+ prompt_audio=prompt_audio,
prompt_text=prompt_text,
- prompt_wav_path=prompt_audio,
- reference_wav_path=ref_audio,
)
except Exception as e:
logger.warning("build_prompt_cache failed: %s; falling back to zero-shot", e)
From 22261430b42b3e91d2019367da9fe1a8bac7f58a Mon Sep 17 00:00:00 2001
From: wangyu <53896905+yenuo26@users.noreply.github.com>
Date: Mon, 13 Apr 2026 14:47:55 +0800
Subject: [PATCH 10/76] [CI][Bugfix] Refactor the test case to add support for
increasing init timeout and stage init timeout in order to resolve the CI
timeout error. (#2711)
Signed-off-by: wangyu <410167048@qq.com>
---
.buildkite/test-merge.yml | 2 +-
.buildkite/test-nightly.yml | 3 +-
tests/conftest.py | 8 +-
.../offline_inference/test_bagel_img2img.py | 15 +-
.../e2e/offline_inference/test_bagel_lora.py | 11 +-
.../offline_inference/test_bagel_text2img.py | 32 ++--
.../test_bagel_understanding.py | 27 +--
tests/e2e/offline_inference/test_cache_dit.py | 35 +---
.../test_diffusion_cpu_offload.py | 43 ++---
.../test_diffusion_layerwise_offload.py | 56 +++---
.../offline_inference/test_diffusion_lora.py | 14 +-
.../e2e/offline_inference/test_dynin_omni.py | 73 ++------
.../offline_inference/test_expert_parallel.py | 51 +++---
.../test_flux_autoround_w4a16.py | 40 ++---
.../offline_inference/test_flux_kontext.py | 97 +++++-----
.../test_hunyuanimage3_text2img.py | 14 +-
.../e2e/offline_inference/test_magi_human.py | 17 +-
.../offline_inference/test_mammoth_moda2.py | 11 +-
tests/e2e/offline_inference/test_omnivoice.py | 55 +++---
.../test_quantization_fp8.py | 19 +-
.../test_qwen_image_diffusion_batching.py | 165 ++++++++----------
.../test_sequence_parallel.py | 63 ++++---
.../test_stable_audio_model.py | 21 +--
tests/e2e/offline_inference/test_t2i_model.py | 101 +++++------
tests/e2e/offline_inference/test_t2v_model.py | 51 +++---
tests/e2e/offline_inference/test_teacache.py | 37 +---
.../test_vae_decode_parallelism.py | 36 ++--
tests/e2e/offline_inference/test_voxcpm2.py | 7 +-
.../e2e/offline_inference/test_voxtral_tts.py | 17 +-
.../test_zimage_parallelism.py | 112 ++++++------
.../test_images_generations_lora.py | 2 +-
31 files changed, 497 insertions(+), 738 deletions(-)
diff --git a/.buildkite/test-merge.yml b/.buildkite/test-merge.yml
index 7355e2b4c7c..24fc6dd3dc2 100644
--- a/.buildkite/test-merge.yml
+++ b/.buildkite/test-merge.yml
@@ -113,7 +113,7 @@ steps:
- "/fsx/hf_cache:/fsx/hf_cache"
- label: "Diffusion Sequence Parallelism Test"
- timeout_in_minutes: 20
+ timeout_in_minutes: 25
depends_on: upload-merge-pipeline
commands:
- pytest -s -v tests/e2e/offline_inference/test_sequence_parallel.py tests/diffusion/distributed/test_ulysses_uaa_perf.py
diff --git a/.buildkite/test-nightly.yml b/.buildkite/test-nightly.yml
index 06b7c14ae1d..31b3e17976c 100644
--- a/.buildkite/test-nightly.yml
+++ b/.buildkite/test-nightly.yml
@@ -141,7 +141,6 @@ steps:
- export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
- pytest -s -v tests/dfx/perf/scripts/run_benchmark.py
- buildkite-agent artifact upload "tests/dfx/perf/results/*.json"
- - buildkite-agent artifact upload "tests/dfx/perf/results/*.html"
agents:
queue: "mithril-h100-pool"
plugins:
@@ -244,7 +243,7 @@ steps:
- export DEFAULT_OUTPUT_DIR=tests/dfx/perf/results
- buildkite-agent artifact download "tests/dfx/perf/results/*.json" . --step nightly-omni-performance
- buildkite-agent artifact download "tests/dfx/perf/results/*.json" . --step nightly-qwen-image-performance
- - buildkite-agent artifact download "tests/dfx/perf/results/*.html" . --step nightly-omni-performance
+ - buildkite-agent artifact download "tests/dfx/perf/results/*.html" . --step nightly-testcase-statistics
- python tools/nightly/generate_nightly_perf_excel.py
- python tools/nightly/generate_nightly_perf_html.py
- python tools/nightly/send_nightly_email.py --report-file "tests/dfx/perf/results/*.xlsx, tests/dfx/perf/results/*.html"
diff --git a/tests/conftest.py b/tests/conftest.py
index 18a0ee57d97..9c739533b83 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1771,8 +1771,12 @@ def omni_server(request: pytest.FixtureRequest, run_level: str, model_prefix: st
server_args = params.server_args or []
if params.use_omni and params.stage_init_timeout is not None:
server_args = [*server_args, "--stage-init-timeout", str(params.stage_init_timeout)]
+ else:
+ server_args = [*server_args, "--stage-init-timeout", "600"]
if params.init_timeout is not None:
server_args = [*server_args, "--init-timeout", str(params.init_timeout)]
+ else:
+ server_args = [*server_args, "--init-timeout", "900"]
if params.use_stage_cli:
if not params.use_omni:
raise ValueError("omni_server with use_stage_cli=True requires use_omni=True")
@@ -2870,9 +2874,9 @@ def __init__(
self,
model_name: str,
seed: int = 42,
- stage_init_timeout: int = 300,
+ stage_init_timeout: int = 600,
batch_timeout: int = 10,
- init_timeout: int = 300,
+ init_timeout: int = 900,
shm_threshold_bytes: int = 65536,
log_stats: bool = False,
stage_configs_path: str | None = None,
diff --git a/tests/e2e/offline_inference/test_bagel_img2img.py b/tests/e2e/offline_inference/test_bagel_img2img.py
index a0c3f6cc9fc..63d2a37da79 100644
--- a/tests/e2e/offline_inference/test_bagel_img2img.py
+++ b/tests/e2e/offline_inference/test_bagel_img2img.py
@@ -22,9 +22,9 @@
from PIL import Image
from vllm.assets.image import ImageAsset
-from tests.conftest import modify_stage_config
+from tests.conftest import OmniRunner, modify_stage_config
from tests.utils import hardware_test
-from vllm_omni.entrypoints.omni import Omni
+from vllm_omni import Omni
from vllm_omni.platforms import current_omni_platform
# Reference pixel data extracted from the known-good output image
@@ -210,11 +210,10 @@ def test_bagel_img2img_shared_memory_connector(run_level):
input_image = _load_input_image()
config_path = str(Path(__file__).parent / "stage_configs" / "bagel_sharedmemory_ci.yaml")
config_path = _resolve_stage_config(config_path, run_level)
- omni = Omni(model="ByteDance-Seed/BAGEL-7B-MoT", stage_configs_path=config_path, stage_init_timeout=300)
-
- try:
- generated_image = _generate_bagel_img2img(omni, input_image)
+ with OmniRunner(
+ "ByteDance-Seed/BAGEL-7B-MoT",
+ stage_configs_path=config_path,
+ ) as runner:
+ generated_image = _generate_bagel_img2img(runner.omni, input_image)
if run_level == "advanced_model":
_validate_pixels(generated_image)
- finally:
- omni.close()
diff --git a/tests/e2e/offline_inference/test_bagel_lora.py b/tests/e2e/offline_inference/test_bagel_lora.py
index 593a640478d..501d23eaa88 100644
--- a/tests/e2e/offline_inference/test_bagel_lora.py
+++ b/tests/e2e/offline_inference/test_bagel_lora.py
@@ -22,7 +22,6 @@
from vllm_omni.outputs import OmniRequestOutput
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
from pathlib import Path
@@ -32,9 +31,9 @@
from PIL import Image
from safetensors.torch import save_file
-from tests.conftest import modify_stage_config
+from tests.conftest import OmniRunner, modify_stage_config
from tests.utils import hardware_test
-from vllm_omni.entrypoints.omni import Omni
+from vllm_omni import Omni
from vllm_omni.lora.request import LoRARequest
from vllm_omni.lora.utils import stable_lora_int_id
@@ -154,8 +153,8 @@ def _make_file_lora_request(adapter_dir: Path) -> LoRARequest:
def test_bagel_lora_scale_and_deactivation(run_level, tmp_path):
"""Validate LoRA effect, bounded perturbation, and clean deactivation."""
config_path = _resolve_stage_config(BAGEL_STAGE_CONFIG, run_level)
- omni = Omni(model=MODEL, stage_configs_path=config_path, stage_init_timeout=300)
- try:
+ with OmniRunner(MODEL, stage_configs_path=config_path) as runner:
+ omni = runner.omni
lora_request = _make_file_lora_request(tmp_path / "bagel_lora")
# 1) Baseline (no LoRA)
@@ -194,5 +193,3 @@ def test_bagel_lora_scale_and_deactivation(run_level, tmp_path):
# (d) Deactivation fully restores base model
assert diff_restored == 0.0, f"Base model not restored after LoRA deactivation: diff={diff_restored}"
- finally:
- omni.close()
diff --git a/tests/e2e/offline_inference/test_bagel_text2img.py b/tests/e2e/offline_inference/test_bagel_text2img.py
index 7cce8da3a73..e45d64f2ac5 100644
--- a/tests/e2e/offline_inference/test_bagel_text2img.py
+++ b/tests/e2e/offline_inference/test_bagel_text2img.py
@@ -16,7 +16,6 @@
import os
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
import signal
import socket
import subprocess
@@ -28,9 +27,9 @@
import pytest
from PIL import Image
-from tests.conftest import modify_stage_config
+from tests.conftest import OmniRunner, modify_stage_config
from tests.utils import hardware_test
-from vllm_omni.entrypoints.omni import Omni
+from vllm_omni import Omni
from vllm_omni.platforms import current_omni_platform
# Reference pixel data extracted from the known-good output image
@@ -199,14 +198,13 @@ def test_bagel_text2img_shared_memory_connector(run_level):
"""Test Bagel text2img with shared memory connector."""
config_path = str(Path(__file__).parent / "stage_configs" / "bagel_sharedmemory_ci.yaml")
config_path = _resolve_stage_config(config_path, run_level)
- omni = Omni(model="ByteDance-Seed/BAGEL-7B-MoT", stage_configs_path=config_path, stage_init_timeout=300)
-
- try:
- generated_image = _generate_bagel_image(omni)
+ with OmniRunner(
+ "ByteDance-Seed/BAGEL-7B-MoT",
+ stage_configs_path=config_path,
+ ) as runner:
+ generated_image = _generate_bagel_image(runner.omni)
if run_level == "advanced_model":
_validate_pixels(generated_image)
- finally:
- omni.close()
def _wait_for_port(host: str, port: int, timeout: int = 30) -> bool:
@@ -319,7 +317,6 @@ def test_bagel_text2img_mooncake_connector(run_level):
mooncake_master_proc = None
temp_config_file = None
- omni = None
try:
_cleanup_mooncake_processes()
@@ -349,15 +346,16 @@ def test_bagel_text2img_mooncake_connector(run_level):
)
temp_config_file = _resolve_stage_config(temp_config_file, run_level)
- omni = Omni(model="ByteDance-Seed/BAGEL-7B-MoT", stage_configs_path=temp_config_file, stage_init_timeout=300)
-
- generated_image = _generate_bagel_image(omni)
- if run_level == "advanced_model":
- _validate_pixels(generated_image)
+ with OmniRunner(
+ "ByteDance-Seed/BAGEL-7B-MoT",
+ stage_configs_path=temp_config_file,
+ stage_init_timeout=300,
+ ) as runner:
+ generated_image = _generate_bagel_image(runner.omni)
+ if run_level == "advanced_model":
+ _validate_pixels(generated_image)
finally:
- if omni:
- omni.close()
if temp_config_file:
try:
os.unlink(temp_config_file)
diff --git a/tests/e2e/offline_inference/test_bagel_understanding.py b/tests/e2e/offline_inference/test_bagel_understanding.py
index 6f95e7ee00f..bbee3298079 100644
--- a/tests/e2e/offline_inference/test_bagel_understanding.py
+++ b/tests/e2e/offline_inference/test_bagel_understanding.py
@@ -21,15 +21,13 @@
import os
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
from pathlib import Path
import pytest
from vllm.assets.image import ImageAsset
-from tests.conftest import modify_stage_config
+from tests.conftest import OmniRunner, modify_stage_config
from tests.utils import hardware_test
-from vllm_omni.entrypoints.omni import Omni
MODEL_NAME = "ByteDance-Seed/BAGEL-7B-MoT"
STAGE_CONFIG = str(Path(__file__).parent / "stage_configs" / "bagel_sharedmemory_ci.yaml")
@@ -76,13 +74,11 @@ def _extract_text(omni_outputs: list) -> str:
def test_bagel_text2text(run_level):
"""Test Bagel text2text produces correct text output."""
config_path = _resolve_stage_config(STAGE_CONFIG, run_level)
- omni = Omni(
- model=MODEL_NAME,
+ with OmniRunner(
+ MODEL_NAME,
stage_configs_path=config_path,
- stage_init_timeout=300,
- )
-
- try:
+ ) as runner:
+ omni = runner.omni
prompt = "<|im_start|>user\nWhere is the capital of France?<|im_end|>\n<|im_start|>assistant\n"
params_list = omni.default_sampling_params_list
omni_outputs = list(
@@ -100,8 +96,6 @@ def test_bagel_text2text(run_level):
assert text == REFERENCE_TEXT_TEXT2TEXT, (
f"Text mismatch: expected {REFERENCE_TEXT_TEXT2TEXT!r}, got {text!r}"
)
- finally:
- omni.close()
@pytest.mark.core_model
@@ -112,13 +106,12 @@ def test_bagel_img2text(run_level):
"""Test Bagel img2text produces correct text output."""
input_image = ImageAsset("2560px-Gfp-wisconsin-madison-the-nature-boardwalk").pil_image.convert("RGB")
config_path = _resolve_stage_config(STAGE_CONFIG, run_level)
- omni = Omni(
- model=MODEL_NAME,
+ with OmniRunner(
+ MODEL_NAME,
stage_configs_path=config_path,
stage_init_timeout=300,
- )
-
- try:
+ ) as runner:
+ omni = runner.omni
prompt = "<|im_start|>user\n<|image_pad|>\nPlease describe this image<|im_end|>\n<|im_start|>assistant\n"
params_list = omni.default_sampling_params_list
omni_outputs = list(
@@ -140,5 +133,3 @@ def test_bagel_img2text(run_level):
if run_level == "advanced_model":
assert text == REFERENCE_TEXT_IMG2TEXT, f"Text mismatch: expected {REFERENCE_TEXT_IMG2TEXT!r}, got {text!r}"
- finally:
- omni.close()
diff --git a/tests/e2e/offline_inference/test_cache_dit.py b/tests/e2e/offline_inference/test_cache_dit.py
index 0e31413dc07..fc08da7bedf 100644
--- a/tests/e2e/offline_inference/test_cache_dit.py
+++ b/tests/e2e/offline_inference/test_cache_dit.py
@@ -8,27 +8,15 @@
It uses minimal settings to keep test time short for CI.
"""
-import os
-import sys
-from pathlib import Path
-
import pytest
import torch
+from tests.conftest import OmniRunner
from tests.utils import hardware_test
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
-# ruff: noqa: E402
-REPO_ROOT = Path(__file__).resolve().parents[2]
-if str(REPO_ROOT) not in sys.path:
- sys.path.insert(0, str(REPO_ROOT))
-
-from vllm_omni import Omni
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
-
# Use random weights model for testing
models = ["riverclouds/qwen_image_random"]
@@ -48,20 +36,17 @@ def test_cache_dit(model_name: str):
"residual_diff_threshold": 0.24,
"max_continuous_cached_steps": 3,
}
- m = None
- try:
- m = Omni(
- model=model_name,
- cache_backend="cache_dit",
- cache_config=cache_config,
- )
-
+ with OmniRunner(
+ model_name,
+ cache_backend="cache_dit",
+ cache_config=cache_config,
+ ) as runner:
# Use minimal settings for fast testing
height = 256
width = 256
num_inference_steps = 4 # Minimal steps for fast test
- outputs = m.generate(
+ outputs = runner.omni.generate(
"a photo of a cat sitting on a laptop keyboard",
OmniDiffusionSamplingParams(
height=height,
@@ -90,9 +75,3 @@ def test_cache_dit(model_name: str):
# Check image size
assert images[0].width == width
assert images[0].height == height
- except Exception as e:
- print(f"Test failed with error: {e}")
- raise
- finally:
- if m is not None and hasattr(m, "close"):
- m.close()
diff --git a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py
index f3830f02e97..257755ef8b9 100644
--- a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py
+++ b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py
@@ -1,22 +1,14 @@
import gc
-import sys
-from pathlib import Path
import pytest
import torch
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
+from tests.conftest import OmniRunner
from tests.utils import DeviceMemoryMonitor, hardware_test
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
-# ruff: noqa: E402
-REPO_ROOT = Path(__file__).resolve().parents[2]
-if str(REPO_ROOT) not in sys.path:
- sys.path.insert(0, str(REPO_ROOT))
-
-from vllm_omni import Omni
-
models = ["riverclouds/qwen_image_random"]
@@ -27,30 +19,29 @@ def inference(model_name: str, offload: bool = True):
current_omni_platform.reset_peak_memory_stats()
monitor = DeviceMemoryMonitor(device_index=device_index, interval=0.02)
monitor.start()
- m = Omni(
- model=model_name,
+ with OmniRunner(
+ model_name,
# TODO: we might want to add overlapped feature e2e tests
# cache_backend="cache_dit",
enable_cpu_offload=offload,
- )
- current_omni_platform.reset_peak_memory_stats()
- height = 256
- width = 256
+ ) as runner:
+ current_omni_platform.reset_peak_memory_stats()
+ height = 256
+ width = 256
- m.generate(
- "a photo of a cat sitting on a laptop keyboard",
- OmniDiffusionSamplingParams(
- height=height,
- width=width,
- num_inference_steps=9,
- guidance_scale=0.0,
- generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
- ),
- )
+ runner.omni.generate(
+ "a photo of a cat sitting on a laptop keyboard",
+ OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ num_inference_steps=9,
+ guidance_scale=0.0,
+ generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
+ ),
+ )
peak = monitor.peak_used_mb
monitor.stop()
- del m
gc.collect()
current_omni_platform.empty_cache()
diff --git a/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py b/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py
index 6132f1bd0eb..bdfd594c774 100644
--- a/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py
+++ b/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py
@@ -1,21 +1,12 @@
-import sys
-from pathlib import Path
-
import pytest
import torch
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
+from tests.conftest import OmniRunner
from tests.utils import DeviceMemoryMonitor
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
-# ruff: noqa: E402
-REPO_ROOT = Path(__file__).resolve().parents[2]
-if str(REPO_ROOT) not in sys.path:
- sys.path.insert(0, str(REPO_ROOT))
-
-from vllm_omni import Omni
-
# Models to test and expected saved memory in MB, correspondingly
MODELS_SAVED_MEMORY_MB = {
"riverclouds/qwen_image_random": 4500,
@@ -33,34 +24,33 @@ def run_inference(
monitor = DeviceMemoryMonitor(device_index=device_index, interval=0.02)
monitor.start()
- m = Omni(
- model=model_name,
+ with OmniRunner(
+ model_name,
enable_layerwise_offload=layerwise_offload,
# TODO: we might want to add overlapped feature e2e tests
# cache_backend="cache_dit",
boundary_ratio=0.875,
flow_shift=5.0,
- )
-
- current_omni_platform.reset_peak_memory_stats()
-
- # Refer to tests/e2e/offline_inference/test_t2v_model.py
- # Use minimal settings for testing
- height = 480
- width = 640
- num_frames = 5
-
- m.generate(
- "A cat sitting on a table",
- OmniDiffusionSamplingParams(
- height=height,
- width=width,
- generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
- guidance_scale=1.0,
- num_inference_steps=num_inference_steps,
- num_frames=num_frames,
- ),
- )
+ ) as runner:
+ current_omni_platform.reset_peak_memory_stats()
+
+ # Refer to tests/e2e/offline_inference/test_t2v_model.py
+ # Use minimal settings for testing
+ height = 480
+ width = 640
+ num_frames = 5
+
+ runner.omni.generate(
+ "A cat sitting on a table",
+ OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
+ guidance_scale=1.0,
+ num_inference_steps=num_inference_steps,
+ num_frames=num_frames,
+ ),
+ )
peak = monitor.peak_used_mb
monitor.stop()
diff --git a/tests/e2e/offline_inference/test_diffusion_lora.py b/tests/e2e/offline_inference/test_diffusion_lora.py
index b414fe30eeb..7edd03f20d1 100644
--- a/tests/e2e/offline_inference/test_diffusion_lora.py
+++ b/tests/e2e/offline_inference/test_diffusion_lora.py
@@ -7,6 +7,7 @@
import torch
from safetensors.torch import save_file
+from tests.conftest import OmniRunner
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
@@ -16,15 +17,12 @@
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
-from vllm_omni import Omni
-
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
# This test is specific to Z-Image LoRA behavior. Keep it focused on a single
# model to reduce runtime and avoid extra downloads.
models = ["Tongyi-MAI/Z-Image-Turbo"]
-DIFFUSION_INIT_TIMEOUT_S = 600
@pytest.mark.parametrize("model_name", models)
@@ -77,12 +75,8 @@ def _write_zimage_lora(adapter_dir: Path) -> str:
)
return str(adapter_dir)
- m = Omni(
- model=model_name,
- stage_init_timeout=DIFFUSION_INIT_TIMEOUT_S,
- init_timeout=DIFFUSION_INIT_TIMEOUT_S,
- )
- try:
+ with OmniRunner(model_name) as runner:
+ m = runner.omni
# high resolution may cause OOM on L4
height = 256
width = 256
@@ -140,5 +134,3 @@ def _write_zimage_lora(adapter_dir: Path) -> str:
diff = np.abs(np.array(images[0], dtype=np.int16) - np.array(images_lora[0], dtype=np.int16)).mean()
assert diff > 0.0
- finally:
- m.close()
diff --git a/tests/e2e/offline_inference/test_dynin_omni.py b/tests/e2e/offline_inference/test_dynin_omni.py
index d17e7b81755..5388ac67468 100644
--- a/tests/e2e/offline_inference/test_dynin_omni.py
+++ b/tests/e2e/offline_inference/test_dynin_omni.py
@@ -18,7 +18,6 @@
import torch
from transformers import AutoTokenizer
-from tests.conftest import OmniRunner
from tests.utils import hardware_test
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@@ -37,6 +36,7 @@
pytestmark = [
pytest.mark.core_model,
pytest.mark.omni,
+ pytest.mark.parametrize("omni_runner", test_params, indirect=True),
]
@@ -291,20 +291,11 @@ def _numel(value: Any) -> int:
@hardware_test(res={"cuda": "L4", "rocm": "MI325"})
-@pytest.mark.parametrize("test_config", test_params)
-def test_dynin_t2i_decode_to_image(test_config: tuple[str, str]) -> None:
- model, stage_config_path = test_config
+def test_dynin_t2i_decode_to_image(omni_runner) -> None:
_configure_dynin_config_env()
prompt = _build_t2i_decode_prompt(dynin_config_path=DYNIN_CONFIG_PATH)
- with OmniRunner(
- model,
- seed=42,
- stage_configs_path=stage_config_path,
- stage_init_timeout=600,
- init_timeout=600,
- ) as runner:
- outputs = runner.generate([prompt])
+ outputs = omni_runner.generate([prompt])
image_output = _find_stage_output(outputs, "image")
assert image_output is not None
@@ -314,25 +305,16 @@ def test_dynin_t2i_decode_to_image(test_config: tuple[str, str]) -> None:
@hardware_test(res={"cuda": "L4", "rocm": "MI325"})
-@pytest.mark.parametrize("test_config", test_params)
-def test_dynin_mmu_to_text(test_config: tuple[str, str]) -> None:
- model, stage_config_path = test_config
+def test_dynin_mmu_to_text(omni_runner) -> None:
_configure_dynin_config_env()
- tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(omni_runner.model_name, trust_remote_code=True)
prompt = _build_mmu_prompt(
tokenizer=tokenizer,
question="What is 2 + 2? Answer in one short sentence.",
dynin_config_path=DYNIN_CONFIG_PATH,
)
- with OmniRunner(
- model,
- seed=42,
- stage_configs_path=stage_config_path,
- stage_init_timeout=600,
- init_timeout=600,
- ) as runner:
- outputs = runner.generate([prompt])
+ outputs = omni_runner.generate([prompt])
text_output = _find_stage_output(outputs, "text")
assert text_output is not None
@@ -341,11 +323,9 @@ def test_dynin_mmu_to_text(test_config: tuple[str, str]) -> None:
@hardware_test(res={"cuda": "L4", "rocm": "MI325"})
-@pytest.mark.parametrize("test_config", test_params)
-def test_dynin_image_to_text(test_config: tuple[str, str]) -> None:
- model, stage_config_path = test_config
+def test_dynin_image_to_text(omni_runner) -> None:
_configure_dynin_config_env()
- tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(omni_runner.model_name, trust_remote_code=True)
prompt = _build_mmu_multimodal_prompt(
tokenizer=tokenizer,
question="Describe the image briefly in one sentence.",
@@ -353,14 +333,7 @@ def test_dynin_image_to_text(test_config: tuple[str, str]) -> None:
image=_generate_synthetic_image(),
)
- with OmniRunner(
- model,
- seed=42,
- stage_configs_path=stage_config_path,
- stage_init_timeout=600,
- init_timeout=600,
- ) as runner:
- outputs = runner.generate([prompt])
+ outputs = omni_runner.generate([prompt])
text_output = _find_stage_output(outputs, "text")
assert text_output is not None
@@ -369,11 +342,9 @@ def test_dynin_image_to_text(test_config: tuple[str, str]) -> None:
@hardware_test(res={"cuda": "L4", "rocm": "MI325"})
-@pytest.mark.parametrize("test_config", test_params)
-def test_dynin_speech_to_text(test_config: tuple[str, str]) -> None:
- model, stage_config_path = test_config
+def test_dynin_speech_to_text(omni_runner) -> None:
_configure_dynin_config_env()
- tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(omni_runner.model_name, trust_remote_code=True)
prompt = _build_mmu_multimodal_prompt(
tokenizer=tokenizer,
question="Transcribe the audio briefly in one sentence.",
@@ -381,14 +352,7 @@ def test_dynin_speech_to_text(test_config: tuple[str, str]) -> None:
audio=_generate_synthetic_audio(),
)
- with OmniRunner(
- model,
- seed=42,
- stage_configs_path=stage_config_path,
- stage_init_timeout=600,
- init_timeout=600,
- ) as runner:
- outputs = runner.generate([prompt])
+ outputs = omni_runner.generate([prompt])
text_output = _find_stage_output(outputs, "text")
assert text_output is not None
@@ -397,20 +361,11 @@ def test_dynin_speech_to_text(test_config: tuple[str, str]) -> None:
@hardware_test(res={"cuda": "L4", "rocm": "MI325"})
-@pytest.mark.parametrize("test_config", test_params)
-def test_dynin_t2s_decode_to_audio(test_config: tuple[str, str]) -> None:
- model, stage_config_path = test_config
+def test_dynin_t2s_decode_to_audio(omni_runner) -> None:
_configure_dynin_config_env()
prompt = _build_t2s_decode_prompt(dynin_config_path=DYNIN_CONFIG_PATH)
- with OmniRunner(
- model,
- seed=42,
- stage_configs_path=stage_config_path,
- stage_init_timeout=600,
- init_timeout=600,
- ) as runner:
- outputs = runner.generate([prompt])
+ outputs = omni_runner.generate([prompt])
audio_output = _find_stage_output(outputs, "audio")
assert audio_output is not None
diff --git a/tests/e2e/offline_inference/test_expert_parallel.py b/tests/e2e/offline_inference/test_expert_parallel.py
index ba126986ec7..29d84d7a3e2 100644
--- a/tests/e2e/offline_inference/test_expert_parallel.py
+++ b/tests/e2e/offline_inference/test_expert_parallel.py
@@ -18,8 +18,8 @@
import torch.distributed as dist
from PIL import Image
+from tests.conftest import OmniRunner
from tests.utils import hardware_test
-from vllm_omni import Omni
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
@@ -96,12 +96,26 @@ def _run_inference(
tensor_parallel_size=tensor_parallel_size,
enable_expert_parallel=enable_expert_parallel,
)
- omni = Omni(model=model_name, parallel_config=parallel_config)
-
try:
- # Warmup run (not timed)
- if warmup:
- _ = omni.generate(
+ with OmniRunner(model_name, parallel_config=parallel_config) as runner:
+ omni = runner.omni
+ # Warmup run (not timed)
+ if warmup:
+ _ = omni.generate(
+ PROMPT,
+ OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ num_inference_steps=DEFAULT_STEPS,
+ guidance_scale=guidance_scale,
+ generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed),
+ num_outputs_per_prompt=1,
+ ),
+ )
+
+ # Timed run
+ start = time.time()
+ outputs = omni.generate(
PROMPT,
OmniDiffusionSamplingParams(
height=height,
@@ -112,28 +126,13 @@ def _run_inference(
num_outputs_per_prompt=1,
),
)
+ elapsed_ms = (time.time() - start) * 1000
- # Timed run
- start = time.time()
- outputs = omni.generate(
- PROMPT,
- OmniDiffusionSamplingParams(
- height=height,
- width=width,
- num_inference_steps=DEFAULT_STEPS,
- guidance_scale=guidance_scale,
- generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed),
- num_outputs_per_prompt=1,
- ),
- )
- elapsed_ms = (time.time() - start) * 1000
-
- return InferenceResult(
- images=outputs[0].images,
- elapsed_ms=elapsed_ms,
- )
+ return InferenceResult(
+ images=outputs[0].images,
+ elapsed_ms=elapsed_ms,
+ )
finally:
- omni.close()
_cleanup_distributed()
diff --git a/tests/e2e/offline_inference/test_flux_autoround_w4a16.py b/tests/e2e/offline_inference/test_flux_autoround_w4a16.py
index 42aab7f26a8..cbcd1009dd5 100644
--- a/tests/e2e/offline_inference/test_flux_autoround_w4a16.py
+++ b/tests/e2e/offline_inference/test_flux_autoround_w4a16.py
@@ -8,31 +8,21 @@
"""
import gc
-import sys
-from pathlib import Path
+import os as _os
import pytest
import torch
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
+from tests.conftest import OmniRunner
from tests.utils import DeviceMemoryMonitor, hardware_test
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
-# ruff: noqa: E402
-REPO_ROOT = Path(__file__).resolve().parents[2]
-if str(REPO_ROOT) not in sys.path:
- sys.path.insert(0, str(REPO_ROOT))
-
-from vllm_omni import Omni
-
QUANTIZED_MODEL = "vllm-project-org/FLUX.1-dev-AutoRound-w4a16"
BASELINE_MODEL = "black-forest-labs/FLUX.1-dev"
-# Allow overriding via environment for local testing
-import os as _os
-
QUANTIZED_MODEL = _os.environ.get("FLUX_AUTOROUND_MODEL", QUANTIZED_MODEL)
BASELINE_MODEL = _os.environ.get("FLUX_BASELINE_MODEL", BASELINE_MODEL)
@@ -51,19 +41,18 @@ def _generate_image(model_name: str, **extra_kwargs) -> tuple[list, float]:
monitor = DeviceMemoryMonitor(device_index=device_index, interval=0.02)
monitor.start()
- m = Omni(model=model_name, enforce_eager=True, **extra_kwargs)
-
- current_omni_platform.reset_peak_memory_stats()
- outputs = m.generate(
- "a photo of a cat sitting on a laptop keyboard",
- OmniDiffusionSamplingParams(
- height=HEIGHT,
- width=WIDTH,
- num_inference_steps=NUM_STEPS,
- guidance_scale=0.0,
- generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
- ),
- )
+ with OmniRunner(model_name, enforce_eager=True, **extra_kwargs) as runner:
+ current_omni_platform.reset_peak_memory_stats()
+ outputs = runner.omni.generate(
+ "a photo of a cat sitting on a laptop keyboard",
+ OmniDiffusionSamplingParams(
+ height=HEIGHT,
+ width=WIDTH,
+ num_inference_steps=NUM_STEPS,
+ guidance_scale=0.0,
+ generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
+ ),
+ )
peak = monitor.peak_used_mb
monitor.stop()
@@ -74,7 +63,6 @@ def _generate_image(model_name: str, **extra_kwargs) -> tuple[list, float]:
assert isinstance(req_out, OmniRequestOutput) and hasattr(req_out, "images")
images = req_out.images
- del m
gc.collect()
current_omni_platform.empty_cache()
diff --git a/tests/e2e/offline_inference/test_flux_kontext.py b/tests/e2e/offline_inference/test_flux_kontext.py
index 93dca21c9ad..cd711d6b818 100644
--- a/tests/e2e/offline_inference/test_flux_kontext.py
+++ b/tests/e2e/offline_inference/test_flux_kontext.py
@@ -9,23 +9,14 @@
- Image editing with text guidance
"""
-import os
-import sys
-from pathlib import Path
-
import pytest
from PIL import Image
+from vllm.assets.image import ImageAsset
+from tests.conftest import OmniRunner
from vllm_omni.diffusion.data import DiffusionParallelConfig
-from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-REPO_ROOT = Path(__file__).resolve().parents[2]
-if str(REPO_ROOT) not in sys.path:
- sys.path.insert(0, str(REPO_ROOT))
-
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
-
MODEL = "black-forest-labs/FLUX.1-Kontext-dev"
@@ -33,17 +24,15 @@
@pytest.mark.diffusion
def test_flux_kontext_text_to_image():
"""Test FluxKontext text-to-image generation with real model."""
- omni = Omni(
- model=MODEL,
+ with OmniRunner(
+ MODEL,
parallel_config=DiffusionParallelConfig(
tensor_parallel_size=2,
),
enable_cpu_offload=False,
- )
-
- try:
+ ) as runner:
omni_outputs = list(
- omni.generate(
+ runner.omni.generate(
prompts=["A photo of a cat sitting on a laptop"],
sampling_params_list=OmniDiffusionSamplingParams(
height=512,
@@ -54,43 +43,37 @@ def test_flux_kontext_text_to_image():
)
)
- assert len(omni_outputs) > 0
- output = omni_outputs[0]
- images = None
- if output.images:
- images = output.images
- elif hasattr(output, "request_output") and output.request_output:
- for stage_out in output.request_output:
- if hasattr(stage_out, "images") and stage_out.images:
- images = stage_out.images
- break
+ assert len(omni_outputs) > 0
+ output = omni_outputs[0]
+ images = None
+ if output.images:
+ images = output.images
+ elif hasattr(output, "request_output") and output.request_output:
+ for stage_out in output.request_output:
+ if hasattr(stage_out, "images") and stage_out.images:
+ images = stage_out.images
+ break
- assert images is not None
- assert len(images) > 0
- assert isinstance(images[0], Image.Image)
- assert images[0].size == (512, 512)
- finally:
- omni.close()
+ assert images is not None
+ assert len(images) > 0
+ assert isinstance(images[0], Image.Image)
+ assert images[0].size == (512, 512)
@pytest.mark.core_model
@pytest.mark.diffusion
def test_flux_kontext_image_edit():
"""Test FluxKontext image-to-image editing with real model."""
- from vllm.assets.image import ImageAsset
-
input_image = ImageAsset("2560px-Gfp-wisconsin-madison-the-nature-boardwalk").pil_image.convert("RGB")
- omni = Omni(
- model=MODEL,
+ with OmniRunner(
+ MODEL,
parallel_config=DiffusionParallelConfig(
tensor_parallel_size=2,
),
enable_cpu_offload=False,
- )
-
- try:
+ ) as runner:
omni_outputs = list(
- omni.generate(
+ runner.omni.generate(
prompts=[
{
"prompt": "Transform this image into a Vincent van Gogh style painting",
@@ -107,20 +90,18 @@ def test_flux_kontext_image_edit():
)
)
- assert len(omni_outputs) > 0
- output = omni_outputs[0]
- images = None
- if output.images:
- images = output.images
- elif hasattr(output, "request_output") and output.request_output:
- for stage_out in output.request_output:
- if hasattr(stage_out, "images") and stage_out.images:
- images = stage_out.images
- break
-
- assert images is not None
- assert len(images) > 0
- assert isinstance(images[0], Image.Image)
- assert images[0].size == (512, 512)
- finally:
- omni.close()
+ assert len(omni_outputs) > 0
+ output = omni_outputs[0]
+ images = None
+ if output.images:
+ images = output.images
+ elif hasattr(output, "request_output") and output.request_output:
+ for stage_out in output.request_output:
+ if hasattr(stage_out, "images") and stage_out.images:
+ images = stage_out.images
+ break
+
+ assert images is not None
+ assert len(images) > 0
+ assert isinstance(images[0], Image.Image)
+ assert images[0].size == (512, 512)
diff --git a/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py b/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py
index 5522f33eaa7..79bb64dca1b 100644
--- a/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py
+++ b/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py
@@ -8,6 +8,7 @@
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
+from tests.conftest import OmniRunner
from vllm_omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
@@ -271,16 +272,11 @@ def clip_bundle() -> tuple[CLIPModel, CLIPProcessor]:
@pytest.fixture(scope="module")
def omni() -> Generator[Omni, None, None]:
- engine = Omni(
- model=MODEL_NAME,
+ with OmniRunner(
+ MODEL_NAME,
stage_configs_path=str(STAGE_CONFIG_PATH),
- stage_init_timeout=600,
- init_timeout=900,
- )
- try:
- yield engine
- finally:
- engine.close()
+ ) as runner:
+ yield runner.omni
def _extract_generated_image(outputs: list[object]) -> Image.Image:
diff --git a/tests/e2e/offline_inference/test_magi_human.py b/tests/e2e/offline_inference/test_magi_human.py
index 8648216a92f..abb7f9c163c 100644
--- a/tests/e2e/offline_inference/test_magi_human.py
+++ b/tests/e2e/offline_inference/test_magi_human.py
@@ -8,9 +8,9 @@
import numpy as np
import pytest
+from tests.conftest import OmniRunner
from tests.utils import hardware_test
from vllm_omni.diffusion.utils.media_utils import mux_video_audio_bytes
-from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
@@ -49,12 +49,6 @@ def test_magi_human_e2e(run_level):
model_path = "SII-GAIR/daVinci-MagiHuman-Base-1080p"
- omni = Omni(
- model=model_path,
- init_timeout=1200,
- tensor_parallel_size=2,
- )
-
prompt = (
"A young woman with long, wavy golden blonde hair and bright blue eyes, "
"wearing a fitted ivory silk blouse with a delicate lace collar, sits "
@@ -94,7 +88,12 @@ def test_magi_human_e2e(run_level):
},
)
- try:
+ with OmniRunner(
+ model_path,
+ init_timeout=1200,
+ tensor_parallel_size=2,
+ ) as runner:
+ omni = runner.omni
outputs = list(
omni.generate(
prompts=[prompt],
@@ -140,5 +139,3 @@ def test_magi_human_e2e(run_level):
assert len(video_bytes) > 1000, f"MP4 too small ({len(video_bytes)} bytes)"
_validate_mp4(video_bytes)
- finally:
- omni.close()
diff --git a/tests/e2e/offline_inference/test_mammoth_moda2.py b/tests/e2e/offline_inference/test_mammoth_moda2.py
index 5293b5ed1b7..ff744c86e1e 100644
--- a/tests/e2e/offline_inference/test_mammoth_moda2.py
+++ b/tests/e2e/offline_inference/test_mammoth_moda2.py
@@ -23,10 +23,9 @@
import torch
from vllm.sampling_params import SamplingParams
+from tests.conftest import OmniRunner
from tests.utils import hardware_test
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
-
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
@@ -116,8 +115,6 @@ def test_mammothmoda2_t2i_e2e():
- A fixed set of pixel values matches a golden reference
(regenerate with ``UPDATE_GOLDEN=1``).
"""
- from vllm_omni import Omni
-
if not Path(MODEL_PATH).exists():
pytest.skip(f"Model weights not found at {MODEL_PATH}")
if not Path(T2I_STAGE_CONFIG).exists():
@@ -135,8 +132,8 @@ def test_mammothmoda2_t2i_e2e():
prompt_text = "A cat sitting on a laptop keyboard"
formatted_prompt = _format_t2i_prompt(prompt_text, ar_width, ar_height)
- omni = Omni(model=MODEL_PATH, stage_configs_path=T2I_STAGE_CONFIG, trust_remote_code=True)
- try:
+ with OmniRunner(MODEL_PATH, stage_configs_path=T2I_STAGE_CONFIG, trust_remote_code=True) as runner:
+ omni = runner.omni
# Greedy / deterministic sampling so pixel values are reproducible.
ar_sampling = SamplingParams(
temperature=0.0,
@@ -211,5 +208,3 @@ def test_mammothmoda2_t2i_e2e():
found_image = True
assert found_image, "No image tensor found in pipeline output"
- finally:
- omni.close()
diff --git a/tests/e2e/offline_inference/test_omnivoice.py b/tests/e2e/offline_inference/test_omnivoice.py
index 4b093e357d9..bb4c8a5dd7e 100644
--- a/tests/e2e/offline_inference/test_omnivoice.py
+++ b/tests/e2e/offline_inference/test_omnivoice.py
@@ -16,6 +16,7 @@
import numpy as np
import pytest
+from tests.conftest import OmniRunner
from tests.utils import hardware_test
MODEL = "k2-fsa/OmniVoice"
@@ -37,48 +38,42 @@ def test_omnivoice_text_to_audio() -> None:
Input Modal: text
Output Modal: audio
"""
- from vllm_omni.entrypoints.omni import Omni
+ from vllm_omni.inputs.data import OmniDiffusionSamplingParams
- omni = Omni(
- model=MODEL,
+ with OmniRunner(
+ MODEL,
stage_configs_path=get_stage_config(),
trust_remote_code=True,
log_stats=True,
- )
-
- try:
+ ) as runner:
prompts = {"prompt": "Hello, this is a test for text to audio."}
- from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
sampling_params_list = [OmniDiffusionSamplingParams()]
- outputs = list(omni.generate(prompts, sampling_params_list=sampling_params_list))
+ outputs = list(runner.omni.generate(prompts, sampling_params_list=sampling_params_list))
- assert len(outputs) > 0, "No outputs generated"
+ assert len(outputs) > 0, "No outputs generated"
- # Check final output has audio
- final_output = outputs[-1]
- ro = final_output.request_output
- assert ro is not None, "No request_output"
+ # Check final output has audio
+ final_output = outputs[-1]
+ ro = final_output.request_output
+ assert ro is not None, "No request_output"
- mm = getattr(ro, "multimodal_output", None)
- if not mm and ro.outputs:
- mm = getattr(ro.outputs[0], "multimodal_output", None)
+ mm = getattr(ro, "multimodal_output", None)
+ if not mm and ro.outputs:
+ mm = getattr(ro.outputs[0], "multimodal_output", None)
- assert mm is not None, "No multimodal_output"
- assert "audio" in mm, f"No 'audio' key in multimodal_output: {mm.keys()}"
+ assert mm is not None, "No multimodal_output"
+ assert "audio" in mm, f"No 'audio' key in multimodal_output: {mm.keys()}"
- audio = mm["audio"]
- if isinstance(audio, np.ndarray):
- audio_np = audio
- else:
- audio_np = audio.cpu().numpy().squeeze()
+ audio = mm["audio"]
+ if isinstance(audio, np.ndarray):
+ audio_np = audio
+ else:
+ audio_np = audio.cpu().numpy().squeeze()
- assert audio_np.size > 0, "Audio output is empty"
- rms = np.sqrt(np.mean(audio_np**2))
- assert rms > 0.01, f"Audio RMS too low ({rms:.4f}), likely silence"
+ assert audio_np.size > 0, "Audio output is empty"
+ rms = np.sqrt(np.mean(audio_np**2))
+ assert rms > 0.01, f"Audio RMS too low ({rms:.4f}), likely silence"
- print(f"Generated audio: {len(audio_np) / 24000:.2f}s, rms={rms:.4f}")
- finally:
- omni.close()
+ print(f"Generated audio: {len(audio_np) / 24000:.2f}s, rms={rms:.4f}")
diff --git a/tests/e2e/offline_inference/test_quantization_fp8.py b/tests/e2e/offline_inference/test_quantization_fp8.py
index f71c53de74c..291779fd931 100644
--- a/tests/e2e/offline_inference/test_quantization_fp8.py
+++ b/tests/e2e/offline_inference/test_quantization_fp8.py
@@ -29,7 +29,6 @@
import os
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
from pathlib import Path
from typing import Any
@@ -37,8 +36,8 @@
import pytest
import torch
+from tests.conftest import OmniRunner
from tests.utils import hardware_test
-from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
@@ -61,16 +60,15 @@ def _generate_single_stage_image(
Returns (images, peak_memory_gib).
"""
- omni_kwargs: dict[str, Any] = {"model": model, **extra_omni_kwargs}
+ omni_kwargs: dict[str, Any] = dict(extra_omni_kwargs)
if quantization:
omni_kwargs["quantization"] = quantization
- omni = Omni(**omni_kwargs)
- try:
+ with OmniRunner(model, **omni_kwargs) as runner:
torch.cuda.reset_peak_memory_stats()
generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(seed)
- outputs = omni.generate(
+ outputs = runner.omni.generate(
"a photo of a cat sitting on a laptop keyboard",
OmniDiffusionSamplingParams(
height=height,
@@ -94,8 +92,6 @@ def _generate_single_stage_image(
assert images[0].height == height
return images, peak_mem
- finally:
- omni.close()
def _generate_bagel_image(
@@ -115,8 +111,9 @@ def _generate_bagel_image(
if quantization_config:
omni_kwargs["quantization_config"] = quantization_config
- omni = Omni(**omni_kwargs)
- try:
+ model_name = omni_kwargs.pop("model")
+ with OmniRunner(model_name, **omni_kwargs) as runner:
+ omni = runner.omni
torch.cuda.reset_peak_memory_stats()
params_list = omni.default_sampling_params_list
@@ -168,8 +165,6 @@ def _generate_bagel_image(
)
return generated_image, peak_mem
- finally:
- omni.close()
# ─── Single-stage diffusion model tests ──────────────────────────────────────
diff --git a/tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py b/tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py
index d5f82f893e6..f0b0b55c9f6 100644
--- a/tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py
+++ b/tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py
@@ -28,7 +28,6 @@
import argparse
import asyncio
-import os
import sys
import time
import uuid
@@ -37,6 +36,7 @@
import pytest
import torch
+from tests.conftest import OmniRunner
from tests.utils import hardware_test
from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
@@ -48,9 +48,6 @@
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
-from vllm_omni import Omni
-
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
# ------------------------------------------------------------------
models = ["tiny-random/Qwen-Image"]
@@ -391,31 +388,28 @@ async def main(model: str, num_prompts: int, mode: str, batch_size: int = 1) ->
def test_diffusion_batching_sync_sequential(model_name: str):
"""Test that synchronous Omni can generate images for multiple prompts
submitted sequentially (one at a time) and each returns a valid image."""
- m = None
try:
- m = Omni(model=model_name)
- sp = _default_sync_sampling_params()
- prompts = TEST_PROMPTS[:4]
+ with OmniRunner(model_name) as runner:
+ m = runner.omni
+ sp = _default_sync_sampling_params()
+ prompts = TEST_PROMPTS[:4]
- for i, prompt in enumerate(prompts):
- outputs = m.generate(prompt, sp)
- first_output = outputs[0]
- assert first_output.final_output_type == "image", (
- f"Expected 'image', got '{first_output.final_output_type}'"
- )
+ for i, prompt in enumerate(prompts):
+ outputs = m.generate(prompt, sp)
+ first_output = outputs[0]
+ assert first_output.final_output_type == "image", (
+ f"Expected 'image', got '{first_output.final_output_type}'"
+ )
- # Images are surfaced both at top-level and inside request_output
- images = _extract_images(first_output)
- assert len(images) >= 1, f"Expected at least 1 image for prompt {i}, got {len(images)}"
- assert images[0].width == 256
- assert images[0].height == 256
- print(f" prompt {i}: OK ({len(images)} images)")
+ # Images are surfaced both at top-level and inside request_output
+ images = _extract_images(first_output)
+ assert len(images) >= 1, f"Expected at least 1 image for prompt {i}, got {len(images)}"
+ assert images[0].width == 256
+ assert images[0].height == 256
+ print(f" prompt {i}: OK ({len(images)} images)")
except Exception as e:
print(f"Test failed with error: {e}")
raise
- finally:
- if m is not None and hasattr(m, "close"):
- m.close()
@pytest.mark.core_model
@@ -431,34 +425,31 @@ def test_diffusion_batching_sync_multi_prompt(model_name: str):
handling at the diffusion stage, not the explicit list-batch path
(which is only available via AsyncOmni).
"""
- m = None
try:
- m = Omni(model=model_name)
- sp = _default_sync_sampling_params()
- prompts = TEST_PROMPTS[:4]
+ with OmniRunner(model_name) as runner:
+ m = runner.omni
+ sp = _default_sync_sampling_params()
+ prompts = TEST_PROMPTS[:4]
- outputs = m.generate(prompts, sp)
- assert len(outputs) == len(prompts), f"Expected {len(prompts)} outputs, got {len(outputs)}"
+ outputs = m.generate(prompts, sp)
+ assert len(outputs) == len(prompts), f"Expected {len(prompts)} outputs, got {len(outputs)}"
- for i, output in enumerate(outputs):
- assert output.final_output_type == "image", (
- f"Output {i} final_output_type expected 'image', got '{output.final_output_type}'"
- )
- images = _extract_images(output)
- assert images and len(images) >= 1, f"Expected at least 1 image for prompt {i}"
- assert images[0].width == 256
- assert images[0].height == 256
- print(f" prompt {i}: OK ({len(images)} images, request_id={output.request_id})")
-
- # Verify all request_ids are distinct
- request_ids = [o.request_id for o in outputs]
- assert len(set(request_ids)) == len(request_ids), f"Duplicate request_ids found: {request_ids}"
+ for i, output in enumerate(outputs):
+ assert output.final_output_type == "image", (
+ f"Output {i} final_output_type expected 'image', got '{output.final_output_type}'"
+ )
+ images = _extract_images(output)
+ assert images and len(images) >= 1, f"Expected at least 1 image for prompt {i}"
+ assert images[0].width == 256
+ assert images[0].height == 256
+ print(f" prompt {i}: OK ({len(images)} images, request_id={output.request_id})")
+
+ # Verify all request_ids are distinct
+ request_ids = [o.request_id for o in outputs]
+ assert len(set(request_ids)) == len(request_ids), f"Duplicate request_ids found: {request_ids}"
except Exception as e:
print(f"Test failed with error: {e}")
raise
- finally:
- if m is not None and hasattr(m, "close"):
- m.close()
@pytest.mark.core_model
@@ -552,32 +543,29 @@ async def _inner():
def test_diffusion_batching_num_outputs(model_name: str):
"""Test that the diffusion model respects num_outputs_per_prompt and
generates the correct number of images per request."""
- m = None
try:
- m = Omni(model=model_name)
- num_outputs = 2
- sp = _default_sync_sampling_params(num_outputs_per_prompt=num_outputs)
-
- outputs = m.generate(
- "a photo of a cat sitting on a laptop keyboard",
- sp,
- )
+ with OmniRunner(model_name) as runner:
+ m = runner.omni
+ num_outputs = 2
+ sp = _default_sync_sampling_params(num_outputs_per_prompt=num_outputs)
+
+ outputs = m.generate(
+ "a photo of a cat sitting on a laptop keyboard",
+ sp,
+ )
- first_output = outputs[0]
- assert first_output.final_output_type == "image"
- images = _extract_images(first_output)
- assert images is not None and len(images) == num_outputs, (
- f"Expected {num_outputs} images, got {len(images) if images else 0}"
- )
- for img in images:
- assert img.width == 256
- assert img.height == 256
+ first_output = outputs[0]
+ assert first_output.final_output_type == "image"
+ images = _extract_images(first_output)
+ assert images is not None and len(images) == num_outputs, (
+ f"Expected {num_outputs} images, got {len(images) if images else 0}"
+ )
+ for img in images:
+ assert img.width == 256
+ assert img.height == 256
except Exception as e:
print(f"Test failed with error: {e}")
raise
- finally:
- if m is not None and hasattr(m, "close"):
- m.close()
@pytest.mark.core_model
@@ -587,34 +575,31 @@ def test_diffusion_batching_num_outputs(model_name: str):
def test_diffusion_batching_distinct_results(model_name: str):
"""Test that different prompts produce distinct images when batched,
ensuring the batching logic does not mix up results across requests."""
- m = None
try:
- m = Omni(model=model_name)
- sp = _default_sync_sampling_params()
- prompts = [
- {"prompt": "a bright red apple on a white table", "negative_prompt": "blurry"},
- {"prompt": "a blue ocean with white waves crashing", "negative_prompt": "blurry"},
- ]
-
- outputs = m.generate(prompts, sp)
- assert len(outputs) == len(prompts), f"Expected {len(prompts)} outputs, got {len(outputs)}"
-
- # Verify each output has a unique request_id
- request_ids = [o.request_id for o in outputs]
- assert len(set(request_ids)) == len(request_ids), f"Duplicate request_ids: {request_ids}"
-
- # Verify each output has images
- for i, output in enumerate(outputs):
- images = _extract_images(output)
- assert images and len(images) >= 1, f"No images for prompt {i}"
- assert images[0].width == 256
- assert images[0].height == 256
+ with OmniRunner(model_name) as runner:
+ m = runner.omni
+ sp = _default_sync_sampling_params()
+ prompts = [
+ {"prompt": "a bright red apple on a white table", "negative_prompt": "blurry"},
+ {"prompt": "a blue ocean with white waves crashing", "negative_prompt": "blurry"},
+ ]
+
+ outputs = m.generate(prompts, sp)
+ assert len(outputs) == len(prompts), f"Expected {len(prompts)} outputs, got {len(outputs)}"
+
+ # Verify each output has a unique request_id
+ request_ids = [o.request_id for o in outputs]
+ assert len(set(request_ids)) == len(request_ids), f"Duplicate request_ids: {request_ids}"
+
+ # Verify each output has images
+ for i, output in enumerate(outputs):
+ images = _extract_images(output)
+ assert images and len(images) >= 1, f"No images for prompt {i}"
+ assert images[0].width == 256
+ assert images[0].height == 256
except Exception as e:
print(f"Test failed with error: {e}")
raise
- finally:
- if m is not None and hasattr(m, "close"):
- m.close()
# ------------------------------------------------------------------
diff --git a/tests/e2e/offline_inference/test_sequence_parallel.py b/tests/e2e/offline_inference/test_sequence_parallel.py
index 16239a1c52f..d3abccd78cf 100644
--- a/tests/e2e/offline_inference/test_sequence_parallel.py
+++ b/tests/e2e/offline_inference/test_sequence_parallel.py
@@ -20,8 +20,8 @@
import torch.distributed as dist
from PIL import Image
+from tests.conftest import OmniRunner
from tests.utils import hardware_test
-from vllm_omni import Omni
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
@@ -92,49 +92,48 @@ def _run_inference(
warmup: If True, run one warmup iteration before the timed run.
"""
parallel_config = DiffusionParallelConfig(ulysses_degree=ulysses_degree, ring_degree=ring_degree)
- omni = Omni(
- model=model_name,
- parallel_config=parallel_config,
- dtype=dtype,
- attention_backend=attn_backend,
- )
-
try:
- # Warmup run (not timed)
- if warmup:
- _ = omni.generate(
+ with OmniRunner(
+ model_name,
+ parallel_config=parallel_config,
+ dtype=dtype,
+ attention_backend=attn_backend,
+ ) as runner:
+ omni = runner.omni
+ # Warmup run (not timed)
+ if warmup:
+ _ = omni.generate(
+ PROMPT,
+ OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ num_inference_steps=DEFAULT_STEPS,
+ guidance_scale=0.0,
+ generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed + 1000),
+ num_outputs_per_prompt=1,
+ ),
+ )
+
+ # Timed run
+ start = time.time()
+ outputs = omni.generate(
PROMPT,
OmniDiffusionSamplingParams(
height=height,
width=width,
num_inference_steps=DEFAULT_STEPS,
guidance_scale=0.0,
- generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed + 1000),
+ generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed),
num_outputs_per_prompt=1,
),
)
+ elapsed_ms = (time.time() - start) * 1000
- # Timed run
- start = time.time()
- outputs = omni.generate(
- PROMPT,
- OmniDiffusionSamplingParams(
- height=height,
- width=width,
- num_inference_steps=DEFAULT_STEPS,
- guidance_scale=0.0,
- generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed),
- num_outputs_per_prompt=1,
- ),
- )
- elapsed_ms = (time.time() - start) * 1000
-
- return InferenceResult(
- images=outputs[0].request_output.images,
- elapsed_ms=elapsed_ms,
- )
+ return InferenceResult(
+ images=outputs[0].request_output.images,
+ elapsed_ms=elapsed_ms,
+ )
finally:
- omni.close()
_cleanup_distributed()
diff --git a/tests/e2e/offline_inference/test_stable_audio_model.py b/tests/e2e/offline_inference/test_stable_audio_model.py
index ff4d9b40172..21d75aad52a 100644
--- a/tests/e2e/offline_inference/test_stable_audio_model.py
+++ b/tests/e2e/offline_inference/test_stable_audio_model.py
@@ -1,6 +1,3 @@
-import sys
-from pathlib import Path
-
import numpy as np
import pytest
import torch
@@ -10,31 +7,25 @@
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
-# ruff: noqa: E402
-REPO_ROOT = Path(__file__).resolve().parents[2]
-if str(REPO_ROOT) not in sys.path:
- sys.path.insert(0, str(REPO_ROOT))
-
-from vllm_omni import Omni
-
# Use random weights model for CI testing (small, no authentication required)
models = ["linyueqian/stable_audio_random"]
+# omni_runner expects (model, stage_configs_path); single-stage diffusion has no YAML.
+test_params = [(m, None) for m in models]
+
@pytest.mark.core_model
@pytest.mark.diffusion
@hardware_test(res={"cuda": "L4", "xpu": "B60"})
-@pytest.mark.parametrize("model_name", models)
-def test_stable_audio_model(model_name: str):
- m = Omni(model=model_name)
-
+@pytest.mark.parametrize("omni_runner", test_params, indirect=True)
+def test_stable_audio_model(omni_runner):
# Use minimal settings for testing
# Generate a short 2-second audio clip with minimal inference steps
audio_start_in_s = 0.0
audio_end_in_s = 2.0 # Short duration for fast testing
sample_rate = 44100 # Stable Audio uses 44100 Hz
- outputs = m.generate(
+ outputs = omni_runner.omni.generate(
prompts={
"prompt": "The sound of a dog barking",
"negative_prompt": "Low quality.",
diff --git a/tests/e2e/offline_inference/test_t2i_model.py b/tests/e2e/offline_inference/test_t2i_model.py
index 55a154f61b9..fc54f9a7ff1 100644
--- a/tests/e2e/offline_inference/test_t2i_model.py
+++ b/tests/e2e/offline_inference/test_t2i_model.py
@@ -1,7 +1,3 @@
-import os
-import sys
-from pathlib import Path
-
import pytest
import torch
@@ -10,14 +6,12 @@
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
-# ruff: noqa: E402
-REPO_ROOT = Path(__file__).resolve().parents[2]
-if str(REPO_ROOT) not in sys.path:
- sys.path.insert(0, str(REPO_ROOT))
+# Match unprefixed HF id even when MODEL_PREFIX is set (omni_runner resolves full path).
+_QWEN_IMAGE_RANDOM_ID = "riverclouds/qwen_image_random"
-from vllm_omni import Omni
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
+def _is_qwen_image_random(model_path: str) -> bool:
+ return model_path.rstrip("/").endswith(_QWEN_IMAGE_RANDOM_ID)
models = ["Tongyi-MAI/Z-Image-Turbo", "riverclouds/qwen_image_random"]
@@ -27,56 +21,55 @@
if current_omni_platform.is_npu():
models = ["Tongyi-MAI/Z-Image-Turbo", "Qwen/Qwen-Image"]
+# omni_runner expects (model, stage_configs_path); single-stage diffusion has no YAML.
+test_params = [(m, None) for m in models]
+
@pytest.mark.core_model
@pytest.mark.advanced_model
@pytest.mark.diffusion
@hardware_test(res={"cuda": "L4", "rocm": "MI325", "xpu": "B60"}, num_cards={"cuda": 1, "rocm": 1, "xpu": 2})
-@pytest.mark.parametrize("model_name", models)
-def test_diffusion_model(model_name: str, run_level):
- if run_level == "core_model" and model_name != "riverclouds/qwen_image_random":
+@pytest.mark.parametrize("omni_runner", test_params, indirect=True)
+def test_diffusion_model(omni_runner, run_level):
+ resolved = omni_runner.model_name
+ if run_level == "core_model" and not _is_qwen_image_random(resolved):
pytest.skip()
- if run_level == "advanced_model" and model_name == "riverclouds/qwen_image_random":
+ if run_level == "advanced_model" and _is_qwen_image_random(resolved):
pytest.skip()
- m = None
- try:
- m = Omni(model=model_name)
- # high resolution may cause OOM on L4
- height = 256
- width = 256
- outputs = m.generate(
- "a photo of a cat sitting on a laptop keyboard",
- OmniDiffusionSamplingParams(
- height=height,
- width=width,
- num_inference_steps=2,
- guidance_scale=0.0,
- generator=torch.Generator(current_omni_platform.device_type).manual_seed(42),
- num_outputs_per_prompt=2,
- ),
- )
- # Extract images from request_output['images']
- first_output = outputs[0]
- assert first_output.final_output_type == "image"
- if not hasattr(first_output, "request_output") or not first_output.request_output:
- raise ValueError("No request_output found in OmniRequestOutput")
-
- req_out = first_output.request_output
- if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"):
- raise ValueError("Invalid request_output structure or missing 'images' key")
-
- images = req_out.images
-
- assert len(images) == 2
- # check image size
- assert images[0].width == width
- assert images[0].height == height
- images[0].save("image_output.png")
- except Exception as e:
- print(f"Test failed with error: {e}")
- raise
- finally:
- if m is not None and hasattr(m, "close"):
- m.close()
+ # high resolution may cause OOM on L4
+ height = 256
+ width = 256
+ sampling = OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ num_inference_steps=2,
+ guidance_scale=0.0,
+ generator=torch.Generator(current_omni_platform.device_type).manual_seed(42),
+ num_outputs_per_prompt=2,
+ )
+
+ # OmniRunner.generate() is typed for list[TextPrompt]; diffusion uses Omni.generate(str, ...).
+ outputs = omni_runner.omni.generate(
+ "a photo of a cat sitting on a laptop keyboard",
+ sampling,
+ )
+
+ # Extract images from request_output['images']
+ first_output = outputs[0]
+ assert first_output.final_output_type == "image"
+ if not hasattr(first_output, "request_output") or not first_output.request_output:
+ raise ValueError("No request_output found in OmniRequestOutput")
+
+ req_out = first_output.request_output
+ if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"):
+ raise ValueError("Invalid request_output structure or missing 'images' key")
+
+ images = req_out.images
+
+ assert len(images) == 2
+ # check image size
+ assert images[0].width == width
+ assert images[0].height == height
+ images[0].save("image_output.png")
diff --git a/tests/e2e/offline_inference/test_t2v_model.py b/tests/e2e/offline_inference/test_t2v_model.py
index 94c9dedf741..6fe623cfc82 100644
--- a/tests/e2e/offline_inference/test_t2v_model.py
+++ b/tests/e2e/offline_inference/test_t2v_model.py
@@ -1,22 +1,13 @@
import os
-import sys
-from pathlib import Path
import pytest
import torch
+from tests.conftest import OmniRunner
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
-# ruff: noqa: E402
-REPO_ROOT = Path(__file__).resolve().parents[2]
-if str(REPO_ROOT) not in sys.path:
- sys.path.insert(0, str(REPO_ROOT))
-
-from vllm_omni import Omni
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
-# os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
models = ["Wan-AI/Wan2.2-T2V-A14B-Diffusers"]
@@ -24,28 +15,28 @@
@pytest.mark.parametrize("model_name", models)
def test_video_diffusion_model(model_name: str):
- m = Omni(
- model=model_name,
+ with OmniRunner(
+ model_name,
boundary_ratio=0.875,
flow_shift=5.0,
- )
- # Use minimal settings for testing
- # num_frames must satisfy: num_frames % vae_scale_factor_temporal == 1
- # For Wan2.2, vae_scale_factor_temporal=4, so valid values are 5, 9, 13, 17, ...
- height = 480
- width = 640
- num_frames = 5
- outputs = m.generate(
- prompts="A cat sitting on a table",
- sampling_params_list=OmniDiffusionSamplingParams(
- height=height,
- width=width,
- num_frames=num_frames,
- num_inference_steps=2,
- guidance_scale=1.0,
- generator=torch.Generator(current_omni_platform.device_type).manual_seed(42),
- ),
- )
+ ) as runner:
+ # Use minimal settings for testing
+ # num_frames must satisfy: num_frames % vae_scale_factor_temporal == 1
+ # For Wan2.2, vae_scale_factor_temporal=4, so valid values are 5, 9, 13, 17, ...
+ height = 480
+ width = 640
+ num_frames = 5
+ outputs = runner.omni.generate(
+ prompts="A cat sitting on a table",
+ sampling_params_list=OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ num_inference_steps=2,
+ guidance_scale=1.0,
+ generator=torch.Generator(current_omni_platform.device_type).manual_seed(42),
+ ),
+ )
first_output = outputs[0]
assert first_output.final_output_type == "image"
if not hasattr(first_output, "request_output") or not first_output.request_output:
diff --git a/tests/e2e/offline_inference/test_teacache.py b/tests/e2e/offline_inference/test_teacache.py
index efc0e43e86f..7cd1c5a4797 100644
--- a/tests/e2e/offline_inference/test_teacache.py
+++ b/tests/e2e/offline_inference/test_teacache.py
@@ -8,26 +8,14 @@
It uses minimal settings to keep test time short for CI.
"""
-import os
-import sys
-from pathlib import Path
-
import pytest
import torch
+from tests.conftest import OmniRunner
from tests.utils import hardware_test
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-from vllm_omni.platforms import current_omni_platform
-
-# ruff: noqa: E402
-REPO_ROOT = Path(__file__).resolve().parents[2]
-if str(REPO_ROOT) not in sys.path:
- sys.path.insert(0, str(REPO_ROOT))
-
-from vllm_omni import Omni
from vllm_omni.outputs import OmniRequestOutput
-
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
+from vllm_omni.platforms import current_omni_platform
# Use random weights model for testing
models = ["riverclouds/qwen_image_random"]
@@ -44,20 +32,17 @@ def test_teacache(model_name: str):
cache_config = {
"rel_l1_thresh": 0.2, # Default threshold
}
- m = None
- try:
- m = Omni(
- model=model_name,
- cache_backend="tea_cache",
- cache_config=cache_config,
- )
-
+ with OmniRunner(
+ model_name,
+ cache_backend="tea_cache",
+ cache_config=cache_config,
+ ) as runner:
# Use minimal settings for fast testing
height = 256
width = 256
num_inference_steps = 4 # Minimal steps for fast test
- outputs = m.generate(
+ outputs = runner.omni.generate(
"a photo of a cat sitting on a laptop keyboard",
OmniDiffusionSamplingParams(
height=height,
@@ -86,9 +71,3 @@ def test_teacache(model_name: str):
# Check image size
assert images[0].width == width
assert images[0].height == height
- except Exception as e:
- print(f"Test failed with error: {e}")
- raise
- finally:
- if m is not None and hasattr(m, "close"):
- m.close()
diff --git a/tests/e2e/offline_inference/test_vae_decode_parallelism.py b/tests/e2e/offline_inference/test_vae_decode_parallelism.py
index cee76fac2e9..0fce28d6692 100644
--- a/tests/e2e/offline_inference/test_vae_decode_parallelism.py
+++ b/tests/e2e/offline_inference/test_vae_decode_parallelism.py
@@ -18,7 +18,7 @@
import time
-from vllm_omni import Omni
+from tests.conftest import OmniRunner
from vllm_omni.platforms import current_omni_platform
# os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
@@ -72,23 +72,22 @@ def is_nextstep_model(model_name: str) -> bool:
def model_run(model_configs, tp, out_height, out_width, out_frames, using_tile, vae_patch_parallel_size=1):
- m = None
- try:
- parallel_config = DiffusionParallelConfig(
- tensor_parallel_size=tp,
- vae_patch_parallel_size=vae_patch_parallel_size,
- )
+ parallel_config = DiffusionParallelConfig(
+ tensor_parallel_size=tp,
+ vae_patch_parallel_size=vae_patch_parallel_size,
+ )
- omni_kwargs = {
- "model": model_configs["model_name"],
- "vae_use_tiling": using_tile,
- "parallel_config": parallel_config,
- }
- use_nextstep = is_nextstep_model(model_configs["model_name"])
- if use_nextstep:
- # NextStep-1.1 requires explicit pipeline class
- omni_kwargs["model_class_name"] = "NextStep11Pipeline"
- m = Omni(**omni_kwargs)
+ omni_kwargs = {
+ "vae_use_tiling": using_tile,
+ "parallel_config": parallel_config,
+ }
+ use_nextstep = is_nextstep_model(model_configs["model_name"])
+ if use_nextstep:
+ # NextStep-1.1 requires explicit pipeline class
+ omni_kwargs["model_class_name"] = "NextStep11Pipeline"
+
+ with OmniRunner(model_configs["model_name"], **omni_kwargs) as runner:
+ m = runner.omni
image = Image.new("RGB", (out_width, out_height), (0, 0, 0))
start = time.perf_counter()
outputs = m.generate(
@@ -115,9 +114,6 @@ def model_run(model_configs, tp, out_height, out_width, out_frames, using_tile,
# frames shape: (batch, num_frames, height, width, channels)
cost = (end - start) * 1000
return frames, cost
- finally:
- if m is not None:
- m.close()
cleanup_dist_env_and_memory()
diff --git a/tests/e2e/offline_inference/test_voxcpm2.py b/tests/e2e/offline_inference/test_voxcpm2.py
index 7e17c6a3691..4e4f635d5c4 100644
--- a/tests/e2e/offline_inference/test_voxcpm2.py
+++ b/tests/e2e/offline_inference/test_voxcpm2.py
@@ -5,6 +5,7 @@
import pytest
import torch
+from tests.conftest import OmniRunner
from tests.utils import hardware_test
VOXCPM2_MODEL = "openbmb/VoxCPM2"
@@ -24,10 +25,8 @@
@pytest.fixture(scope="module")
def voxcpm2_engine():
"""Create VoxCPM2 engine for testing."""
- from vllm_omni import Omni
-
- engine = Omni(model=VOXCPM2_MODEL, stage_configs_path=STAGE_CONFIG)
- yield engine
+ with OmniRunner(VOXCPM2_MODEL, stage_configs_path=STAGE_CONFIG) as runner:
+ yield runner.omni
def _extract_audio(multimodal_output: dict) -> torch.Tensor:
diff --git a/tests/e2e/offline_inference/test_voxtral_tts.py b/tests/e2e/offline_inference/test_voxtral_tts.py
index b559cc252dc..4f440f243bf 100644
--- a/tests/e2e/offline_inference/test_voxtral_tts.py
+++ b/tests/e2e/offline_inference/test_voxtral_tts.py
@@ -19,7 +19,6 @@
import uuid
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
from pathlib import Path
@@ -30,10 +29,9 @@
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from vllm import SamplingParams
-from tests.conftest import modify_stage_config
+from tests.conftest import OmniRunner, modify_stage_config
from tests.utils import hardware_test
from vllm_omni.entrypoints.async_omni import AsyncOmni
-from vllm_omni.entrypoints.omni import Omni
MODEL = "mistralai/Voxtral-4B-TTS-2603"
STAGE_CONFIG = str(
@@ -83,14 +81,12 @@ def test_voxtral_tts_offline_basic(run_level):
"""Test basic Voxtral TTS offline inference with a voice preset."""
stage_config = _resolve_stage_config(run_level)
- omni = Omni(
- model=MODEL,
+ with OmniRunner(
+ MODEL,
stage_configs_path=stage_config,
- stage_init_timeout=300,
enforce_eager=True,
- )
-
- try:
+ ) as runner:
+ omni = runner.omni
inputs = _compose_request(MODEL, TEST_TEXT, VOICE)
sampling_params = SamplingParams(max_tokens=2500)
@@ -127,9 +123,6 @@ def test_voxtral_tts_offline_basic(run_level):
# Verify audio isn't all zeros / silence
assert np.max(np.abs(audio_array)) > 0.01, "Audio appears to be silence"
- finally:
- omni.close()
-
@pytest.mark.advanced_model
@pytest.mark.omni
diff --git a/tests/e2e/offline_inference/test_zimage_parallelism.py b/tests/e2e/offline_inference/test_zimage_parallelism.py
index b685704ae4b..27edc48f205 100644
--- a/tests/e2e/offline_inference/test_zimage_parallelism.py
+++ b/tests/e2e/offline_inference/test_zimage_parallelism.py
@@ -12,7 +12,6 @@
"""
import os
-import sys
import time
from pathlib import Path
@@ -20,21 +19,14 @@
import pytest
import torch
from PIL import Image
-from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
+from tests.conftest import OmniRunner
from tests.utils import DeviceMemoryMonitor, hardware_test
-from vllm_omni import Omni
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
-# ruff: noqa: E402
-REPO_ROOT = Path(__file__).resolve().parents[2]
-if str(REPO_ROOT) not in sys.path:
- sys.path.insert(0, str(REPO_ROOT))
-
-
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
PROMPT = "a photo of a cat sitting on a laptop keyboard"
@@ -97,61 +89,61 @@ def _run_zimage_generate(
device_index = current_omni_platform.current_device()
monitor = DeviceMemoryMonitor(device_index=device_index, interval=0.02)
monitor.start()
- m = Omni(
- model=_get_zimage_model(),
- parallel_config=DiffusionParallelConfig(
- tensor_parallel_size=tp_size,
- vae_patch_parallel_size=vae_patch_parallel_size,
- ),
- enforce_eager=enforce_eager,
- vae_use_tiling=vae_use_tiling,
- )
try:
- # NOTE: Omni closes itself when a generate() call is exhausted.
- # To avoid measuring teardown time (process shutdown, memory cleanup),
- # we measure the latency to produce *subsequent* outputs within a single
- # generator run.
- #
- # This also serves as a warmup: the first output may include extra
- # compilation/caching overhead, while later outputs are closer to
- # steady-state inference.
- gen = m.generate(
- [PROMPT] * num_requests,
- OmniDiffusionSamplingParams(
- height=height,
- width=width,
- num_inference_steps=num_inference_steps,
- guidance_scale=0.0,
- seed=seed,
- num_outputs_per_prompt=1,
+ # Each run needs a distinct DiffusionParallelConfig; use OmniRunner per call (not the
+ # parametrized omni_runner fixture, which is fixed per module).
+ with OmniRunner(
+ _get_zimage_model(),
+ parallel_config=DiffusionParallelConfig(
+ tensor_parallel_size=tp_size,
+ vae_patch_parallel_size=vae_patch_parallel_size,
),
- py_generator=True,
- )
-
- warmup_output = next(gen)
-
- t_prev = time.perf_counter()
- per_request_times_s: list[float] = []
- last_output = warmup_output
- for _ in range(num_requests - 1):
- last_output = next(gen)
- t_now = time.perf_counter()
- per_request_times_s.append(t_now - t_prev)
- t_prev = t_now
-
- # Ensure the generator is fully consumed so it can clean up.
- for _ in gen:
- pass
-
- median_time_s = float(np.median(per_request_times_s))
-
- peak_memory_mb = monitor.peak_used_mb
-
- return _extract_single_image([last_output]), median_time_s, peak_memory_mb
+ enforce_eager=enforce_eager,
+ vae_use_tiling=vae_use_tiling,
+ ) as runner:
+ # NOTE: Omni closes itself when a generate() call is exhausted.
+ # To avoid measuring teardown time (process shutdown, memory cleanup),
+ # we measure the latency to produce *subsequent* outputs within a single
+ # generator run.
+ #
+ # This also serves as a warmup: the first output may include extra
+ # compilation/caching overhead, while later outputs are closer to
+ # steady-state inference.
+ gen = runner.omni.generate(
+ [PROMPT] * num_requests,
+ OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=0.0,
+ seed=seed,
+ num_outputs_per_prompt=1,
+ ),
+ py_generator=True,
+ )
+
+ warmup_output = next(gen)
+
+ t_prev = time.perf_counter()
+ per_request_times_s: list[float] = []
+ last_output = warmup_output
+ for _ in range(num_requests - 1):
+ last_output = next(gen)
+ t_now = time.perf_counter()
+ per_request_times_s.append(t_now - t_prev)
+ t_prev = t_now
+
+ # Ensure the generator is fully consumed so it can clean up.
+ for _ in gen:
+ pass
+
+ median_time_s = float(np.median(per_request_times_s))
+
+ peak_memory_mb = monitor.peak_used_mb
+
+ return _extract_single_image([last_output]), median_time_s, peak_memory_mb
finally:
monitor.stop()
- m.close()
- cleanup_dist_env_and_memory()
@pytest.mark.advanced_model
diff --git a/tests/e2e/online_serving/test_images_generations_lora.py b/tests/e2e/online_serving/test_images_generations_lora.py
index 8c826591a56..fb1e3ea1e0f 100644
--- a/tests/e2e/online_serving/test_images_generations_lora.py
+++ b/tests/e2e/online_serving/test_images_generations_lora.py
@@ -28,7 +28,7 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
MODEL = "Tongyi-MAI/Z-Image-Turbo"
-DIFFUSION_INIT_TIMEOUT_S = 700
+DIFFUSION_INIT_TIMEOUT_S = 900
PROMPT = "a photo of a cat sitting on a laptop keyboard"
From d369648e668b66ce6191003157fc5ad17dd67597 Mon Sep 17 00:00:00 2001
From: ZhengWG
Date: Mon, 13 Apr 2026 14:51:34 +0800
Subject: [PATCH 11/76] refactor: add stage_pool
Signed-off-by: ZhengWG
---
vllm_omni/engine/orchestrator.py | 81 ++++++++++++---------
vllm_omni/engine/stage_pool.py | 116 +++++++++++++++++++++++++++++++
2 files changed, 165 insertions(+), 32 deletions(-)
create mode 100644 vllm_omni/engine/stage_pool.py
diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py
index b79f88933ff..da16d3ca66c 100644
--- a/vllm_omni/engine/orchestrator.py
+++ b/vllm_omni/engine/orchestrator.py
@@ -28,6 +28,7 @@
OmniEngineCoreRequest,
)
from vllm_omni.engine.serialization import serialize_additional_information
+from vllm_omni.engine.stage_pool import StagePool, StageReplica, build_stage_pools
from vllm_omni.metrics.stats import StageRequestStats as StageRequestMetrics
from vllm_omni.metrics.stats import StageStats
from vllm_omni.metrics.utils import count_tokens_from_outputs
@@ -136,6 +137,12 @@ def __init__(
self.num_clients = len(stage_clients)
self.async_chunk = bool(async_chunk)
+ # Flat-list view: retained as a compatibility layer so existing call
+ # sites that index by flat client_index (metrics, shutdown, collective
+ # RPC fan-out, etc.) keep working. StagePool below is the canonical
+ # path for replica selection and should be preferred in new code.
+ # TODO(stage-pool): migrate remaining flat-list readers onto
+ # self.stage_pools and drop these attributes.
self.stage_clients: list[Any] = stage_clients
self.output_processors: list[Any] = output_processors
self.stage_vllm_configs: list[Any] = stage_vllm_configs
@@ -148,7 +155,16 @@ def __init__(
self.logical_stage_to_clients = [[i] for i in range(self.num_clients)]
self.num_logical_stages = len(self.logical_stage_to_clients)
- # Reverse mappings: client_index -> (logical_stage_id, replica_index)
+ # Canonical per-logical-stage replica container.
+ self.stage_pools: list[StagePool] = build_stage_pools(
+ stage_clients,
+ output_processors,
+ stage_vllm_configs,
+ self.logical_stage_to_clients,
+ )
+
+ # Reverse mappings: client_index -> (logical_stage_id, replica_index).
+ # Kept for metrics/shutdown log lines that index by flat client_index.
self._client_to_logical: list[int] = [0] * self.num_clients
self._client_to_replica: list[int] = [0] * self.num_clients
for logical_id, client_indices in enumerate(self.logical_stage_to_clients):
@@ -156,9 +172,6 @@ def __init__(
self._client_to_logical[ci] = logical_id
self._client_to_replica[ci] = ri
- # Round-robin counters for replica selection per logical stage
- self._replica_rr: list[int] = [0] * self.num_logical_stages
-
# Backward compat: num_stages now means num_logical_stages
self.num_stages = self.num_logical_stages
@@ -186,26 +199,14 @@ def _choose_client_index(
logical_stage_id: int,
req_state: OrchestratorRequestState,
) -> int:
- """Pick a client for *logical_stage_id* and record the choice.
+ """Pick a flat client_index for *logical_stage_id* via the stage pool.
- If this request already has a chosen client for the logical stage,
- return the existing one (affinity). Otherwise round-robin among the
- available replicas.
+ Thin wrapper that delegates to ``StagePool.select_replica`` so the
+ flat-index-based call sites keep working. New code should call the
+ pool directly when the StageReplica object itself is useful.
"""
- existing = req_state.chosen_client_index.get(logical_stage_id)
- if existing is not None:
- return existing
-
- candidates = self.logical_stage_to_clients[logical_stage_id]
- if len(candidates) == 1:
- chosen = candidates[0]
- else:
- rr = self._replica_rr[logical_stage_id]
- chosen = candidates[rr % len(candidates)]
- self._replica_rr[logical_stage_id] = rr + 1
-
- req_state.chosen_client_index[logical_stage_id] = chosen
- return chosen
+ replica = self.stage_pools[logical_stage_id].select_replica(req_state)
+ return replica.flat_index
def _resolve_client_index(self, stage_id: int, replica_index: int = 0) -> int:
"""Resolve (stage_id, replica_index) to a flat client index."""
@@ -251,7 +252,14 @@ async def run(self) -> None:
await asyncio.gather(*pending, return_exceptions=True)
async def _request_handler(self) -> None:
- """Read messages from the main thread via request_async_queue."""
+ """Read messages from the main thread via request_async_queue.
+
+ TODO(stage-pool): the while loop below has no top-level try/except, so
+ any unhandled exception inside a _handle_* coroutine kills this task
+ and leaves the orchestrator unable to consume further messages. Wrap
+ each dispatch in a per-message try/except so one bad request can't
+ wedge the whole engine.
+ """
while True:
msg = await self.request_async_queue.get()
msg_type = msg.get("type")
@@ -908,26 +916,35 @@ async def _handle_add_companion(self, msg: dict[str, Any]) -> None:
)
self.request_states[companion_id] = companion_state
- # Use same replica as the parent for affinity, or choose one
+ # CFG companions must land on the same stage-0 replica as their
+ # parent so the diffusion stage can fetch both KV caches from a
+ # single device. Pass affinity_from explicitly; if the parent is
+ # already gone (aborted between add_request and add_companion) or has
+ # no stage-0 replica recorded, fall back to round-robin rather than failing.
+ stage0_pool = self.stage_pools[0]
parent_state = self.request_states.get(parent_id)
- if parent_state is not None and 0 in parent_state.chosen_client_index:
- client_index = parent_state.chosen_client_index[0]
- companion_state.chosen_client_index[0] = client_index
- else:
- client_index = self._choose_client_index(0, companion_state)
+ parent_replica: StageReplica | None = None
+ if parent_state is not None:
+ parent_flat = parent_state.chosen_client_index.get(0)
+ if parent_flat is not None:
+ parent_replica = stage0_pool.get_replica_by_flat_index(parent_flat)
+
+ companion_replica = stage0_pool.select_replica(
+ companion_state,
+ affinity_from=parent_replica,
+ )
companion_state.stage_submit_ts[0] = _time.time()
request = companion_prompt # Already a processed OmniEngineCoreRequest
- stage_client = self.stage_clients[client_index]
- await stage_client.add_request_async(request)
+ await companion_replica.client.add_request_async(request)
logger.info(
"[Orchestrator] CFG companion submitted: %s (role=%s, parent=%s, stage-0 replica-%s)",
companion_id,
role,
parent_id,
- self._client_to_replica[client_index],
+ companion_replica.replica_index,
)
async def _handle_abort(self, msg: dict[str, Any]) -> None:
diff --git a/vllm_omni/engine/stage_pool.py b/vllm_omni/engine/stage_pool.py
new file mode 100644
index 00000000000..86a2fdcd77f
--- /dev/null
+++ b/vllm_omni/engine/stage_pool.py
@@ -0,0 +1,116 @@
+"""StagePool: per-logical-stage replica container.
+
+Groups the {client, output_processor, vllm_config} triple of each replica
+under a single logical stage and centralizes replica selection (round-robin
++ per-request affinity). The Orchestrator still owns flat lists as a
+compatibility view; StagePool is the canonical lookup going forward.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+ from vllm_omni.engine.orchestrator import OrchestratorRequestState
+
+
+@dataclass
+class StageReplica:
+ """One replica of a logical stage.
+
+ flat_index is the index into Orchestrator's flat stage_clients list; it
+ is the value cached in OrchestratorRequestState.chosen_client_index so
+ existing call sites that resolve a flat client keep working unchanged.
+ """
+
+ logical_stage_id: int
+ replica_index: int
+ flat_index: int
+ client: Any
+ output_processor: Any
+ vllm_config: Any
+
+
+class StagePool:
+ """Replicas of one logical stage with round-robin + affinity selection."""
+
+ def __init__(
+ self,
+ logical_stage_id: int,
+ stage_type: str | None,
+ replicas: list[StageReplica],
+ ) -> None:
+ if not replicas:
+ raise ValueError(f"StagePool for logical stage {logical_stage_id} has no replicas")
+ self.logical_stage_id = logical_stage_id
+ self.stage_type = stage_type
+ self.replicas: list[StageReplica] = replicas
+ self._rr_cursor = 0
+ self._by_flat_index: dict[int, StageReplica] = {r.flat_index: r for r in replicas}
+
+ @property
+ def num_replicas(self) -> int:
+ return len(self.replicas)
+
+ def get_replica_by_flat_index(self, flat_index: int) -> StageReplica:
+ return self._by_flat_index[flat_index]
+
+ def select_replica(
+ self,
+ req_state: OrchestratorRequestState,
+ *,
+ affinity_from: StageReplica | None = None,
+ ) -> StageReplica:
+ """Pick a replica for *req_state* and cache the choice.
+
+ Resolution order:
+ 1. Existing choice recorded on req_state (per-request affinity).
+ 2. affinity_from (explicit cross-request binding, e.g. CFG companion
+ inheriting its parent's replica at stage 0).
+ 3. Round-robin across replicas.
+ """
+ cached = req_state.chosen_client_index.get(self.logical_stage_id)
+ if cached is not None:
+ return self._by_flat_index[cached]
+
+ if affinity_from is not None:
+ if affinity_from.logical_stage_id != self.logical_stage_id:
+ raise ValueError(
+ f"affinity_from is for logical stage {affinity_from.logical_stage_id}, "
+ f"cannot be used to select in stage {self.logical_stage_id}"
+ )
+ chosen = affinity_from
+ elif self.num_replicas == 1:
+ chosen = self.replicas[0]
+ else:
+ chosen = self.replicas[self._rr_cursor % self.num_replicas]
+ self._rr_cursor += 1
+
+ req_state.chosen_client_index[self.logical_stage_id] = chosen.flat_index
+ return chosen
+
+
+def build_stage_pools(
+ stage_clients: list[Any],
+ output_processors: list[Any],
+ stage_vllm_configs: list[Any],
+ logical_stage_to_clients: list[list[int]],
+) -> list[StagePool]:
+ """Assemble StagePool list from the flat-list view owned by the engine."""
+ pools: list[StagePool] = []
+ for logical_id, client_indices in enumerate(logical_stage_to_clients):
+ replicas = [
+ StageReplica(
+ logical_stage_id=logical_id,
+ replica_index=ri,
+ flat_index=ci,
+ client=stage_clients[ci],
+ output_processor=output_processors[ci],
+ vllm_config=stage_vllm_configs[ci],
+ )
+ for ri, ci in enumerate(client_indices)
+ ]
+ stage_type = getattr(stage_clients[client_indices[0]], "stage_type", None)
+ pools.append(StagePool(logical_id, stage_type, replicas))
+ return pools
From 2b70e89535aca2f29eff74687a6b07b5fd2bd077 Mon Sep 17 00:00:00 2001
From: amy-why-3459
Date: Mon, 13 Apr 2026 14:55:16 +0800
Subject: [PATCH 12/76] =?UTF-8?q?[Revert]=20Revert=20"[Log]=20Wire=20stat?=
=?UTF-8?q?=20loggers=20into=20AsyncOmniEngine=20to=20match=20AsyncLL?=
=?UTF-8?q?=E2=80=A6=20(#2716)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: amy-why-3459
---
.../test_async_omni_engine_do_log_stats.py | 56 ------------------
.../test_async_omni_engine_stage_init.py | 2 -
tests/engine/test_single_stage_mode.py | 3 -
vllm_omni/engine/async_omni_engine.py | 58 +------------------
vllm_omni/engine/orchestrator.py | 26 +--------
vllm_omni/entrypoints/async_omni.py | 7 ++-
6 files changed, 8 insertions(+), 144 deletions(-)
delete mode 100644 tests/engine/test_async_omni_engine_do_log_stats.py
diff --git a/tests/engine/test_async_omni_engine_do_log_stats.py b/tests/engine/test_async_omni_engine_do_log_stats.py
deleted file mode 100644
index e2b8c03b935..00000000000
--- a/tests/engine/test_async_omni_engine_do_log_stats.py
+++ /dev/null
@@ -1,56 +0,0 @@
-"""Guard tests for AsyncOmniEngine.do_log_stats edge cases.
-
-These are pure-Python tests that bypass __init__ and only exercise the
-no-op branches of do_log_stats, so no stage cores / threads are needed.
-"""
-
-import asyncio
-
-import pytest
-
-from vllm_omni.engine.async_omni_engine import AsyncOmniEngine
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-def _make_bare_engine() -> AsyncOmniEngine:
- # Bypass __init__ so we don't spin up stage cores; we only need the
- # attributes do_log_stats touches.
- return AsyncOmniEngine.__new__(AsyncOmniEngine)
-
-
-@pytest.mark.asyncio
-async def test_do_log_stats_noop_when_manager_missing():
- engine = _make_bare_engine()
- engine.logger_manager = None
- engine.orchestrator_loop = None
- await engine.do_log_stats() # should silently return
-
-
-@pytest.mark.asyncio
-async def test_do_log_stats_noop_when_loop_missing():
- engine = _make_bare_engine()
-
- class _Manager:
- def log(self) -> None: # pragma: no cover - must not be called
- raise AssertionError("log() should not be called without a loop")
-
- engine.logger_manager = _Manager()
- engine.orchestrator_loop = None
- await engine.do_log_stats()
-
-
-@pytest.mark.asyncio
-async def test_do_log_stats_noop_when_loop_not_running():
- engine = _make_bare_engine()
-
- class _Manager:
- def log(self) -> None: # pragma: no cover - must not be called
- raise AssertionError("log() should not be called on a stopped loop")
-
- dead_loop = asyncio.new_event_loop()
- dead_loop.close()
-
- engine.logger_manager = _Manager()
- engine.orchestrator_loop = dead_loop
- await engine.do_log_stats()
diff --git a/tests/engine/test_async_omni_engine_stage_init.py b/tests/engine/test_async_omni_engine_stage_init.py
index f3973079365..6993f391ebc 100644
--- a/tests/engine/test_async_omni_engine_stage_init.py
+++ b/tests/engine/test_async_omni_engine_stage_init.py
@@ -31,7 +31,6 @@ def test_initialize_stages_restores_device_visibility_after_diffusion_init(monke
from vllm_omni.platforms import current_omni_platform
engine = object.__new__(AsyncOmniEngine)
- engine.log_stats = False
engine.model = "dummy-model"
engine.config_path = "dummy-config"
engine.num_stages = 1
@@ -283,7 +282,6 @@ def __init__(self, vllm_config, renderer=None):
)
engine = object.__new__(AsyncOmniEngine)
- engine.log_stats = False
_stage_client, _out_proc, _vllm_cfg, input_processor = engine._attach_llm_stage(started)
diff --git a/tests/engine/test_single_stage_mode.py b/tests/engine/test_single_stage_mode.py
index 1afe2fd6d9c..2c5bf6cc79c 100644
--- a/tests/engine/test_single_stage_mode.py
+++ b/tests/engine/test_single_stage_mode.py
@@ -461,7 +461,6 @@ def _build_engine_skeleton(
engine.stage_configs = stage_cfgs
engine.num_stages = len(stage_cfgs)
engine.async_chunk = False
- engine.log_stats = False
engine.single_stage_mode = single_stage_mode
engine._single_stage_id_filter = stage_id_filter
engine._omni_master_address = omni_master_address
@@ -1367,7 +1366,6 @@ class TestLaunchLlmStageSingleStageMode:
def _build_engine_with_oms(self) -> AsyncOmniEngine:
engine = object.__new__(AsyncOmniEngine)
engine.model = "fake-model"
- engine.log_stats = False
engine.single_stage_mode = True
engine._single_stage_id_filter = 0
engine._llm_stage_launch_lock = threading.Lock()
@@ -1448,7 +1446,6 @@ def test_spawn_stage_core_used_in_normal_mode(self):
"""~single_stage_mode → spawn_stage_core + complete_stage_handshake."""
engine = object.__new__(AsyncOmniEngine)
engine.model = "fake-model"
- engine.log_stats = False
engine.single_stage_mode = False
engine._omni_master_server = None
engine._llm_stage_launch_lock = threading.Lock()
diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py
index 32e8336f6da..0a2e02d66ef 100644
--- a/vllm_omni/engine/async_omni_engine.py
+++ b/vllm_omni/engine/async_omni_engine.py
@@ -31,7 +31,6 @@
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.input_processor import InputProcessor
-from vllm.v1.metrics.loggers import StatLoggerManager
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient
@@ -285,7 +284,6 @@ def __init__(
self.num_stages = len(self.stage_configs)
stage0_args = getattr(self.stage_configs[0], "engine_args", None) if self.num_stages > 0 else None
self.async_chunk = bool(getattr(stage0_args, "async_chunk", False))
- self.log_stats = not bool(getattr(stage0_args, "disable_log_stats", False))
self.stage_clients: list[Any] = []
self.stage_vllm_configs: list[Any] = []
self.output_processors: list[MultimodalOutputProcessor | None] = []
@@ -415,7 +413,7 @@ def _launch_llm_stage(
addresses, proc, handshake_address = spawn_stage_core(
vllm_config=vllm_config,
executor_class=executor_class,
- log_stats=self.log_stats,
+ log_stats=False,
)
started_stage = StartedLlmStage(
stage_id=metadata.stage_id,
@@ -617,7 +615,7 @@ def _attach_llm_stage(
)
output_processor = MultimodalOutputProcessor(
tokenizer=tokenizer,
- log_stats=self.log_stats,
+ log_stats=False,
engine_core_output_type=started.metadata.engine_output_type,
)
input_processor = None
@@ -872,30 +870,6 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
self.default_sampling_params_list = default_sampling_params_list
self.stage_metadata = stage_metadata
- # Single StatLoggerManager for the whole pipeline, mirroring how
- # vLLM AsyncLLM uses one manager with multiple engine indices for DP.
- # We treat each stage as a separate "engine_idx" so logs are
- # distinguishable as "Engine 000/001/002/...". Using a single manager
- # also avoids PrometheusStatLogger registry collisions.
- self.logger_manager: StatLoggerManager | None = None
- if self.log_stats:
- base_vllm_config = next(
- (cfg for cfg in self.stage_vllm_configs if cfg is not None),
- None,
- )
- if base_vllm_config is not None:
- try:
- self.logger_manager = StatLoggerManager(
- vllm_config=base_vllm_config,
- engine_idxs=list(range(self.num_stages)),
- custom_stat_loggers=None,
- enable_default_loggers=True,
- )
- self.logger_manager.log_engine_initialized()
- except Exception:
- logger.exception("[AsyncOmniEngine] Failed to build StatLoggerManager")
- self.logger_manager = None
-
def _initialize_janus_queues(self) -> None:
"""Initialize janus queues inside orchestrator thread loop context."""
self.request_queue = janus.Queue()
@@ -912,10 +886,6 @@ def _bootstrap_orchestrator(
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
- # Expose the orchestrator loop so other threads (API server) can
- # schedule coroutines onto it via run_coroutine_threadsafe, keeping
- # single-threaded access to StatLoggerManager (mirrors AsyncLLM).
- self.orchestrator_loop = loop
async def _run_orchestrator() -> None:
self._initialize_janus_queues()
@@ -929,7 +899,6 @@ async def _run_orchestrator() -> None:
stage_clients=self.stage_clients,
output_processors=self.output_processors,
stage_vllm_configs=self.stage_vllm_configs,
- logger_manager=self.logger_manager,
)
if not startup_future.done():
startup_future.set_result(asyncio.get_running_loop())
@@ -1554,29 +1523,6 @@ async def abort_async(self, request_ids: list[str]) -> None:
"""Async abort API."""
self.abort(request_ids)
- async def do_log_stats(self) -> None:
- """Flush the StatLoggerManager on the orchestrator thread.
-
- ``StatLoggerManager`` is only safe to access from the orchestrator
- loop (where ``record()`` runs). Schedule ``log()`` onto that loop
- via ``run_coroutine_threadsafe`` so all access stays single-threaded,
- matching upstream vLLM ``AsyncLLM``.
- """
- manager = self.logger_manager
- if manager is None:
- return
- loop = getattr(self, "orchestrator_loop", None)
- if loop is None or not loop.is_running():
- return
-
- async def _log() -> None:
- manager.log()
-
- try:
- await asyncio.wrap_future(asyncio.run_coroutine_threadsafe(_log(), loop))
- except Exception:
- logger.exception("[AsyncOmniEngine] do_log_stats failed")
-
def collective_rpc(
self,
method: str,
diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py
index e64fd3685cf..386b545eb75 100644
--- a/vllm_omni/engine/orchestrator.py
+++ b/vllm_omni/engine/orchestrator.py
@@ -22,8 +22,6 @@
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import EngineCoreOutputs
-from vllm.v1.metrics.loggers import StatLoggerManager
-from vllm.v1.metrics.stats import IterationStats
from vllm_omni.distributed.omni_connectors.adapter import compute_talker_prompt_ids_length
from vllm_omni.engine import (
@@ -124,7 +122,6 @@ def __init__(
stage_vllm_configs: list[Any],
*,
async_chunk: bool = False,
- logger_manager: StatLoggerManager | None = None,
) -> None:
self.request_async_queue = request_async_queue
self.output_async_queue = output_async_queue
@@ -136,8 +133,6 @@ def __init__(
self.stage_clients: list[Any] = stage_clients
self.output_processors: list[Any] = output_processors
self.stage_vllm_configs: list[Any] = stage_vllm_configs
- self.logger_manager: StatLoggerManager | None = logger_manager
- self.log_stats = self.logger_manager is not None
# Per-request state
self.request_states: dict[str, OrchestratorRequestState] = {}
@@ -629,13 +624,10 @@ async def _process_stage_outputs(self, stage_id: int, raw_outputs: EngineCoreOut
"""
processor = self.output_processors[stage_id]
- num_outputs = len(raw_outputs.outputs)
- iteration_stats = IterationStats() if (self.log_stats and num_outputs) else None
-
processed = processor.process_outputs(
raw_outputs.outputs,
raw_outputs.timestamp,
- iteration_stats,
+ None,
)
if processed.reqs_to_abort:
@@ -644,22 +636,6 @@ async def _process_stage_outputs(self, stage_id: int, raw_outputs: EngineCoreOut
if raw_outputs.scheduler_stats is not None:
processor.update_scheduler_stats(raw_outputs.scheduler_stats)
- # Mirror vLLM AsyncLLM output_handler: feed stats to the logger
- # manager so LoggingStatLogger can periodically print KV cache /
- # prefix cache hit rate, and PrometheusStatLogger can publish.
- if self.logger_manager is not None:
- try:
- self.logger_manager.record(
- engine_idx=stage_id,
- scheduler_stats=raw_outputs.scheduler_stats,
- iteration_stats=iteration_stats,
- )
- except Exception:
- logger.exception(
- "[Orchestrator] stat logger record failed for stage-%s",
- stage_id,
- )
-
return processed.request_outputs
async def _handle_add_request(self, msg: dict[str, Any]) -> None:
diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py
index 0b25ce71418..129ef3c99d8 100644
--- a/vllm_omni/entrypoints/async_omni.py
+++ b/vllm_omni/entrypoints/async_omni.py
@@ -743,8 +743,11 @@ async def is_tracing_enabled(self) -> bool:
return False
async def do_log_stats(self) -> None:
- """Log statistics via the engine, mirroring vLLM ``AsyncLLM``."""
- await self.engine.do_log_stats()
+ """Log statistics (currently a no-op placeholder).
+
+ TODO: Forward to Orchestrator process via message.
+ """
+ pass
async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
"""Return the task set exposed by the orchestrator-backed engine."""
From 0d4e975e1bf6c574babc7e8279db2b4ff612dd22 Mon Sep 17 00:00:00 2001
From: NATURE
Date: Mon, 13 Apr 2026 16:01:14 +0800
Subject: [PATCH 13/76] [core]refactor communication layer: PR1(Added Refactor
Infra Only) (#1555)
Signed-off-by: natureofnature
Co-authored-by: Hongsheng Liu
---
.../test_chunk_scheduling_coordinator.py | 690 ++++++
tests/worker/test_omni_connector_mixin.py | 1419 +++++++++++
.../core/sched/omni_scheduling_coordinator.py | 380 +++
.../worker/diffusion_model_runner.py | 3 +-
vllm_omni/outputs.py | 28 +
vllm_omni/worker/gpu_ar_model_runner.py | 3 +-
.../worker/gpu_generation_model_runner.py | 3 +-
.../omni_connector_model_runner_mixin.py | 2125 +++++++++++++++++
vllm_omni/worker/payload_span.py | 64 +
9 files changed, 4712 insertions(+), 3 deletions(-)
create mode 100644 tests/core/sched/test_chunk_scheduling_coordinator.py
create mode 100644 tests/worker/test_omni_connector_mixin.py
create mode 100644 vllm_omni/core/sched/omni_scheduling_coordinator.py
create mode 100644 vllm_omni/worker/omni_connector_model_runner_mixin.py
create mode 100644 vllm_omni/worker/payload_span.py
diff --git a/tests/core/sched/test_chunk_scheduling_coordinator.py b/tests/core/sched/test_chunk_scheduling_coordinator.py
new file mode 100644
index 00000000000..5e19465e224
--- /dev/null
+++ b/tests/core/sched/test_chunk_scheduling_coordinator.py
@@ -0,0 +1,690 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for OmniSchedulingCoordinator (formerly ChunkSchedulingCoordinator).
+
+These tests use mock request objects and mock queues. They do not require
+GPU, vLLM runtime, or any connector.
+"""
+
+from __future__ import annotations
+
+import unittest
+from types import SimpleNamespace
+
+import vllm_omni.core.sched.omni_scheduling_coordinator as coord_mod
+from vllm_omni.core.sched.omni_scheduling_coordinator import (
+ ChunkSchedulingCoordinator,
+ OmniSchedulingCoordinator,
+)
+
+# ------------------------------------------------------------------ #
+# Mock helpers
+# ------------------------------------------------------------------ #
+
+
+class _RequestStatus:
+ WAITING = "waiting"
+ RUNNING = "running"
+ WAITING_FOR_CHUNK = "waiting_for_chunk"
+ WAITING_FOR_INPUT = "waiting_for_input"
+ FINISHED_STOPPED = "finished_stopped"
+
+
+# Prefer vllm's RequestStatus; fall back to the local stub if vllm is missing or lacks WAITING_FOR_INPUT.
+try:
+ from vllm.v1.request import RequestStatus
+except ImportError:
+ RequestStatus = _RequestStatus # type: ignore[misc,assignment]
+
+if not hasattr(RequestStatus, "WAITING_FOR_INPUT"):
+ coord_mod.RequestStatus = _RequestStatus # type: ignore[assignment]
+ RequestStatus = _RequestStatus # type: ignore[misc,assignment]
+
+
+def _make_request(req_id: str, status: str = "waiting") -> SimpleNamespace:
+    """Return a request stub with empty token lists and zeroed counters."""
+    # The external id mirrors the internal one; every list below is a fresh object.
+    stub = SimpleNamespace(request_id=req_id, external_req_id=req_id, status=status)
+    stub.additional_information = None
+    stub.prompt_token_ids = []
+    stub.num_prompt_tokens = 0
+    stub.num_computed_tokens = 0
+    stub._all_token_ids = []
+    stub._output_token_ids = []
+
+    return stub
+
+
+class MockQueue:
+    """Simplified queue that mimics the Scheduler waiting queue interface."""
+
+    def __init__(self, items: list | None = None):
+        self._items: list = list(items or [])  # private copy; the caller's list is never aliased
+
+    def __iter__(self):
+        return iter(self._items)
+
+    def __len__(self):
+        return len(self._items)
+
+    def __contains__(self, item):
+        return item in self._items
+
+    def add_request(self, request):
+        self._items.append(request)  # enqueue at the tail
+
+    def prepend_requests(self, requests):
+        self._items = list(requests) + self._items  # new arrivals go ahead of existing items
+
+    def remove(self, request):
+        self._items.remove(request)  # equality-based; list.remove raises ValueError if absent
+
+    def remove_requests(self, requests):
+        doomed_ids = {id(r) for r in requests}  # identity-based, unlike remove()
+        self._items = [r for r in self._items if id(r) not in doomed_ids]
+
+
+# ------------------------------------------------------------------ #
+# Tests
+# ------------------------------------------------------------------ #
+
+
+class TestChunkCoordinatorStateTransition(unittest.TestCase):
+ """Test 5: process_pending_chunks transitions WAITING_FOR_CHUNK → target."""
+
+ def test_ready_request_transitions_to_waiting(self):
+ coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1, async_chunk=True)
+
+ req = _make_request("r1", status=RequestStatus.WAITING_FOR_CHUNK)
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_chunks(
+ waiting,
+ running,
+ chunk_ready_req_ids={"r1"},
+ chunk_finished_req_ids=set(),
+ )
+
+ self.assertEqual(req.status, RequestStatus.WAITING)
+ self.assertIn("r1", coord.requests_with_ready_chunks)
+
+ def test_non_ready_stays_waiting_for_chunk(self):
+ coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1, async_chunk=True)
+
+ req = _make_request("r1", status=RequestStatus.WAITING_FOR_CHUNK)
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_chunks(
+ waiting,
+ running,
+ chunk_ready_req_ids=set(),
+ chunk_finished_req_ids=set(),
+ )
+
+ self.assertEqual(req.status, RequestStatus.WAITING_FOR_CHUNK)
+
+ def test_stage_0_is_noop(self):
+ coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=0)
+ req = _make_request("r1")
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_chunks(
+ waiting,
+ running,
+ chunk_ready_req_ids={"r1"},
+ chunk_finished_req_ids=set(),
+ )
+ self.assertNotEqual(req.status, RequestStatus.WAITING_FOR_CHUNK)
+
+
+class TestChunkCoordinatorRestoreQueues(unittest.TestCase):
+ """Test 6: restore_queues returns waiting-for-chunk requests."""
+
+ def test_restore(self):
+ coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
+
+ r1 = _make_request("r1")
+ r2 = _make_request("r2")
+ coord._waiting_for_chunk_waiting.append(r1)
+ coord._waiting_for_chunk_running.append(r2)
+
+ waiting = MockQueue()
+ running: list = []
+
+ coord.restore_queues(waiting, running)
+
+ self.assertIn(r1, waiting)
+ self.assertIn(r2, running)
+ self.assertEqual(len(coord._waiting_for_chunk_waiting), 0)
+ self.assertEqual(len(coord._waiting_for_chunk_running), 0)
+
+
+class TestChunkCoordinatorFinishedSignal(unittest.TestCase):
+ """Test 8: chunk_finished_req_ids → finished_requests."""
+
+ def test_finished_signal(self):
+ coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1, async_chunk=True)
+
+ req = _make_request("r1", status=RequestStatus.WAITING_FOR_CHUNK)
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_chunks(
+ waiting,
+ running,
+ chunk_ready_req_ids={"r1"},
+ chunk_finished_req_ids={"r1"},
+ )
+
+ self.assertIn("r1", coord.finished_requests)
+
+
+class TestChunkCoordinatorUpdateRequestMetadata(unittest.TestCase):
+ """Test update_request_metadata applies scheduling metadata to requests."""
+
+ def test_ar_mode_no_longer_sets_additional_information(self):
+ """AR mode only processes scheduling metadata, not full payloads."""
+ coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
+
+ req = _make_request("r1")
+ requests = {"r1": req}
+
+ # Only scheduling metadata is passed now (full payload stays in model runner)
+ request_metadata = {"r1": {"next_stage_prompt_len": 50}}
+
+ coord.update_request_metadata(requests, request_metadata, model_mode="ar")
+
+ # next_stage_prompt_len should update prompt_token_ids
+ self.assertEqual(len(req.prompt_token_ids), 50)
+ self.assertEqual(req.num_prompt_tokens, 50)
+ # additional_information should NOT be set
+ self.assertIsNone(getattr(req, "additional_information", None))
+
+ def test_generation_mode(self):
+ coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
+
+ req = _make_request("r1")
+ req.prompt_token_ids = [0, 0, 0]
+ requests = {"r1": req}
+
+ request_metadata = {
+ "r1": {
+ "code_predictor_codes": [10, 20, 30],
+ "left_context_size": 25,
+ }
+ }
+
+ coord.update_request_metadata(requests, request_metadata, model_mode="generation")
+
+ self.assertEqual(req.prompt_token_ids, [10, 20, 30])
+ self.assertEqual(req.num_computed_tokens, 0)
+ self.assertIsNone(req.additional_information)
+ self.assertEqual(req._omni_initial_model_buffer, {"left_context_size": 25})
+
+
+class TestChunkCoordinatorPostprocess(unittest.TestCase):
+ """Test postprocess_scheduler_output clears ready chunks."""
+
+ def test_clear_ready(self):
+ coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
+ coord.requests_with_ready_chunks = {"r1", "r2"}
+
+ new_req = SimpleNamespace(req_id="r1")
+ cached_reqs = SimpleNamespace(req_ids=["r2"])
+ scheduler_output = SimpleNamespace(
+ scheduled_new_reqs=[new_req],
+ scheduled_cached_reqs=cached_reqs,
+ )
+
+ coord.postprocess_scheduler_output(scheduler_output)
+
+ self.assertEqual(coord.requests_with_ready_chunks, set())
+
+
+class TestWaitingForInputTransition(unittest.TestCase):
+ """Test B8: process_pending_full_payload_inputs transitions WAITING_FOR_INPUT."""
+
+ def test_transition_on_recv(self):
+ coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
+
+ req = _make_request("r1", status=RequestStatus.WAITING_FOR_INPUT)
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_full_payload_inputs(
+ waiting,
+ running,
+ stage_recv_req_ids={"r1"},
+ )
+
+ self.assertEqual(req.status, RequestStatus.WAITING)
+
+ def test_stays_waiting_for_input_if_not_received(self):
+ coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
+
+ req = _make_request("r1", status=RequestStatus.WAITING_FOR_INPUT)
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_full_payload_inputs(
+ waiting,
+ running,
+ stage_recv_req_ids=set(),
+ )
+
+ self.assertEqual(req.status, RequestStatus.WAITING_FOR_INPUT)
+ self.assertEqual(len(coord._waiting_for_input), 1)
+
+ def test_stage_0_is_noop(self):
+ coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=0)
+
+ req = _make_request("r1", status=RequestStatus.WAITING_FOR_INPUT)
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_full_payload_inputs(
+ waiting,
+ running,
+ stage_recv_req_ids={"r1"},
+ )
+ self.assertEqual(req.status, RequestStatus.WAITING_FOR_INPUT)
+
+ def test_restore_queues_includes_waiting_for_input(self):
+ coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
+
+ r1 = _make_request("r1")
+ coord._waiting_for_input.append(r1)
+
+ waiting = MockQueue()
+ running: list = []
+
+ coord.restore_queues(waiting, running)
+
+ self.assertIn(r1, waiting)
+ self.assertEqual(len(coord._waiting_for_input), 0)
+
+ def test_full_payload_mode_auto_transitions_waiting_to_waiting_for_input(self):
+ """In full_payload_mode (async_chunk=False), fresh WAITING requests on
+ non-Stage-0 should be transitioned to WAITING_FOR_INPUT."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ async_chunk=False,
+ )
+
+ req = _make_request("r1", status=RequestStatus.WAITING)
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_full_payload_inputs(
+ waiting,
+ running,
+ stage_recv_req_ids=set(),
+ )
+
+ self.assertEqual(req.status, RequestStatus.WAITING_FOR_INPUT)
+ self.assertEqual(len(coord._waiting_for_input), 1)
+ self.assertEqual(len(coord.pending_input_registrations), 1)
+
+ def test_async_chunk_mode_does_not_auto_transition(self):
+ """In async_chunk mode, fresh WAITING requests should NOT be
+ transitioned to WAITING_FOR_INPUT."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ async_chunk=True,
+ )
+
+ req = _make_request("r1", status=RequestStatus.WAITING)
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_full_payload_inputs(
+ waiting,
+ running,
+ stage_recv_req_ids=set(),
+ )
+
+ self.assertEqual(req.status, RequestStatus.WAITING)
+
+ def test_pending_input_registrations(self):
+ coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
+
+ req = _make_request("r1", status=RequestStatus.WAITING_FOR_INPUT)
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_full_payload_inputs(
+ waiting,
+ running,
+ stage_recv_req_ids=set(),
+ )
+
+ self.assertEqual(len(coord.pending_input_registrations), 1)
+ self.assertEqual(coord.pending_input_registrations[0].request_id, "r1")
+
+
+class TestTimeoutDetection(unittest.TestCase):
+ """Regression tests for orphaned pending-recv timeout detection.
+
+ Covers the full lifecycle:
+ 1. Request enters WAITING_FOR_CHUNK from either waiting or running queue
+ 2. restore_queues() moves it back to the scheduler queue
+ 3. Timeout fires via collect_timed_out_request_ids()
+ 4. Scheduler removes from both queues and calls _free_request()
+ """
+
+ def test_waiting_since_recorded_on_chunk_wait(self):
+ """_waiting_since is set when a request enters WAITING_FOR_CHUNK."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ async_chunk=True,
+ )
+ req = _make_request("r1", status=RequestStatus.WAITING)
+ waiting = MockQueue([req])
+
+ coord.process_pending_chunks(
+ waiting,
+ [],
+ chunk_ready_req_ids=set(),
+ chunk_finished_req_ids=set(),
+ )
+
+ self.assertIn("r1", coord._waiting_since)
+ self.assertEqual(req.status, RequestStatus.WAITING_FOR_CHUNK)
+
+ def test_waiting_since_cleared_on_chunk_arrival(self):
+ """_waiting_since is cleared when a chunk arrives."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ async_chunk=True,
+ )
+ req = _make_request("r1", status=RequestStatus.WAITING_FOR_CHUNK)
+ waiting = MockQueue([req])
+
+ coord.process_pending_chunks(
+ waiting,
+ [],
+ chunk_ready_req_ids={"r1"},
+ chunk_finished_req_ids=set(),
+ )
+
+ self.assertNotIn("r1", coord._waiting_since)
+
+ def test_waiting_since_recorded_on_input_wait(self):
+ """_waiting_since is set when a request enters WAITING_FOR_INPUT."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ async_chunk=False,
+ )
+ req = _make_request("r1", status=RequestStatus.WAITING)
+ waiting = MockQueue([req])
+
+ coord.process_pending_full_payload_inputs(
+ waiting,
+ [],
+ stage_recv_req_ids=set(),
+ )
+
+ self.assertIn("r1", coord._waiting_since)
+
+ def test_waiting_since_cleared_on_input_arrival(self):
+ """_waiting_since is cleared when input data arrives."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ async_chunk=False,
+ )
+ req = _make_request("r1", status=RequestStatus.WAITING_FOR_INPUT)
+ coord._waiting_for_input.append(req)
+ coord._waiting_since["r1"] = 0.0
+
+ waiting = MockQueue()
+ coord.process_pending_full_payload_inputs(
+ waiting,
+ [],
+ stage_recv_req_ids={"r1"},
+ )
+
+ self.assertNotIn("r1", coord._waiting_since)
+ self.assertEqual(req.status, RequestStatus.WAITING)
+
+ def test_collect_timed_out_request_ids_no_timeout(self):
+ """No IDs returned when nothing has timed out."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ )
+ import time
+
+ coord._waiting_since["r1"] = time.monotonic()
+
+ result = coord.collect_timed_out_request_ids(timeout_s=300.0)
+ self.assertEqual(result, set())
+
+ def test_collect_timed_out_request_ids_expired(self):
+ """Timed-out IDs are returned and _waiting_since is cleared."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ )
+ coord._waiting_since["r1"] = 0.0 # stale monotonic timestamp (monotonic time has no epoch) → expected to be expired
+ coord._waiting_since["r2"] = 0.0
+
+ import time
+
+ coord._waiting_since["r3"] = time.monotonic() + 9999 # far future
+
+ result = coord.collect_timed_out_request_ids(timeout_s=1.0)
+
+ self.assertEqual(result, {"r1", "r2"})
+ self.assertNotIn("r1", coord._waiting_since)
+ self.assertNotIn("r2", coord._waiting_since)
+ self.assertIn("r3", coord._waiting_since)
+
+ def test_collect_removes_from_coordinator_queues(self):
+ """Timed-out requests are defensively removed from internal queues."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ )
+ r1 = _make_request("r1")
+ r2 = _make_request("r2")
+ coord._waiting_for_chunk_waiting.append(r1)
+ coord._waiting_for_input.append(r2)
+ coord._waiting_since["r1"] = 0.0
+ coord._waiting_since["r2"] = 0.0
+
+ result = coord.collect_timed_out_request_ids(timeout_s=1.0)
+
+ self.assertEqual(result, {"r1", "r2"})
+ self.assertEqual(len(coord._waiting_for_chunk_waiting), 0)
+ self.assertEqual(len(coord._waiting_for_input), 0)
+
+ def test_free_finished_request_clears_waiting_since(self):
+ """free_finished_request clears _waiting_since."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ )
+ coord._waiting_since["r1"] = 0.0
+ coord.free_finished_request("r1")
+ self.assertNotIn("r1", coord._waiting_since)
+
+ def test_timeout_from_running_queue_full_lifecycle(self):
+ """End-to-end: request from running → WAITING_FOR_CHUNK → restore →
+ timeout → removed from running list.
+
+ This is the critical regression case: WAITING_FOR_CHUNK requests
+ that originated from self.running are placed back into self.running
+ by restore_queues(), but their status remains WAITING_FOR_CHUNK.
+ The scheduler must remove from BOTH queues unconditionally.
+ """
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ async_chunk=True,
+ )
+
+ # 1) Request starts in running queue with WAITING status
+ req = _make_request("r1", status=RequestStatus.WAITING)
+ running = [req]
+ waiting = MockQueue()
+
+ # 2) process_pending_chunks: moves to WAITING_FOR_CHUNK
+ coord.process_pending_chunks(
+ waiting,
+ running,
+ chunk_ready_req_ids=set(),
+ chunk_finished_req_ids=set(),
+ )
+ self.assertEqual(req.status, RequestStatus.WAITING_FOR_CHUNK)
+ self.assertIn("r1", coord._waiting_since)
+ self.assertEqual(len(coord._waiting_for_chunk_running), 1)
+
+ # 3) restore_queues: back to running (status stays WAITING_FOR_CHUNK)
+ coord.restore_queues(waiting, running)
+ self.assertIn(req, running)
+ self.assertEqual(len(coord._waiting_for_chunk_running), 0)
+ self.assertEqual(req.status, RequestStatus.WAITING_FOR_CHUNK)
+
+ # 4) Force timeout by backdating _waiting_since to a stale monotonic timestamp
+ coord._waiting_since["r1"] = 0.0
+
+ timed_out_ids = coord.collect_timed_out_request_ids(timeout_s=1.0)
+ self.assertEqual(timed_out_ids, {"r1"})
+
+ # 5) Scheduler removes from both queues (simulating the scheduler path)
+ timed_out_id_set = {id(req)}
+ running = [r for r in running if id(r) not in timed_out_id_set]
+ waiting.remove_requests([req])
+
+ self.assertNotIn(req, running)
+ self.assertEqual(len(waiting), 0)
+
+ def test_timeout_from_waiting_queue_full_lifecycle(self):
+ """End-to-end: request from waiting → WAITING_FOR_CHUNK → restore →
+ timeout → removed from waiting queue."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ async_chunk=True,
+ )
+
+ req = _make_request("r1", status=RequestStatus.WAITING)
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_chunks(
+ waiting,
+ running,
+ chunk_ready_req_ids=set(),
+ chunk_finished_req_ids=set(),
+ )
+ self.assertEqual(len(coord._waiting_for_chunk_waiting), 1)
+
+ coord.restore_queues(waiting, running)
+ self.assertIn(req, waiting)
+
+ coord._waiting_since["r1"] = 0.0
+ timed_out_ids = coord.collect_timed_out_request_ids(timeout_s=1.0)
+ self.assertEqual(timed_out_ids, {"r1"})
+
+ waiting.remove_requests([req])
+ self.assertEqual(len(waiting), 0)
+
+
+class TestOverflowPreemption(unittest.TestCase):
+ """Tests for P1-1: overflow requests must get WAITING status.
+
+ Overflow happens when multiple WAITING_FOR_CHUNK requests in
+ ``_waiting_for_chunk_running`` receive their chunk in the same cycle.
+ ``_process_chunk_queue`` restores them to RUNNING (``continue``
+ path) while RUNNING requests without chunks are moved out. If the
+ net result exceeds ``scheduler_max_num_seqs``, the tail is pushed
+ to ``waiting_queue`` and must have status == WAITING.
+ """
+
+ def test_overflow_sets_waiting_status(self):
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=1,
+ stage_id=1,
+ async_chunk=True,
+ )
+
+ # r1 is currently RUNNING in the queue.
+ # r2, r3 were previously moved to _waiting_for_chunk_running.
+ r1 = _make_request("r1", status=RequestStatus.RUNNING)
+ r2 = _make_request("r2", status=RequestStatus.WAITING_FOR_CHUNK)
+ r3 = _make_request("r3", status=RequestStatus.WAITING_FOR_CHUNK)
+
+ running = [r1]
+ waiting = MockQueue([])
+ coord._waiting_for_chunk_running.extend([r2, r3])
+
+ # restore_queues puts r2, r3 back into running
+ coord.restore_queues(waiting, running)
+ self.assertEqual(len(running), 3)
+
+ # Now process_pending_chunks with r2, r3 chunks ready:
+ # _process_chunk_queue will:
+ # r1 (RUNNING) → no chunk → move to _waiting_for_chunk_running
+ # r2 (WAITING_FOR_CHUNK, chunk ready) → set RUNNING, stay in running
+ # r3 (WAITING_FOR_CHUNK, chunk ready) → set RUNNING, stay in running
+ # running = [r2, r3], len=2 > max=1 → overflow
+ coord.process_pending_chunks(
+ waiting,
+ running,
+ chunk_ready_req_ids={"r2", "r3"},
+ chunk_finished_req_ids=set(),
+ )
+
+ self.assertEqual(len(running), 1)
+ self.assertEqual(len(waiting), 1)
+ overflow_req = list(waiting)[0]
+ self.assertEqual(
+ overflow_req.status,
+ RequestStatus.WAITING,
+ f"Overflowed request should have WAITING status, got {overflow_req.status}",
+ )
+
+ def test_overflow_does_not_strand_request(self):
+ """Without the fix, the overflowed request would keep its
+ RUNNING status in the waiting queue and never be re-scheduled."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=1,
+ stage_id=1,
+ async_chunk=True,
+ )
+
+ r1 = _make_request("r1", status=RequestStatus.WAITING_FOR_CHUNK)
+ r2 = _make_request("r2", status=RequestStatus.WAITING_FOR_CHUNK)
+ coord._waiting_for_chunk_running.extend([r1, r2])
+
+ running: list = []
+ waiting = MockQueue([])
+
+ coord.restore_queues(waiting, running)
+ self.assertEqual(len(running), 2)
+
+ coord.process_pending_chunks(
+ waiting,
+ running,
+ chunk_ready_req_ids={"r1", "r2"},
+ chunk_finished_req_ids=set(),
+ )
+
+ self.assertEqual(len(running), 1)
+ self.assertEqual(len(waiting), 1)
+ for req in waiting:
+ self.assertNotEqual(req.status, RequestStatus.RUNNING, "Overflowed request must not keep RUNNING status")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/worker/test_omni_connector_mixin.py b/tests/worker/test_omni_connector_mixin.py
new file mode 100644
index 00000000000..0e162a37e5b
--- /dev/null
+++ b/tests/worker/test_omni_connector_mixin.py
@@ -0,0 +1,1419 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for OmniConnectorModelRunnerMixin.
+
+These tests use a mock connector (in-memory dict store) and do not require
+GPU or vLLM runtime.
+"""
+
+from __future__ import annotations
+
+import time
+import unittest
+from types import SimpleNamespace
+from typing import Any
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+
+from vllm_omni.outputs import OmniConnectorOutput
+from vllm_omni.worker.omni_connector_model_runner_mixin import (
+ OmniConnectorModelRunnerMixin,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+# ------------------------------------------------------------------ #
+# Mock helpers
+# ------------------------------------------------------------------ #
+
+
+class MockConnector:
+    """Dict-backed connector double for tests (mimics OmniConnectorBase); reads consume entries."""
+
+    def __init__(self, stage_id: int = 0):
+        self._store: dict[str, Any] = {}
+        self.stage_id = stage_id
+
+    def put(self, from_stage, to_stage, put_key, data):
+        self._store[f"{from_stage}_{to_stage}_{put_key}"] = data
+        size = len(str(data))  # reported size is the length of the payload's str() form
+        return True, size, None
+
+    def get(self, from_stage, to_stage, get_key, metadata=None):
+        payload = self._store.pop(f"{from_stage}_{to_stage}_{get_key}", None)
+        if payload is not None:
+            return payload, len(str(payload))
+        # Missing key (or a stored None): report that nothing is available.
+        return None
+
+    def close(self):
+        pass
+
+
+def _make_model_config(
+ stage_id: int = 0, # NOTE(review): accepted but never stored on the returned namespace
+ async_chunk: bool = False,
+ worker_type: str = "ar",
+ custom_func: str | None = None,
+) -> SimpleNamespace:
+ return SimpleNamespace(
+ stage_connector_config=None,
+ async_chunk=async_chunk,
+ worker_type=worker_type,
+ custom_process_next_stage_input_func=custom_func,
+ )
+
+
+def _make_request(req_id: str, external_req_id: str | None = None):
+    """Return a minimal request stub; a falsy external id falls back to ``req_id``."""
+    return SimpleNamespace(
+        request_id=req_id,
+        external_req_id=external_req_id if external_req_id else req_id,
+        additional_information=None,
+        prompt_token_ids=[],
+        num_computed_tokens=0,
+    )
+
+
+class MixinHost(OmniConnectorModelRunnerMixin):
+ """Minimal class that mixes in the mixin for testing."""
+
+ pass
+
+
+class _FakeTPGroup:
+    """Fake group: records every broadcast input; only the ``src`` rank gets its own object back."""
+
+    def __init__(self, *, world_size: int, rank_in_group: int, follower_result: Any = None):
+        self.broadcast_inputs: list[Any] = []
+        self.follower_result = follower_result
+        self.rank_in_group, self.world_size = rank_in_group, world_size
+
+    def broadcast_object(self, obj: Any | None = None, src: int = 0):
+        self.broadcast_inputs.append(obj)
+        is_source = src == self.rank_in_group
+        return obj if is_source else self.follower_result
+
+
+# ------------------------------------------------------------------ #
+# Test cases
+# ------------------------------------------------------------------ #
+
+
+class TestMixinAsyncChunkSendRecv(unittest.TestCase):
+    """Test 2: Async chunk send/recv + bg threads."""
+
+    def test_send_chunk_passes_is_finished_and_connector(self):
+        connector = MockConnector(stage_id=0)
+
+        sender = MixinHost()
+        sender.init_omni_connectors(
+            vllm_config=None,
+            model_config=_make_model_config(stage_id=0, async_chunk=True),
+        )
+        # Register shutdown up front so it still runs when an assertion below fails.
+        self.addCleanup(sender.shutdown_omni_connectors)
+        sender._omni_connector = connector
+        sender._stage_id = 0
+        sender._async_chunk = True
+
+        seen = {}
+
+        def mock_process(transfer_manager, pooling_output, request, is_finished=False):
+            seen["connector"] = transfer_manager.connector
+            seen["is_finished"] = is_finished
+            return {"data": pooling_output, "finished": is_finished}
+
+        sender._custom_process_func = mock_process
+
+        request = _make_request("req-1", "ext-req-1")
+        request.is_finished = lambda: True
+        sender._send_single_request(
+            {
+                "stage_id": 0,
+                "next_stage_id": 1,
+                "request_id": "ext-req-1",
+                "request": request,
+                "pooling_output": {"value": 42},
+            }
+        )
+        self.assertIs(seen["connector"], connector)
+        self.assertTrue(seen["is_finished"])
+
+    def test_send_chunk_does_not_retry_real_type_error(self):
+        connector = MockConnector(stage_id=0)
+
+        sender = MixinHost()
+        sender.init_omni_connectors(
+            vllm_config=None,
+            model_config=_make_model_config(stage_id=0, async_chunk=True),
+        )
+        # Register shutdown up front so it still runs when an assertion below fails.
+        self.addCleanup(sender.shutdown_omni_connectors)
+        sender._omni_connector = connector
+        sender._stage_id = 0
+        sender._async_chunk = True
+
+        seen = {"calls": 0}
+
+        def broken_process(transfer_manager, pooling_output, request, is_finished=""):
+            seen["calls"] += 1
+            return {"data": is_finished + "tail"}
+
+        sender._custom_process_func = broken_process
+
+        request = _make_request("req-1", "ext-req-1")
+        request.is_finished = lambda: True
+        ok = sender.send_chunk(request, pooling_output={"value": 42})
+        self.assertFalse(ok)
+        self.assertEqual(seen["calls"], 1)
+
+
+class TestMixinKVCacheTransfer(unittest.TestCase):
+ """Test 3: KV cache delegation to OmniKVTransferManager."""
+
+ def test_send_kv_delegates(self):
+ mock_kvm = MagicMock()
+ mock_kvm.handle_finished_requests_kv_transfer.return_value = ["req-1"]
+
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(),
+ kv_transfer_manager=mock_kvm,
+ )
+
+ result = host.send_kv_cache(
+ finished_reqs={"req-1": {"seq_len": 10, "block_ids": [0]}},
+ kv_caches=[],
+ block_size=16,
+ cache_dtype="float16",
+ )
+ self.assertEqual(result, ["req-1"])
+ mock_kvm.handle_finished_requests_kv_transfer.assert_called_once()
+
+ host.shutdown_omni_connectors()
+
+ def test_recv_kv_delegates(self):
+ mock_kvm = MagicMock()
+ mock_kvm.receive_kv_cache_for_request.return_value = ({"layer_blocks": {}}, 100)
+
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(),
+ kv_transfer_manager=mock_kvm,
+ )
+
+ data, size = host.recv_kv_cache("req-1")
+ self.assertIsNotNone(data)
+ self.assertEqual(size, 100)
+ mock_kvm.receive_kv_cache_for_request.assert_called_once()
+
+ host.shutdown_omni_connectors()
+
+ def test_receive_multi_kv_fetches_companions_via_mixin(self):
+ mock_kvm = MagicMock()
+
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(),
+ kv_transfer_manager=mock_kvm,
+ )
+
+ host.recv_kv_cache = MagicMock(
+ side_effect=[({"layer_blocks": {"k": [1]}}, 64), ({"layer_blocks": {"k": [2]}}, 32)]
+ )
+ seen = {}
+
+ def collect_cfg(request_id, cfg_role_payloads):
+ seen["request_id"] = request_id
+ seen["cfg_role_payloads"] = cfg_role_payloads
+ return {"cfg_text_kv_metadata": {"seq_len": 3}}
+
+ req = SimpleNamespace(
+ request_id="req-1",
+ sampling_params=SimpleNamespace(cfg_kv_request_ids={"cfg_text": "req-1__cfg_text"}),
+ )
+ ok = host.receive_multi_kv_cache(req, cfg_kv_collect_func=collect_cfg)
+ self.assertTrue(ok)
+ host.recv_kv_cache.assert_any_call("req-1", target_device=None)
+ host.recv_kv_cache.assert_any_call("req-1__cfg_text", target_device=None)
+ mock_kvm.apply_kv_cache_to_request.assert_called_once_with(req, {"layer_blocks": {"k": [1]}})
+ self.assertEqual(seen["request_id"], "req-1")
+ self.assertEqual(
+ seen["cfg_role_payloads"],
+ {"cfg_text": ({"layer_blocks": {"k": [2]}}, 32)},
+ )
+ self.assertEqual(req.sampling_params.cfg_text_kv_metadata, {"seq_len": 3})
+
+ host.shutdown_omni_connectors()
+
+ def test_receive_multi_kv_skips_inactive_request(self):
+ mock_kvm = MagicMock()
+
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(),
+ kv_transfer_manager=mock_kvm,
+ )
+
+ host.requests = {}
+ host.recv_kv_cache = MagicMock(return_value=({"layer_blocks": {"k": [1]}}, 64))
+ req = SimpleNamespace(request_id="req-1", sampling_params=None)
+
+ ok = host.receive_multi_kv_cache(req)
+
+ self.assertFalse(ok)
+ host.recv_kv_cache.assert_not_called()
+ mock_kvm.apply_kv_cache_to_request.assert_not_called()
+
+ host.shutdown_omni_connectors()
+
+
+class TestOmniConnectorOutput(unittest.TestCase):
+ """Test 4: Output aggregation across transfer modes."""
+
+ def test_output_aggregation(self):
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(),
+ )
+
+ host._chunk_ready_req_ids.add("req-1")
+ host._chunk_finished_req_ids.add("req-2")
+ host._local_request_metadata["req-1"] = {"next_stage_prompt_len": 10}
+ host._stage_recv_req_ids.add("req-3")
+
+ output = host.get_omni_connector_output()
+ self.assertIsInstance(output, OmniConnectorOutput)
+ self.assertEqual(output.chunk_ready_req_ids, {"req-1"})
+ self.assertEqual(output.chunk_finished_req_ids, {"req-2"})
+ self.assertEqual(output.request_metadata, {"req-1": {"next_stage_prompt_len": 10}})
+ self.assertEqual(output.stage_recv_req_ids, {"req-3"})
+
+ output2 = host.get_omni_connector_output()
+ self.assertEqual(output2.chunk_ready_req_ids, set())
+ self.assertEqual(output2.request_metadata, {})
+
+ host.shutdown_omni_connectors()
+
+
+class TestMixinNoConnector(unittest.TestCase):
+ """Edge case: mixin works gracefully without a connector."""
+
+ def test_no_connector(self):
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(),
+ )
+ self.assertIsNone(host._omni_connector)
+
+ results = host.recv_full_payload_inputs(scheduler_output=None)
+ self.assertIsNone(results)
+
+ sent = host.send_full_payload_outputs(None, {"req-1": {}})
+ self.assertEqual(sent, [])
+
+ ok = host.send_chunk(_make_request("req-1"), pooling_output={})
+ self.assertFalse(ok)
+
+ output = host.get_omni_connector_output()
+ self.assertIsInstance(output, OmniConnectorOutput)
+
+ host.shutdown_omni_connectors()
+
+
+class TestFinishedLoadReqsDrain(unittest.TestCase):
+ """Test A1 fix: get_omni_connector_output drains _finished_load_reqs."""
+
+ def test_finished_load_reqs_flow_to_chunk_ready(self):
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(),
+ )
+
+ host._finished_load_reqs.add("req-1")
+ host._finished_load_reqs.add("req-2")
+
+ output = host.get_omni_connector_output()
+ self.assertIn("req-1", output.chunk_ready_req_ids)
+ self.assertIn("req-2", output.chunk_ready_req_ids)
+
+ self.assertEqual(len(host._finished_load_reqs), 0)
+ self.assertEqual(len(host._chunk_ready_req_ids), 0)
+
+ host.shutdown_omni_connectors()
+
+
+class TestLoadCustomFuncSelection(unittest.TestCase):
+    def test_skips_legacy_stage_list_processors_for_full_payload_mode(self):
+        """With async_chunk=False, _load_custom_func must not hand back the legacy path it was given."""
+        legacy_paths = [
+            "vllm_omni.model_executor.stage_input_processors.mimo_audio.llm2code2wav",
+            "vllm_omni.model_executor.stage_input_processors.mammoth_moda2.ar2dit",
+            "vllm_omni.model_executor.stage_input_processors.cosyvoice3.text2flow",
+            "vllm_omni.model_executor.stage_input_processors.glm_image.ar2diffusion",
+        ]
+        for func_path in legacy_paths:
+            selected_path, func = MixinHost._load_custom_func(
+                SimpleNamespace(
+                    async_chunk=False,
+                    custom_process_input_func=func_path,
+                    custom_process_next_stage_input_func=None,
+                )
+            )
+            self.assertNotEqual(selected_path, func_path)
+            self.assertTrue(func is None or MixinHost._is_connector_payload_builder(func), func_path)
+
+
+class TestFullPayloadSendWithCustomFunc(unittest.TestCase):
+ """Test B4: send_full_payload_outputs with full_payload_mode custom process func."""
+
+ def test_full_payload_send_passes_is_finished_and_connector(self):
+ seen = {}
+
+ def full_payload_func(transfer_manager, pooling_output, request, is_finished=False):
+ seen["connector"] = transfer_manager.connector
+ seen["is_finished"] = is_finished
+ seen["data"] = pooling_output
+ seen["rid"] = request.request_id if request else None
+ return {"processed": True, "finished": is_finished}
+
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(),
+ )
+ host._omni_connector = MockConnector(stage_id=0)
+ host._stage_id = 0
+ host._custom_process_func = full_payload_func
+
+ req = _make_request("req-1")
+ req.is_finished = lambda: True
+ sent = host.send_full_payload_outputs(
+ scheduler_output=None,
+ outputs={"req-1": ({"raw": 100}, req)},
+ )
+ self.assertEqual(sent, ["req-1"])
+ self.assertEqual(
+ seen,
+ {
+ "connector": host._omni_connector,
+ "is_finished": True,
+ "data": {"raw": 100},
+ "rid": "req-1",
+ },
+ )
+
+ host.shutdown_omni_connectors()
+
+ def test_accumulate_and_flush(self):
+ call_log = []
+
+ def full_payload_func(transfer_manager, pooling_output, request):
+ call_log.append(request.request_id if request else None)
+ return {"processed": True}
+
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(),
+ )
+ host._omni_connector = MockConnector(stage_id=0)
+ host._stage_id = 0
+ host._custom_process_func = full_payload_func
+
+ req = _make_request("req-1")
+ host.accumulate_full_payload_output("req-1", {"raw": 42}, req)
+ self.assertEqual(len(host._pending_full_payload_send), 1)
+
+ host.flush_full_payload_outputs({"req-1"})
+ self.assertEqual(len(host._pending_full_payload_send), 0)
+ self.assertEqual(len(call_log), 1)
+ self.assertEqual(call_log[0], "req-1")
+
+ time.sleep(0.1)
+ host.shutdown_omni_connectors()
+
+
+class TestKVSentReqIdsAccumulation(unittest.TestCase):
+ """Test that kv_sent_req_ids accumulates results from send_kv_cache."""
+
+ def test_kv_sent_accumulation(self):
+ mock_kvm = MagicMock()
+ mock_kvm.handle_finished_requests_kv_transfer.return_value = ["req-1", "req-2"]
+
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(),
+ kv_transfer_manager=mock_kvm,
+ )
+
+ host.send_kv_cache(
+ finished_reqs={"req-1": {}, "req-2": {}},
+ kv_caches=[],
+ block_size=16,
+ cache_dtype="float16",
+ )
+
+ output = host.get_omni_connector_output()
+ self.assertIn("req-1", output.kv_sent_req_ids)
+ self.assertIn("req-2", output.kv_sent_req_ids)
+
+ output2 = host.get_omni_connector_output()
+ self.assertEqual(output2.kv_sent_req_ids, [])
+
+ host.shutdown_omni_connectors()
+
+
+class TestChunkStreamCompletedGuard(unittest.TestCase):
+ """Test that register_chunk_recv is skipped after finish sentinel.
+
+ This validates the fix for the race condition where the scheduling
+ coordinator re-registers a request for chunk polling after its
+ upstream chunk stream has already finished (is_finished sentinel
+ received), causing the bg recv thread to poll for a non-existent
+ shared-memory segment (e.g. ``_0_7`` when only 7 chunks 0–6 exist).
+ """
+
+ def _make_host(self, stage_id: int = 1) -> MixinHost:
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=stage_id, async_chunk=True),
+ )
+ host._omni_connector = MockConnector(stage_id=stage_id)
+ host._stage_id = stage_id
+ host._async_chunk = True
+ return host
+
+ def test_register_blocked_after_finish_sentinel(self):
+ """register_chunk_recv must be a no-op after the finish sentinel."""
+ host = self._make_host(stage_id=1)
+
+ req = _make_request("req-1", "ext-req-1")
+
+ # Simulate the bg thread having received the finish sentinel:
+ with host._lock:
+ host._chunk_stream_completed.add("req-1")
+
+ # Now try to re-register — this mimics the coordinator asking
+ # the model runner to poll for the next (non-existent) chunk.
+ host.register_chunk_recv(req)
+
+ # The request must NOT appear in _pending_load_reqs
+ self.assertNotIn(
+ "req-1",
+ host._pending_load_reqs,
+ "register_chunk_recv should skip requests whose chunk stream is already complete",
+ )
+
+ host.shutdown_omni_connectors()
+
+ def test_register_allowed_before_finish(self):
+ """register_chunk_recv works normally before finish sentinel."""
+ host = self._make_host(stage_id=1)
+ req = _make_request("req-1", "ext-req-1")
+
+ host.register_chunk_recv(req)
+ self.assertIn(
+ "req-1",
+ host._pending_load_reqs,
+ "register_chunk_recv should add request to pending when stream is not yet complete",
+ )
+
+ host.shutdown_omni_connectors()
+
+ def test_finish_sentinel_populates_completed_set(self):
+ """State left by a finish sentinel keeps register_chunk_recv from re-registering the request."""
+ host = self._make_host(stage_id=1)
+
+ # Simulate _poll_single_request receiving is_finished=True (state is set by hand; nothing is polled)
+ req_id = "req-1"
+ with host._lock:
+ host._chunk_finished_req_ids.add(req_id)
+ host._chunk_stream_completed.add(req_id)
+ host._local_stage_payload_cache[req_id] = {"finished": True}
+ host._local_request_metadata[req_id] = {}
+ host._finished_load_reqs.add(req_id)
+ host._pending_load_reqs.pop(req_id, None)
+
+ self.assertIn(req_id, host._chunk_stream_completed)
+
+ # Subsequent register_chunk_recv should be blocked
+ req = _make_request(req_id, f"ext-{req_id}")
+ host.register_chunk_recv(req)
+ self.assertNotIn(req_id, host._pending_load_reqs)
+
+ host.shutdown_omni_connectors()
+
+ def test_stage_0_always_skipped(self):
+ """Stage-0 has no upstream, register_chunk_recv is always no-op."""
+ host = self._make_host(stage_id=0)
+ host._stage_id = 0
+
+ req = _make_request("req-1")
+ host.register_chunk_recv(req)
+ self.assertNotIn("req-1", host._pending_load_reqs)
+
+ host.shutdown_omni_connectors()
+
+ def test_full_payload_recv_guard_still_works(self):
+ """Pre-existing guard: staged full-payload results prevent registration."""
+ host = self._make_host(stage_id=1)
+
+ with host._lock:
+ host._stage_recv_req_ids.add("req-1")
+
+ req = _make_request("req-1", "ext-req-1")
+ host.register_chunk_recv(req)
+ self.assertNotIn("req-1", host._pending_load_reqs)
+
+ host.shutdown_omni_connectors()
+
+
+class TestCleanupFinishedRequest(unittest.TestCase):
+ """Test cleanup_finished_request frees per-request mixin state."""
+
+ def _make_host(self, stage_id: int = 1) -> MixinHost:
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=stage_id, async_chunk=True),
+ )
+ host._omni_connector = MockConnector(stage_id=stage_id)
+ host._stage_id = stage_id
+ host._async_chunk = True
+ return host
+
+ def test_cleanup_removes_all_state(self):
+ """cleanup_finished_request removes all tracking dicts/sets."""
+ host = self._make_host(stage_id=1)
+ req_id = "req-1"
+ ext_id = "ext-req-1"
+
+ # Simulate state accumulated during a request's lifetime
+ host._request_ids_mapping[req_id] = ext_id
+ host._put_req_chunk[ext_id] = 5
+ host._get_req_chunk[req_id] = 3
+ host._send_side_request_payload[ext_id] = {"some": "data"}
+ host._code_prompt_token_ids[ext_id] = [[1, 2, 3]]
+ host._chunk_stream_completed.add(req_id)
+ host._stage_recv_req_ids.add(req_id)
+ host._local_stage_payload_cache[req_id] = {"engine_inputs": {}}
+ host._local_request_metadata[req_id] = {"prompt_len": 10}
+
+ # Cleanup
+ host.cleanup_finished_request(req_id)
+
+ # All state should be gone
+ self.assertNotIn(req_id, host._request_ids_mapping)
+ self.assertNotIn(ext_id, host._put_req_chunk)
+ self.assertNotIn(req_id, host._get_req_chunk)
+ self.assertNotIn(ext_id, host._send_side_request_payload)
+ self.assertNotIn(ext_id, host._code_prompt_token_ids)
+ self.assertNotIn(req_id, host._chunk_stream_completed)
+ self.assertNotIn(req_id, host._stage_recv_req_ids)
+ self.assertNotIn(req_id, host._local_stage_payload_cache)
+ self.assertNotIn(req_id, host._local_request_metadata)
+
+ host.shutdown_omni_connectors()
+
+ def test_cleanup_removes_per_cycle_ready_state(self):
+ """cleanup_finished_request clears ready/finished carry-over for req-id reuse."""
+ host = self._make_host(stage_id=1)
+ req_id = "req-1"
+
+ host._pending_load_reqs[req_id] = _make_request(req_id, "ext-req-1")
+ host._finished_load_reqs.add(req_id)
+ host._chunk_ready_req_ids.add(req_id)
+ host._chunk_finished_req_ids.add(req_id)
+
+ host.cleanup_finished_request(req_id)
+
+ self.assertNotIn(req_id, host._pending_load_reqs)
+ self.assertNotIn(req_id, host._finished_load_reqs)
+ self.assertNotIn(req_id, host._chunk_ready_req_ids)
+ self.assertNotIn(req_id, host._chunk_finished_req_ids)
+
+ host.shutdown_omni_connectors()
+
+ def test_cleanup_without_mapping(self):
+ """cleanup works for Stage-0 where _request_ids_mapping isn't set."""
+ host = self._make_host(stage_id=0)
+ host._stage_id = 0
+ req_id = "req-1"
+
+ # Stage-0 uses req_id directly (no ext_id mapping)
+ host._put_req_chunk[req_id] = 3
+ host._get_req_chunk[req_id] = 0
+
+ host.cleanup_finished_request(req_id)
+
+ self.assertNotIn(req_id, host._put_req_chunk)
+ self.assertNotIn(req_id, host._get_req_chunk)
+
+ host.shutdown_omni_connectors()
+
+ def test_prune_inactive_requests_cleans_stale_state_but_keeps_active(self):
+ """Inactive request IDs should be pruned without touching active ones."""
+ host = self._make_host(stage_id=1)
+ active_req_id = "req-active"
+ stale_req_id = "req-stale"
+ stale_ext_id = "ext-stale"
+
+ host._request_ids_mapping[active_req_id] = "ext-active"
+ host._request_ids_mapping[stale_req_id] = stale_ext_id
+ host._put_req_chunk[stale_ext_id] = 2
+ host._get_req_chunk[stale_req_id] = 1
+ host._finished_load_reqs.add(stale_req_id)
+ host._chunk_ready_req_ids.update({active_req_id, stale_req_id})
+ host._chunk_finished_req_ids.add(stale_req_id)
+ host._chunk_stream_completed.add(stale_req_id)
+ host._stage_recv_req_ids.add(active_req_id)
+ host._send_side_request_payload[stale_ext_id] = {"stale": True}
+ host._code_prompt_token_ids[stale_ext_id] = [[1, 2, 3]]
+
+ pruned = host.prune_inactive_requests({active_req_id})
+
+ self.assertEqual(pruned, {stale_req_id})
+ self.assertIn(active_req_id, host._request_ids_mapping)
+ self.assertIn(active_req_id, host._chunk_ready_req_ids)
+ self.assertIn(active_req_id, host._stage_recv_req_ids)
+ self.assertNotIn(stale_req_id, host._request_ids_mapping)
+ self.assertNotIn(stale_ext_id, host._put_req_chunk)
+ self.assertNotIn(stale_req_id, host._get_req_chunk)
+ self.assertNotIn(stale_req_id, host._pending_load_reqs)
+ self.assertNotIn(stale_req_id, host._finished_load_reqs)
+ self.assertNotIn(stale_req_id, host._chunk_ready_req_ids)
+ self.assertNotIn(stale_req_id, host._chunk_finished_req_ids)
+ self.assertNotIn(stale_req_id, host._chunk_stream_completed)
+ self.assertNotIn(stale_req_id, host._stage_recv_req_ids)
+ self.assertNotIn(stale_ext_id, host._send_side_request_payload)
+ self.assertNotIn(stale_ext_id, host._code_prompt_token_ids)
+
+ host.shutdown_omni_connectors()
+
+ def test_prune_inactive_requests_keeps_recently_received_full_payload_state(self):
+ """Late bg-thread receives must survive until the scheduler catches up."""
+ host = self._make_host(stage_id=1)
+ req_id = "req-recv-race"
+ ext_id = "ext-recv-race"
+
+ host._request_ids_mapping[req_id] = ext_id
+ host._put_req_chunk[ext_id] = 1
+ host._local_stage_payload_cache[req_id] = {"engine_inputs": {"ids": [1, 2, 3]}}
+ host._local_request_metadata[req_id] = {"next_stage_prompt_len": 3}
+ host._stage_recv_req_ids.add(req_id)
+
+ pruned = host.prune_inactive_requests(set())
+
+ self.assertEqual(pruned, set())
+ self.assertIn(req_id, host._request_ids_mapping)
+ self.assertIn(req_id, host._local_stage_payload_cache)
+ self.assertIn(req_id, host._local_request_metadata)
+ self.assertIn(req_id, host._stage_recv_req_ids)
+ self.assertIn(ext_id, host._put_req_chunk)
+
+ # Once the scheduler has consumed the wake-up and the request really
+ # disappears from all protected sets, prune should clean it up.
+ host._stage_recv_req_ids.clear()
+ host._local_stage_payload_cache.clear()
+ host._local_request_metadata.clear()
+
+ pruned = host.prune_inactive_requests(set())
+
+ self.assertEqual(pruned, {req_id})
+ self.assertNotIn(req_id, host._request_ids_mapping)
+ self.assertNotIn(ext_id, host._put_req_chunk)
+
+ host.shutdown_omni_connectors()
+
+
+class TestSendChunkCachesMapping(unittest.TestCase):
+ """Test that send_chunk caches internal→external req ID mapping."""
+
+ def test_send_chunk_populates_request_ids_mapping(self):
+ """send_chunk should cache the internal→external mapping."""
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=0, async_chunk=True),
+ )
+ host._omni_connector = MockConnector(stage_id=0)
+ host._stage_id = 0
+ host._async_chunk = True
+
+ def mock_process(transfer_manager, pooling_output, request):
+ return {"data": "test", "finished": False}
+
+ host._custom_process_func = mock_process
+
+ request = _make_request("internal-1", "external-1")
+ host.send_chunk(request, pooling_output={"v": 1})
+
+ # The mapping should be cached
+ self.assertEqual(
+ host._request_ids_mapping.get("internal-1"),
+ "external-1",
+ )
+
+ time.sleep(0.1)
+ host.shutdown_omni_connectors()
+
+
+class TestLocalPayloadCacheLifecycle(unittest.TestCase):
+ """Unit tests for the local payload cache API (RFC §2.4)."""
+
+ def _make_host(self) -> MixinHost:
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=0),
+ )
+ host._omni_connector = MockConnector(stage_id=0)
+ host._stage_id = 0
+ return host
+
+ def test_put_get_pop(self):
+ host = self._make_host()
+ payload = {"engine_inputs": {"ids": [1, 2, 3]}}
+ host.put_local_stage_payload("r1", payload)
+
+ self.assertEqual(host.get_local_stage_payload("r1"), payload)
+ popped = host.pop_local_stage_payload("r1")
+ self.assertEqual(popped, payload)
+ self.assertIsNone(host.get_local_stage_payload("r1"))
+ host.shutdown_omni_connectors()
+
+ def test_recv_full_payload_inputs_populates_local_cache(self):
+ host = self._make_host()
+ host._omni_connector = MockConnector(stage_id=0)
+ host._stage_id = 0
+
+ # Simulate a full payload already staged by the bg recv path
+ with host._lock:
+ host._local_stage_payload_cache["r1"] = {"tok": [10]}
+ host._stage_recv_req_ids.add("r1")
+
+ host.recv_full_payload_inputs(scheduler_output=None)
+ self.assertEqual(host.get_local_stage_payload("r1"), {"tok": [10]})
+ host.shutdown_omni_connectors()
+
+ def test_rank0_only_polls_connector_for_tp_full_payload(self):
+ host = self._make_host()
+ host._omni_connector = MagicMock()
+ host._stage_id = 2
+ host._local_rank = 0
+ host._request_ids_mapping["r1"] = "ext-r1"
+ host._get_req_chunk["r1"] = 0
+ payload = {"tok": [10], "finished": torch.tensor(True)}
+ connector_result = (payload, 123)
+ host._omni_connector.get.return_value = connector_result
+ tp_group = _FakeTPGroup(world_size=2, rank_in_group=0)
+
+ with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group):
+ made_progress = host._poll_single_request("r1")
+
+ self.assertTrue(made_progress)
+ host._omni_connector.get.assert_called_once_with("1", "2", "ext-r1_1_0")
+ self.assertEqual(tp_group.broadcast_inputs, [])
+ self.assertEqual(host.get_local_stage_payload("r1"), payload)
+ self.assertIn("r1", host._full_payload_pending_broadcast_req_ids)
+ self.assertNotIn("r1", host._stage_recv_req_ids)
+ self.assertIsNone(host.get_local_request_metadata("r1"))
+ host.shutdown_omni_connectors()
+
+ def test_tp_follower_skips_connector_poll_for_full_payload(self):
+ host = self._make_host()
+ host._omni_connector = MagicMock()
+ host._stage_id = 2
+ host._local_rank = 1
+ host._request_ids_mapping["r1"] = "ext-r1"
+ host._get_req_chunk["r1"] = 0
+ tp_group = _FakeTPGroup(world_size=2, rank_in_group=1)
+
+ with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group):
+ made_progress = host._poll_single_request("r1")
+
+ self.assertFalse(made_progress)
+ host._omni_connector.get.assert_not_called()
+ self.assertEqual(tp_group.broadcast_inputs, [])
+ self.assertNotIn("r1", host._local_stage_payload_cache)
+ host.shutdown_omni_connectors()
+
+ def test_recv_full_payload_inputs_broadcasts_tp_leader_results_to_followers(self):
+ host = self._make_host()
+ host._omni_connector = MagicMock()
+ host._stage_id = 2
+ host._local_rank = 1
+ host._pending_load_reqs["r1"] = object()
+ payload = {"tok": [10], "finished": torch.tensor(True)}
+ tp_group = _FakeTPGroup(world_size=2, rank_in_group=1, follower_result={"r1": payload})
+
+ with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group):
+ results = host.recv_full_payload_inputs(scheduler_output=None)
+
+ self.assertEqual(results, {"r1": payload})
+ self.assertEqual(host.get_local_stage_payload("r1"), payload)
+ self.assertEqual(host.get_local_request_metadata("r1"), {})
+ self.assertEqual(host._stage_recv_req_ids, {"r1"})
+ self.assertNotIn("r1", host._pending_load_reqs)
+ self.assertEqual(tp_group.broadcast_inputs, [None])
+ host.shutdown_omni_connectors()
+
+
+class TestTPAsyncChunkFanout(unittest.TestCase):
+ def _make_host(self, rank: int) -> MixinHost:
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=2, async_chunk=True, worker_type="gen"),
+ )
+ host._omni_connector = MagicMock()
+ host._stage_id = 2
+ host._async_chunk = True
+ host._model_mode = "gen"
+ host._local_rank = rank
+ host._request_ids_mapping["r1"] = "ext-r1"
+ host._get_req_chunk["r1"] = 0
+ return host
+
+ def test_rank0_only_polls_connector_for_tp_async_chunk(self):
+ host = self._make_host(rank=0)
+ payload = {
+ "code_predictor_codes": [10, 11],
+ "left_context_size": 0,
+ "finished": torch.tensor(False),
+ }
+ host._omni_connector.get.return_value = (payload, 123)
+ tp_group = _FakeTPGroup(world_size=2, rank_in_group=0)
+
+ with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group):
+ made_progress = host._poll_single_request("r1")
+
+ self.assertTrue(made_progress)
+ host._omni_connector.get.assert_called_once_with("1", "2", "ext-r1_1_0")
+ self.assertEqual(host.get_local_stage_payload("r1"), payload)
+ self.assertIn("r1", host._finished_load_reqs)
+ self.assertIn("r1", host._async_chunk_updated_req_ids)
+ self.assertEqual(tp_group.broadcast_inputs, [])
+ host.shutdown_omni_connectors()
+
+ def test_tp_follower_skips_connector_poll_for_async_chunk(self):
+ host = self._make_host(rank=1)
+ tp_group = _FakeTPGroup(world_size=2, rank_in_group=1)
+
+ with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group):
+ made_progress = host._poll_single_request("r1")
+
+ self.assertFalse(made_progress)
+ host._omni_connector.get.assert_not_called()
+ self.assertIsNone(host.get_local_stage_payload("r1"))
+ self.assertEqual(tp_group.broadcast_inputs, [])
+ host.shutdown_omni_connectors()
+
+ def test_get_output_broadcasts_tp_async_chunk_payloads_to_followers(self):
+ host = self._make_host(rank=1)
+ host._pending_load_reqs["r1"] = object()
+ payload = {
+ "code_predictor_codes": [10, 11],
+ "left_context_size": 0,
+ "finished": torch.tensor(True),
+ }
+ packet = {
+ "staged_payloads": {"r1": payload},
+ "request_metadata": {"r1": {"code_predictor_codes": [10, 11], "left_context_size": 0}},
+ "newly_finished": {"r1"},
+ "chunk_finished": {"r1"},
+ }
+ tp_group = _FakeTPGroup(world_size=2, rank_in_group=1, follower_result=packet)
+
+ with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group):
+ output = host.get_omni_connector_output()
+
+ self.assertEqual(output.chunk_ready_req_ids, {"r1"})
+ self.assertEqual(output.chunk_finished_req_ids, {"r1"})
+ self.assertEqual(
+ output.request_metadata,
+ {"r1": {"code_predictor_codes": [10, 11], "left_context_size": 0}},
+ )
+ self.assertEqual(host.get_local_stage_payload("r1"), payload)
+ self.assertNotIn("r1", host._pending_load_reqs)
+ self.assertIn("r1", host._chunk_stream_completed)
+ self.assertEqual(tp_group.broadcast_inputs, [None])
+ host.shutdown_omni_connectors()
+
+
+class TestKVTransferLifecycle(unittest.TestCase):
+ """Unit tests for KV transfer lifecycle methods."""
+
+ def _make_host(self) -> MixinHost:
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=0),
+ )
+ return host
+
+ def test_mark_drain_ack_complete(self):
+ host = self._make_host()
+ self.assertFalse(host.has_pending_kv_work())
+
+ host.mark_kv_transfer("r1", seq_len=100, block_ids=[0, 1, 2])
+ self.assertTrue(host.has_pending_kv_work())
+ self.assertTrue(host.is_kv_transfer_triggered("r1"))
+
+ # Drain moves pending → active
+ pending = host.drain_pending_kv_transfers()
+ self.assertEqual(pending, {"r1": {"seq_len": 100, "block_ids": [0, 1, 2]}})
+ self.assertIn("r1", host._kv_active_transfers)
+ self.assertTrue(host.has_pending_kv_work())
+
+ # Ack moves active → completed
+ host.ack_kv_transfers(["r1"])
+ self.assertNotIn("r1", host._kv_active_transfers)
+ self.assertIn("r1", host._kv_completed_transfers)
+
+ # Drain completed
+ completed = host.drain_completed_kv_transfers()
+ self.assertEqual(completed, {"r1"})
+ self.assertFalse(host.has_pending_kv_work())
+ host.shutdown_omni_connectors()
+
+ def test_mark_dedup(self):
+ host = self._make_host()
+ host.mark_kv_transfer("r1", seq_len=100, block_ids=[0])
+ host.mark_kv_transfer("r1", seq_len=200, block_ids=[0, 1])
+ # Second mark is a no-op
+ self.assertEqual(host._kv_pending_transfers["r1"]["seq_len"], 100)
+ host.shutdown_omni_connectors()
+
+ def test_cleanup_removes_kv_state(self):
+ host = self._make_host()
+ host.mark_kv_transfer("r1", seq_len=50, block_ids=[0])
+ host.drain_pending_kv_transfers()
+ host.cleanup_finished_request("r1")
+ self.assertFalse(host.is_kv_transfer_triggered("r1"))
+ self.assertNotIn("r1", host._kv_active_transfers)
+ self.assertFalse(host.has_pending_kv_work())
+ host.shutdown_omni_connectors()
+
+
+class TestAsyncPayloadLifecycle(unittest.TestCase):
+ """Regression tests for async payload delivery lifecycle."""
+
+ def test_send_side_request_payload_not_cleared_before_payload_is_consumable(self):
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"),
+ )
+ host._request_ids_mapping["r1"] = "r1"
+ payload = {
+ "thinker_decode_embeddings": torch.ones(1, 2),
+ "thinker_output_token_ids": [1],
+ "override_keys": ["thinker_decode_embeddings", "thinker_output_token_ids"],
+ "finished": torch.tensor(False),
+ }
+
+ host._accumulate_payload("r1", dict(payload))
+ with host._lock:
+ host._finished_load_reqs.add("r1")
+
+ host.get_omni_connector_output()
+ self.assertIn("r1", host._send_side_request_payload)
+ host.shutdown_omni_connectors()
+
+ def test_payload_consumable_ignores_token_horizon_only_updates(self):
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"),
+ )
+ payload = {
+ "thinker_output_token_ids": [1, 2, 3],
+ "finished": torch.tensor(False),
+ "override_keys": [
+ "thinker_output_token_ids",
+ "thinker_decode_embeddings_token_start",
+ "thinker_decode_embeddings_token_end",
+ ],
+ "thinker_decode_embeddings_token_start": 2,
+ "thinker_decode_embeddings_token_end": 3,
+ }
+ self.assertFalse(host._payload_is_consumable(payload))
+ host.shutdown_omni_connectors()
+
+ def test_payload_consumable_accepts_decode_embeddings(self):
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"),
+ )
+ payload = {
+ "thinker_output_token_ids": [1, 2, 3],
+ "thinker_decode_embeddings": torch.ones(1, 2),
+ "finished": torch.tensor(False),
+ }
+ self.assertTrue(host._payload_is_consumable(payload))
+ host.shutdown_omni_connectors()
+
+ def test_ar_metadata_only_followup_chunk_does_not_rewake_request(self):
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"),
+ )
+ host._omni_connector = MagicMock()
+ host._stage_id = 1
+ host._async_chunk = True
+ host._model_mode = "ar"
+ host._request_ids_mapping["r1"] = "ext-r1"
+ host._get_req_chunk["r1"] = 0
+
+ host._omni_connector.get.side_effect = [
+ (
+ {
+ "thinker_decode_embeddings": torch.ones(1, 2),
+ "finished": torch.tensor(False),
+ },
+ 1,
+ ),
+ (
+ {
+ "next_stage_prompt_len": 7,
+ "finished": torch.tensor(False),
+ },
+ 1,
+ ),
+ ]
+
+ host._poll_single_request("r1")
+ output1 = host.get_omni_connector_output()
+ self.assertEqual(output1.chunk_ready_req_ids, {"r1"})
+
+ host._poll_single_request("r1")
+ output2 = host.get_omni_connector_output()
+ self.assertEqual(output2.chunk_ready_req_ids, set())
+ self.assertEqual(output2.request_metadata, {"r1": {"next_stage_prompt_len": 7}})
+
+ host.shutdown_omni_connectors()
+
+ def test_non_ar_recv_does_not_overwrite_unconsumed_staged_chunk(self):
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=2, async_chunk=True, worker_type="gen"),
+ )
+ host._omni_connector = MagicMock()
+ host._stage_id = 2
+ host._async_chunk = True
+ host._model_mode = "gen"
+ host._request_ids_mapping["r1"] = "ext-r1"
+ host._get_req_chunk["r1"] = 1
+ host._local_stage_payload_cache["r1"] = {
+ "code_predictor_codes": [1, 2, 3],
+ "left_context_size": 0,
+ "finished": torch.tensor(False),
+ }
+
+ made_progress = host._poll_single_request("r1")
+
+ self.assertFalse(made_progress)
+ host._omni_connector.get.assert_not_called()
+ self.assertEqual(host._get_req_chunk["r1"], 1)
+
+ host.shutdown_omni_connectors()
+
+    def test_non_ar_recv_waits_for_scheduler_handoff_before_fetching_next_chunk(self):
+        """The next chunk is fetched only after get_omni_connector_output() handed over the staged one."""
+        host = MixinHost()
+        host.init_omni_connectors(
+            vllm_config=None,
+            model_config=_make_model_config(stage_id=2, async_chunk=True, worker_type="gen"),
+        )
+        # Registered right after init so shutdown still runs when an assertion below fails.
+        self.addCleanup(host.shutdown_omni_connectors)
+        host._omni_connector = MagicMock()
+        host._stage_id = 2
+        host._async_chunk = True
+        host._model_mode = "gen"
+        host._request_ids_mapping["r1"] = "ext-r1"
+        host._get_req_chunk["r1"] = 1
+        host._local_request_metadata["r1"] = {
+            "code_predictor_codes": [10, 11, 12],
+            "left_context_size": 0,
+        }
+        host._finished_load_reqs.add("r1")
+
+        made_progress = host._poll_single_request("r1")
+        self.assertFalse(made_progress)
+        host._omni_connector.get.assert_not_called()
+        self.assertEqual(host._get_req_chunk["r1"], 1)
+
+        output = host.get_omni_connector_output()
+        self.assertEqual(output.request_metadata["r1"]["code_predictor_codes"], [10, 11, 12])
+        self.assertEqual(output.chunk_ready_req_ids, {"r1"})
+
+        host._omni_connector.get.return_value = (
+            {
+                "code_predictor_codes": [20, 21, 22],
+                "left_context_size": 0,
+                "finished": torch.tensor(False),
+            },
+            1,
+        )
+        made_progress = host._poll_single_request("r1")
+
+        self.assertTrue(made_progress)
+        host._omni_connector.get.assert_called_once()
+        self.assertEqual(host._get_req_chunk["r1"], 2)
+
+
+class TestRankAwareKVRouting(unittest.TestCase):
+    """Rank-aware KV key routing and payload reshaping when the two stages use different TP sizes."""
+
+    def _make_host(self, *, from_tp: int, to_tp: int, local_rank: int) -> MixinHost:
+        """Return a stage-1 host with the given TP layout; it is shut down automatically."""
+        host = MixinHost()
+        host.init_omni_connectors(vllm_config=None, model_config=_make_model_config(stage_id=1))
+        # addCleanup also runs when a test fails, unlike a trailing shutdown call in each test.
+        self.addCleanup(host.shutdown_omni_connectors)
+        host._from_tp = from_tp
+        host._to_tp = to_tp
+        host._local_rank = local_rank
+        return host
+
+    def test_recv_keys_use_remote_rank_as_from_rank(self):
+        host = self._make_host(from_tp=4, to_tp=2, local_rank=1)
+        self.assertEqual(
+            host.get_rank_aware_kv_keys("req", from_stage=0),
+            ["req_0_0_2_1", "req_0_0_3_1"],
+        )
+
+    def test_send_keys_route_from_rank_gt_to_rank(self):
+        host = self._make_host(from_tp=4, to_tp=2, local_rank=3)
+        self.assertEqual(host.get_rank_aware_kv_send_keys("req", from_stage=0), ["req_0_0_3_1"])
+
+    def test_invalid_recv_rank_mapping_raises(self):
+        host = self._make_host(from_tp=3, to_tp=2, local_rank=1)
+        with self.assertRaises(ValueError):
+            host.get_rank_aware_kv_keys("req", from_stage=0)
+
+    def test_invalid_send_rank_mapping_raises(self):
+        host = self._make_host(from_tp=3, to_tp=2, local_rank=1)
+        with self.assertRaises(ValueError):
+            host.get_rank_aware_kv_send_keys("req", from_stage=0)
+
+    def test_merge_rank_sharded_payloads_concatenates_head_dimension(self):
+        """Two (2, 1, 3) shards are expected to merge into (2, 2, 3), one shard per index of dim 1."""
+        host = self._make_host(from_tp=4, to_tp=2, local_rank=0)
+        payloads = [
+            {"layer_blocks": {"key_cache": [torch.ones(2, 1, 3)], "value_cache": [torch.ones(2, 1, 3)]}},
+            {"layer_blocks": {"key_cache": [torch.full((2, 1, 3), 2.0)], "value_cache": [torch.full((2, 1, 3), 2.0)]}},
+        ]
+        merged = host._merge_rank_sharded_kv_payloads(payloads)
+        self.assertEqual(tuple(merged["layer_blocks"]["key_cache"][0].shape), (2, 2, 3))
+        self.assertTrue(torch.equal(merged["layer_blocks"]["key_cache"][0][:, 0], torch.ones(2, 3)))
+        self.assertTrue(torch.equal(merged["layer_blocks"]["key_cache"][0][:, 1], torch.full((2, 3), 2.0)))
+
+    def test_slice_rank_sharded_payload_splits_head_dimension(self):
+        host = self._make_host(from_tp=2, to_tp=4, local_rank=1)
+        payload = {
+            "layer_blocks": {
+                "key_cache": [torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)],
+                "value_cache": [torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)],
+            },
+            "metadata": {},
+        }
+        sliced = host._slice_rank_sharded_kv_payload(payload)
+        self.assertEqual(tuple(sliced["layer_blocks"]["key_cache"][0].shape), (2, 2, 3))
+        expected = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)[:, 2:4, :]
+        self.assertTrue(torch.equal(sliced["layer_blocks"]["key_cache"][0], expected))
+
+
+class TestAttachOmniConnectorOutput(unittest.TestCase):
+    def test_wraps_empty_model_runner_output_when_signals_exist(self):
+        """With pending signals, a distinct output carrying them replaces the shared empty one."""
+        from vllm.v1.worker.gpu_model_runner import EMPTY_MODEL_RUNNER_OUTPUT
+
+        runner = MixinHost()
+        runner.get_omni_connector_output = lambda: OmniConnectorOutput(chunk_ready_req_ids={"req-1"})
+
+        result = runner.attach_omni_connector_output(EMPTY_MODEL_RUNNER_OUTPUT)
+        self.assertIsNot(result, EMPTY_MODEL_RUNNER_OUTPUT)
+        self.assertEqual(result.omni_connector_output.chunk_ready_req_ids, {"req-1"})
+
+
+class TestConnectorConfigValidation(unittest.TestCase):
+    def test_invalid_connector_name_raises(self):
+        """A whitespace-only connector name is rejected while the connectors are initialised."""
+        blank_name_config = _make_model_config(stage_id=1)
+        blank_name_config.stage_connector_config = {"name": " "}
+        runner = MixinHost()
+        with self.assertRaisesRegex(RuntimeError, "missing connector name"):
+            runner.init_omni_connectors(vllm_config=None, model_config=blank_name_config)
+
+
+class _FailingConnector:
+    """Test double whose first ``fail_count`` put() calls fail; later calls succeed."""
+
+    def __init__(self, fail_count: int = 1, raise_on_fail: bool = False):
+        self.attempt = 0  # number of put() calls seen so far
+        self._fail_count = fail_count
+        self._raise_on_fail = raise_on_fail
+
+    def put(self, from_stage, to_stage, put_key, data):
+        self.attempt += 1
+        if self.attempt > self._fail_count:
+            return True, len(str(data)), None
+        if self._raise_on_fail:
+            raise ConnectionError("transient connector error")
+        return False, 0, None
+
+    def get(self, *a, **kw):
+        """Receiving is not exercised by these tests; always returns None."""
+
+    def close(self):
+        """Nothing to release."""
+
+
+class TestSendRetry(unittest.TestCase):
+    """Tests for P1-2: failed connector sends must be retried."""
+
+    def _make_sender(self, connector):
+        """Return a stage-0 async-chunk host wired to *connector*; it is shut down automatically."""
+        sender = MixinHost()
+        sender.init_omni_connectors(
+            vllm_config=None,
+            model_config=_make_model_config(stage_id=0, async_chunk=True),
+        )
+        # addCleanup also runs when a test fails, unlike a trailing shutdown call in each test.
+        self.addCleanup(sender.shutdown_omni_connectors)
+        sender._omni_connector = connector
+        sender._stage_id = 0
+        sender._async_chunk = True
+        return sender
+
+    def _make_task(self, req_id="r1"):
+        """Return a minimal stage-0 -> stage-1 send task for *req_id*."""
+        return {
+            "stage_id": 0,
+            "next_stage_id": 1,
+            "request_id": req_id,
+            "data": {"payload": "test"},
+        }
+
+    def test_send_single_request_returns_false_on_put_failure(self):
+        """A put() that reports failure makes ``_send_single_request`` return a falsy result."""
+        connector = _FailingConnector(fail_count=999)
+        sender = self._make_sender(connector)
+
+        result = sender._send_single_request(self._make_task())
+        self.assertFalse(result)
+
+    def test_send_single_request_does_not_decrement_on_failure(self):
+        connector = _FailingConnector(fail_count=999)
+        sender = self._make_sender(connector)
+        sender._pending_save_counts["r1"] = 1
+
+        sender._send_single_request(self._make_task())
+        self.assertEqual(sender._pending_save_counts.get("r1"), 1, "pending count must NOT be decremented on failure")
+
+    def test_send_single_request_decrements_on_success(self):
+        """A successful send removes the request's pending-save count."""
+        connector = MockConnector(stage_id=0)
+        sender = self._make_sender(connector)
+        sender._pending_save_counts["r1"] = 1
+
+        result = sender._send_single_request(self._make_task())
+        self.assertTrue(result)
+        self.assertNotIn("r1", sender._pending_save_counts, "pending count should be zero/removed on success")
+
+    def test_requeue_or_drop_requeues_on_first_failure(self):
+        """A first failure bumps ``_retry_count`` to 1 and puts the task back in the queue."""
+        sender = self._make_sender(MockConnector(stage_id=0))
+        task = self._make_task()
+
+        sender._requeue_or_drop_failed_send(task)
+
+        self.assertEqual(task.get("_retry_count"), 1)
+        with sender._lock:
+            dq = sender._pending_save_reqs.get("r1")
+        self.assertIsNotNone(dq)
+        self.assertEqual(len(dq), 1)
+
+    def test_requeue_or_drop_drops_after_max_retries(self):
+        """A task already at ``_MAX_SEND_RETRIES`` is dropped and its pending count removed."""
+        sender = self._make_sender(MockConnector(stage_id=0))
+        sender._pending_save_counts["r1"] = 1
+        task = self._make_task()
+        task["_retry_count"] = sender._MAX_SEND_RETRIES  # already at max
+
+        sender._requeue_or_drop_failed_send(task)
+
+        with sender._lock:
+            dq = sender._pending_save_reqs.get("r1")
+        self.assertTrue(dq is None or len(dq) == 0, "task should NOT be re-enqueued after max retries")
+        self.assertNotIn("r1", sender._pending_save_counts, "pending count should be cleaned up on final drop")
+
+    def test_save_loop_retries_on_exception(self):
+        """Replay one save-loop step by hand: a raising put() is requeued, the retry then succeeds."""
+        from collections import deque
+
+        connector = _FailingConnector(fail_count=1, raise_on_fail=True)
+        sender = self._make_sender(connector)
+        task = self._make_task()
+
+        with sender._lock:
+            sender._pending_save_reqs["r1"] = deque([task])
+            sender._pending_save_counts["r1"] = 1
+
+        # NOTE(review): clear() + set() kept from the original flow; presumably this keeps a
+        # background save loop from racing with the manual replay below - confirm.
+        sender._stop_event.clear()
+        sender._stop_event.set()
+
+        # _save_loop() itself is not invoked; pop, send and requeue are simulated here.
+        popped_task = None
+        with sender._lock:
+            dq = sender._pending_save_reqs.get("r1")
+            if dq:
+                popped_task = dq.popleft()
+                if not dq:
+                    del sender._pending_save_reqs["r1"]
+
+        if popped_task is not None:
+            try:
+                success = sender._send_single_request(popped_task)
+            except Exception:
+                # A raising put() counts as a failed send, exactly like a False result.
+                success = False
+            if not success:
+                sender._requeue_or_drop_failed_send(popped_task)
+
+        # After first failure, task should be re-enqueued
+        with sender._lock:
+            dq = sender._pending_save_reqs.get("r1")
+        self.assertIsNotNone(dq)
+        self.assertEqual(len(dq), 1)
+        requeued = dq[0]
+        self.assertEqual(requeued.get("_retry_count"), 1)
+
+        # Second attempt should succeed (connector now returns True)
+        success = sender._send_single_request(requeued)
+        self.assertTrue(success)
+
+
+if __name__ == "__main__":
+    unittest.main()  # lets this test module be run directly as a script
diff --git a/vllm_omni/core/sched/omni_scheduling_coordinator.py b/vllm_omni/core/sched/omni_scheduling_coordinator.py
new file mode 100644
index 00000000000..c9d891afb41
--- /dev/null
+++ b/vllm_omni/core/sched/omni_scheduling_coordinator.py
@@ -0,0 +1,380 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Scheduling-side coordination for chunk and full_payload input waiting.
+
+Manages WAITING_FOR_CHUNK and WAITING_FOR_INPUT state transitions
+based on readiness signals from OmniConnectorOutput, without ever
+calling connector.put()/get().
+
+This replaces the scheduling half of OmniChunkTransferAdapter; the
+transport half lives in OmniConnectorModelRunnerMixin.
+"""
+
+from __future__ import annotations
+
+import time
+from collections import deque
+from typing import Any
+
+from vllm.logger import init_logger
+from vllm.v1.request import Request, RequestStatus
+
+logger = init_logger(__name__)
+
+
+class OmniSchedulingCoordinator:
+ """Pure-scheduling coordinator for chunk and full_payload input waiting.
+
+ The Scheduler owns an instance of this class. It consumes readiness
+ signals produced by the Model Runner's ``OmniConnectorModelRunnerMixin``
+ (via ``OmniConnectorOutput``) and manages ``WAITING_FOR_CHUNK`` and
+ ``WAITING_FOR_INPUT`` state transitions accordingly.
+ """
+
+    def __init__(self, scheduler_max_num_seqs: int, stage_id: int = 0, async_chunk: bool = False):
+        """Create the coordinator for one stage; all tracking sets and queues start empty."""
+        self._stage_id = stage_id
+        self._scheduler_max_num_seqs = scheduler_max_num_seqs
+        self._async_chunk = async_chunk
+
+        self.finished_requests: set[str] = set()
+        self.requests_with_ready_chunks: set[str] = set()
+        self._full_payload_input_received: set[str] = set()
+
+        self._waiting_for_chunk_waiting: deque[Any] = deque()
+        self._waiting_for_chunk_running: deque[Any] = deque()
+
+        # Requests (the objects, not just IDs) newly parked in WAITING_FOR_CHUNK this cycle;
+        # the engine/Model Runner should call register_chunk_recv() so the bg thread polls them.
+        self.pending_chunk_registrations: list[Any] = []
+
+        # Requests waiting for full_payload stage input (WAITING_FOR_INPUT).
+        self._waiting_for_input: deque[Any] = deque()
+        self.pending_input_registrations: list[Any] = []
+
+        # Monotonic timestamp recording when each request first entered
+        # WAITING_FOR_CHUNK or WAITING_FOR_INPUT. Used by
+        # collect_timed_out_request_ids() to detect orphaned waits.
+        self._waiting_since: dict[str, float] = {}
+
+ # ------------------------------------------------------------------ #
+ # Core scheduling methods
+ # ------------------------------------------------------------------ #
+
+    def process_pending_chunks(
+        self,
+        waiting_queue: Any,
+        running_queue: list[Request],
+        chunk_ready_req_ids: set[str],
+        chunk_finished_req_ids: set[str],
+    ) -> None:
+        """Resume requests whose chunk arrived and park the ones still waiting.
+
+        Does nothing on stage 0 or when ``async_chunk`` is disabled.
+
+        Args:
+            waiting_queue: Scheduler's waiting request queue.
+            running_queue: Scheduler's running request list.
+            chunk_ready_req_ids: IDs with a newly arrived chunk this cycle.
+            chunk_finished_req_ids: IDs whose final chunk has arrived.
+        """
+        if not self._async_chunk or self._stage_id == 0:
+            return
+
+        # IDs that are ready *and* finished join finished_requests only after
+        # both queues were processed, so this cycle still treats them as ready.
+        ready_and_finished = chunk_ready_req_ids & chunk_finished_req_ids
+        self.finished_requests.update(chunk_finished_req_ids - ready_and_finished)
+        self.pending_chunk_registrations = []
+
+        # The same pass runs over both scheduler queues; only the parking
+        # deque and the status restored on arrival differ.
+        passes = (
+            (waiting_queue, self._waiting_for_chunk_waiting, RequestStatus.WAITING),
+            (running_queue, self._waiting_for_chunk_running, RequestStatus.RUNNING),
+        )
+        for source_queue, parked, resume_status in passes:
+            self._process_chunk_queue(source_queue, parked, resume_status, chunk_ready_req_ids)
+        self.finished_requests.update(ready_and_finished)
+
+        # Trim the running list to capacity. The evicted tail goes to the
+        # front of the waiting queue as WAITING rather than PREEMPTED because
+        # its KV blocks are not freed here.
+        while len(running_queue) > self._scheduler_max_num_seqs:
+            evicted = running_queue.pop()
+            evicted.status = RequestStatus.WAITING
+            waiting_queue.prepend_requests([evicted])
+
+    def process_pending_full_payload_inputs(
+        self,
+        waiting_queue: Any,
+        running_queue: list[Request],
+        stage_recv_req_ids: set[str],
+    ) -> None:
+        """Manage the WAITING_FOR_INPUT lifecycle for full-payload stage inputs.
+
+        No-op on stage 0. Parked requests whose ID is in ``stage_recv_req_ids``
+        go back to ``waiting_queue`` as WAITING. When ``async_chunk`` is off,
+        WAITING requests in ``waiting_queue`` with no input recorded yet are
+        parked as WAITING_FOR_INPUT and listed in ``pending_input_registrations``.
+        ``running_queue`` is accepted but not used by this method.
+        """
+        if self._stage_id == 0:
+            return
+        # Remember arrivals; free_finished_request() prunes this set later.
+        self._full_payload_input_received.update(stage_recv_req_ids)
+        if not self._async_chunk and stage_recv_req_ids:
+            self.finished_requests.update(stage_recv_req_ids)
+            logger.debug(
+                "[Coordinator stage-%s] full_payload recv -> finished_requests: %s",
+                self._stage_id,
+                stage_recv_req_ids,
+            )
+        self.pending_input_registrations = []
+        # Release parked requests whose input arrived; keep the rest parked.
+        remaining: deque[Any] = deque()
+        for request in self._waiting_for_input:
+            if request.request_id in stage_recv_req_ids:
+                request.status = RequestStatus.WAITING
+                self._waiting_since.pop(request.request_id, None)
+                waiting_queue.add_request(request)
+            else:
+                remaining.append(request)
+        self._waiting_for_input = remaining
+        # Scan a snapshot: entries are removed from waiting_queue once the scan is done.
+        if not self._async_chunk:
+            to_remove: list[Any] = []
+            queue_snapshot = list(waiting_queue)
+            for request in queue_snapshot:
+                if request.status == RequestStatus.WAITING:
+                    if request.request_id in self._full_payload_input_received:
+                        continue
+                    if request.request_id in self.requests_with_ready_chunks:
+                        continue
+                    if request.request_id in self.finished_requests:
+                        continue
+                    request.status = RequestStatus.WAITING_FOR_INPUT
+                    self._waiting_since.setdefault(request.request_id, time.monotonic())
+                    to_remove.append(request)
+                    self._waiting_for_input.append(request)
+                    self.pending_input_registrations.append(request)
+                elif request.status == RequestStatus.WAITING_FOR_INPUT:
+                    if request.request_id in stage_recv_req_ids:
+                        request.status = RequestStatus.WAITING
+                        self._waiting_since.pop(request.request_id, None)
+                    else:
+                        to_remove.append(request)
+                        self._waiting_for_input.append(request)
+                        self.pending_input_registrations.append(request)
+            for request in to_remove:
+                waiting_queue.remove(request)
+
+    def process_pending_full_payload_inputs_legacy(
+        self,
+        waiting_queue: Any,
+        running_queue: list[Request],
+        stage_recv_req_ids: set[str],
+    ) -> None:
+        """Legacy name kept for compatibility; forwards to ``process_pending_full_payload_inputs``."""
+        self.process_pending_full_payload_inputs(waiting_queue, running_queue, stage_recv_req_ids)
+
+    def free_finished_request(self, request_id: str) -> None:
+        """Forget *request_id* in every tracking structure so they cannot grow without bound."""
+        tracking_sets = (self._full_payload_input_received, self.finished_requests, self.requests_with_ready_chunks)
+        for tracked in tracking_sets:
+            tracked.discard(request_id)
+        self._waiting_since.pop(request_id, None)
+
+    def collect_timed_out_request_ids(
+        self,
+        timeout_s: float,
+    ) -> set[str]:
+        """Pop and return the IDs of requests that waited longer than *timeout_s*.
+
+        Detection relies only on the ``_waiting_since`` timestamps, so the
+        result does not depend on whether the internal wait queues are
+        populated when this is called (``restore_queues()`` empties them).
+
+        For each timed-out ID the ``_waiting_since`` entry is dropped, the
+        request is filtered out of the internal wait queues if present, and
+        a warning is logged. The caller (scheduler) is expected to remove it
+        from its own queues, set ``FINISHED_ERROR`` and free the request.
+
+        Args:
+            timeout_s: Maximum tolerated wait in seconds; values <= 0 disable
+                the check.
+
+        Returns:
+            The timed-out request IDs; an empty set when there are none.
+        """
+        if timeout_s <= 0:
+            return set()
+
+        now = time.monotonic()
+        timed_out_ids: set[str] = {
+            req_id for req_id, since in self._waiting_since.items() if now - since > timeout_s
+        }
+        if not timed_out_ids:
+            return set()
+
+        # Rebuild every internal wait queue without the expired requests.
+        for attr_name in (
+            "_waiting_for_chunk_waiting",
+            "_waiting_for_chunk_running",
+            "_waiting_for_input",
+        ):
+            parked_requests = getattr(self, attr_name)
+            kept = deque(parked for parked in parked_requests if parked.request_id not in timed_out_ids)
+            setattr(self, attr_name, kept)
+
+        for req_id in timed_out_ids:
+            self._waiting_since.pop(req_id, None)
+            logger.warning(
+                "[Coordinator stage-%s] Request %s timed out waiting for chunk/input (waited > %.0fs)",
+                self._stage_id,
+                req_id,
+                timeout_s,
+            )
+        return timed_out_ids
+
+    def restore_queues(
+        self,
+        waiting_queue: Any,
+        running_queue: list[Request],
+    ) -> None:
+        """Hand every parked request back to the scheduler queues and empty the internal ones.
+
+        Chunk waiters rejoin the queue they were taken from; input waiters rejoin ``waiting_queue``.
+        """
+        for parked in self._waiting_for_chunk_waiting:
+            waiting_queue.add_request(parked)
+        running_queue.extend(self._waiting_for_chunk_running)
+        for parked in self._waiting_for_input:
+            waiting_queue.add_request(parked)
+
+        self._waiting_for_chunk_waiting, self._waiting_for_chunk_running = deque(), deque()
+        self._waiting_for_input = deque()
+
+    def update_request_metadata(
+        self,
+        requests: dict[str, Request],
+        request_metadata: dict[str, dict[str, Any]],
+        model_mode: str = "ar",
+    ) -> None:
+        """Apply scheduling metadata received for each request to its Request object.
+
+        IDs missing from ``requests`` are skipped. A positive integer
+        ``next_stage_prompt_len`` resizes the prompt to that many zero tokens,
+        but only while the request has no output tokens. When ``model_mode``
+        is not "ar", non-empty ``code_predictor_codes`` replace the prompt and
+        ``left_context_size`` (if present) seeds ``_omni_initial_model_buffer``.
+        """
+        for req_id, metadata in request_metadata.items():
+            request = requests.get(req_id)
+            if request is None:
+                continue
+            # Handle next_stage_prompt_len if present (for models like Qwen3-Omni).
+            # Only applied before decoding starts (no output tokens): resetting a
+            # mid-decode request would destroy generated tokens and desync KV state.
+            # NOTE(review): the resize path calls request._output_token_ids.clear()
+            # unguarded although getattr() here tolerates a missing attribute - confirm.
+            if "next_stage_prompt_len" in metadata:
+                next_len = metadata["next_stage_prompt_len"]
+                if isinstance(next_len, int) and next_len > 0:
+                    output_token_ids = getattr(request, "_output_token_ids", None)
+                    has_decode_output = output_token_ids is not None and len(output_token_ids) > 0
+                    if has_decode_output:
+                        logger.debug(
+                            "[Coordinator stage-%s] Skipping prompt resize for req %s: "
+                            "request already has %s output tokens",
+                            self._stage_id,
+                            req_id,
+                            len(output_token_ids),
+                        )
+                    else:
+                        current_prompt_ids = getattr(request, "prompt_token_ids", []) or []
+                        current_prompt_len = len(current_prompt_ids)
+                        if current_prompt_len != next_len or getattr(request, "num_prompt_tokens", None) != next_len:
+                            new_prompt = [0] * next_len
+                            request.prompt_token_ids = new_prompt
+                            request.num_prompt_tokens = next_len
+                            request._all_token_ids.clear()
+                            request._all_token_ids.extend(new_prompt)
+                            request._output_token_ids.clear()
+                            request.num_computed_tokens = 0
+                            logger.debug(
+                                "[Coordinator stage-%s] Updated prompt_token_ids length to %s for req %s",
+                                self._stage_id,
+                                next_len,
+                                req_id,
+                            )
+
+            if model_mode != "ar":
+                new_ids = metadata.get("code_predictor_codes", [])
+                runtime_seed = None
+                if "left_context_size" in metadata:
+                    runtime_seed = {
+                        "left_context_size": metadata["left_context_size"],
+                    }
+                request._omni_initial_model_buffer = runtime_seed
+                if new_ids:
+                    request.prompt_token_ids = new_ids
+                    request.num_computed_tokens = 0
+
+    def postprocess_scheduler_output(
+        self,
+        scheduler_output: Any,
+        requests: dict[str, Request] | None = None,
+    ) -> None:
+        """Drop the chunk-ready mark of every request scheduled in *scheduler_output*."""
+        self._clear_chunk_ready(scheduler_output)  # ``requests`` is accepted but not used here
+
+ # ------------------------------------------------------------------ #
+ # Internal helpers
+ # ------------------------------------------------------------------ #
+
+    def _process_chunk_queue(
+        self,
+        queue: Any,
+        waiting_for_chunk_list: deque[Any],
+        target_status: RequestStatus,
+        chunk_ready_req_ids: set[str],
+    ) -> None:
+        """Move requests of *queue* that still lack a chunk into *waiting_for_chunk_list*; resume ready ones."""
+        for request in list(queue):
+            req_id = request.request_id
+            if request.status == RequestStatus.WAITING_FOR_CHUNK:
+                if req_id in chunk_ready_req_ids:
+                    request.status = target_status
+                    self.requests_with_ready_chunks.add(req_id)
+                    self._waiting_since.pop(req_id, None)
+                    continue
+            else:
+                # Requests that are already served, finished or waiting for input stay where they are.
+                already_served = req_id in self.requests_with_ready_chunks or req_id in self.finished_requests
+                if already_served or request.status == RequestStatus.WAITING_FOR_INPUT:
+                    continue
+                if req_id in chunk_ready_req_ids:
+                    self.requests_with_ready_chunks.add(req_id)
+                    continue
+                self.pending_chunk_registrations.append(request)
+                request.status = RequestStatus.WAITING_FOR_CHUNK
+                self._waiting_since.setdefault(req_id, time.monotonic())
+            # Only requests that have to wait (or keep waiting) for a chunk get here.
+            queue.remove(request)
+            waiting_for_chunk_list.append(request)
+
+    def _clear_chunk_ready(self, scheduler_output: Any) -> None:
+        """Drop the chunk-ready mark of every new or cached request in *scheduler_output*."""
+        ready = self.requests_with_ready_chunks
+
+        # New requests expose their ID as ``req_id``; a missing one falls back to None.
+        for req_data in scheduler_output.scheduled_new_reqs or ():
+            ready.discard(getattr(req_data, "req_id", None))
+
+        if scheduler_output.scheduled_cached_reqs:
+            ready.difference_update(scheduler_output.scheduled_cached_reqs.req_ids)
+
+
+# Backward-compatible alias: ``ChunkSchedulingCoordinator`` resolves to the same class.
+ChunkSchedulingCoordinator = OmniSchedulingCoordinator
diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py
index 32ea5bf64dc..535f053c388 100644
--- a/vllm_omni/diffusion/worker/diffusion_model_runner.py
+++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py
@@ -35,11 +35,12 @@
from vllm_omni.diffusion.worker.utils import DiffusionRequestState, RunnerOutput
from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager
from vllm_omni.platforms import current_omni_platform
+from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin
logger = init_logger(__name__)
-class DiffusionModelRunner:
+class DiffusionModelRunner(OmniConnectorModelRunnerMixin):
"""
Model runner that handles model loading and execution for diffusion models.
diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py
index 9a7bb670658..2c2c1d21c11 100644
--- a/vllm_omni/outputs.py
+++ b/vllm_omni/outputs.py
@@ -9,6 +9,33 @@
from vllm_omni.inputs.data import OmniPromptType
+@dataclass
+class OmniConnectorOutput:
+    """Communication results from Model Runner to Scheduler.
+
+    Carries transfer readiness signals so the Scheduler can make scheduling
+    decisions without ever calling connector.put()/get() directly.
+
+    Attributes:
+        chunk_ready_req_ids: Request IDs with newly arrived chunks this cycle.
+        chunk_finished_req_ids: Request IDs whose final chunk has arrived.
+        request_metadata: Lightweight scheduling metadata keyed by request ID
+            (e.g. next_stage_prompt_len, code_predictor_codes, left_context_size).
+            Full payloads are owned by the Model Runner's local cache.
+        kv_sent_req_ids: Request IDs whose KV cache was successfully sent.
+        stage_recv_req_ids: Request IDs that received batch stage inputs.
+        has_pending_kv_work: True if the mixin has pending, active, or
+            completed KV transfers that the scheduler should account for.
+    """
+    # Collections use default_factory, so every instance gets its own empty containers.
+    chunk_ready_req_ids: set[str] = field(default_factory=set)
+    chunk_finished_req_ids: set[str] = field(default_factory=set)
+    request_metadata: dict[str, dict[str, Any]] = field(default_factory=dict)
+    kv_sent_req_ids: list[str] = field(default_factory=list)  # a list, unlike the set-typed ID fields
+    stage_recv_req_ids: set[str] = field(default_factory=set)
+    has_pending_kv_work: bool = False
+
+
class OmniModelRunnerOutput(ModelRunnerOutput):
"""Model runner output for omni models.
@@ -24,6 +51,7 @@ class OmniModelRunnerOutput(ModelRunnerOutput):
# IDs of requests whose KV cache has been extracted from GPU/NPU to CPU.
# The Scheduler can safely free the block tables for these requests.
kv_extracted_req_ids: list[str] | None = None
+ omni_connector_output: OmniConnectorOutput | None = None
@dataclass
diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py
index 72e745fb172..868140d265b 100644
--- a/vllm_omni/worker/gpu_ar_model_runner.py
+++ b/vllm_omni/worker/gpu_ar_model_runner.py
@@ -40,6 +40,7 @@
from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager
from vllm_omni.outputs import OmniModelRunnerOutput
from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner
+from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin
logger = init_logger(__name__)
@@ -60,7 +61,7 @@ class ExecuteModelState(NamedTuple):
slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None
-class GPUARModelRunner(OmniGPUModelRunner):
+class GPUARModelRunner(OmniGPUModelRunner, OmniConnectorModelRunnerMixin):
"""Autoregressive GPU model runner that returns hidden states per request.
Follows the v0.12 two-phase execute/sample flow from GPUModelRunner, and
diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py
index d95b676f6d6..f10115c8e90 100644
--- a/vllm_omni/worker/gpu_generation_model_runner.py
+++ b/vllm_omni/worker/gpu_generation_model_runner.py
@@ -39,11 +39,12 @@
from vllm_omni.outputs import OmniModelRunnerOutput
from vllm_omni.worker.gpu_ar_model_runner import ExecuteModelState
from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner
+from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin
logger = logging.getLogger(__name__)
-class GPUGenerationModelRunner(OmniGPUModelRunner):
+class GPUGenerationModelRunner(OmniGPUModelRunner, OmniConnectorModelRunnerMixin):
"""Generation model runner for vLLM-Omni (non-autoregressive).
- Reuses GPUModelRunner preparation, multimodal handling, and TP/PP/DP glue.
diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py
new file mode 100644
index 00000000000..e0df3ba3d7a
--- /dev/null
+++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py
@@ -0,0 +1,2125 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unified data-plane communication mixin for Model Runners.
+
+All connector.put()/get() calls are consolidated here. Background I/O
+threads handle async_chunk and full_payload_mode transfers; KV cache is delegated to
+the existing OmniKVTransferManager (to be absorbed later).
+
+The mixin reports transfer results via OmniConnectorOutput so that the
+Scheduler can make scheduling decisions without ever touching a connector.
+"""
+
+from __future__ import annotations
+
+import importlib
+import inspect
+import os
+import threading
+from collections import defaultdict, deque
+from types import SimpleNamespace
+from typing import TYPE_CHECKING, Any
+
+import torch
+from vllm.distributed.parallel_state import get_tp_group
+from vllm.logger import init_logger
+
+from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory
+from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec
+from vllm_omni.outputs import OmniConnectorOutput
+from vllm_omni.worker.payload_span import (
+ THINKER_DECODE_EMBEDDINGS_KEY,
+ THINKER_DECODE_TOKEN_END_KEY,
+ THINKER_DECODE_TOKEN_START_KEY,
+ THINKER_OUTPUT_TOKEN_IDS_KEY,
+ get_tensor_span,
+ merge_tensor_spans,
+)
+
+if TYPE_CHECKING:
+ from vllm_omni.distributed.omni_connectors.connectors.base import (
+ OmniConnectorBase,
+ )
+ from vllm_omni.distributed.omni_connectors.kv_transfer_manager import (
+ OmniKVTransferManager,
+ )
+
+logger = init_logger(__name__)
+
+
+class OmniConnectorModelRunnerMixin:
+ """Unified data-plane communication mixin for Model Runners.
+
+ Provides three transfer modes through a single pair of bg I/O threads:
+ - **full_payload_mode**: ``recv_full_payload_inputs`` / ``send_full_payload_outputs``
+ - **Streaming (async_chunk)**: ``recv_chunk`` / ``send_chunk``
+ - **KV cache**: ``send_kv_cache`` / ``recv_kv_cache`` (delegates to
+ the existing ``OmniKVTransferManager``)
+
+ The mixin owns connector instances and background threads. It never
+ touches scheduling queues -- readiness is communicated to the Scheduler
+ via ``OmniConnectorOutput``.
+ """
+
+ # ------------------------------------------------------------------ #
+ # Init / Shutdown
+ # ------------------------------------------------------------------ #
+
+ def init_omni_connectors(
+ self,
+ vllm_config: Any,
+ model_config: Any,
+ kv_transfer_manager: OmniKVTransferManager | None = None,
+ ) -> None:
+ """Initialize connectors and background threads.
+
+ Creates the connector, resolves stage / rank information, resets every
+ per-request tracking container and, only when a connector was created,
+ starts the background recv and save threads as daemon threads.
+
+ Args:
+ vllm_config: Full vLLM config object. Not referenced by this method.
+ model_config: Stage-level model config with connector settings.
+ kv_transfer_manager: Existing KV transfer manager to delegate to.
+ When provided, its key-builder / merger / slicer hooks are
+ pointed at this mixin's rank-aware implementations.
+ """
+ self._omni_connector: OmniConnectorBase | None = self._create_connector(model_config)
+ self._kv_transfer_manager = kv_transfer_manager
+
+ self._async_chunk: bool = getattr(model_config, "async_chunk", False)
+ self._model_mode: str = getattr(model_config, "worker_type", "ar")
+ # stage_id may be configured as a string; a value that is neither str
+ # nor int falls back to 0. A non-numeric string makes int() raise.
+ stage_id = getattr(model_config, "stage_id", 0)
+ if isinstance(stage_id, str):
+ stage_id = int(stage_id)
+ self._stage_id: int = stage_id if isinstance(stage_id, int) else 0
+
+ self._custom_process_func_path, self._custom_process_func = self._load_custom_func(model_config)
+ self._custom_process_supports_is_finished = self._custom_process_supports_is_finished_kwarg()
+ logger.info(
+ "[Stage-%s] init_omni_connectors: async_chunk=%s, custom_process_func=%s, connector=%s, func_path=%s",
+ self._stage_id,
+ self._async_chunk,
+ self._custom_process_func,
+ type(self._omni_connector).__name__ if self._omni_connector else None,
+ self._custom_process_func_path,
+ )
+
+ # -- next stage ID (from connector config or default stage_id + 1) --
+ self._next_stage_id: int = self._resolve_next_stage_id(model_config)
+
+ # -- heterogeneous TP rank support --
+ rank_cfg = self._parse_rank_mapping(model_config)
+ self._from_tp: int = rank_cfg["from_tp"]
+ self._to_tp: int = rank_cfg["to_tp"]
+ self._local_rank: int = rank_cfg["local_rank"]
+ # Install this mixin's rank-aware hooks on the delegated KV manager.
+ if self._kv_transfer_manager is not None:
+ self._kv_transfer_manager.kv_send_key_builder = self.get_rank_aware_kv_send_keys
+ self._kv_transfer_manager.kv_recv_key_builder = self.get_rank_aware_kv_keys
+ self._kv_transfer_manager.kv_payload_merger = self._merge_rank_sharded_kv_payloads
+ self._kv_transfer_manager.kv_payload_slicer = self._slice_rank_sharded_kv_payload
+
+ # -- chunk index tracking (ported from OmniChunkTransferAdapter) --
+ self._put_req_chunk: dict[str, int] = defaultdict(int)
+ self._get_req_chunk: dict[str, int] = defaultdict(int)
+ # Send-side async accumulation / staging buffer. Receive-side payload
+ # ownership lives in ``_local_stage_payload_cache``.
+ self._send_side_request_payload: dict[str, dict[str, Any]] = {}
+ self._code_prompt_token_ids: dict[str, list[list[int]]] = defaultdict(list)
+ # Internal request id -> external request id.
+ self._request_ids_mapping: dict[str, str] = {}
+
+ # -- async I/O state (shared by chunk + full_payload_mode) --
+ self._pending_load_reqs: dict[str, Any] = {}
+ self._finished_load_reqs: set[str] = set()
+ self._pending_save_reqs: dict[str, deque] = {}
+ self._pending_save_counts: dict[str, int] = defaultdict(int)
+ # Send-side cleanup postponed because saves were still pending when
+ # ``cleanup_finished_request`` ran.
+ self._deferred_send_cleanup: set[str] = set()
+ # -- per-cycle output accumulator --
+ self._chunk_ready_req_ids: set[str] = set()
+ self._chunk_finished_req_ids: set[str] = set()
+ self._stage_recv_req_ids: set[str] = set()
+ self._full_payload_pending_broadcast_req_ids: set[str] = set()
+ self._async_chunk_updated_req_ids: set[str] = set()
+
+ # -- Model Runner local payload cache (RFC §2.4) --
+ # Full stage payloads land here first on the recv side. We
+ # intentionally do not write connector recv results straight into
+ # `model_intermediate_buffer`: runner-owned runtime state is
+ # materialized later by `_sync_local_stage_payloads()` on the
+ # model thread. This keeps recv timing separate from execute-step
+ # visibility and avoids mixing connector I/O with model runtime
+ # ownership.
+ self._local_stage_payload_cache: dict[str, dict[str, Any]] = {}
+ # Lightweight scheduling metadata pending delivery to the Scheduler.
+ self._local_request_metadata: dict[str, dict[str, Any]] = {}
+
+ # -- persistent set of request IDs whose chunk stream is complete --
+ # Prevents re-registration after the finish sentinel has been received.
+ self._chunk_stream_completed: set[str] = set()
+
+ # -- full_payload_mode: accumulate latest pooler_output per request,
+ # send only when the request finishes (next-cycle flush) --
+ self._pending_full_payload_send: dict[str, tuple[Any, Any]] = {}
+
+ # -- KV sent accumulator --
+ self._kv_sent_req_ids: list[str] = []
+
+ # -- KV transfer lifecycle (absorbed from scheduler) --
+ # Requests marked for KV transfer: {req_id: {seq_len, block_ids}}
+ self._kv_pending_transfers: dict[str, dict[str, Any]] = {}
+ # Requests whose KV transfer has been submitted but not yet acked
+ self._kv_active_transfers: set[str] = set()
+ # Requests whose KV transfer is complete (acked by kv_extracted_req_ids)
+ self._kv_completed_transfers: set[str] = set()
+ # Dedup guard: requests that have already triggered KV transfer
+ self._kv_triggered_requests: set[str] = set()
+
+ # ``_lock`` is acquired by the methods of this mixin before they touch
+ # the tracking containers above. ``_stop_event`` is set on shutdown and
+ # ``_work_available`` is set whenever new recv/save work is enqueued.
+ self._lock = threading.Lock()
+ self._stop_event = threading.Event()
+ self._work_available = threading.Event()
+
+ # Start background threads only when there's a connector
+ # (all state above must exist before the threads begin running).
+ self._recv_thread: threading.Thread | None = None
+ self._save_thread: threading.Thread | None = None
+ if self._omni_connector is not None:
+ self._recv_thread = threading.Thread(
+ target=self._recv_loop,
+ daemon=True,
+ name="omni-mixin-recv",
+ )
+ self._recv_thread.start()
+ self._save_thread = threading.Thread(
+ target=self._save_loop,
+ daemon=True,
+ name="omni-mixin-save",
+ )
+ self._save_thread.start()
+
+    def shutdown_omni_connectors(self) -> None:
+        """Stop background threads and release connector resources.
+
+        Sets the stop event, wakes any worker that may be waiting for work,
+        joins each background thread for up to five seconds and finally closes
+        the connector. Shutdown stays best-effort: a thread that does not stop
+        in time or a failing ``close()`` is logged instead of raised.
+        """
+        self._stop_event.set()
+        # Wake workers that may be parked on the work-available event so they
+        # can observe the stop flag right away; harmless if they only poll.
+        self._work_available.set()
+
+        for worker in (self._recv_thread, self._save_thread):
+            if worker is None:
+                continue
+            worker.join(timeout=5)
+            if worker.is_alive():
+                logger.warning(
+                    "[Stage-%s] shutdown_omni_connectors: thread %s did not stop within 5s",
+                    self._stage_id,
+                    worker.name,
+                )
+
+        if self._omni_connector is None:
+            return
+        try:
+            self._omni_connector.close()
+        except Exception:
+            # Do not let a close failure abort shutdown, but leave a trace.
+            logger.warning(
+                "[Stage-%s] shutdown_omni_connectors: connector close failed",
+                self._stage_id,
+                exc_info=True,
+            )
+
+ def cleanup_finished_request(self, req_id: str) -> None:
+ """Clean up per-request state after a request is fully finished.
+
+ Call this when a request is freed from the model runner to prevent
+ memory leaks in the mixin's tracking dicts/sets. The external
+ request ID is resolved before cleaning up ``_put_req_chunk`` which
+ is keyed by external ID.
+
+ When saves are still pending for the request, send-side cleanup is
+ not performed here; the ID is recorded in ``_deferred_send_cleanup``
+ instead. KV lifecycle and recv-side state are always cleared.
+
+ Args:
+ req_id: Internal request ID of the finished request.
+ """
+ # Fall back to the internal ID when no external mapping was recorded.
+ ext_id = self._request_ids_mapping.pop(req_id, None)
+ send_req_id = ext_id if ext_id is not None else req_id
+
+ # NOTE(review): ``send_full_payload_outputs`` indexes ``_put_req_chunk``
+ # and ``_pending_save_counts`` by the internal request ID, whereas
+ # ``send_chunk`` uses the external ID -- confirm that ``send_req_id``
+ # matches the key used on the full-payload path.
+ with self._lock:
+ # ``.get`` avoids inserting a key into the defaultdict.
+ if self._pending_save_counts.get(send_req_id, 0):
+ self._deferred_send_cleanup.add(send_req_id)
+ else:
+ self._put_req_chunk.pop(send_req_id, None)
+ self._send_side_request_payload.pop(send_req_id, None)
+ self._code_prompt_token_ids.pop(send_req_id, None)
+ # KV transfer lifecycle state is keyed by the internal request ID.
+ self._kv_pending_transfers.pop(req_id, None)
+ self._kv_active_transfers.discard(req_id)
+ self._kv_completed_transfers.discard(req_id)
+ self._kv_triggered_requests.discard(req_id)
+ # Acquires ``_lock`` itself (when present).
+ self._cleanup_recv_delivery_state(req_id)
+
+    def drop_inactive_request_delivery_state(self, req_id: str) -> None:
+        """Discard delivery state held for a request that is no longer active.
+
+        Forgets the internal -> external ID mapping, drops the send-side
+        staging payload (under ``_lock`` when the lock exists) and then clears
+        the recv-side delivery state.
+        """
+        external_id = self._request_ids_mapping.pop(req_id, None)
+        if not hasattr(self, "_lock"):
+            self._drop_send_side_payload_state(req_id, external_id)
+        else:
+            with self._lock:
+                self._drop_send_side_payload_state(req_id, external_id)
+        # Takes the lock on its own, so it must run after the block above.
+        self._cleanup_recv_delivery_state(req_id)
+
+    def _drop_send_side_payload_state(self, req_id: str, ext_id: str | None) -> None:
+        """Remove the send-side staging payload stored under either request ID."""
+        keys = (req_id,) if ext_id is None else (ext_id, req_id)
+        for key in keys:
+            self._send_side_request_payload.pop(key, None)
+
+    def _cleanup_recv_delivery_state(self, req_id: str) -> None:
+        """Clear recv-side delivery-cycle state, holding ``_lock`` when it exists."""
+        if not hasattr(self, "_lock"):
+            # No lock attribute on this instance: clear without locking.
+            self._clear_recv_delivery_state(req_id)
+            return
+        with self._lock:
+            self._clear_recv_delivery_state(req_id)
+
+    def _clear_recv_delivery_state(self, req_id: str) -> None:
+        """Remove ``req_id`` from every recv-side tracking dict and set."""
+        for dict_attr in ("_get_req_chunk", "_pending_load_reqs"):
+            getattr(self, dict_attr).pop(req_id, None)
+        for set_attr in (
+            "_finished_load_reqs",
+            "_chunk_ready_req_ids",
+            "_chunk_finished_req_ids",
+            "_chunk_stream_completed",
+            "_stage_recv_req_ids",
+            "_full_payload_pending_broadcast_req_ids",
+            "_async_chunk_updated_req_ids",
+        ):
+            getattr(self, set_attr).discard(req_id)
+        for dict_attr in ("_local_stage_payload_cache", "_local_request_metadata"):
+            getattr(self, dict_attr).pop(req_id, None)
+
+    def prune_inactive_requests(self, active_req_ids: Any) -> set[str]:
+        """Drop connector state for requests that no longer exist locally.
+
+        Preempted / unscheduled requests are expected to remain in
+        ``self.requests`` and are therefore left alone. Only request IDs that
+        have already fallen out of the active request map are pruned, so that
+        background recv/send bookkeeping does not outlive the request.
+
+        Returns the set of request IDs that were cleaned up (empty when
+        ``active_req_ids`` is ``None``).
+        """
+        if active_req_ids is None:
+            return set()
+
+        protected = set(active_req_ids)
+
+        # Pending recv requests may not be in the caller's active set yet
+        # (e.g. WAITING_FOR_CHUNK requests live in the coordinator's internal
+        # queues, not in model runner self.requests), so they are protected.
+        # The mixin cannot tell a legitimately waiting pending recv from an
+        # orphaned one; only the coordinator/scheduler knows. Orphans (e.g.
+        # upstream stage crash) are handled by
+        # OmniSchedulingCoordinator.collect_timed_out_request_ids(): the
+        # scheduler removes the request, sets FINISHED_ERROR and calls
+        # _free_request(), which ultimately reaches cleanup_finished_request().
+        protected.update(getattr(self, "_pending_load_reqs", {}).keys())
+
+        # A full payload can arrive on the background recv thread after the
+        # scheduler_output snapshot for the current execute_model() cycle was
+        # materialized. Such requests may briefly live only in recv-side
+        # buffers / the local cache until the next scheduler cycle wakes them
+        # up; pruning them would drop the payload before stage_recv can be
+        # published.
+        protected.update(getattr(self, "_stage_recv_req_ids", set()))
+        protected.update(getattr(self, "_full_payload_pending_broadcast_req_ids", set()))
+        protected.update(getattr(self, "_local_request_metadata", {}).keys())
+
+        # _pending_load_reqs is not scanned: all of its entries are protected.
+        scanned_attrs = (
+            "_request_ids_mapping",
+            "_get_req_chunk",
+            "_finished_load_reqs",
+            "_chunk_ready_req_ids",
+            "_chunk_finished_req_ids",
+            "_chunk_stream_completed",
+            "_stage_recv_req_ids",
+            "_full_payload_pending_broadcast_req_ids",
+            "_async_chunk_updated_req_ids",
+            "_local_stage_payload_cache",
+            "_local_request_metadata",
+            "_kv_pending_transfers",
+            "_kv_active_transfers",
+            "_kv_completed_transfers",
+            "_kv_triggered_requests",
+        )
+        stale_req_ids: set[str] = set()
+        for attr_name in scanned_attrs:
+            tracked = getattr(self, attr_name, None)
+            # Only dict- and set-typed containers are scanned.
+            if isinstance(tracked, (dict, set)):
+                stale_req_ids.update(rid for rid in tracked if rid not in protected)
+
+        for stale_id in stale_req_ids:
+            self.cleanup_finished_request(stale_id)
+
+        return stale_req_ids
+
+ # ------------------------------------------------------------------ #
+ # Local payload cache (RFC §2.4 – Model Runner ownership)
+ # ------------------------------------------------------------------ #
+
+    def put_local_stage_payload(self, req_id: str, payload: dict[str, Any]) -> None:
+        """Cache ``payload`` as the full stage payload of ``req_id``, replacing any previous entry."""
+        cache = self._local_stage_payload_cache
+        cache[req_id] = payload
+
+    def get_local_stage_payload(self, req_id: str) -> dict[str, Any] | None:
+        """Return the cached stage payload of ``req_id`` (``None`` if absent); the entry stays cached."""
+        cached = self._local_stage_payload_cache.get(req_id)
+        return cached
+
+    def pop_local_stage_payload(self, req_id: str) -> dict[str, Any] | None:
+        """Take the stage payload of ``req_id`` out of the cache and return it (``None`` if absent)."""
+        consumed = self._local_stage_payload_cache.pop(req_id, None)
+        return consumed
+
+    def put_local_request_metadata(self, req_id: str, metadata: dict[str, Any]) -> None:
+        """Record the lightweight scheduling metadata of ``req_id``, replacing any previous entry."""
+        pending_metadata = self._local_request_metadata
+        pending_metadata[req_id] = metadata
+
+    def get_local_request_metadata(self, req_id: str) -> dict[str, Any] | None:
+        """Return the scheduling metadata recorded for ``req_id`` (``None`` if there is none)."""
+        recorded = self._local_request_metadata.get(req_id)
+        return recorded
+
+ # ------------------------------------------------------------------ #
+ # Scheduling metadata extraction
+ # ------------------------------------------------------------------ #
+
+    _SCHEDULING_METADATA_KEYS = (
+        "next_stage_prompt_len",
+        "code_predictor_codes",
+        "left_context_size",
+    )
+
+    @classmethod
+    def _extract_scheduling_metadata(cls, payload: dict[str, Any]) -> dict[str, Any]:
+        """Copy the scheduler-relevant fields that are present in ``payload``.
+
+        Only keys listed in ``_SCHEDULING_METADATA_KEYS`` are considered; keys
+        missing from ``payload`` are simply left out of the result.
+        """
+        metadata: dict[str, Any] = {}
+        for field in cls._SCHEDULING_METADATA_KEYS:
+            if field in payload:
+                metadata[field] = payload[field]
+        return metadata
+
+ # Keys skipped by the generic content scan in ``_payload_is_consumable``:
+ # a value stored under one of these keys never makes a payload consumable
+ # through that scan.
+ _NON_CONSUMABLE_PAYLOAD_KEYS = {
+ "finished",
+ "override_keys",
+ "next_stage_prompt_len",
+ "left_context_size",
+ THINKER_OUTPUT_TOKEN_IDS_KEY,
+ THINKER_DECODE_TOKEN_START_KEY,
+ THINKER_DECODE_TOKEN_END_KEY,
+ }
+
+    @staticmethod
+    def _payload_value_has_content(value: Any) -> bool:
+        """Tell whether a payload value carries any content.
+
+        ``None``, tensors without elements and empty list/tuple/dict/set
+        containers count as empty; every other value counts as content.
+        """
+        if isinstance(value, torch.Tensor):
+            return value.numel() > 0
+        if isinstance(value, (list, tuple, dict, set)):
+            return len(value) > 0
+        return value is not None
+
+    @classmethod
+    def _payload_is_consumable(cls, payload: dict[str, Any] | None) -> bool:
+        """Return True when an async payload can drive a real forward step.
+
+        Metadata-only wake-ups must not move WAITING_FOR_CHUNK requests back to
+        a schedulable state. In particular, a widened token horizon without any
+        newly visible thinker decode embeds should not force a placeholder-only
+        talker decode step.
+
+        Checks are applied in order: thinker decode embeddings (if a tensor),
+        then ``code_predictor_codes`` (if the key is present), then a scan of
+        all remaining keys outside ``_NON_CONSUMABLE_PAYLOAD_KEYS``.
+        """
+        if not (isinstance(payload, dict) and payload):
+            return False
+
+        embeds = payload.get(THINKER_DECODE_EMBEDDINGS_KEY)
+        if isinstance(embeds, torch.Tensor):
+            # A 0-d tensor is consumable as-is; otherwise at least one element
+            # and a non-empty leading dimension are required.
+            return embeds.ndim == 0 or (embeds.numel() > 0 and embeds.shape[0] > 0)
+
+        if "code_predictor_codes" in payload:
+            codes = payload["code_predictor_codes"]
+            if isinstance(codes, torch.Tensor):
+                return codes.numel() > 0
+            # Codec code 0 is valid; non-empty code payloads are consumable.
+            return len(codes) > 0 if hasattr(codes, "__len__") else codes is not None
+
+        return any(
+            cls._payload_value_has_content(value)
+            for key, value in payload.items()
+            if key not in cls._NON_CONSUMABLE_PAYLOAD_KEYS
+        )
+
+    @staticmethod
+    def _get_local_tp_group() -> Any | None:
+        """Return the local TP group, or ``None`` when ``get_tp_group()`` raises."""
+        try:
+            tp_group = get_tp_group()
+        except Exception:
+            # Lookup failed: report "no TP group".
+            return None
+        return tp_group
+
+    def _recv_ordinary_stage_result(
+        self,
+        connector: OmniConnectorBase,
+        from_stage: str,
+        to_stage: str,
+        connector_get_key: str,
+    ) -> Any:
+        """Fetch one ordinary (non-KV) stage payload from ``connector``.
+
+        Without an active TP group (none available or world size <= 1) the
+        connector is always queried. With TP active, only the data-transfer
+        rank queries the connector; every other rank gets ``None``.
+        """
+        tp_group = self._get_local_tp_group()
+        tp_inactive = tp_group is None or getattr(tp_group, "world_size", 1) <= 1
+        if tp_inactive or self.is_data_transfer_rank():
+            return connector.get(from_stage, to_stage, connector_get_key)
+        return None
+
+    def _recv_full_payload_result(
+        self,
+        connector: OmniConnectorBase,
+        from_stage: str,
+        to_stage: str,
+        connector_get_key: str,
+    ) -> Any:
+        """Receive one full-payload transfer; forwards to ``_recv_ordinary_stage_result``."""
+        received = self._recv_ordinary_stage_result(connector, from_stage, to_stage, connector_get_key)
+        return received
+
+    def _recv_async_chunk_result(
+        self,
+        connector: OmniConnectorBase,
+        from_stage: str,
+        to_stage: str,
+        connector_get_key: str,
+    ) -> Any:
+        """Receive one ordinary async chunk; forwards to ``_recv_ordinary_stage_result``."""
+        received = self._recv_ordinary_stage_result(connector, from_stage, to_stage, connector_get_key)
+        return received
+
+    @staticmethod
+    def _snapshot_payload(payload: Any) -> Any:
+        """Return a shallow copy of a dict payload; any other value is returned unchanged."""
+        return dict(payload) if isinstance(payload, dict) else payload
+
+    def _broadcast_tp_payload_packet(self, packet: Any) -> Any:
+        """Share one ordinary payload packet across the TP group when TP is active.
+
+        Without an active TP group the packet is returned untouched. Otherwise
+        the data-transfer rank contributes ``packet`` (all other ranks
+        contribute ``None``) and the result of ``broadcast_object(..., src=0)``
+        is returned.
+        """
+        tp_group = self._get_local_tp_group()
+        if tp_group is not None and getattr(tp_group, "world_size", 1) > 1:
+            outgoing = packet if self.is_data_transfer_rank() else None
+            return tp_group.broadcast_object(outgoing, src=0)
+        return packet
+
+    def _apply_staged_payloads_locked(self, staged_payloads: dict[str, Any]) -> None:
+        """Store a snapshot of every staged payload in the local stage payload cache."""
+        self._local_stage_payload_cache.update(
+            (req_id, self._snapshot_payload(payload)) for req_id, payload in staged_payloads.items()
+        )
+
+    def _collect_full_payload_results_locked(self) -> dict[str, Any] | None:
+        """Gather snapshots of full payloads that are pending broadcast.
+
+        Every pending request whose payload is present in the local cache is
+        snapshotted and removed from the pending set. Requests without a cached
+        payload stay pending and are reported in a warning. Returns ``None``
+        when nothing was collected.
+        """
+        pending = self._full_payload_pending_broadcast_req_ids
+        if not pending:
+            return None
+
+        collected: dict[str, Any] = {}
+        absent: list[str] = []
+        # Iterate over a frozen copy because the pending set shrinks in the loop.
+        for req_id in tuple(pending):
+            cached = self._local_stage_payload_cache.get(req_id)
+            if cached is not None:
+                collected[req_id] = self._snapshot_payload(cached)
+                pending.discard(req_id)
+            else:
+                absent.append(req_id)
+
+        if absent:
+            logger.warning(
+                "[Stage-%s] _collect_full_payload_results_locked: "
+                "pending full-payload reqs missing from local cache: %s",
+                self._stage_id,
+                absent,
+            )
+        return collected or None
+
+    def _collect_async_chunk_fanout_packet_locked(self) -> dict[str, Any] | None:
+        """Build the async-chunk fan-out packet and reset the per-cycle accumulators.
+
+        The packet carries snapshots of the cached payloads of every request
+        touched this cycle, the pending request metadata and the
+        newly-finished / chunk-finished request ID sets. Returns ``None`` when
+        no request was touched. For chunk-finished requests that have a cached
+        payload, the send-side staging payload is dropped as well.
+        """
+        newly_finished = set(self._finished_load_reqs)
+        chunk_finished = set(self._chunk_finished_req_ids)
+        request_metadata = dict(self._local_request_metadata)
+
+        touched = set(self._async_chunk_updated_req_ids)
+        touched |= newly_finished
+        touched |= chunk_finished
+        touched.update(request_metadata)
+        # ``touched`` is a superset of every accumulator, so it alone decides
+        # whether there is anything to report.
+        if not touched:
+            return None
+
+        cache = self._local_stage_payload_cache
+        packet = {
+            "staged_payloads": {
+                req_id: self._snapshot_payload(cache[req_id]) for req_id in touched if req_id in cache
+            },
+            "request_metadata": request_metadata,
+            "newly_finished": newly_finished,
+            "chunk_finished": chunk_finished,
+        }
+
+        for accumulator in (
+            self._async_chunk_updated_req_ids,
+            self._finished_load_reqs,
+            self._chunk_finished_req_ids,
+            self._local_request_metadata,
+        ):
+            accumulator.clear()
+
+        for req_id in chunk_finished:
+            if req_id in cache:
+                ext_req_id = self._request_ids_mapping.get(req_id, req_id)
+                # Drop the staging entry under the external and internal ID.
+                for key in dict.fromkeys((ext_req_id, req_id)):
+                    self._send_side_request_payload.pop(key, None)
+
+        return packet
+
+ def _apply_async_chunk_fanout_packet(self, packet: dict[str, Any]) -> None:
+ """Apply a fan-out packet to this rank's local state.
+
+ Reads the ``"staged_payloads"`` and ``"chunk_finished"`` entries of
+ ``packet`` (both optional): staged payloads are snapshotted into the
+ local stage payload cache, and every chunk-finished request is removed
+ from the pending-load map and remembered in ``_chunk_stream_completed``.
+ The other packet entries are not read here.
+ """
+ staged_payloads = packet.get("staged_payloads", {})
+ chunk_finished = set(packet.get("chunk_finished", ()))
+ with self._lock:
+ self._apply_staged_payloads_locked(staged_payloads)
+ for req_id in chunk_finished:
+ self._pending_load_reqs.pop(req_id, None)
+ self._chunk_stream_completed.add(req_id)
+
+ # ------------------------------------------------------------------ #
+ # full_payload_mode (recv_full_payload_inputs / send_full_payload_outputs)
+ # ------------------------------------------------------------------ #
+
+ def recv_full_payload_inputs(self, scheduler_output: Any) -> dict[str, Any] | None:
+ """Check for incoming full_payload_mode stage inputs (non-blocking).
+
+ Returns a dict mapping ``request_id -> engine_inputs`` for data
+ that has arrived, or ``None`` if nothing is ready. Stores full
+ payloads in the local cache and extracts scheduling metadata.
+
+ Args:
+ scheduler_output: Not read by this method.
+ """
+ # Only the data-transfer rank collects; the other ranks obtain the
+ # results through the TP broadcast below.
+ with self._lock:
+ results = self._collect_full_payload_results_locked() if self.is_data_transfer_rank() else None
+ results = self._broadcast_tp_payload_packet(results)
+ if not results:
+ return None
+ with self._lock:
+ self._stage_recv_req_ids.update(results.keys())
+ # Received requests no longer need to be polled for.
+ for req_id in results:
+ self._pending_load_reqs.pop(req_id, None)
+ self._apply_staged_payloads_locked(results)
+ for req_id, payload in results.items():
+ self._local_request_metadata[req_id] = self._extract_scheduling_metadata(payload)
+ logger.info(
+ "[Stage-%s] recv_full_payload_inputs: consumed %s reqs: %s, stage_recv_req_ids now=%s",
+ self._stage_id,
+ len(results),
+ list(results.keys()),
+ self._stage_recv_req_ids,
+ )
+ return results
+
+    @staticmethod
+    def _is_all_zero_tensor(t: Any) -> bool:
+        """Return True only for a non-empty ``torch.Tensor`` whose elements are all zero."""
+        if not isinstance(t, torch.Tensor) or t.numel() == 0:
+            return False
+        return not t.any()
+
+    def accumulate_full_payload_output(
+        self,
+        req_id: str,
+        pooler_output: Any,
+        request: Any,
+    ) -> None:
+        """Accumulate pooler_output for a request across steps (full_payload_mode).
+
+        Tensors with two or more dimensions whose trailing dimensions match the
+        previously stored tensor are concatenated along dim-0; every other
+        value is replaced by the latest one. A ``None`` value keeps whatever
+        was stored before.
+
+        All-zero tensors (e.g. ``code_predictor_codes`` emitted during
+        prefill) are dropped from the incoming output so that they do not
+        pollute downstream stages with garbage / noise frames.
+
+        Nothing is sent here: the data goes out when
+        ``flush_full_payload_outputs`` is called with the finished request IDs
+        from the next scheduler cycle. The stored request object is always the
+        most recent one.
+        """
+        # ---- Drop all-zero tensors (prefill placeholders) from the input ----
+        kept: dict[str, Any] = {}
+        dropped_zero_keys: list[tuple[str, tuple[int, ...]]] = []
+        for name, value in pooler_output.items():
+            if not self._is_all_zero_tensor(value):
+                kept[name] = value
+            else:
+                dropped_zero_keys.append((name, tuple(value.shape)))
+        if dropped_zero_keys:
+            logger.info(
+                "[Stage-%s] accumulate_full_payload_output: req=%s dropped_zero_keys=%s",
+                self._stage_id,
+                req_id,
+                dropped_zero_keys,
+            )
+
+        previous_entry = self._pending_full_payload_send.get(req_id)
+        if previous_entry is None:
+            # First output for this request: store it as-is.
+            self._pending_full_payload_send[req_id] = (kept, request)
+            return
+
+        # Start from the previous output and fold the new values into it.
+        merged: dict[str, Any] = dict(previous_entry[0])
+        for name, incoming in kept.items():
+            stored = merged.get(name)
+            if incoming is None:
+                # Keep the stored value; a brand-new key is recorded as None.
+                merged.setdefault(name, None)
+                continue
+            appendable = (
+                isinstance(incoming, torch.Tensor)
+                and isinstance(stored, torch.Tensor)
+                and incoming.dim() >= 2
+                and stored.dim() >= 2
+                and incoming.shape[1:] == stored.shape[1:]
+            )
+            merged[name] = torch.cat([stored, incoming], dim=0) if appendable else incoming
+        self._pending_full_payload_send[req_id] = (merged, request)
+
+    def flush_full_payload_outputs(self, finished_req_ids: set[str]) -> None:
+        """Send the accumulated full_payload output of every request that just finished.
+
+        Entries for the finished requests are removed from the pending
+        accumulator; when at least one was found, they are handed to
+        ``send_full_payload_outputs`` in a single call.
+        """
+        pending = self._pending_full_payload_send
+        logger.info(
+            "[Stage-%s] flush_full_payload_outputs: finished_req_ids=%s, pending=%s",
+            self._stage_id,
+            finished_req_ids,
+            list(pending.keys()),
+        )
+        popped = ((req_id, pending.pop(req_id, None)) for req_id in finished_req_ids)
+        to_send: dict[str, tuple[Any, Any]] = {req_id: entry for req_id, entry in popped if entry is not None}
+        logger.info("[Stage-%s] flush_full_payload_outputs: to_send=%s", self._stage_id, list(to_send.keys()))
+        if not to_send:
+            return
+        self.send_full_payload_outputs(scheduler_output=None, outputs=to_send)
+
+    def send_full_payload_outputs(
+        self,
+        scheduler_output: Any,
+        outputs: dict[str, tuple[Any, Any] | Any],
+    ) -> list[str]:
+        """Send full_payload stage outputs to the next stage via connector.
+
+        The payload for each request is built in the caller thread (through
+        the custom process function when one is configured) and enqueued as a
+        save task; the enqueued tasks are picked up after ``_work_available``
+        is set.
+
+        Args:
+            scheduler_output: Not read by this method.
+            outputs: Mapping of ``req_id`` to either a
+                ``(pooling_output, request)`` tuple (preferred) or a raw
+                payload dict. When a tuple is supplied the request object
+                is forwarded to ``custom_process_stage_input_func``.
+
+        Returns:
+            The request IDs successfully enqueued. An empty list when there is
+            no connector; all request IDs of ``outputs`` when this rank is not
+            the data-transfer rank (nothing is enqueued in that case).
+        """
+        if self._omni_connector is None:
+            logger.info("[Stage-%s] send_full_payload_outputs: connector is None, skip", self._stage_id)
+            return []
+        if not self.is_data_transfer_rank():
+            logger.info(
+                "[Stage-%s] send_full_payload_outputs: not data_transfer_rank (rank=%s), skip",
+                self._stage_id,
+                self._local_rank,
+            )
+            return list(outputs.keys())
+
+        sent_ids: list[str] = []
+        next_stage_id = self._next_stage_id
+        for req_id, value in outputs.items():
+            if isinstance(value, tuple) and len(value) == 2:
+                raw_output, request = value
+            else:
+                raw_output, request = value, None
+
+            payload = raw_output
+            if self._custom_process_func is not None:
+                payload = self._build_custom_process_payload(
+                    request_id=req_id,
+                    request=request,
+                    pooling_output=raw_output,
+                )
+            # Single None check for both paths, so that a skipped request is
+            # always visible in the log (also when the custom process function
+            # produced no payload).
+            if payload is None:
+                logger.info("[Stage-%s] send_full_payload_outputs: payload is None for %s", self._stage_id, req_id)
+                continue
+
+            if isinstance(payload, dict):
+                code_predictor_codes = payload.get("code_predictor_codes")
+                if isinstance(code_predictor_codes, torch.Tensor):
+                    code_len = int(code_predictor_codes.numel())
+                elif hasattr(code_predictor_codes, "__len__"):
+                    code_len = len(code_predictor_codes)
+                else:
+                    code_len = None
+                logger.info(
+                    "[Stage-%s] send_full_payload_outputs: req=%s payload_keys=%s code_len=%s left_context_size=%s",
+                    self._stage_id,
+                    req_id,
+                    sorted(payload.keys()),
+                    code_len,
+                    payload.get("left_context_size"),
+                )
+
+            # NOTE(review): the chunk counter and the save queue are keyed by
+            # the internal ``req_id`` here, while ``send_chunk`` keys them by
+            # the external ID -- confirm this is intended.
+            external_req_id = self._resolve_external_req_id(request, req_id)
+            chunk_id = self._put_req_chunk[req_id]
+            self._put_req_chunk[req_id] += 1
+            connector_put_key = f"{external_req_id}_{self._stage_id}_{chunk_id}"
+
+            logger.info(
+                "[Stage-%s] send_full_payload_outputs: enqueue req=%s put_key=%s next_stage=%s",
+                self._stage_id,
+                req_id,
+                connector_put_key,
+                next_stage_id,
+            )
+            task = {
+                "stage_id": self._stage_id,
+                "next_stage_id": next_stage_id,
+                "put_key": connector_put_key,
+                "data": payload,
+                "request_id": req_id,
+            }
+            with self._lock:
+                self._pending_save_reqs.setdefault(req_id, deque()).append(task)
+                self._pending_save_counts[req_id] += 1
+            sent_ids.append(req_id)
+
+        if sent_ids:
+            self._work_available.set()
+        return sent_ids
+
+    def recv_stage_inputs(self, scheduler_output: Any) -> dict[str, Any] | None:
+        """Alias of ``recv_full_payload_inputs`` kept for compatibility."""
+        received = self.recv_full_payload_inputs(scheduler_output)
+        return received
+
+    def accumulate_batch_output(
+        self,
+        req_id: str,
+        pooler_output: Any,
+        request: Any,
+    ) -> None:
+        """Alias of ``accumulate_full_payload_output`` kept for compatibility."""
+        forwarded = (req_id, pooler_output, request)
+        self.accumulate_full_payload_output(*forwarded)
+
+    def flush_batch_outputs(self, finished_req_ids: set[str]) -> None:
+        """Alias of ``flush_full_payload_outputs`` kept for compatibility."""
+        flush = self.flush_full_payload_outputs
+        flush(finished_req_ids)
+
+    def send_stage_outputs(
+        self,
+        scheduler_output: Any,
+        outputs: dict[str, tuple[Any, Any] | Any],
+    ) -> list[str]:
+        """Alias of ``send_full_payload_outputs`` kept for compatibility."""
+        enqueued = self.send_full_payload_outputs(scheduler_output, outputs)
+        return enqueued
+
+ # ------------------------------------------------------------------ #
+ # Streaming chunk mode (recv_chunk / send_chunk)
+ # ------------------------------------------------------------------ #
+
+ def register_chunk_recv(self, request: Any) -> None:
+ """Register a request for async chunk retrieval by the bg thread.
+
+ Stage-0 has no upstream producer so this is a no-op there.
+ Skips requests whose full-payload data has already been received
+ (tracked in ``_stage_recv_req_ids``) and requests whose chunk stream
+ already completed, to prevent the bg thread from polling for
+ non-existent chunks.
+
+ Args:
+ request: Request object; ``request_id`` is required and
+ ``external_req_id`` is used when present.
+ """
+ if self._stage_id == 0:
+ return
+ request_id = request.request_id
+ # The internal -> external mapping is recorded even when registration
+ # is skipped below.
+ self._request_ids_mapping[request_id] = getattr(
+ request,
+ "external_req_id",
+ request_id,
+ )
+ with self._lock:
+ if request_id in self._stage_recv_req_ids:
+ return
+ # Don't re-register if the finish sentinel was already received
+ if request_id in self._chunk_stream_completed:
+ return
+ self._pending_load_reqs[request_id] = request
+ # Signal that new recv work is available.
+ self._work_available.set()
+
+ def recv_chunk(self) -> dict[str, Any]:
+ """Collect chunks received by the bg thread since last call.
+
+ Returns a dict ``{request_id: chunk_payload}`` for newly arrived
+ chunks. Empty dict when nothing is ready. A request listed in
+ ``_finished_load_reqs`` without a cached payload maps to ``None``.
+
+ This method reads from ``_finished_load_reqs`` without clearing
+ it -- ``get_omni_connector_output()`` is the sole consumer that
+ drains and resets ``_finished_load_reqs`` at the end of each
+ ``execute_model`` cycle.
+
+ Returns **shallow copies** of the cached payloads so that the
+ caller can read them without racing against the background recv
+ thread, which may concurrently mutate the live cache entries via
+ ``dict.update()``.
+ """
+ # NOTE(review): ``_collect_async_chunk_fanout_packet_locked`` also
+ # clears ``_finished_load_reqs``; presumably it runs as part of
+ # ``get_omni_connector_output()`` -- confirm.
+ with self._lock:
+ finished = set(self._finished_load_reqs)
+ if not finished:
+ return {}
+ # Snapshot the payloads under the lock to avoid racing with
+ # _poll_single_request which does existing.update(payload_data)
+ # on the same dict objects.
+ result = {}
+ for rid in finished:
+ payload = self._local_stage_payload_cache.get(rid)
+ result[rid] = dict(payload) if isinstance(payload, dict) else payload
+
+ # Remember which requests had a chunk handed out by this call.
+ self._chunk_ready_req_ids.update(finished)
+ return result
+
+    def send_chunk(
+        self,
+        request: Any,
+        pooling_output: Any | None = None,
+    ) -> bool:
+        """Derive one chunk from ``request`` and enqueue it for async sending.
+
+        The payload is built in the caller thread (via
+        ``custom_process_stage_input_func``); the actual ``connector.put()``
+        is left to the background save thread. Non-KV data is identical across
+        TP ranks, so only the data-transfer rank enqueues anything.
+
+        Returns:
+            ``False`` when there is no connector or no payload was produced,
+            ``True`` when the chunk was enqueued or this rank does not send.
+        """
+        if self._omni_connector is None:
+            logger.warning("[Stage-%s] send_chunk: connector is None", self._stage_id)
+            return False
+        if not self.is_data_transfer_rank():
+            return True
+
+        internal_req_id = getattr(request, "request_id", None) or getattr(request, "req_id", None)
+        request_id = self._resolve_external_req_id(request, internal_req_id)
+        # Remember internal -> external so that finish sentinels can still
+        # resolve the external ID once the request object has been freed.
+        if internal_req_id and internal_req_id != request_id:
+            self._request_ids_mapping.setdefault(internal_req_id, request_id)
+
+        chunk_id = self._put_req_chunk[request_id]
+        is_first_chunk = chunk_id == 0
+
+        payload_data = self._build_custom_process_payload(
+            request_id=request_id,
+            request=request,
+            pooling_output=pooling_output,
+        )
+        if payload_data is None:
+            # Only the very first missing payload is worth a warning.
+            if is_first_chunk:
+                logger.warning(
+                    "[Stage-%s] send_chunk: payload is None for req=%s chunk=%s (process_func=%s)",
+                    self._stage_id,
+                    request_id,
+                    chunk_id,
+                    self._custom_process_func,
+                )
+            return False
+
+        # The chunk index is consumed only once a payload exists.
+        self._put_req_chunk[request_id] += 1
+        connector_put_key = f"{request_id}_{self._stage_id}_{chunk_id}"
+
+        if is_first_chunk:
+            logger.info(
+                "[Stage-%s] send_chunk: first chunk enqueued, req=%s key=%s",
+                self._stage_id,
+                request_id,
+                connector_put_key,
+            )
+
+        save_task = {
+            "stage_id": self._stage_id,
+            "next_stage_id": self._next_stage_id,
+            "put_key": connector_put_key,
+            "data": payload_data,
+            "request_id": request_id,
+        }
+        with self._lock:
+            queue = self._pending_save_reqs.setdefault(request_id, deque())
+            queue.append(save_task)
+            self._pending_save_counts[request_id] += 1
+        self._work_available.set()
+        return True
+
+ # ------------------------------------------------------------------ #
+ # KV cache (delegates to OmniKVTransferManager)
+ # ------------------------------------------------------------------ #
+
+ def send_kv_cache(
+ self,
+ finished_reqs: dict[str, dict[str, Any]],
+ kv_caches: list[torch.Tensor],
+ block_size: int,
+ cache_dtype: str,
+ request_id_resolver: Any | None = None,
+ ) -> list[str]:
+ """Send KV cache for finished requests.
+
+ Delegates to the existing ``OmniKVTransferManager``.
+ """
+ if self._kv_transfer_manager is None:
+ return list(finished_reqs.keys()) if finished_reqs else []  # without a manager every request id is echoed back
+ result = self._kv_transfer_manager.handle_finished_requests_kv_transfer(
+ finished_reqs=finished_reqs,
+ kv_caches=kv_caches,
+ block_size=block_size,
+ cache_dtype=cache_dtype,
+ request_id_resolver=request_id_resolver,
+ )
+ if result:
+ self._kv_sent_req_ids.extend(result)  # surfaced as kv_sent_req_ids by get_omni_connector_output()
+ return result
+
+ def recv_kv_cache(
+ self,
+ request_id: str,
+ target_device: torch.device | None = None,
+ ) -> tuple[dict[str, Any] | None, int]:
+ """Receive KV cache for a request.
+
+ Delegates to the existing ``OmniKVTransferManager``.
+ """
+ if self._kv_transfer_manager is None:
+ return None, 0  # no manager configured: (no payload, size 0)
+ return self._kv_transfer_manager.receive_kv_cache_for_request(
+ request_id=request_id,
+ target_device=target_device,
+ )
+
+ def receive_cfg_companion_kv_payloads(
+ self,
+ cfg_request_ids: dict[str, str],
+ target_device: torch.device | None = None,
+ ) -> dict[str, tuple[dict[str, Any] | None, int]]:
+ """Receive raw CFG companion KV payloads keyed by role."""
+ return {  # one recv_kv_cache() call per (role, companion request id) entry
+ role: self.recv_kv_cache(companion_rid, target_device=target_device)
+ for role, companion_rid in cfg_request_ids.items()
+ }
+
+ def receive_multi_kv_cache(
+ self,
+ req: Any,
+ cfg_kv_collect_func: Any | None = None,
+ target_device: torch.device | None = None,
+ ) -> bool:
+ """Receive primary and optional companion KV caches for a request.
+
+ The mixin owns the runner-facing orchestration: primary KV receive,
+ companion payload fetch, and applying any model-specific CFG fields back
+ onto ``req.sampling_params``.
+ """
+ if self._kv_transfer_manager is None:
+ return False
+
+ request_id = getattr(req, "request_id", None) or (  # fall back to the first entry of req.request_ids
+ req.request_ids[0] if hasattr(req, "request_ids") and req.request_ids else None
+ )
+ if not request_id:
+ logger.warning("Request has no ID, cannot receive KV cache")
+ return False
+
+ active_requests = getattr(self, "requests", None)  # optional attribute; the check is skipped when absent
+ if active_requests is not None and request_id not in active_requests:
+ logger.info("Skip receiving KV cache for inactive request %s", request_id)
+ return False
+
+ primary_ok = False
+ data, _size = self.recv_kv_cache(request_id, target_device=target_device)
+ if data:
+ self._kv_transfer_manager.apply_kv_cache_to_request(req, data)
+ primary_ok = True
+
+ cfg_ids = getattr(getattr(req, "sampling_params", None), "cfg_kv_request_ids", None)  # None when either attribute is missing
+ if cfg_ids and cfg_kv_collect_func:
+ try:
+ cfg_role_payloads = self.receive_cfg_companion_kv_payloads(
+ cfg_ids,
+ target_device=target_device,
+ )
+ cfg_kvs = cfg_kv_collect_func(request_id, cfg_role_payloads)
+ if cfg_kvs and hasattr(req, "sampling_params") and req.sampling_params is not None:
+ for key, value in cfg_kvs.items():
+ setattr(req.sampling_params, key, value)  # each collected entry becomes an attribute on sampling_params
+ logger.info("Applied CFG KV caches: %s", list(cfg_kvs.keys()))
+ except Exception:
+ logger.exception("Failed to collect CFG KV caches for %s", request_id)  # logged, not re-raised
+
+ return primary_ok  # reflects only the primary KV receive, not the CFG companions
+
+ # ------------------------------------------------------------------ #
+ # Rank-aware KV transfer routing
+ # ------------------------------------------------------------------ #
+
+ def get_rank_aware_kv_keys(
+ self,
+ req_id: str,
+ from_stage: int,
+ to_stage: int | None = None,  # NOTE(review): accepted but not referenced in this method body
+ chunk_id: int = 0,
+ ) -> list[str]:
+ """Build recv-side connector keys for all remote ranks this rank needs.
+
+ For heterogeneous TP receive, the local rank is the target rank and must
+ fetch one or more source-rank shards keyed as ``from_rank -> to_rank``.
+ """
+ remote_ranks = self.get_kv_remote_ranks()
+ return [  # one key per remote source rank, always addressed to this local rank
+ self.get_kv_connector_key(
+ req_id=req_id,
+ from_stage=from_stage,
+ chunk_id=chunk_id,
+ from_rank=remote_rank,
+ to_rank=self._local_rank,
+ )
+ for remote_rank in remote_ranks
+ ]
+
+ def get_kv_target_ranks_for_send(self) -> list[int]:
+ """Determine which target ranks this local rank should send KV shards to."""
+ self._validate_kv_tp_topology()
+ if self._from_tp == self._to_tp:
+ return [self._local_rank]  # equal TP sizes: rank i sends to rank i
+ if self._from_tp > self._to_tp:
+ tp_ratio = self._from_tp // self._to_tp
+ return [self._local_rank // tp_ratio]  # tp_ratio consecutive source ranks share one target rank
+ tp_ratio = self._to_tp // self._from_tp
+ base_rank = self._local_rank * tp_ratio
+ return [base_rank + i for i in range(tp_ratio)]  # from_tp < to_tp: fan out to tp_ratio consecutive target ranks
+
+ def get_rank_aware_kv_send_keys(
+ self,
+ req_id: str,
+ from_stage: int,
+ to_stage: int | None = None,  # NOTE(review): accepted but not referenced in this method body
+ chunk_id: int = 0,
+ ) -> list[str]:
+ """Build send-side connector keys for this rank's KV shard(s)."""
+ target_ranks = self.get_kv_target_ranks_for_send()
+ return [  # one key per target rank, always originating from this local rank
+ self.get_kv_connector_key(
+ req_id=req_id,
+ from_stage=from_stage,
+ chunk_id=chunk_id,
+ from_rank=self._local_rank,
+ to_rank=target_rank,
+ )
+ for target_rank in target_ranks
+ ]
+
+ @staticmethod
+ def _merge_rank_sharded_kv_payloads(payloads: list[dict[str, Any]]) -> dict[str, Any] | None:
+ """Merge multiple source-rank KV shards for one target rank."""
+ payloads = [payload for payload in payloads if isinstance(payload, dict)]  # non-dict entries (e.g. None) are dropped
+ if not payloads:
+ return None
+ if len(payloads) == 1:
+ return payloads[0]  # single shard: returned as-is, not copied
+
+ merged = dict(payloads[0])  # shallow copy; non-layer fields come from the first payload
+ layer_blocks = merged.get("layer_blocks")
+ if not isinstance(layer_blocks, dict):
+ return merged
+
+ def _merge_tensor_lists(name: str) -> list[torch.Tensor | None]:
+ merged_list: list[torch.Tensor | None] = []
+ cache_lists = [payload.get("layer_blocks", {}).get(name, []) for payload in payloads]
+ max_len = max((len(cache_list) for cache_list in cache_lists), default=0)  # shorter lists simply contribute nothing past their end
+ for idx in range(max_len):
+ tensors = [cache_list[idx] for cache_list in cache_lists if idx < len(cache_list)]
+ tensors = [tensor for tensor in tensors if isinstance(tensor, torch.Tensor)]
+ if not tensors:
+ merged_list.append(None)
+ elif len(tensors) == 1:
+ merged_list.append(tensors[0])
+ else:
+ merged_list.append(torch.cat(tensors, dim=-2).contiguous())  # joins shards on dim -2; presumably the head axis -- TODO confirm
+ return merged_list
+
+ merged["layer_blocks"] = {  # only key_cache / value_cache are carried over into the merged layer_blocks
+ "key_cache": _merge_tensor_lists("key_cache"),
+ "value_cache": _merge_tensor_lists("value_cache"),
+ }
+ metadata = dict(merged.get("metadata", {}))  # copy so the first payload's metadata dict is not mutated
+ metadata["merged_remote_rank_count"] = len(payloads)
+ merged["metadata"] = metadata
+ return merged
+
+ def _slice_rank_sharded_kv_payload(self, payload: dict[str, Any] | None) -> dict[str, Any] | None:
+ """Slice a duplicated source-rank KV shard for ``from_tp < to_tp`` cases."""
+ if payload is None or self._from_tp >= self._to_tp:
+ return payload  # nothing to slice unless the target side has more TP ranks
+
+ tp_ratio = self._to_tp // self._from_tp
+ shard_index = self._local_rank % tp_ratio  # which of the tp_ratio slices belongs to this rank
+ layer_blocks = payload.get("layer_blocks") if isinstance(payload, dict) else None
+ if not isinstance(layer_blocks, dict):
+ return payload
+
+ def _slice_tensor_list(name: str) -> list[torch.Tensor | None]:
+ sliced: list[torch.Tensor | None] = []
+ for tensor in layer_blocks.get(name, []):
+ if not isinstance(tensor, torch.Tensor) or tensor.ndim < 2:
+ sliced.append(tensor)  # non-tensors and tensors below rank 2 pass through unchanged
+ continue
+ head_dim = tensor.shape[-2]  # size of dim -2; presumably the head axis -- TODO confirm
+ if head_dim % tp_ratio != 0:
+ sliced.append(tensor)  # not evenly divisible: passed through unsliced
+ continue
+ per_rank = head_dim // tp_ratio
+ start = shard_index * per_rank
+ sliced.append(tensor.narrow(-2, start, per_rank).contiguous())
+ return sliced
+
+ payload = dict(payload)  # shallow copy so the caller's dict is not modified
+ payload["layer_blocks"] = {
+ "key_cache": _slice_tensor_list("key_cache"),
+ "value_cache": _slice_tensor_list("value_cache"),
+ }
+ metadata = dict(payload.get("metadata", {}))
+ metadata["sliced_for_local_rank"] = self._local_rank
+ payload["metadata"] = metadata
+ return payload
+
+ def should_replicate_payload(self) -> bool:
+ """Whether non-KV payloads should be replicated across ranks.
+
+ Data payloads (stage inputs, chunks) are identical after all-gather,
+ so only rank 0 transfers them. KV payloads are rank-specific and
+ all ranks participate.
+ """
+ return self._local_rank != 0  # True on every rank except 0; NOTE(review): confirm this matches the name's intent
+
+ def get_kv_rank_mapping(self) -> dict[str, Any]:
+ """Return the current rank mapping configuration.
+
+ Useful for debugging and for downstream code that needs to know
+ the TP topology without re-parsing model config.
+ """
+ return {  # freshly built on every call
+ "from_tp": self._from_tp,
+ "to_tp": self._to_tp,
+ "local_rank": self._local_rank,
+ "remote_ranks": self.get_kv_remote_ranks(),
+ "is_data_transfer_rank": self.is_data_transfer_rank(),
+ }
+
+ # ------------------------------------------------------------------ #
+ # KV transfer lifecycle (RFC – mixin-owned)
+ # ------------------------------------------------------------------ #
+
+ def mark_kv_transfer(
+ self,
+ req_id: str,
+ seq_len: int,
+ block_ids: list[int],
+ custom_metadata: dict[str, Any] | None = None,
+ ) -> None:
+ """Mark a request as needing KV cache transfer.
+
+ Called by the scheduler when a transfer trigger fires. The mixin
+ owns the lifecycle from this point: pending → active → completed.
+ """
+ if req_id in self._kv_pending_transfers:
+ return  # already pending: the first registration wins
+ self._kv_triggered_requests.add(req_id)  # queried by is_kv_transfer_triggered()
+ transfer = {
+ "seq_len": seq_len,
+ "block_ids": block_ids,  # stored by reference, not copied
+ }
+ if custom_metadata is not None:
+ transfer["custom_metadata"] = custom_metadata  # key is present only when metadata was supplied
+ self._kv_pending_transfers[req_id] = transfer
+
+ def drain_pending_kv_transfers(self) -> dict[str, dict[str, Any]]:
+ """Drain pending KV transfers and move them to active.
+
+ Returns ``{req_id: {seq_len, block_ids}}`` for the model runner
+ to submit to ``send_kv_cache``.
+ """
+ if not self._kv_pending_transfers:
+ return {}
+ pending = dict(self._kv_pending_transfers)  # snapshot before the pending dict is cleared below
+ self._kv_active_transfers.update(pending.keys())
+ self._kv_pending_transfers.clear()
+ return pending  # entries may also carry "custom_metadata" (see mark_kv_transfer)
+
+ def ack_kv_transfers(self, req_ids: list[str] | set[str]) -> None:
+ """Acknowledge completed KV transfers (from kv_extracted_req_ids).
+
+ Moves requests from active to completed so the scheduler can
+ safely free their blocks.
+ """
+ for req_id in req_ids:
+ self._kv_active_transfers.discard(req_id)  # discard: no error if the id was not active
+ self._kv_completed_transfers.add(req_id)
+
+ def drain_completed_kv_transfers(self) -> set[str]:
+ """Drain and return completed KV transfer request IDs.
+
+ The scheduler calls this to know which requests' blocks can be freed.
+ """
+ completed = set(self._kv_completed_transfers)  # copy first so the returned set survives the clear()
+ self._kv_completed_transfers.clear()
+ return completed
+
+ def is_kv_transfer_triggered(self, req_id: str) -> bool:
+ """Check if a request has already triggered KV transfer."""
+ return req_id in self._kv_triggered_requests  # populated by mark_kv_transfer()
+
+ def has_pending_kv_work(self) -> bool:
+ """True if any KV transfers are pending, active, or awaiting ack."""
+ return bool(self._kv_pending_transfers or self._kv_active_transfers or self._kv_completed_transfers)  # completed-but-undrained ids count as work too
+
+ # Output aggregation
+ # ------------------------------------------------------------------ #
+
+ def _empty_output_with_connector_signals(self) -> Any:
+ """Return a minimal ModelRunnerOutput carrying pending connector signals.
+
+ Used by early-return paths (e.g. ``num_scheduled_tokens == 0``)
+ that still need to deliver ``omni_connector_output`` to the
+ Scheduler so that WAITING_FOR_INPUT / WAITING_FOR_CHUNK
+ transitions are not lost.
+ """
+ from vllm_omni.outputs import OmniModelRunnerOutput  # imported at call time rather than at module level
+
+ output = OmniModelRunnerOutput(req_ids=[], req_id_to_index={})
+ output.omni_connector_output = self.get_omni_connector_output()  # this call also resets the per-cycle signal sets
+ return output
+
+ def get_omni_connector_output(self) -> OmniConnectorOutput:
+ """Collect and reset transfer results for this execute_model cycle.
+
+ ``request_metadata`` carries only lightweight scheduling metadata.
+ Full payloads remain owned by the Model Runner local cache for all
+ paths.
+ """
+ if not hasattr(self, "_lock"):
+ return OmniConnectorOutput()  # no _lock attribute yet: return an empty output
+
+ tp_group = self._get_local_tp_group()
+ if self._async_chunk and tp_group is not None and getattr(tp_group, "world_size", 1) > 1:  # async-chunk mode with more than one TP rank
+ if self.is_data_transfer_rank():
+ with self._lock:
+ fanout_packet = self._collect_async_chunk_fanout_packet_locked()
+ else:
+ fanout_packet = None
+ fanout_packet = self._broadcast_tp_payload_packet(fanout_packet)  # every rank continues with the broadcast result
+ if fanout_packet is None:
+ newly_finished = set()
+ chunk_finished = set()
+ request_metadata = {}
+ else:
+ if not self.is_data_transfer_rank():
+ self._apply_async_chunk_fanout_packet(fanout_packet)
+ newly_finished = set(fanout_packet["newly_finished"])
+ chunk_finished = set(fanout_packet["chunk_finished"])
+ request_metadata = dict(fanout_packet["request_metadata"])
+ else:
+ with self._lock:
+ newly_finished = set(self._finished_load_reqs)  # copy, then clear: each id is reported once
+ self._finished_load_reqs.clear()
+ chunk_finished = set(self._chunk_finished_req_ids)
+ self._chunk_finished_req_ids.clear()
+ request_metadata = dict(self._local_request_metadata)
+ self._local_request_metadata.clear()
+ # _send_side_request_payload is the async accumulation buffer for
+ # future recv chunks. Clearing it on every consumable wake-up would
+ # drop intermediate thinker decode spans before the model side can
+ # consume them, so only terminal chunk_finished requests may release
+ # that buffer.
+ for req_id in chunk_finished:
+ if req_id not in self._local_stage_payload_cache:
+ continue
+ ext_req_id = self._request_ids_mapping.get(req_id, req_id)
+ self._send_side_request_payload.pop(ext_req_id, None)
+ if ext_req_id != req_id:
+ self._send_side_request_payload.pop(req_id, None)  # also drop any entry stored under the internal id
+ self._chunk_ready_req_ids.update(newly_finished)
+
+ output = OmniConnectorOutput(  # set()/list() copies so the clear() calls below do not empty the output
+ chunk_ready_req_ids=set(self._chunk_ready_req_ids),
+ chunk_finished_req_ids=chunk_finished,
+ request_metadata=request_metadata,
+ kv_sent_req_ids=list(self._kv_sent_req_ids),
+ stage_recv_req_ids=set(self._stage_recv_req_ids),
+ has_pending_kv_work=self.has_pending_kv_work(),
+ )
+ if output.stage_recv_req_ids or chunk_finished or newly_finished:
+ logger.info(
+ "[Stage-%s] get_omni_connector_output: stage_recv=%s, chunk_finished=%s, chunk_ready=%s",
+ self._stage_id,
+ output.stage_recv_req_ids,
+ chunk_finished,
+ output.chunk_ready_req_ids,
+ )
+ self._chunk_ready_req_ids.clear()
+ self._kv_sent_req_ids.clear()
+ self._stage_recv_req_ids.clear()
+ return output
+
+ @staticmethod
+ def _connector_output_has_signals(output: OmniConnectorOutput) -> bool:  # True when any field of the output is truthy
+ return bool(
+ output.chunk_ready_req_ids
+ or output.chunk_finished_req_ids
+ or output.request_metadata
+ or output.kv_sent_req_ids
+ or output.stage_recv_req_ids
+ or output.has_pending_kv_work
+ )
+
+ def attach_omni_connector_output(self, result: Any | None) -> Any:  # attach this cycle's connector output to a copy of `result`
+ omni_output = self.get_omni_connector_output()  # collects and resets the per-cycle signals
+ if not self._connector_output_has_signals(omni_output):
+ return result  # nothing to report: the input (possibly None) is returned untouched
+
+ from copy import copy
+
+ from vllm.v1.worker.gpu_model_runner import EMPTY_MODEL_RUNNER_OUTPUT
+
+ wrapped = copy(result if result is not None else EMPTY_MODEL_RUNNER_OUTPUT)  # shallow copy so neither `result` nor the shared empty output is mutated
+ wrapped.omni_connector_output = omni_output
+ return wrapped
+
+ # ------------------------------------------------------------------ #
+ # Properties for compatibility with custom_process funcs that access
+ # transfer_manager.put_req_chunk / request_payload / code_prompt_token_ids
+ # ------------------------------------------------------------------ #
+
+ @property
+ def put_req_chunk(self) -> dict[str, int]:  # per-request send-side chunk counter (see send_chunk)
+ return self._put_req_chunk  # the live dict, not a copy
+
+ @property
+ def request_payload(self) -> dict[str, dict[str, Any]]:  # public alias for _send_side_request_payload
+ return self._send_side_request_payload  # the live dict, not a copy
+
+ @request_payload.setter
+ def request_payload(self, value: dict[str, dict[str, Any]]) -> None:  # rebinds the whole buffer to `value`
+ self._send_side_request_payload = value
+
+ @property
+ def code_prompt_token_ids(self) -> dict[str, list[list[int]]]:  # read-only accessor for _code_prompt_token_ids
+ return self._code_prompt_token_ids  # the live dict, not a copy
+
+ @property
+ def connector(self) -> Any | None:  # read-only accessor for the underlying connector
+ return self._omni_connector  # may be None; callers such as send_chunk check for that
+
+ # ------------------------------------------------------------------ #
+ # Background I/O threads
+ # ------------------------------------------------------------------ #
+
+ def _recv_loop(self) -> None:
+ """Background thread: poll connector for incoming data."""
+ _recv_poll_count = 0
+ while not self._stop_event.is_set():
+ with self._lock:
+ pending_ids = list(self._pending_load_reqs.keys())  # snapshot of the ids to poll in this pass
+
+ if not pending_ids:
+ self._work_available.wait(timeout=0.01)  # idle: wait up to 10 ms for new work
+ self._work_available.clear()
+ continue
+
+ _recv_poll_count += 1
+ if _recv_poll_count % 5000 == 1:  # logs on poll 1, 5001, 10001, ...
+ logger.info(
+ "[Stage-%s] _recv_loop: polling %s pending reqs: %s (poll#%s)",
+ self._stage_id,
+ len(pending_ids),
+ pending_ids[:5],
+ _recv_poll_count,
+ )
+
+ made_progress = False
+ for req_id in pending_ids:
+ if self._stop_event.is_set():
+ break
+ try:
+ made_progress = self._poll_single_request(req_id) or made_progress
+ except Exception:
+ logger.warning("Error receiving data for %s", req_id, exc_info=True)  # one failing request does not stop the loop
+
+ if not made_progress and not self._stop_event.is_set():
+ self._work_available.wait(timeout=0.001)  # nothing received: wait up to 1 ms before polling again
+ self._work_available.clear()
+
+ _MAX_SEND_RETRIES = 3  # max re-enqueue attempts per failed send task (see _requeue_or_drop_failed_send)
+
+ def _save_loop(self) -> None:
+ """Background thread: send outgoing data via connector."""
+ while not self._stop_event.is_set():
+ task = None
+ with self._lock:
+ for req_id in list(self._pending_save_reqs.keys()):  # iterate a key snapshot because entries are deleted below
+ dq = self._pending_save_reqs[req_id]
+ if dq:
+ task = dq.popleft()  # oldest queued task of this request
+ if not dq:
+ del self._pending_save_reqs[req_id]
+ break
+ del self._pending_save_reqs[req_id]
+
+ if task is not None:
+ success = False
+ try:
+ success = self._send_single_request(task)
+ except Exception:
+ logger.error(
+ "Error saving data for %s",
+ task.get("request_id"),
+ exc_info=True,
+ )
+ if not success:
+ self._requeue_or_drop_failed_send(task)  # covers both a raised exception and a False return
+ continue
+
+ self._work_available.wait(timeout=0.01)  # wait up to 10 ms for new work
+ self._work_available.clear()
+
+ def _requeue_or_drop_failed_send(self, task: dict) -> None:
+ """Re-enqueue a failed send task or drop it after max retries."""
+ retry_count = task.get("_retry_count", 0) + 1  # the attempt count is carried on the task dict itself
+ req_id = task.get("request_id")
+ if retry_count <= self._MAX_SEND_RETRIES:
+ task["_retry_count"] = retry_count
+ logger.warning(
+ "[Stage-%s] Re-enqueuing failed send for %s (retry %d/%d)",
+ getattr(self, "_stage_id", "?"),
+ req_id,
+ retry_count,
+ self._MAX_SEND_RETRIES,
+ )
+ with self._lock:
+ dq = self._pending_save_reqs.setdefault(req_id, deque())
+ dq.appendleft(task)  # front of the queue, so it is retried before later tasks of the same request
+ else:
+ logger.error(
+ "[Stage-%s] Giving up on send for %s after %d retries",
+ getattr(self, "_stage_id", "?"),
+ req_id,
+ self._MAX_SEND_RETRIES,
+ )
+ self._decrement_pending_save_count(req_id)  # a dropped task still releases its pending-save count
+
+ # ------------------------------------------------------------------ #
+ # Chunk-level poll / send (ported from OmniChunkTransferAdapter)
+ # ------------------------------------------------------------------ #
+
+ def _poll_single_request(self, req_id: str) -> bool:
+ """Poll connector for one chunk of a request (non-blocking)."""
+ connector = self._omni_connector
+ if connector is None:
+ return False
+
+ if self._async_chunk and self._model_mode != "ar":
+ with self._lock:
+ staged_payload = self._local_stage_payload_cache.get(req_id)
+ metadata_in_flight = req_id in self._local_request_metadata
+ scheduler_wakeup_pending = req_id in self._finished_load_reqs
+ if self._payload_is_consumable(staged_payload) or metadata_in_flight or scheduler_wakeup_pending:
+ logger.debug(
+ "[Stage-%s] delaying recv for req=%s until staged async payload is handed to scheduler",
+ self._stage_id,
+ req_id,
+ )
+ return False
+
+ target_stage_id = self._stage_id - 1  # data is always fetched from the immediately preceding stage
+ chunk_id = self._get_req_chunk[req_id]  # presumably a defaultdict(int) so the first poll asks for chunk 0 -- TODO confirm
+ external_req_id = self._request_ids_mapping.get(req_id, req_id)
+ connector_get_key = f"{external_req_id}_{target_stage_id}_{chunk_id}"  # same "<req>_<stage>_<chunk>" layout as the put key built in send_chunk
+
+ if self._async_chunk:
+ result = self._recv_async_chunk_result(
+ connector,
+ str(target_stage_id),
+ str(self._stage_id),
+ connector_get_key,
+ )
+ else:
+ result = self._recv_full_payload_result(
+ connector,
+ str(target_stage_id),
+ str(self._stage_id),
+ connector_get_key,
+ )
+
+ if result is None:
+ return False
+
+ payload_data, _size = result
+ if not payload_data:
+ return False
+ if isinstance(payload_data, dict):
+ logger.info(
+ "[Stage-%s] recv_chunk_result: req=%s ext=%s key=%s keys=%s finished=%s",
+ self._stage_id,
+ req_id,
+ external_req_id,
+ connector_get_key,
+ sorted(payload_data.keys()),
+ bool(payload_data.get("finished")) if "finished" in payload_data else None,
+ )
+
+ self._get_req_chunk[req_id] += 1  # the next poll for this request asks for the following chunk
+
+ if self._async_chunk:
+ is_finished = bool(payload_data.get("finished"))
+ incoming_payload_consumable = self._payload_is_consumable(payload_data)  # evaluated on the raw chunk, before any accumulation
+
+ if self._model_mode == "ar":
+ payload_data = self._accumulate_payload(external_req_id, payload_data)  # accumulation is keyed by the external id
+ payload_consumable = incoming_payload_consumable
+ else:
+ new_ids = payload_data.get("code_predictor_codes", [])
+ if not new_ids and not is_finished:
+ return False
+ payload_consumable = self._payload_is_consumable(payload_data)
+
+ with self._lock:
+ if is_finished:
+ self._chunk_finished_req_ids.add(req_id)
+ self._chunk_stream_completed.add(req_id)  # register_chunk_recv checks this set to avoid re-registering
+ # Local cache (RFC §2.4) — merge, don't replace, so that
+ # earlier chunk keys (e.g. thinker_prefill_embeddings from
+ # chunk 0) are not overwritten by later chunks.
+ existing = self._local_stage_payload_cache.get(req_id)
+ if existing is not None and isinstance(existing, dict) and isinstance(payload_data, dict):
+ existing.update(payload_data)
+ else:
+ self._local_stage_payload_cache[req_id] = payload_data
+ staged_payload = self._local_stage_payload_cache[req_id]
+ self._async_chunk_updated_req_ids.add(req_id)
+ self.put_local_request_metadata(req_id, self._extract_scheduling_metadata(staged_payload))
+ # A finish-only sentinel still needs one terminal wake-up so
+ # the downstream stage can sync the merged local payload and
+ # flush/finish even when the last recv carries no new
+ # consumable chunk bytes.
+ if payload_consumable or is_finished:
+ self._finished_load_reqs.add(req_id)
+ if is_finished and not payload_consumable:
+ logger.debug(
+ "[Stage-%s] finish sentinel arrived for req=%s without new consumable payload",
+ self._stage_id,
+ req_id,
+ )
+ elif not payload_consumable:
+ logger.debug(
+ "[Stage-%s] req=%s received metadata-only / non-consumable async payload; delaying wake-up",
+ self._stage_id,
+ req_id,
+ )
+ if is_finished:
+ self._pending_load_reqs.pop(req_id, None)  # stop polling once the finish sentinel has arrived
+ else:
+ # full_payload_mode: the complete payload arrives in a single get(),
+ # so always unregister immediately.
+ if isinstance(payload_data, dict):
+ engine_inputs = payload_data.get("engine_inputs", payload_data)  # falls back to the whole dict when the key is absent
+ else:
+ engine_inputs = payload_data
+ with self._lock:
+ self._local_stage_payload_cache[req_id] = self._snapshot_payload(engine_inputs)
+ # Publish full-payload readiness only after the aligned TP broadcast
+ # path in recv_full_payload_inputs() has materialized the payload on all
+ # local ranks. Publishing metadata / stage_recv from the background recv
+ # thread can let the scheduler observe a request before the payload is
+ # actually visible to the model thread.
+ self._full_payload_pending_broadcast_req_ids.add(req_id)
+ self._pending_load_reqs.pop(req_id, None)
+ logger.info(
+ "[Stage-%s] full_payload recv complete: req=%s key=%s payload_type=%s",
+ self._stage_id,
+ req_id,
+ connector_get_key,
+ type(engine_inputs).__name__,
+ )
+
+ logger.debug("[Stage-%s] Received data for key %s", self._stage_id, connector_get_key)
+ return True
+
+ def _build_custom_process_payload(
+ self,
+ request_id: str | None,
+ request: Any | None,
+ pooling_output: Any | None,
+ ) -> Any | None:
+ """Run the custom process hook with a best-effort finished kwarg."""
+ if self._custom_process_func is None:
+ return None
+
+ kwargs = {
+ "transfer_manager": self,  # the hook receives this object as its transfer manager
+ "pooling_output": pooling_output,
+ "request": request,
+ }
+ supports_is_finished = getattr(  # NOTE(review): the default argument is evaluated eagerly, so the signature is inspected on every call
+ self,
+ "_custom_process_supports_is_finished",
+ self._custom_process_supports_is_finished_kwarg(),
+ )
+ is_finished_fn = getattr(request, "is_finished", None)
+ if callable(is_finished_fn):
+ try:
+ if supports_is_finished is not False:  # None ("unknown") is treated like True: pass the kwarg and rely on the retry below
+ kwargs["is_finished"] = bool(is_finished_fn())
+ except Exception:
+ logger.debug("request.is_finished() failed for %s", request_id, exc_info=True)
+
+ try:
+ return self._custom_process_func(**kwargs)
+ except TypeError as exc:
+ if "is_finished" not in kwargs or not self._is_unexpected_is_finished_kwarg_error(exc):
+ logger.exception("custom_process_stage_input_func failed for chunk %s", request_id)
+ return None
+ kwargs.pop("is_finished", None)  # the hook rejected the kwarg: retry once without it
+ try:
+ return self._custom_process_func(**kwargs)
+ except Exception:
+ logger.exception("custom_process_stage_input_func failed for chunk %s", request_id)
+ return None
+ except Exception:
+ logger.exception("custom_process_stage_input_func failed for chunk %s", request_id)
+ return None  # hook failures are logged and reported as "no payload"
+
+ def _custom_process_supports_is_finished_kwarg(self) -> bool | None:
+ """Return whether the custom process hook accepts `is_finished`."""
+ if self._custom_process_func is None:
+ return None
+ try:
+ signature = inspect.signature(self._custom_process_func)
+ except (TypeError, ValueError):
+ return None  # signature cannot be introspected: answer is "unknown"
+
+ for param in signature.parameters.values():
+ if param.kind == inspect.Parameter.VAR_KEYWORD:
+ return True  # a **kwargs parameter accepts any keyword, including is_finished
+
+ is_finished_param = signature.parameters.get("is_finished")
+ if is_finished_param is None:
+ return False
+ return is_finished_param.kind in (  # must be passable by keyword; positional-only does not qualify
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ inspect.Parameter.KEYWORD_ONLY,
+ )
+
+ @staticmethod
+ def _is_unexpected_is_finished_kwarg_error(exc: TypeError) -> bool:  # True when the TypeError message says `is_finished` was rejected as a keyword
+ message = str(exc)
+ return (  # matched by substring on the exception text
+ "unexpected keyword argument 'is_finished'" in message
+ or 'unexpected keyword argument "is_finished"' in message
+ or "positional-only arguments passed as keyword arguments: 'is_finished'" in message
+ )
+
+ def _send_single_request(self, task: dict) -> bool:
+ """Send one queued task via connector.put().
+
+ Returns True on success. On failure (put() raises or returns
+ ``success=False``), returns False **without** decrementing
+ ``_pending_save_counts`` so the caller can retry or clean up.
+ """
+ connector = self._omni_connector
+ if connector is None:
+ return True  # NOTE(review): returns success without decrementing _pending_save_counts -- confirm this is intended
+
+ request_id = task.get("request_id")
+ payload_data = task.get("data")
+ if payload_data is None and task.get("request") is not None:  # task carries a raw request instead of a prebuilt payload
+ payload_data = self._build_custom_process_payload(
+ request_id=request_id,
+ request=task.get("request"),
+ pooling_output=task.get("pooling_output"),
+ )
+ put_key = task.get("put_key")
+
+ success, _size, _metadata = connector.put(
+ from_stage=str(task["stage_id"]),  # stage ids are passed to the connector as strings
+ to_stage=str(task["next_stage_id"]),
+ put_key=put_key,
+ data=payload_data,
+ )
+ logger.info(
+ "[Stage-%s] _send_single_request: put_key=%s success=%s size=%s",
+ task["stage_id"],
+ put_key,
+ success,
+ _size,
+ )
+
+ if not success:
+ return False
+
+ self._decrement_pending_save_count(request_id)
+ return True
+
+ def _decrement_pending_save_count(self, request_id: str) -> None:
+ """Decrement pending save count and run deferred cleanup if zero."""
+ cleanup_req_id = None
+ with self._lock:
+ remaining = self._pending_save_counts.get(request_id, 0)  # .get(): an unknown id reads as 0 and no branch below runs
+ if remaining > 1:
+ self._pending_save_counts[request_id] = remaining - 1
+ elif remaining == 1:
+ self._pending_save_counts.pop(request_id, None)  # last outstanding send: drop the counter entry
+ if request_id in self._deferred_send_cleanup:
+ self._deferred_send_cleanup.remove(request_id)
+ cleanup_req_id = request_id
+ if cleanup_req_id is not None:
+ self._put_req_chunk.pop(cleanup_req_id, None)  # release the per-request send-side state
+ self._send_side_request_payload.pop(cleanup_req_id, None)
+ self._code_prompt_token_ids.pop(cleanup_req_id, None)
+
+ # ------------------------------------------------------------------ #
+ # Payload accumulation (ported from OmniChunkTransferAdapter)
+ # ------------------------------------------------------------------ #
+
+ def _accumulate_payload(self, req_id: str, payload_data: dict[str, Any]) -> dict[str, Any]:
+ """Accumulate chunk payloads (concat tensors, extend lists).
+
+ Returns a **shallow copy** of the accumulated state so callers
+ (e.g. ``_poll_single_request``) can store it in
+ ``_local_stage_payload_cache`` without aliasing the authoritative
+ ``_send_side_request_payload`` dict.
+ """
+ if req_id not in self._send_side_request_payload:
+ self._send_side_request_payload[req_id] = dict(payload_data)  # first chunk: stored as a shallow copy
+ return dict(self._send_side_request_payload[req_id])
+
+ origin = self._send_side_request_payload[req_id]
+ merged = dict(origin)  # work on a copy; it replaces the stored entry at the end
+ override_keys = payload_data.get("override_keys", ())  # keys listed here are replaced instead of concatenated/extended
+ drop_decode_span = False
+ decode_span_handled = False
+ for key, value in payload_data.items():
+ if key == "finished":
+ merged[key] = value
+ continue
+ if key == THINKER_DECODE_EMBEDDINGS_KEY:
+ merged_span = merge_tensor_spans(
+ get_tensor_span(
+ origin,
+ tensor_key=THINKER_DECODE_EMBEDDINGS_KEY,
+ start_key=THINKER_DECODE_TOKEN_START_KEY,
+ end_key=THINKER_DECODE_TOKEN_END_KEY,
+ ),
+ get_tensor_span(
+ payload_data,
+ tensor_key=THINKER_DECODE_EMBEDDINGS_KEY,
+ start_key=THINKER_DECODE_TOKEN_START_KEY,
+ end_key=THINKER_DECODE_TOKEN_END_KEY,
+ ),
+ )
+ if merged_span is not None:
+ merged[key], merged[THINKER_DECODE_TOKEN_START_KEY], merged[THINKER_DECODE_TOKEN_END_KEY] = (  # merged_span unpacks into (tensor, start, end)
+ merged_span
+ )
+ decode_span_handled = True
+ continue
+ if isinstance(value, torch.Tensor) and key in origin:
+ if (
+ THINKER_DECODE_TOKEN_START_KEY in origin
+ or THINKER_DECODE_TOKEN_END_KEY in origin
+ or THINKER_DECODE_TOKEN_START_KEY in payload_data
+ or THINKER_DECODE_TOKEN_END_KEY in payload_data
+ ):
+ logger.warning(
+ "[Stage-%s] req=%s falling back to legacy thinker decode "
+ "merge due to missing/invalid/non-contiguous span "
+ "metadata",
+ self._stage_id,
+ req_id,
+ )
+ drop_decode_span = True
+ merged[key] = torch.cat([origin[key], value], dim=0)  # legacy path: plain concatenation along dim 0
+ continue
+ merged[key] = value
+ continue
+ if key in {THINKER_DECODE_TOKEN_START_KEY, THINKER_DECODE_TOKEN_END_KEY}:
+ if decode_span_handled or drop_decode_span:
+ continue
+ merged[key] = value
+ continue
+ if key in override_keys:
+ merged[key] = value
+ continue
+ if isinstance(value, torch.Tensor) and key in origin:
+ merged[key] = torch.cat([origin[key], value], dim=0)
+ elif isinstance(value, list) and key in origin:
+ merged[key] = origin[key] + value  # builds a new list; the stored one is not extended in place
+ else:
+ merged[key] = value
+
+ if drop_decode_span:
+ merged.pop(THINKER_DECODE_TOKEN_START_KEY, None)  # span bounds are removed after a legacy fallback
+ merged.pop(THINKER_DECODE_TOKEN_END_KEY, None)
+ self._send_side_request_payload[req_id] = merged
+ return dict(merged)
+
+ def drop_inactive_request_runtime_state(self, req_id: str) -> None:
+ """Clear inactive request state used by both the runner and mixin.
+
+ This centralizes the model-runner-side cleanup pattern so
+ ``OmniGPUModelRunner`` can reuse it instead of open-coding the same
+ inactive-request state mutations.
+ """
+ if hasattr(self, "model_intermediate_buffer"):
+ self.model_intermediate_buffer.pop(req_id, None)
+ self.drop_inactive_request_delivery_state(req_id)
+
+ # ------------------------------------------------------------------ #
+ # Helpers
+ # ------------------------------------------------------------------ #
+
+    @staticmethod
+    def _freeze_request_attr(value: Any) -> Any:
+        """Copy list/tuple/tensor values (and ``_x``-backed wrappers); return anything else as-is."""
+        if isinstance(value, (list, tuple)):
+            return list(value)
+        if isinstance(value, torch.Tensor):
+            return value.clone()
+        # Objects exposing a non-None ``_x`` attribute are copied from it.
+        backing = getattr(value, "_x", None)
+        if backing is None:
+            return value
+        return list(backing)
+
+ def _snapshot_request_for_send(self, request: Any, external_req_id: str) -> Any:
+ finished = bool(getattr(request, "is_finished", lambda: False)())
+ attrs: dict[str, Any] = {}
+ try:
+ attrs.update(vars(request))
+ except TypeError:
+ pass
+
+ for name in (
+ "request_id",
+ "req_id",
+ "external_req_id",
+ "prompt_token_ids",
+ "output_token_ids",
+ "all_token_ids",
+ "additional_information",
+ "sampling_params",
+ "multi_modal_data",
+ "mm_hashes",
+ ):
+ if hasattr(request, name):
+ attrs[name] = self._freeze_request_attr(getattr(request, name))
+
+ attrs["external_req_id"] = external_req_id
+ attrs["_frozen_is_finished"] = finished
+ snapshot = SimpleNamespace(**attrs)
+ snapshot.is_finished = lambda: finished
+ return snapshot
+
+ @staticmethod
+ def _create_connector(model_config: Any) -> OmniConnectorBase | None:
+ """Create a connector from model_config, or None if unconfigured."""
+ connector_config = getattr(model_config, "stage_connector_config", None)
+ if connector_config is None:
+ return None
+
+ if not isinstance(connector_config, dict):
+ connector_config = {
+ "name": getattr(connector_config, "name", None),
+ "extra": getattr(connector_config, "extra", None),
+ }
+
+ name = connector_config.get("name")
+ if not isinstance(name, str) or not name.strip():
+ raise RuntimeError("Invalid stage connector config: missing connector name")
+ name = name.strip()
+
+ extra = connector_config.get("extra")
+ if extra is None:
+ extra = {}
+ elif not isinstance(extra, dict):
+ raise RuntimeError(f"Invalid extra config for connector {name}: expected dict, got {type(extra).__name__}")
+
+ spec = ConnectorSpec(name=name, extra=extra)
+ try:
+ return OmniConnectorFactory.create_connector(spec)
+ except Exception as exc:
+ raise RuntimeError(f"Failed to create connector {name}") from exc
+
+ @staticmethod
+ def _load_custom_func(model_config: Any) -> tuple[str | None, Any | None]:
+ """Load the connector payload builder for the downstream stage.
+
+ Preferred source is ``custom_process_next_stage_input_func``. Some
+ full_payload_mode configs (async_chunk=false) only expose the next-stage prompt builder via
+ ``custom_process_input_func`` (for example ``thinker2talker``), while the
+ connector payload builder lives beside it as ``thinker2talker_full_payload``.
+ In that case, derive the full_payload_mode builder path automatically.
+ """
+ candidates: list[str] = []
+
+ next_stage_func = getattr(model_config, "custom_process_next_stage_input_func", None)
+ if isinstance(next_stage_func, str) and next_stage_func:
+ candidates.append(next_stage_func)
+
+ if not getattr(model_config, "async_chunk", False):
+ input_func = getattr(model_config, "custom_process_input_func", None)
+ if isinstance(input_func, str) and input_func:
+ try:
+ module_path, func_name = input_func.rsplit(".", 1)
+ if func_name.endswith("_full_payload") or func_name.endswith("_batch"):
+ candidates.append(f"{module_path}.{func_name}")
+ else:
+ candidates.append(f"{module_path}.{func_name}_full_payload")
+ candidates.append(f"{module_path}.{func_name}_batch")
+ candidates.append(input_func)
+ except ValueError:
+ candidates.append(input_func)
+
+ tried: set[str] = set()
+ for func_path in candidates:
+ if func_path in tried:
+ continue
+ tried.add(func_path)
+ try:
+ module_path, func_name = func_path.rsplit(".", 1)
+ module = importlib.import_module(module_path)
+ func = getattr(module, func_name, None)
+ if callable(func):
+ if not OmniConnectorModelRunnerMixin._is_connector_payload_builder(func):
+ logger.debug(
+ "Skipping incompatible connector payload hook %s; signature=%s",
+ func_path,
+ inspect.signature(func),
+ )
+ continue
+ return func_path, func
+ except Exception:
+ logger.warning("Failed to load custom func: %s", func_path, exc_info=True)
+
+ return None, None
+
+    @staticmethod
+    def _is_connector_payload_builder(func: Any) -> bool:
+        """Return True if *func* can be called with the payload-builder keyword arguments.
+
+        A function qualifies when it takes ``**kwargs``, or when ``transfer_manager``,
+        ``pooling_output`` and ``request`` are all parameters that may be passed by
+        keyword. Objects whose signature cannot be inspected do not qualify.
+        """
+        try:
+            parameters = inspect.signature(func).parameters.values()
+        except (TypeError, ValueError):
+            return False
+
+        by_keyword = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
+        keyword_names = set()
+        for parameter in parameters:
+            if parameter.kind == inspect.Parameter.VAR_KEYWORD:
+                # ``**kwargs`` absorbs every keyword the caller may send.
+                return True
+            if parameter.kind in by_keyword:
+                keyword_names.add(parameter.name)
+
+        return {"transfer_manager", "pooling_output", "request"} <= keyword_names
+
+    def _resolve_external_req_id(self, request: Any, fallback_req_id: str) -> str:
+        """Resolve the external request ID consistently.
+
+        Lookup order: the ``_request_ids_mapping`` entry for
+        ``fallback_req_id`` (when not None), then the request's
+        ``external_req_id`` attribute (when a request is given), and
+        finally ``fallback_req_id`` itself.
+        """
+        external_id = self._request_ids_mapping.get(fallback_req_id)
+        if external_id is None:
+            if request is None:
+                return fallback_req_id
+            external_id = getattr(request, "external_req_id", fallback_req_id)
+        return external_id
+
+ def _resolve_next_stage_id(self, model_config: Any) -> int:
+ """Determine the downstream stage ID from connector config.
+
+ Falls back to ``stage_id + 1`` when the config does not specify
+ a ``to_stage`` explicitly.
+ """
+ connector_config = getattr(model_config, "stage_connector_config", None)
+ if connector_config is not None:
+ if isinstance(connector_config, dict):
+ to_stage = connector_config.get("to_stage")
+ else:
+ to_stage = getattr(connector_config, "to_stage", None)
+ if isinstance(to_stage, int):
+ return to_stage
+ if isinstance(to_stage, str) and to_stage.strip():
+ return int(to_stage)
+ return self._stage_id + 1
+
+    @staticmethod
+    def _parse_rank_mapping(model_config: Any) -> dict[str, int]:
+        """Parse rank_mapping from connector config (optional).
+
+        Returns ``{"from_tp": int, "to_tp": int, "local_rank": int}``.
+        When ``rank_mapping`` is absent or null, assumes 1:1 homogeneous mapping.
+        """
+        connector_config = getattr(model_config, "stage_connector_config", None)
+        if connector_config is not None and not isinstance(connector_config, dict):
+            connector_config = getattr(connector_config, "__dict__", {})
+
+        rank_mapping: dict = {}
+        if isinstance(connector_config, dict):
+            # An explicit ``rank_mapping: null`` must behave like a missing key.
+            rank_mapping = connector_config.get("rank_mapping") or {}
+
+        from_tp = int(rank_mapping.get("from_tp", 1))
+        to_tp = int(rank_mapping.get("to_tp", 1))
+
+        try:
+            local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+        except (ValueError, TypeError):
+            local_rank = 0
+
+        return {"from_tp": from_tp, "to_tp": to_tp, "local_rank": local_rank}
+
+ # ------------------------------------------------------------------ #
+ # Heterogeneous TP rank support
+ # ------------------------------------------------------------------ #
+
+    def _validate_kv_tp_topology(self) -> None:
+        """Raise ``ValueError`` unless both TP sizes are positive and one divides the other."""
+        if min(self._from_tp, self._to_tp) <= 0:
+            raise ValueError(f"Invalid KV TP mapping: from_tp={self._from_tp}, to_tp={self._to_tp}")
+        # Order the two sizes so the divisibility test is direction-agnostic.
+        low, high = sorted((self._from_tp, self._to_tp))
+        if high % low:
+            raise ValueError(
+                f"KV TP mapping must be divisible for rank-aware routing: from_tp={self._from_tp}, to_tp={self._to_tp}"
+            )
+
+    def get_kv_remote_ranks(self) -> list[int]:
+        """Return the remote TP ranks this local rank exchanges KV with.
+
+        After validating the topology, the mapping is:
+        - ``from_tp == to_tp``: the local rank itself (1:1)
+        - ``from_tp > to_tp``: ``from_tp // to_tp`` consecutive ranks starting at ``local_rank * ratio``
+        - ``from_tp < to_tp``: the single rank ``local_rank // (to_tp // from_tp)``
+        """
+        self._validate_kv_tp_topology()
+        src_tp, dst_tp, rank = self._from_tp, self._to_tp, self._local_rank
+        if src_tp == dst_tp:
+            return [rank]
+
+        if src_tp < dst_tp:
+            return [rank // (dst_tp // src_tp)]
+        fan = src_tp // dst_tp
+        first = rank * fan
+        return list(range(first, first + fan))
+
+ def is_data_transfer_rank(self) -> bool:
+ """Whether this rank should participate in data (non-KV) transfer.
+
+ Ordinary stage payloads are TP-identical, so exactly one TP rank
+ should talk to the connector. When TP is initialized, use TP rank 0
+ so the connector leader matches TP-local broadcast source rank.
+ Otherwise fall back to LOCAL_RANK==0 for the single-rank case.
+ """
+ tp_group = self._get_local_tp_group()
+ if tp_group is not None and getattr(tp_group, "world_size", 1) > 1:
+ return getattr(tp_group, "rank_in_group", 0) == 0
+ return self._local_rank == 0
+
+ def get_kv_connector_key(
+ self,
+ req_id: str,
+ from_stage: int,
+ chunk_id: int,
+ from_rank: int,
+ to_rank: int,
+ ) -> str:
+ """Build connector key that includes rank info for KV transfers."""
+ return f"{req_id}_{from_stage}_{chunk_id}_{from_rank}_{to_rank}"
diff --git a/vllm_omni/worker/payload_span.py b/vllm_omni/worker/payload_span.py
new file mode 100644
index 00000000000..994392343a9
--- /dev/null
+++ b/vllm_omni/worker/payload_span.py
@@ -0,0 +1,64 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Helpers for explicit thinker decode span metadata."""
+
+from collections.abc import Mapping
+from typing import Any
+
+import torch
+
+THINKER_DECODE_EMBEDDINGS_KEY = "thinker_decode_embeddings"
+THINKER_OUTPUT_TOKEN_IDS_KEY = "thinker_output_token_ids"
+THINKER_DECODE_TOKEN_START_KEY = "thinker_decode_embeddings_token_start"
+THINKER_DECODE_TOKEN_END_KEY = "thinker_decode_embeddings_token_end"
+
+CACHED_THINKER_DECODE_EMBEDDINGS_KEY = "cached_thinker_decode_embeddings"
+CACHED_THINKER_DECODE_TOKEN_START_KEY = "cached_thinker_decode_embeddings_token_start"
+CACHED_THINKER_DECODE_TOKEN_END_KEY = "cached_thinker_decode_embeddings_token_end"
+
+TensorSpan = tuple[torch.Tensor, int, int]
+
+
+def get_tensor_span(payload: Mapping[str, Any], *, tensor_key: str, start_key: str, end_key: str) -> TensorSpan | None:
+    """Return ``(tensor, start, end)`` read from *payload*, or None if the span is not well-formed.
+
+    Well-formed means: a tensor value, int bounds with ``0 <= start <= end``, and ``end - start`` rows.
+    """
+    rows = payload.get(tensor_key)
+    lo, hi = payload.get(start_key), payload.get(end_key)
+    if not (isinstance(rows, torch.Tensor) and isinstance(lo, int) and isinstance(hi, int)):
+        return None
+    well_formed = 0 <= lo <= hi and (hi - lo) == int(rows.shape[0])
+    return (rows, lo, hi) if well_formed else None
+
+
+def merge_tensor_spans(existing_span: TensorSpan | None, incoming_span: TensorSpan | None) -> TensorSpan | None:
+    """Append *incoming_span* to *existing_span*, dropping rows the existing span already covers.
+
+    Returns None if either span is None or the incoming span starts after the existing one ends (a gap).
+    """
+    if existing_span is None or incoming_span is None:
+        return None
+    base, base_start, base_end = existing_span
+    extra, extra_start, extra_end = incoming_span
+    if extra.device != base.device or extra.dtype != base.dtype:
+        extra = extra.to(device=base.device, dtype=base.dtype)
+    if extra_start > base_end:
+        return None
+    if extra_start == base_end:
+        return torch.cat([base, extra], dim=0), base_start, extra_end
+    # Overlap: skip the leading rows of *extra* that fall before ``base_end``.
+    skip = base_end - extra_start
+    if skip >= int(extra.shape[0]):
+        return base, base_start, base_end
+    tail = extra[skip:]
+    return torch.cat([base, tail], dim=0), base_start, base_end + int(tail.shape[0])
+
+
+def get_tensor_span_row(span: TensorSpan | None, index: int) -> torch.Tensor | None:
+    """Return the row at absolute position *index*, or None when *span* is None or does not cover it."""
+    if span is None:
+        return None
+    rows, first, stop = span
+    # *index* is absolute; rebase it onto the tensor's own row numbering.
+    return rows[index - first] if first <= index < stop else None
From cd2761e15c8e49ea7c53cd551f820318155b4988 Mon Sep 17 00:00:00 2001
From: JohnJan
Date: Mon, 13 Apr 2026 17:51:48 +0800
Subject: [PATCH 14/76] [Feature]: support Flux.2-dev tea_cache (#1871)
Co-authored-by: wuzhongjian
---
docs/user_guide/diffusion_features.md | 2 +-
.../cache/test_teacache_extractors.py | 105 ++++++++++++-
.../cache/teacache/coefficient_estimator.py | 27 ++++
vllm_omni/diffusion/cache/teacache/config.py | 9 ++
.../diffusion/cache/teacache/extractors.py | 140 ++++++++++++++++++
5 files changed, 281 insertions(+), 2 deletions(-)
diff --git a/docs/user_guide/diffusion_features.md b/docs/user_guide/diffusion_features.md
index 2f28131ee55..ac140ff84a0 100644
--- a/docs/user_guide/diffusion_features.md
+++ b/docs/user_guide/diffusion_features.md
@@ -110,7 +110,7 @@ The following tables show which models support each feature:
| **FLUX.1-dev** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
| **FLUX.2-klein** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
| **FLUX.1-Kontext-dev** | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
-| **FLUX.2-dev** | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
+| **FLUX.2-dev** | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| **GLM-Image** | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| **HunyuanImage3** | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| **LongCat-Image** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
diff --git a/tests/diffusion/cache/test_teacache_extractors.py b/tests/diffusion/cache/test_teacache_extractors.py
index a52e11b3d46..c22a60e227e 100644
--- a/tests/diffusion/cache/test_teacache_extractors.py
+++ b/tests/diffusion/cache/test_teacache_extractors.py
@@ -22,7 +22,7 @@
import torch
from tests.utils import hardware_test
-from vllm_omni.diffusion.cache.teacache.extractors import extract_flux2_klein_context
+from vllm_omni.diffusion.cache.teacache.extractors import extract_flux2_context, extract_flux2_klein_context
from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import (
Flux2Transformer2DModel,
)
@@ -174,3 +174,106 @@ def test_invalid_module_raises_error(self):
img_ids=torch.randint(0, 64, (1, 1024, 4)),
txt_ids=torch.randint(0, 64, (1, 512, 4)),
)
+
+
+class TestFlux2Extractor(BaseExtractorTest):
+ """Test extract_flux2_context function."""
+
+ def get_extractor(self):
+ return extract_flux2_context
+
+ @pytest.fixture
+ def flux2_module(self):
+ """Create a minimal Flux2Transformer2DModel for testing."""
+ from vllm_omni.diffusion.models.flux2.flux2_transformer import Flux2Transformer2DModel
+
+ model = Flux2Transformer2DModel(
+ num_layers=2,
+ num_single_layers=2,
+ num_attention_heads=48,
+ attention_head_dim=128,
+ joint_attention_dim=15360,
+ )
+ return model
+
+ def get_module(self, flux2_module):
+ return flux2_module
+
+ @pytest.fixture
+ def sample_inputs(self):
+ """Create sample input tensors for Flux2.
+
+ Note: hidden_states uses in_channels=128 (default for Flux2),
+ not inner_dim=6144. The x_embedder projects from 128 -> 6144.
+ encoder_hidden_states uses joint_attention_dim=15360 (model default),
+ which then gets projected to inner_dim=6144 by context_embedder.
+ """
+ batch_size = 1
+ img_seq_len = 1024
+ txt_seq_len = 512
+ in_channels = 128 # Model default in_channels
+ txt_dim = 15360 # Model default joint_attention_dim
+
+ return {
+ "hidden_states": torch.randn(batch_size, img_seq_len, in_channels),
+ "encoder_hidden_states": torch.randn(batch_size, txt_seq_len, txt_dim),
+ "timestep": torch.tensor([500]),
+ "img_ids": torch.randint(0, 64, (batch_size, img_seq_len, 4)),
+ "txt_ids": torch.randint(0, 64, (batch_size, txt_seq_len, 4)),
+ "guidance": torch.tensor([3.5]),
+ }
+
+ def get_sample_inputs(self, sample_inputs):
+ return sample_inputs
+
+    @hardware_test(res={"cuda": "L4"}, num_cards=1)
+    def test_modulated_input_shape(self, flux2_module, sample_inputs):
+        """Test that modulated_input has correct shape matching the model's inner_dim.
+
+        Note: After x_embedder projection, hidden_states are projected from
+        in_channels (128) to inner_dim (6144), so modulated_input should match
+        the projected shape, not the input shape.
+        """
+        context = extract_flux2_context(flux2_module, **sample_inputs)
+
+        batch_size, img_seq_len, _ = sample_inputs["hidden_states"].shape
+        inner_dim = flux2_module.inner_dim
+        assert context.modulated_input.shape == (batch_size, img_seq_len, inner_dim)
+
+ @hardware_test(res={"cuda": "L4"}, num_cards=1)
+ def test_run_transformer_blocks_callable(self, flux2_module, sample_inputs):
+ """Test that run_transformer_blocks is callable."""
+ context = extract_flux2_context(flux2_module, **sample_inputs)
+ assert callable(context.run_transformer_blocks)
+
+ @hardware_test(res={"cuda": "L4"}, num_cards=1)
+ def test_postprocess_callable(self, flux2_module, sample_inputs):
+ """Test that postprocess is callable."""
+ context = extract_flux2_context(flux2_module, **sample_inputs)
+ assert callable(context.postprocess)
+
+ def test_without_guidance(self, flux2_module, sample_inputs):
+ """Test context extraction works without guidance (no CFG)."""
+ inputs = sample_inputs.copy()
+ inputs["guidance"] = None
+
+ context = extract_flux2_context(flux2_module, **inputs)
+
+ assert context is not None
+ assert context.temb is not None
+
+ @pytest.mark.cpu
+ def test_invalid_module_raises_error(self):
+ """Test that invalid module without transformer_blocks raises ValueError."""
+ invalid_module = Mock()
+ invalid_module.transformer_blocks = []
+
+ with pytest.raises(ValueError, match="Module must have transformer_blocks"):
+ extract_flux2_context(
+ invalid_module,
+ hidden_states=torch.randn(1, 1024, 6144),
+ encoder_hidden_states=torch.randn(1, 512, 15360),
+ timestep=torch.tensor([500]),
+ img_ids=torch.randint(0, 64, (1, 1024, 4)),
+ txt_ids=torch.randint(0, 64, (1, 512, 4)),
+ )
diff --git a/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py
index 5dd80718d11..baec21c2762 100644
--- a/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py
+++ b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py
@@ -13,6 +13,7 @@
from vllm_omni.diffusion.hooks import HookRegistry, ModelHook
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline
+from vllm_omni.diffusion.models.flux2.pipeline_flux2 import Flux2Pipeline
from vllm_omni.diffusion.models.stable_audio.pipeline_stable_audio import StableAudioPipeline
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
@@ -103,6 +104,31 @@ def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
registry.register_hook(hook._HOOK_NAME, hook)
+class Flux2Adapter:
+ """Adapter for Flux2 model coefficient estimation."""
+
+ @staticmethod
+ def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16) -> Flux2Pipeline:
+ """Load Flux2 pipeline for coefficient estimation."""
+ od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype)
+ od_config.model_class_name = "Flux2Pipeline"
+
+ pipeline = Flux2Pipeline(od_config=od_config)
+ loader = DiffusersPipelineLoader(LoadConfig())
+ loader.load_weights(pipeline)
+ pipeline.to(device)
+ return pipeline
+
+ @staticmethod
+ def get_transformer(pipeline: Any) -> tuple[Any, str]:
+ return pipeline.transformer, pipeline.transformer.__class__.__name__
+
+ @staticmethod
+ def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
+ registry = HookRegistry.get_or_create(transformer)
+ registry.register_hook(hook._HOOK_NAME, hook)
+
+
class DefaultAdapter:
"""Default adapter for standard diffusers pipelines."""
@@ -123,6 +149,7 @@ def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
_MODEL_ADAPTERS: dict[str, type] = {
"Bagel": BagelAdapter,
"StableAudio": StableAudioAdapter,
+ "Flux2": Flux2Adapter,
}
_EPSILON = 1e-6
diff --git a/vllm_omni/diffusion/cache/teacache/config.py b/vllm_omni/diffusion/cache/teacache/config.py
index 96cf3f03eec..ecf3bfc1d3d 100644
--- a/vllm_omni/diffusion/cache/teacache/config.py
+++ b/vllm_omni/diffusion/cache/teacache/config.py
@@ -64,6 +64,15 @@
-1.04182570e01,
6.78098549e-01,
],
+ # Flux2 transformer coefficients
+ # Copied from Qwen-Image; these need to be tuned specifically for Flux2 in the future.
+ "Flux2Transformer2DModel": [
+ -4.50000000e02,
+ 2.80000000e02,
+ -4.50000000e01,
+ 3.20000000e00,
+ -2.00000000e-02,
+ ],
}
diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py
index bdb3f6a7865..3d247e31878 100644
--- a/vllm_omni/diffusion/cache/teacache/extractors.py
+++ b/vllm_omni/diffusion/cache/teacache/extractors.py
@@ -21,6 +21,7 @@
import torch.nn as nn
from vllm_omni.diffusion.forward_context import get_forward_context
+from vllm_omni.platforms import current_omni_platform
@dataclass
@@ -827,6 +828,144 @@ def postprocess(h: torch.Tensor) -> Any:
)
+def extract_flux2_context(
+ module: nn.Module,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor | None = None,
+ joint_attention_kwargs: dict[str, Any] | None = None,
+ return_dict: bool = True,
+ **kwargs: Any,
+) -> CacheContext:
+ """
+ Extract cache context for Flux2Transformer2DModel.
+
+ This is the ONLY Flux2-specific code needed for TeaCache support.
+ It encapsulates preprocessing, modulated input extraction, transformer execution,
+ and postprocessing logic.
+
+ Args:
+ module: Flux2Transformer2DModel instance
+ hidden_states: Input hidden states tensor
+ encoder_hidden_states: Text encoder outputs
+ timestep: Current diffusion timestep
+ img_ids: Image inputs for position embedding
+ txt_ids: Text inputs for position embedding
+ guidance: Optional guidance scale for CFG
+ joint_attention_kwargs: Additional attention arguments
+ return_dict: Whether to return a Transformer2DModelOutput instead of a plain tensor
+ **kwargs: Additional keyword arguments ignored by this extractor
+
+ Returns:
+ CacheContext with all information needed for generic caching
+ """
+
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
+
+ if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0:
+ raise ValueError("Module must have transformer_blocks")
+
+ # ============================================================================
+ # PREPROCESSING (Flux2-specific)
+ # ============================================================================
+ num_txt_tokens = encoder_hidden_states.shape[1]
+
+ timestep = timestep.to(hidden_states.dtype) * 1000
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype) * 1000
+
+ temb = module.time_guidance_embed(timestep, guidance)
+
+ double_stream_mod_img = module.double_stream_modulation_img(temb)
+ double_stream_mod_txt = module.double_stream_modulation_txt(temb)
+ single_stream_mod = module.single_stream_modulation(temb)[0]
+
+ hidden_states = module.x_embedder(hidden_states)
+ encoder_hidden_states = module.context_embedder(encoder_hidden_states)
+
+ if img_ids.ndim == 3:
+ img_ids = img_ids[0]
+ if txt_ids.ndim == 3:
+ txt_ids = txt_ids[0]
+
+ if current_omni_platform.is_npu():
+ freqs_cos_image, freqs_sin_image = module.pos_embed(img_ids.cpu())
+ image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
+ freqs_cos_text, freqs_sin_text = module.pos_embed(txt_ids.cpu())
+ text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
+ else:
+ image_rotary_emb = module.pos_embed(img_ids)
+ text_rotary_emb = module.pos_embed(txt_ids)
+ concat_rotary_emb = (
+ torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
+ torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
+ )
+
+ # ============================================================================
+ # EXTRACT MODULATED INPUT (for cache decision)
+ # ============================================================================
+ block = module.transformer_blocks[0]
+ (shift_msa, scale_msa, gate_msa), _ = double_stream_mod_img
+ modulated_input = block.norm1(hidden_states)
+ modulated_input = (1 + scale_msa) * modulated_input + shift_msa
+
+ # ============================================================================
+ # DEFINE TRANSFORMER EXECUTION (Flux2-specific)
+ # ============================================================================
+ def run_transformer_blocks():
+ """Execute all Flux2 transformer blocks."""
+ h = hidden_states
+ e = encoder_hidden_states
+
+ for transformer_block in module.transformer_blocks:
+ e, h = transformer_block(
+ hidden_states=h,
+ encoder_hidden_states=e,
+ temb_mod_params_img=double_stream_mod_img,
+ temb_mod_params_txt=double_stream_mod_txt,
+ image_rotary_emb=concat_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+ h = torch.cat([e, h], dim=1)
+
+ for single_transformer_block in module.single_transformer_blocks:
+ h = single_transformer_block(
+ hidden_states=h,
+ encoder_hidden_states=None,
+ temb_mod_params=single_stream_mod,
+ image_rotary_emb=concat_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ h = h[:, num_txt_tokens:, ...]
+ return (h,)
+
+ # ============================================================================
+ # DEFINE POSTPROCESSING
+ # ============================================================================
+ def postprocess(h):
+ h = module.norm_out(h, temb)
+ output = module.proj_out(h)
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
+
+ # ============================================================================
+ # RETURN CONTEXT
+ # ============================================================================
+ return CacheContext(
+ modulated_input=modulated_input,
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ run_transformer_blocks=run_transformer_blocks,
+ postprocess=postprocess,
+ )
+
+
# Registry for model-specific extractors
# Key: Transformer class name
# Value: extractor function with signature (module, *args, **kwargs) -> CacheContext
@@ -839,6 +978,7 @@ def postprocess(h: torch.Tensor) -> Any:
"ZImageTransformer2DModel": extract_zimage_context,
"Flux2Klein": extract_flux2_klein_context,
"StableAudioDiTModel": extract_stable_audio_context,
+ "Flux2Transformer2DModel": extract_flux2_context,
# Future models:
# "FluxTransformer2DModel": extract_flux_context,
# "CogVideoXTransformer3DModel": extract_cogvideox_context,
From 155583f49f9a20477ea95a0119a7abfddbf0c646 Mon Sep 17 00:00:00 2001
From: Chenguang Zheng <645327136@qq.com>
Date: Mon, 13 Apr 2026 18:35:59 +0800
Subject: [PATCH 15/76] [Bugfix] Release stage launch lock before handshake
(#2717)
Signed-off-by: Chenguang ZHENG <645327136@qq.com>
---
.../test_async_omni_engine_stage_init.py | 89 +++++++++++++++++++
vllm_omni/engine/async_omni_engine.py | 23 ++---
2 files changed, 101 insertions(+), 11 deletions(-)
diff --git a/tests/engine/test_async_omni_engine_stage_init.py b/tests/engine/test_async_omni_engine_stage_init.py
index 6993f391ebc..7b995fe70db 100644
--- a/tests/engine/test_async_omni_engine_stage_init.py
+++ b/tests/engine/test_async_omni_engine_stage_init.py
@@ -227,6 +227,95 @@ def _capture_stage_timeout(_proc, _handshake_addr, _addresses, _vllm_cfg, handsh
assert captured_timeout == 302
+def test_launch_llm_stage_releases_launch_lock_before_complete_stage_handshake(monkeypatch):
+ """Regression test for parallel LLM stage startup during handshake wait."""
+ import vllm_omni.engine.async_omni_engine as engine_mod
+ from vllm_omni.platforms import current_omni_platform
+
+ engine = object.__new__(AsyncOmniEngine)
+ engine.log_stats = False
+ engine.model = "dummy-model"
+ engine.single_stage_mode = False
+ engine._omni_master_server = None
+
+ fake_vllm_config = types.SimpleNamespace()
+ fake_addresses = types.SimpleNamespace()
+ shared_launch_lock = threading.Lock()
+ counter_lock = threading.Lock()
+ first_handshake_started = threading.Event()
+ second_stage_spawned = threading.Event()
+ allow_first_handshake_to_finish = threading.Event()
+ launch_errors: list[BaseException] = []
+ spawn_count = 0
+
+ device_env_var = current_omni_platform.device_control_env_var
+ prev_device_env = os.environ.get(device_env_var)
+ os.environ[device_env_var] = "0"
+
+ monkeypatch.setattr(engine_mod, "setup_stage_devices", lambda *_: None)
+ monkeypatch.setattr(engine_mod, "build_engine_args_dict", lambda *_, **__: {})
+ monkeypatch.setattr(engine_mod, "build_vllm_config", lambda *_, **__: (fake_vllm_config, object))
+ monkeypatch.setattr(engine_mod, "acquire_device_locks", lambda *_: [])
+
+ def _spawn_stage_core(**_):
+ nonlocal spawn_count
+ with counter_lock:
+ spawn_count += 1
+ call_idx = spawn_count
+ if call_idx == 2:
+ second_stage_spawned.set()
+ return fake_addresses, types.SimpleNamespace(), f"ipc://handshake-{call_idx}"
+
+ def _complete_stage_handshake(_proc, handshake_address, _addresses, _vllm_cfg, _timeout):
+ if handshake_address == "ipc://handshake-1":
+ first_handshake_started.set()
+ assert second_stage_spawned.wait(timeout=1), (
+ "second stage did not reach spawn_stage_core while first stage waited in handshake"
+ )
+ assert allow_first_handshake_to_finish.wait(timeout=1), (
+ "second stage did not enter handshake while first stage was still waiting"
+ )
+ else:
+ allow_first_handshake_to_finish.set()
+
+ monkeypatch.setattr(engine_mod, "spawn_stage_core", _spawn_stage_core)
+ monkeypatch.setattr(engine_mod, "complete_stage_handshake", _complete_stage_handshake)
+
+ def _launch_stage(stage_id: int) -> None:
+ metadata = types.SimpleNamespace(stage_id=stage_id, runtime_cfg={"devices": str(stage_id)})
+ try:
+ engine._launch_llm_stage(
+ stage_cfg=types.SimpleNamespace(engine_args={}),
+ metadata=metadata,
+ stage_connector_spec={},
+ stage_init_timeout=302,
+ llm_stage_launch_lock=shared_launch_lock,
+ )
+ except BaseException as exc: # pragma: no cover - surfaced through assertion below
+ launch_errors.append(exc)
+
+ try:
+ first_thread = threading.Thread(target=_launch_stage, args=(0,))
+ first_thread.start()
+ assert first_handshake_started.wait(timeout=1), "first stage never entered handshake"
+
+ second_thread = threading.Thread(target=_launch_stage, args=(1,))
+ second_thread.start()
+
+ first_thread.join(timeout=3)
+ second_thread.join(timeout=3)
+ finally:
+ if prev_device_env is None:
+ os.environ.pop(device_env_var, None)
+ else:
+ os.environ[device_env_var] = prev_device_env
+
+ assert not first_thread.is_alive()
+ assert not second_thread.is_alive()
+ assert second_stage_spawned.is_set()
+ assert not launch_errors
+
+
def test_attach_llm_stage_uses_omni_input_preprocessor(monkeypatch):
"""Regression test for GLM-Image t2i preprocessing path.
diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py
index 0a2e02d66ef..9609cf6e26b 100644
--- a/vllm_omni/engine/async_omni_engine.py
+++ b/vllm_omni/engine/async_omni_engine.py
@@ -424,23 +424,24 @@ def _launch_llm_stage(
proc=proc,
)
logger.info("[AsyncOmniEngine] Stage %s engine launch started", metadata.stage_id)
- # Keep the stage-specific device visibility until vLLM
- # finishes starting all child processes.
- if self.single_stage_mode and self._omni_master_server is not None:
- launch_stack.close()
- else:
- assert proc is not None
- assert handshake_address is not None
- complete_stage_handshake(
- proc, handshake_address, addresses, vllm_config, stage_init_timeout
- )
- logger.info("[AsyncOmniEngine] Stage %s engine startup completed", metadata.stage_id)
finally:
if previous_visible_devices is None:
current_omni_platform.unset_device_control_env_var()
else:
current_omni_platform.set_device_control_env_var(previous_visible_devices)
+ # After StageEngineCoreProc has been spawned it carries its
+ # stage-specific device visibility into descendants, so the
+ # slow HELLO/READY handshake can run without holding the
+ # process-wide launch lock.
+ if self.single_stage_mode and self._omni_master_server is not None:
+ launch_stack.close()
+ else:
+ assert proc is not None
+ assert handshake_address is not None
+ complete_stage_handshake(proc, handshake_address, addresses, vllm_config, stage_init_timeout)
+ logger.info("[AsyncOmniEngine] Stage %s engine startup completed", metadata.stage_id)
+
assert started_stage is not None
return started_stage
except Exception:
From ef3f72b9ae0bee0baf45258abde55bec3ae6752d Mon Sep 17 00:00:00 2001
From: amy-why-3459
Date: Mon, 13 Apr 2026 19:03:13 +0800
Subject: [PATCH 16/76] [Tests][Qwen3-Omni]Modify Qwen3-Omni performance test
cases (#2600)
Signed-off-by: amy-why-3459
---
tests/dfx/perf/scripts/run_benchmark.py | 2 +
tests/dfx/perf/tests/test.json | 305 +++++++++++++++++-------
2 files changed, 219 insertions(+), 88 deletions(-)
diff --git a/tests/dfx/perf/scripts/run_benchmark.py b/tests/dfx/perf/scripts/run_benchmark.py
index c566c2e0a0a..b64cc0d9503 100644
--- a/tests/dfx/perf/scripts/run_benchmark.py
+++ b/tests/dfx/perf/scripts/run_benchmark.py
@@ -72,6 +72,8 @@ def run_benchmark(
["vllm", "bench", "serve", "--omni"]
+ args
+ [
+ "--num-warmups",
+ "2",
"--save-result",
"--result-dir",
os.environ.get("BENCHMARK_DIR", "tests"),
diff --git a/tests/dfx/perf/tests/test.json b/tests/dfx/perf/tests/test.json
index fe7e3804698..159e27a064b 100644
--- a/tests/dfx/perf/tests/test.json
+++ b/tests/dfx/perf/tests/test.json
@@ -10,83 +10,97 @@
"dataset_name": "random",
"backend": "openai-chat-omni",
"endpoint": "/v1/chat/completions",
- "num_prompts": [
- 10,
- 40,
- 100
- ],
- "max_concurrency": [
- 1,
- 4,
- 10
- ],
+ "num_prompts": [4, 16, 40],
+ "max_concurrency": [1, 4, 10],
+ "random_input_len": 2500,
+ "random_output_len": 900,
+ "ignore_eos": true,
+ "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
+ "baseline": {
+ "mean_ttft_ms": [1000, 3000, 5000],
+ "mean_audio_ttfp_ms": [30000, 60000, 90000],
+ "mean_audio_rtf": [0.35, 0.45, 0.55]
+ }
+ },
+ {
+ "dataset_name": "random-mm",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [10],
+ "request_rate": [0.1],
"random_input_len": 100,
"random_output_len": 100,
+ "random_range_ratio": 0.0,
"ignore_eos": true,
+ "random_mm_base_items_per_request": 1,
+ "random_mm_num_mm_items_range_ratio": 0.5,
+ "random_mm_limit_mm_per_prompt": {
+ "audio": 1
+ },
+ "random_mm_bucket_config": {
+ "(0, 60, 3)": 1.0
+ },
"percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
"baseline": {
- "mean_ttft_ms": [1000, 3000, 5000],
- "mean_audio_ttfp_ms": [8000, 10000, 13000],
- "mean_audio_rtf": [0.2, 0.25, 0.45]
+ "mean_ttft_ms": [2000],
+ "mean_audio_ttfp_ms": [10000],
+ "mean_audio_rtf": [0.25]
}
},
{
"dataset_name": "random-mm",
"backend": "openai-chat-omni",
"endpoint": "/v1/chat/completions",
- "num_prompts": [
- 10,
- 40,
- 100
- ],
- "request_rate": [
- 0.1,
- 0.3,
- 0.5
- ],
+ "num_prompts": [40],
+ "request_rate": [0.3],
"random_input_len": 100,
"random_output_len": 100,
"random_range_ratio": 0.0,
"ignore_eos": true,
- "random_mm_base_items_per_request": 3,
- "random_mm_num_mm_items_range_ratio": 0,
+ "random_mm_base_items_per_request": 2,
+ "random_mm_num_mm_items_range_ratio": 0.5,
"random_mm_limit_mm_per_prompt": {
"image": 1,
- "video": 1,
- "audio": 1
+ "video": 1
},
"random_mm_bucket_config": {
- "(32, 32, 1)": 0.5,
- "(0, 1, 1)": 0.1,
- "(32, 32, 2)": 0.4
+ "(256, 256, 1)": 0.5,
+ "(720, 1280, 2)": 0.5
},
"percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
"baseline": {
- "mean_ttft_ms": [2000, 4000, 6000],
- "mean_audio_ttfp_ms": [10000, 13000, 15000],
- "mean_audio_rtf": [0.25, 0.35, 0.45]
+ "mean_ttft_ms": [4000],
+ "mean_audio_ttfp_ms": [13000],
+ "mean_audio_rtf": [0.35]
}
},
{
- "dataset_name": "random",
+ "dataset_name": "random-mm",
"backend": "openai-chat-omni",
"endpoint": "/v1/chat/completions",
- "num_prompts": [
- 4,
- 16
- ],
- "max_concurrency": [
- 1,
- 4
- ],
- "random_input_len": 2500,
- "random_output_len": 900,
+ "num_prompts": [100],
+ "request_rate": [0.5],
+ "random_input_len": 100,
+ "random_output_len": 100,
+ "random_range_ratio": 0.0,
"ignore_eos": true,
+ "random_mm_base_items_per_request": 3,
+ "random_mm_num_mm_items_range_ratio": 0.5,
+ "random_mm_limit_mm_per_prompt": {
+ "image": 1,
+ "video": 1,
+ "audio": 1
+ },
+ "random_mm_bucket_config": {
+ "(256, 256, 1)": 0.34,
+ "(720, 1280, 2)": 0.33,
+ "(0, 60, 3)": 0.33
+ },
"percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
"baseline": {
- "mean_ttft_ms": [1000, 3000],
- "mean_audio_ttfp_ms": [30000, 60000],
- "mean_audio_rtf": [0.35, 0.45]
+ "mean_ttft_ms": [6000],
+ "mean_audio_ttfp_ms": [15000],
+ "mean_audio_rtf": [0.45]
}
}
]
@@ -120,18 +134,10 @@
"dataset_name": "random",
"backend": "openai-chat-omni",
"endpoint": "/v1/chat/completions",
- "num_prompts": [
- 10,
- 40,
- 100
- ],
- "max_concurrency": [
- 1,
- 4,
- 10
- ],
- "random_input_len": 100,
- "random_output_len": 100,
+ "num_prompts": [4, 16, 40],
+ "max_concurrency": [1, 4, 10],
+ "random_input_len": 2500,
+ "random_output_len": 900,
"ignore_eos": true,
"percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
"baseline": {
@@ -144,59 +150,182 @@
"dataset_name": "random-mm",
"backend": "openai-chat-omni",
"endpoint": "/v1/chat/completions",
- "num_prompts": [
- 10,
- 40,
- 100
- ],
- "request_rate": [
- 0.1,
- 0.3,
- 0.5
- ],
+ "num_prompts": [10],
+ "request_rate": [0.1],
+ "random_input_len": 100,
+ "random_output_len": 100,
+ "random_range_ratio": 0.0,
+ "ignore_eos": true,
+ "random_mm_base_items_per_request": 1,
+ "random_mm_num_mm_items_range_ratio": 0.5,
+ "random_mm_limit_mm_per_prompt": {
+ "audio": 1
+ },
+ "random_mm_bucket_config": {
+ "(0, 60, 3)": 1.0
+ },
+ "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
+ "baseline": {
+ "mean_ttft_ms": [2000],
+ "mean_audio_ttfp_ms": [2000],
+ "mean_audio_rtf": [0.25]
+ }
+ },
+ {
+ "dataset_name": "random-mm",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [40],
+ "request_rate": [0.3],
+ "random_input_len": 100,
+ "random_output_len": 100,
+ "random_range_ratio": 0.0,
+ "ignore_eos": true,
+ "random_mm_base_items_per_request": 2,
+ "random_mm_num_mm_items_range_ratio": 0.5,
+ "random_mm_limit_mm_per_prompt": {
+ "image": 1,
+ "video": 1
+ },
+ "random_mm_bucket_config": {
+ "(256, 256, 1)": 0.5,
+ "(720, 1280, 2)": 0.5
+ },
+ "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
+ "baseline": {
+ "mean_ttft_ms": [4000],
+ "mean_audio_ttfp_ms": [4000],
+ "mean_audio_rtf": [0.4]
+ }
+ },
+ {
+ "dataset_name": "random-mm",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [100],
+ "request_rate": [0.5],
"random_input_len": 100,
"random_output_len": 100,
"random_range_ratio": 0.0,
"ignore_eos": true,
"random_mm_base_items_per_request": 3,
- "random_mm_num_mm_items_range_ratio": 0,
+ "random_mm_num_mm_items_range_ratio": 0.5,
"random_mm_limit_mm_per_prompt": {
"image": 1,
"video": 1,
"audio": 1
},
"random_mm_bucket_config": {
- "(32, 32, 1)": 0.5,
- "(0, 1, 1)": 0.1,
- "(32, 32, 2)": 0.4
+ "(256, 256, 1)": 0.34,
+ "(720, 1280, 2)": 0.33,
+ "(0, 60, 3)": 0.33
},
"percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
"baseline": {
- "mean_ttft_ms": [2000, 4000, 6000],
- "mean_audio_ttfp_ms": [2000, 4000, 6000],
- "mean_audio_rtf": [0.25, 0.4, 0.7]
+ "mean_ttft_ms": [6000],
+ "mean_audio_ttfp_ms": [6000],
+ "mean_audio_rtf": [0.7]
}
},
{
"dataset_name": "random",
"backend": "openai-chat-omni",
"endpoint": "/v1/chat/completions",
- "num_prompts": [
- 4,
- 16
- ],
- "max_concurrency": [
- 1,
- 4
- ],
+ "num_prompts": [4, 16, 40],
+ "max_concurrency": [1, 4, 10],
"random_input_len": 2500,
"random_output_len": 900,
"ignore_eos": true,
- "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
+ "extra_body": {
+ "modalities": ["text"]
+ },
+ "percentile-metrics": "ttft,tpot,itl,e2el",
+ "baseline": {
+ "mean_ttft_ms": [1000, 3000, 5000]
+ }
+ },
+ {
+ "dataset_name": "random-mm",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [10],
+ "request_rate": [0.1],
+ "random_input_len": 100,
+ "random_output_len": 100,
+ "random_range_ratio": 0.0,
+ "ignore_eos": true,
+ "extra_body": {
+ "modalities": ["text"]
+ },
+ "random_mm_base_items_per_request": 1,
+ "random_mm_num_mm_items_range_ratio": 0.5,
+ "random_mm_limit_mm_per_prompt": {
+ "audio": 1
+ },
+ "random_mm_bucket_config": {
+ "(0, 60, 3)": 1.0
+ },
+ "percentile-metrics": "ttft,tpot,itl,e2el",
+ "baseline": {
+ "mean_ttft_ms": [2000]
+ }
+ },
+ {
+ "dataset_name": "random-mm",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [40],
+ "request_rate": [0.3],
+ "random_input_len": 100,
+ "random_output_len": 100,
+ "random_range_ratio": 0.0,
+ "ignore_eos": true,
+ "extra_body": {
+ "modalities": ["text"]
+ },
+ "random_mm_base_items_per_request": 2,
+ "random_mm_num_mm_items_range_ratio": 0.5,
+ "random_mm_limit_mm_per_prompt": {
+ "image": 1,
+ "video": 1
+ },
+ "random_mm_bucket_config": {
+ "(256, 256, 1)": 0.5,
+ "(720, 1280, 2)": 0.5
+ },
+ "percentile-metrics": "ttft,tpot,itl,e2el",
+ "baseline": {
+ "mean_ttft_ms": [4000]
+ }
+ },
+ {
+ "dataset_name": "random-mm",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [100],
+ "request_rate": [0.5],
+ "random_input_len": 100,
+ "random_output_len": 100,
+ "random_range_ratio": 0.0,
+ "ignore_eos": true,
+ "extra_body": {
+ "modalities": ["text"]
+ },
+ "random_mm_base_items_per_request": 3,
+ "random_mm_num_mm_items_range_ratio": 0.5,
+ "random_mm_limit_mm_per_prompt": {
+ "image": 1,
+ "video": 1,
+ "audio": 1
+ },
+ "random_mm_bucket_config": {
+ "(256, 256, 1)": 0.34,
+ "(720, 1280, 2)": 0.33,
+ "(0, 60, 3)": 0.33
+ },
+ "percentile-metrics": "ttft,tpot,itl,e2el",
"baseline": {
- "mean_ttft_ms": [1000, 3000],
- "mean_audio_ttfp_ms": [1000, 3000],
- "mean_audio_rtf": [0.35, 0.45]
+ "mean_ttft_ms": [6000]
}
}
]
From 2c67c30550ad91e62a5919b0008caba459a09049 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B1=AA=E5=BF=97=E9=B9=8F?=
Date: Mon, 13 Apr 2026 19:15:49 +0800
Subject: [PATCH 17/76] [Bagel]: Support `think mode` in single stage
deployment of Bagel (#2650)
Signed-off-by: princepride
---
examples/offline_inference/bagel/end2end.py | 98 ++++++++----
.../models/bagel/bagel_transformer.py | 113 +++++++++++++-
.../diffusion/models/bagel/pipeline_bagel.py | 146 +++++++++++++++---
3 files changed, 301 insertions(+), 56 deletions(-)
diff --git a/examples/offline_inference/bagel/end2end.py b/examples/offline_inference/bagel/end2end.py
index 472d748d1e6..ed5fa57e8d6 100644
--- a/examples/offline_inference/bagel/end2end.py
+++ b/examples/offline_inference/bagel/end2end.py
@@ -97,6 +97,24 @@ def parse_args():
default=False,
help="Enable thinking mode: AR stage decodes ... planning tokens before image generation.",
)
+ parser.add_argument(
+ "--max-think-tokens",
+ type=int,
+ default=1000,
+ help="Maximum number of tokens for thinking text generation (default: 1000).",
+ )
+ parser.add_argument(
+ "--do-sample",
+ action="store_true",
+ default=False,
+ help="Enable sampling for text generation (default: greedy).",
+ )
+ parser.add_argument(
+ "--text-temperature",
+ type=float,
+ default=0.3,
+ help="Temperature for text generation sampling (default: 0.3).",
+ )
args = parser.parse_args()
return args
@@ -108,7 +126,6 @@ def main():
model_name = args.model
prompts: list[OmniPromptType] = []
try:
- # Preferred: load from txt file (one prompt per line)
if getattr(args, "txt_prompts", None) and args.prompt_type == "text":
with open(args.txt_prompts, encoding="utf-8") as f:
lines = [ln.strip() for ln in f.readlines()]
@@ -121,10 +138,8 @@ def main():
raise
if not prompts:
- # Default prompt for text2img test if none provided
prompts = ["A cute cat"]
print(f"[Info] No prompts provided, using default: {prompts}")
- omni_outputs = []
from PIL import Image
@@ -132,11 +147,13 @@ def main():
omni_kwargs = {}
stage_configs_path = args.stage_configs_path
+ is_single_stage = stage_configs_path and "single_stage" in stage_configs_path
if args.think and stage_configs_path is None:
stage_configs_path = "vllm_omni/model_executor/stage_configs/bagel_think.yaml"
print(f"[Info] Think mode enabled, using stage config: {stage_configs_path}")
if stage_configs_path:
omni_kwargs["stage_configs_path"] = stage_configs_path
+ is_single_stage = "single_stage" in stage_configs_path
omni_kwargs.update(
{
@@ -198,40 +215,61 @@ def main():
formatted_prompts.append(prompt_dict)
params_list = omni.default_sampling_params_list
+
+ # For single-stage DiT, think/text params go into the diffusion sampling params extra_args.
+ # For 2-stage, diffusion params are at index 1 (falls back to index 0 if the list has one entry).
+ diffusion_params_idx = 0 if is_single_stage else (1 if len(params_list) > 1 else 0)
+ diffusion_params = params_list[diffusion_params_idx]
+
if args.modality in ("text2img", "img2img"):
- if len(params_list) > 1:
- diffusion_params = params_list[1]
- diffusion_params.num_inference_steps = args.steps # type: ignore
- diffusion_params.cfg_parallel_size = args.cfg_parallel_size # type: ignore
- if args.seed is not None:
- diffusion_params.seed = args.seed # type: ignore
- extra = {
- "cfg_text_scale": args.cfg_text_scale,
- "cfg_img_scale": args.cfg_img_scale,
- }
- if args.cfg_interval is not None:
- extra["cfg_interval"] = tuple(args.cfg_interval)
- if args.cfg_renorm_type is not None:
- extra["cfg_renorm_type"] = args.cfg_renorm_type
- if args.cfg_renorm_min is not None:
- extra["cfg_renorm_min"] = args.cfg_renorm_min
- if args.negative_prompt is not None:
- extra["negative_prompt"] = args.negative_prompt
- diffusion_params.extra_args = extra # type: ignore
+ diffusion_params.num_inference_steps = args.steps # type: ignore
+ diffusion_params.cfg_parallel_size = args.cfg_parallel_size # type: ignore
+ if args.seed is not None:
+ diffusion_params.seed = args.seed # type: ignore
+
+ extra = getattr(diffusion_params, "extra_args", {}) or {}
+ extra["cfg_text_scale"] = args.cfg_text_scale
+ extra["cfg_img_scale"] = args.cfg_img_scale
+ if args.cfg_interval is not None:
+ extra["cfg_interval"] = tuple(args.cfg_interval)
+ if args.cfg_renorm_type is not None:
+ extra["cfg_renorm_type"] = args.cfg_renorm_type
+ if args.cfg_renorm_min is not None:
+ extra["cfg_renorm_min"] = args.cfg_renorm_min
+ if args.negative_prompt is not None:
+ extra["negative_prompt"] = args.negative_prompt
+
+ needs_text_gen = is_single_stage and (args.think or args.modality in ("text2text", "img2text"))
+ if needs_text_gen:
+ if args.think:
+ extra["think"] = True
+ extra["max_think_tokens"] = args.max_think_tokens
+ extra["do_sample"] = args.do_sample
+ extra["text_temperature"] = args.text_temperature
+ diffusion_params.extra_args = extra # type: ignore
omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list))
img_idx = 0
for req_output in omni_outputs:
- if args.think:
- ro = getattr(req_output, "request_output", None)
- if ro and getattr(ro, "outputs", None):
- txt = "".join(getattr(o, "text", "") or "" for o in ro.outputs)
- if txt:
- print(txt)
+ # 2-stage mode: text from the thinker stage (think trace, or plain text output)
+ ro = getattr(req_output, "request_output", None)
+ if ro and getattr(ro, "outputs", None):
+ txt = "".join(getattr(o, "text", "") or "" for o in ro.outputs)
+ if txt:
+ if args.think:
+ print(f"[Think]\n{txt}")
+ else:
+ print(f"[Output] Text:\n{txt}")
- images = getattr(req_output, "images", None)
+ # Single-stage DiT: text from custom_output
+ custom = getattr(req_output, "_custom_output", {}) or {}
+ if custom.get("think_text"):
+ print(f"[Think]\n{custom['think_text']}")
+ if custom.get("text_output"):
+ print(f"[Output] Text:\n{custom['text_output']}")
+ images = getattr(req_output, "images", None)
if not images:
continue
@@ -241,8 +279,6 @@ def main():
print(f"[Output] Saved image to {save_path}")
img_idx += 1
- print(omni_outputs)
-
if __name__ == "__main__":
main()
diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py
index f8480775687..d1254f84566 100644
--- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py
+++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py
@@ -854,6 +854,7 @@ def __init__(
config, parallel_config=parallel_config, quant_config=quant_config, prefix=f"{prefix}.model"
)
self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
@@ -864,6 +865,12 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.embed_tokens = value
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
def set_decoder(self, decoder):
self.model = decoder
@@ -1207,7 +1214,7 @@ def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
curr += curr_kvlen
- text_ids = tokenizer.encode(prompt)
+ text_ids = tokenizer.encode(prompt, add_special_tokens=False)
text_ids = [new_token_ids["bos_token_id"]] + text_ids + [new_token_ids["eos_token_id"]]
text_token_lens.append(len(text_ids))
packed_text_ids.extend(text_ids)
@@ -1619,10 +1626,110 @@ def _merge_naive_caches(caches: list) -> NaiveCache:
num_layers = len(caches[0].key_cache)
merged = NaiveCache(num_layers)
for layer_idx in range(num_layers):
- merged.key_cache[layer_idx] = torch.cat([c.key_cache[layer_idx] for c in caches], dim=0)
- merged.value_cache[layer_idx] = torch.cat([c.value_cache[layer_idx] for c in caches], dim=0)
+ key_parts = [c.key_cache[layer_idx] for c in caches if c.key_cache[layer_idx] is not None]
+ val_parts = [c.value_cache[layer_idx] for c in caches if c.value_cache[layer_idx] is not None]
+ merged.key_cache[layer_idx] = torch.cat(key_parts, dim=0) if key_parts else None
+ merged.value_cache[layer_idx] = torch.cat(val_parts, dim=0) if val_parts else None
return merged
+ def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids):
+ """Prepare start tokens for autoregressive text generation.
+
+ Ported from the original BAGEL ``Bagel.prepare_start_tokens``.
+ """
+ packed_start_tokens, packed_key_value_indexes = list(), list()
+ packed_query_position_ids = list()
+
+ curr = 0
+ for curr_kvlen, curr_position_id in zip(curr_kvlens, curr_rope):
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
+ packed_start_tokens.append(new_token_ids["bos_token_id"])
+ packed_query_position_ids.append(curr_position_id)
+ curr += curr_kvlen
+
+ generation_input = {
+ "packed_start_tokens": torch.tensor(packed_start_tokens, dtype=torch.long),
+ "packed_query_position_ids": torch.tensor(packed_query_position_ids, dtype=torch.long),
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
+ }
+ return generation_input
+
+ @torch.no_grad()
+ def generate_text(
+ self,
+ past_key_values: NaiveCache,
+ packed_key_value_indexes: torch.LongTensor,
+ key_values_lens: torch.IntTensor,
+ packed_start_tokens: torch.LongTensor,
+ packed_query_position_ids: torch.LongTensor,
+ max_length: int,
+ do_sample: bool = False,
+ temperature: float = 1.0,
+ end_token_id: int | None = None,
+ ):
+ """Autoregressive text generation (ported from original BAGEL).
+
+ Decodes one token per step, appending to ``past_key_values``, until
+ ``max_length`` steps have run or the first sequence emits ``end_token_id``.
+ """
+ step = 0
+ generated_sequence = []
+ curr_tokens = packed_start_tokens
+ while step < max_length:
+ generated_sequence.append(curr_tokens)
+ packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens)
+ query_lens = torch.ones_like(curr_tokens)
+ packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange(
+ 0,
+ len(key_values_lens),
+ device=key_values_lens.device,
+ dtype=key_values_lens.dtype,
+ )
+
+ uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
+ for i in range(len(uppacked)):
+ uppacked[i] += i
+ packed_key_value_indexes = torch.cat(uppacked, dim=0)
+
+ output = self.language_model(
+ packed_query_sequence=packed_text_embedding,
+ query_lens=query_lens,
+ packed_query_position_ids=packed_query_position_ids,
+ packed_query_indexes=packed_query_indexes,
+ past_key_values=past_key_values,
+ key_values_lens=key_values_lens,
+ packed_key_value_indexes=packed_key_value_indexes,
+ update_past_key_values=True,
+ is_causal=True,
+ mode="und",
+ )
+ past_key_values = output.past_key_values
+ packed_query_sequence = output.packed_query_sequence
+ pred_logits = self.language_model.lm_head(packed_query_sequence)
+
+ if do_sample:
+ probs = nn.functional.softmax(pred_logits / temperature, dim=-1)
+ curr_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+ else:
+ curr_tokens = torch.argmax(pred_logits, dim=-1)
+
+ uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
+ for i in range(len(uppacked)):
+ uppacked[i] = torch.cat(
+ [uppacked[i], torch.tensor([uppacked[i][-1] + 1], device=uppacked[i].device)], dim=0
+ )
+ packed_key_value_indexes = torch.cat(uppacked, dim=0)
+ key_values_lens = key_values_lens + 1
+ packed_query_position_ids = packed_query_position_ids + 1
+ step += 1
+
+ if end_token_id is not None and curr_tokens[0] == end_token_id:
+ break
+
+ output_device = generated_sequence[0].device
+ return torch.stack([i.to(output_device) for i in generated_sequence], dim=0)
+
def generate_image(
self,
packed_text_ids: torch.LongTensor,
diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py
index 13d0cc2093b..72e53e7f48f 100644
--- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py
+++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py
@@ -495,11 +495,15 @@ def vae_transforms(img):
cfg_text_context = deepcopy(gen_context)
+ # Strip <|im_start|>/<|im_end|> wrappers that end2end.py may have
+ # already added, so prepare_prompts doesn't double-add bos/eos.
+ clean_prompt = prompt.removeprefix("<|im_start|>").removesuffix("<|im_end|>")
+
# Update gen_context with text prompt
generation_input, newlens, new_rope = self.bagel.prepare_prompts(
curr_kvlens=gen_context["kv_lens"],
curr_rope=gen_context["ropes"],
- prompts=[prompt],
+ prompts=[clean_prompt],
tokenizer=self.tokenizer,
new_token_ids=self.new_token_ids,
)
@@ -527,34 +531,37 @@ def vae_transforms(img):
gen_context["kv_lens"] = newlens
gen_context["ropes"] = new_rope
- # cfg_text_context: update with negative prompt (no text condition)
+ # cfg_text_context: update with negative prompt (no text condition).
+ # When empty, keep cfg_text_context as-is (kv_lens=0) to match
+ # original BAGEL; _merge_naive_caches handles None KV entries.
neg_prompt = extra_args.get("negative_prompt", "")
- neg_input, neg_newlens, neg_rope = self.bagel.prepare_prompts(
- curr_kvlens=cfg_text_context["kv_lens"],
- curr_rope=cfg_text_context["ropes"],
- prompts=[neg_prompt],
- tokenizer=self.tokenizer,
- new_token_ids=self.new_token_ids,
- )
- for k, v in neg_input.items():
- if torch.is_tensor(v):
- neg_input[k] = v.to(self.device)
- with torch.autocast(
- device_type=self.device.type,
- enabled=self.device.type != "cpu",
- dtype=self.od_config.dtype,
- ):
- cfg_text_context["past_key_values"] = self.bagel.forward_cache_update_text(
- cfg_text_context["past_key_values"], **neg_input
+ if neg_prompt:
+ neg_input, neg_newlens, neg_rope = self.bagel.prepare_prompts(
+ curr_kvlens=cfg_text_context["kv_lens"],
+ curr_rope=cfg_text_context["ropes"],
+ prompts=[neg_prompt],
+ tokenizer=self.tokenizer,
+ new_token_ids=self.new_token_ids,
)
- cfg_text_context["kv_lens"] = neg_newlens
- cfg_text_context["ropes"] = neg_rope
+ for k, v in neg_input.items():
+ if torch.is_tensor(v):
+ neg_input[k] = v.to(self.device)
+ with torch.autocast(
+ device_type=self.device.type,
+ enabled=self.device.type != "cpu",
+ dtype=self.od_config.dtype,
+ ):
+ cfg_text_context["past_key_values"] = self.bagel.forward_cache_update_text(
+ cfg_text_context["past_key_values"], **neg_input
+ )
+ cfg_text_context["kv_lens"] = neg_newlens
+ cfg_text_context["ropes"] = neg_rope
# cfg_img_context: update with text prompt (no image condition)
cfg_img_generation_input, cfg_img_newlens, cfg_img_new_rope = self.bagel.prepare_prompts(
curr_kvlens=cfg_img_context["kv_lens"],
curr_rope=cfg_img_context["ropes"],
- prompts=[prompt],
+ prompts=[clean_prompt],
tokenizer=self.tokenizer,
new_token_ids=self.new_token_ids,
)
@@ -572,6 +579,96 @@ def vae_transforms(img):
cfg_img_context["kv_lens"] = cfg_img_newlens
cfg_img_context["ropes"] = cfg_img_new_rope
+ # ---- Detect output modality and think mode ----
+ modalities = first_prompt.get("modalities", []) if isinstance(first_prompt, dict) else []
+ is_text_output = "text" in modalities
+ think_enabled = extra_args.get("think", False)
+ think_text = None
+
+ if think_enabled and injected_kv is None:
+ max_think_tokens = int(extra_args.get("max_think_tokens", 1000))
+ do_sample = bool(extra_args.get("do_sample", False))
+ text_temperature = float(extra_args.get("text_temperature", 0.3))
+
+ with torch.autocast(
+ device_type=self.device.type,
+ enabled=self.device.type != "cpu",
+ dtype=self.od_config.dtype,
+ ):
+ start_input = self.bagel.prepare_start_tokens(
+ gen_context["kv_lens"], gen_context["ropes"], self.new_token_ids
+ )
+ for k, v in start_input.items():
+ if torch.is_tensor(v):
+ start_input[k] = v.to(self.device)
+
+ gen_ctx_copy = deepcopy(gen_context)
+ token_ids = self.bagel.generate_text(
+ past_key_values=gen_ctx_copy["past_key_values"],
+ max_length=max_think_tokens,
+ do_sample=do_sample,
+ temperature=text_temperature,
+ end_token_id=self.new_token_ids["eos_token_id"],
+ **start_input,
+ )
+ # token_ids shape: (seq_len, batch=1)
+ decoded = self.tokenizer.decode(token_ids[:, 0].tolist())
+ # Strip chat markers to get clean text
+ think_text = decoded.split("<|im_end|>")[0]
+ if "<|im_start|>" in think_text:
+ think_text = think_text.split("<|im_start|>")[-1]
+ logger.info("Think mode generated %d tokens", token_ids.shape[0])
+
+ if not is_text_output:
+ # Use the autoregressive KV cache from think generation
+ # directly, instead of decode→re-encode which adds extra
+ # bos/eos and may alter tokenization.
+ num_think_tokens = token_ids.shape[0]
+ gen_context["past_key_values"] = gen_ctx_copy["past_key_values"]
+ gen_context["kv_lens"] = [kl + num_think_tokens for kl in gen_context["kv_lens"]]
+ gen_context["ropes"] = [r + num_think_tokens for r in gen_context["ropes"]]
+
+ # ---- Text-only output (text2text / img2text) ----
+ if is_text_output and injected_kv is None:
+ if think_text is not None:
+ # Think mode already generated the text (including reasoning)
+ text_output = think_text
+ else:
+ max_text_tokens = int(extra_args.get("max_think_tokens", 500))
+ do_sample = bool(extra_args.get("do_sample", False))
+ text_temperature = float(extra_args.get("text_temperature", 0.3))
+
+ with torch.autocast(
+ device_type=self.device.type,
+ enabled=self.device.type != "cpu",
+ dtype=self.od_config.dtype,
+ ):
+ start_input = self.bagel.prepare_start_tokens(
+ gen_context["kv_lens"], gen_context["ropes"], self.new_token_ids
+ )
+ for k, v in start_input.items():
+ if torch.is_tensor(v):
+ start_input[k] = v.to(self.device)
+ token_ids = self.bagel.generate_text(
+ past_key_values=gen_context["past_key_values"],
+ max_length=max_text_tokens,
+ do_sample=do_sample,
+ temperature=text_temperature,
+ end_token_id=self.new_token_ids["eos_token_id"],
+ **start_input,
+ )
+ decoded = self.tokenizer.decode(token_ids[:, 0].tolist())
+ text_output = decoded.split("<|im_end|>")[0]
+ if "<|im_start|>" in text_output:
+ text_output = text_output.split("<|im_start|>")[-1]
+
+ return DiffusionOutput(
+ output=text_output,
+ custom_output={"text_output": text_output},
+ stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None,
+ )
+
+ # ---- Image generation (text2img / img2img) ----
if req.sampling_params.seed is not None:
torch.manual_seed(req.sampling_params.seed)
if self.device.type == "cuda":
@@ -676,12 +773,17 @@ def vae_transforms(img):
if trajectory_log_probs:
trajectory_log_probs_stacked = torch.stack(trajectory_log_probs)
+ custom = {}
+ if think_text is not None:
+ custom["think_text"] = think_text
+
return DiffusionOutput(
output=img,
trajectory_latents=trajectory_latents_stacked,
trajectory_timesteps=trajectory_timesteps_stacked,
trajectory_log_probs=trajectory_log_probs_stacked,
trajectory_decoded=trajectory_decoded,
+ custom_output=custom,
stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None,
)
From e0cdbe9a5d7ec654bbbe26c2fb6e76abe41446d2 Mon Sep 17 00:00:00 2001
From: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Date: Mon, 13 Apr 2026 19:21:42 +0800
Subject: [PATCH 18/76] [Misc] Cleanup: use consistent pytest-mock in unit
tests (#2698)
Signed-off-by: yuanheng
---
tests/comfyui/conftest.py | 18 +-
tests/comfyui/test_comfyui_integration.py | 95 +-
.../test_generation_scheduler_restore.py | 27 +-
.../test_distributed_vae_executor.py | 41 +-
.../models/bagel/test_trajectory_recording.py | 34 +-
.../models/flux2/test_flux2_transformer_tp.py | 20 +-
.../offloader/test_sequential_backend.py | 120 +-
.../quantization/test_int8_config.py | 32 +-
tests/diffusion/test_diffusion_scheduler.py | 103 +-
.../diffusion/test_diffusion_step_pipeline.py | 26 +-
.../test_diffusion_worker_cuda_profiler.py | 6 +-
.../test_multiproc_engine_concurrency.py | 28 +-
tests/engine/test_arg_utils.py | 9 +-
tests/engine/test_async_omni_engine_input.py | 15 +-
.../engine/test_async_omni_engine_outputs.py | 20 +-
tests/engine/test_single_stage_mode.py | 1533 ++++++++++-------
.../openai_api/test_serving_chat_speaker.py | 40 +-
.../openai_api/test_serving_speech.py | 215 ++-
.../openai_api/test_serving_speech_stream.py | 117 +-
tests/entrypoints/test_omni_base_profiler.py | 27 +-
tests/entrypoints/test_serve.py | 188 +-
.../test_mimo_audio_code2wav_batch_decode.py | 40 +-
.../qwen2_5_omni/test_qwen2_5_omni_embed.py | 37 +-
.../qwen3_tts/test_code_predictor_dtype.py | 131 +-
.../models/test_fish_speech_voice_cache.py | 30 +-
tests/test_fish_speech_voice_cache.py | 39 +-
26 files changed, 1610 insertions(+), 1381 deletions(-)
diff --git a/tests/comfyui/conftest.py b/tests/comfyui/conftest.py
index 0b4565e9465..4280d3506ff 100644
--- a/tests/comfyui/conftest.py
+++ b/tests/comfyui/conftest.py
@@ -9,8 +9,8 @@
import os
import sys
+from types import ModuleType, SimpleNamespace
from typing import BinaryIO, TypedDict
-from unittest.mock import MagicMock
def pytest_configure(config):
@@ -58,15 +58,15 @@ def save_to(self, file: str | BinaryIO):
else:
file.write(self._data)
- mock_comfy_api = MagicMock()
- mock_comfy_api_input = MagicMock()
+ mock_comfy_api = ModuleType("comfy_api")
+ mock_comfy_api_input = ModuleType("comfy_api.input")
mock_comfy_api_input.AudioInput = AudioInput
mock_comfy_api_input.VideoInput = VideoInput
mock_comfy_api.input = mock_comfy_api_input
- mock_comfy_api_latest = MagicMock()
- mock_comfy_api_latest.Types.VideoComponents = MagicMock(side_effect=lambda **kwargs: kwargs)
- mock_comfy_api_latest.InputImpl.VideoFromComponents = MagicMock(
- side_effect=lambda _: VideoInput(b"mock_video_from_components")
+ mock_comfy_api_latest = ModuleType("comfy_api.latest")
+ mock_comfy_api_latest.Types = SimpleNamespace(VideoComponents=lambda **kwargs: kwargs)
+ mock_comfy_api_latest.InputImpl = SimpleNamespace(
+ VideoFromComponents=lambda _: VideoInput(b"mock_video_from_components")
)
mock_comfy_api.latest = mock_comfy_api_latest
@@ -76,8 +76,8 @@ def mock_load(_: str | BinaryIO):
sample_rate = 24000
return waveform, sample_rate
- mock_comfy_extras = MagicMock()
- mock_nodes_audio = MagicMock()
+ mock_comfy_extras = ModuleType("comfy_extras")
+ mock_nodes_audio = ModuleType("comfy_extras.nodes_audio")
mock_nodes_audio.load = mock_load
mock_comfy_extras.nodes_audio = mock_nodes_audio
diff --git a/tests/comfyui/test_comfyui_integration.py b/tests/comfyui/test_comfyui_integration.py
index f6ce82f9b28..80e86d82412 100644
--- a/tests/comfyui/test_comfyui_integration.py
+++ b/tests/comfyui/test_comfyui_integration.py
@@ -13,7 +13,6 @@
from enum import StrEnum, auto
from types import SimpleNamespace
from typing import Any, NamedTuple
-from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import requests
@@ -28,6 +27,7 @@
)
from comfyui_vllm_omni.utils.types import AutoregressionSamplingParams, DiffusionSamplingParams, WanModelSpecificParams
from PIL import Image
+from pytest_mock import MockerFixture
from vllm import SamplingParams
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.utils.argparse_utils import FlexibleArgumentParser
@@ -217,9 +217,10 @@ def _build_diffusion_video_output() -> OmniRequestOutput:
def _build_diffusion_image_output_for_chat_endpoint() -> OmniRequestOutput:
- request_output = MagicMock()
- request_output.images = [_build_image_output(color="blue")]
- request_output.finished = True
+ request_output = SimpleNamespace(
+ images=[_build_image_output(color="blue")],
+ finished=True,
+ )
return OmniRequestOutput(
request_id="test_req_img_chat",
finished=True,
@@ -389,51 +390,55 @@ def sampling_case(request) -> SamplingCase:
@pytest.fixture
-def mock_async_omni(server_case: ServerCase, sampling_case: SamplingCase):
+def mock_async_omni(
+ server_case: ServerCase,
+ sampling_case: SamplingCase,
+ monkeypatch: pytest.MonkeyPatch,
+ mocker: MockerFixture,
+):
async def _mock_preprocess_chat(self, *args, **kwargs):
return ([{"role": "user", "content": "test"}], [{"prompt": "test prompt"}])
# Need to mock AsyncOmni itself (not only its generate method) because
# 1. The API layer uses its stage_list and stage_configs attributes
# 2. Its __init__ method has slow side effects (model & config loading).
- with (
- patch("vllm_omni.entrypoints.openai.api_server.AsyncOmni") as MockAsyncOmni,
- patch(
- "vllm_omni.entrypoints.openai.serving_chat.OmniOpenAIServingChat._preprocess_chat",
- new=_mock_preprocess_chat,
- ),
- ):
- mock_instance = AsyncMock(spec=RealAsyncOmni)
- mock_instance.generate = _build_mock_outputs(server_case.outputs, sampling_case, server_case)
-
- mock_instance.stage_list = server_case.stage_list
- mock_instance.stage_configs = server_case.stage_configs
- mock_instance.output_modalities = _build_output_modalities(server_case.stage_configs)
- mock_instance.default_sampling_params_list = [
- SamplingParams() if _stage_type(stage) != "diffusion" else MagicMock()
- for stage in server_case.stage_configs
- ]
- mock_instance.errored = False
- mock_instance.dead_error = RuntimeError("Mock engine error")
- mock_instance.model_config = MagicMock(
- max_model_len=4096,
- io_processor_plugin=None,
- allowed_local_media_path=None,
- allowed_media_domains=None,
- )
- # Mimic Qwen3-TTS talker speaker config so CustomVoice validation passes.
- mock_instance.model_config.hf_config = MagicMock()
- mock_instance.model_config.hf_config.talker_config = MagicMock()
- mock_instance.model_config.hf_config.talker_config.speaker_id = {"Vivian": 0}
- mock_instance.io_processor = MagicMock()
- mock_instance.input_processor = MagicMock()
- mock_instance.shutdown = MagicMock()
- mock_instance.get_vllm_config = AsyncMock(return_value=None)
- mock_instance.get_supported_tasks = AsyncMock(return_value=["generate"])
- mock_instance.get_tokenizer = AsyncMock(return_value=None)
+ mock_async_omni_cls = mocker.patch("vllm_omni.entrypoints.openai.api_server.AsyncOmni")
+ monkeypatch.setattr(
+ "vllm_omni.entrypoints.openai.serving_chat.OmniOpenAIServingChat._preprocess_chat",
+ _mock_preprocess_chat,
+ )
+
+ mock_instance = mocker.AsyncMock(spec=RealAsyncOmni)
+ mock_instance.generate = _build_mock_outputs(server_case.outputs, sampling_case, server_case)
+
+ mock_instance.stage_list = server_case.stage_list
+ mock_instance.stage_configs = server_case.stage_configs
+ mock_instance.output_modalities = _build_output_modalities(server_case.stage_configs)
+ mock_instance.default_sampling_params_list = [
+ SamplingParams() if _stage_type(stage) != "diffusion" else mocker.MagicMock()
+ for stage in server_case.stage_configs
+ ]
+ mock_instance.errored = False
+ mock_instance.dead_error = RuntimeError("Mock engine error")
+ mock_instance.model_config = mocker.MagicMock(
+ max_model_len=4096,
+ io_processor_plugin=None,
+ allowed_local_media_path=None,
+ allowed_media_domains=None,
+ )
+ # Mimic Qwen3-TTS talker speaker config so CustomVoice validation passes.
+ mock_instance.model_config.hf_config = mocker.MagicMock()
+ mock_instance.model_config.hf_config.talker_config = mocker.MagicMock()
+ mock_instance.model_config.hf_config.talker_config.speaker_id = {"Vivian": 0}
+ mock_instance.io_processor = mocker.MagicMock()
+ mock_instance.input_processor = mocker.MagicMock()
+ mock_instance.shutdown = mocker.MagicMock()
+ mock_instance.get_vllm_config = mocker.AsyncMock(return_value=None)
+ mock_instance.get_supported_tasks = mocker.AsyncMock(return_value=["generate"])
+ mock_instance.get_tokenizer = mocker.AsyncMock(return_value=None)
- MockAsyncOmni.return_value = mock_instance
- yield MockAsyncOmni
+ mock_async_omni_cls.return_value = mock_instance
+ yield mock_async_omni_cls
@pytest.fixture
@@ -583,9 +588,9 @@ async def test_image_generation_node(api_server: str, model: str, image_input: b
ServerCase(
served_model="Qwen/Qwen2.5-Omni-7B",
stage_list=[
- MagicMock(is_comprehension=True, model_stage="llm"),
- MagicMock(is_comprehension=False, model_stage="llm"),
- MagicMock(is_comprehension=False, model_stage="llm"),
+ SimpleNamespace(is_comprehension=True, model_stage="llm"),
+ SimpleNamespace(is_comprehension=False, model_stage="llm"),
+ SimpleNamespace(is_comprehension=False, model_stage="llm"),
],
stage_configs=[
_make_stage_config("llm", is_comprehension=True, model_stage="thinker"),
diff --git a/tests/core/sched/test_generation_scheduler_restore.py b/tests/core/sched/test_generation_scheduler_restore.py
index 154f40b3995..5cc1cab7025 100644
--- a/tests/core/sched/test_generation_scheduler_restore.py
+++ b/tests/core/sched/test_generation_scheduler_restore.py
@@ -6,7 +6,6 @@
those requests are permanently orphaned.
"""
-import unittest
from collections import deque
import pytest
@@ -39,7 +38,7 @@ def postprocess_scheduler_output(self, output):
pass
-class TestRestoreQueuesOnError(unittest.TestCase):
+class TestRestoreQueuesOnError:
"""Verify that restore_queues is called even when rewrapping raises."""
def test_requests_not_lost_on_exception(self):
@@ -52,8 +51,8 @@ def test_requests_not_lost_on_exception(self):
# Step 1: process_pending_chunks moves req-B out
adapter.process_pending_chunks(waiting=[], running=running)
- self.assertEqual(running, ["req-A"])
- self.assertEqual(len(adapter.waiting_for_chunk_running_requests), 1)
+ assert running == ["req-A"]
+ assert len(adapter.waiting_for_chunk_running_requests) == 1
# Step 2: simulate the try/except/finally pattern
try:
@@ -65,9 +64,9 @@ def test_requests_not_lost_on_exception(self):
adapter.restore_queues(waiting=[], running=running)
# Step 3: verify request is restored
- self.assertTrue(adapter.restore_called)
- self.assertIn("req-B", running)
- self.assertEqual(len(adapter.waiting_for_chunk_running_requests), 0)
+ assert adapter.restore_called is True
+ assert "req-B" in running
+ assert len(adapter.waiting_for_chunk_running_requests) == 0
def test_requests_lost_without_fix(self):
"""Demonstrate the bug: without restore in except, request is lost."""
@@ -76,7 +75,7 @@ def test_requests_lost_without_fix(self):
running = ["req-A", "req-B"]
adapter.process_pending_chunks(waiting=[], running=running)
- self.assertEqual(running, ["req-A"])
+ assert running == ["req-A"]
# Simulate the BUGGY code: except without restore
try:
@@ -85,8 +84,8 @@ def test_requests_lost_without_fix(self):
pass # Bug: no restore_queues call
# Request is lost!
- self.assertNotIn("req-B", running)
- self.assertEqual(len(adapter.waiting_for_chunk_running_requests), 1)
+ assert "req-B" not in running
+ assert len(adapter.waiting_for_chunk_running_requests) == 1
def test_happy_path_restores_via_finally(self):
"""When no exception, restore_queues is still called via finally."""
@@ -102,9 +101,5 @@ def test_happy_path_restores_via_finally(self):
finally:
adapter.restore_queues(waiting=[], running=running)
- self.assertTrue(adapter.restore_called)
- self.assertIn("req-B", running)
-
-
-if __name__ == "__main__":
- unittest.main()
+ assert adapter.restore_called is True
+ assert "req-B" in running
diff --git a/tests/diffusion/distributed/test_distributed_vae_executor.py b/tests/diffusion/distributed/test_distributed_vae_executor.py
index dc491dcdaf1..b2ee7c10d33 100644
--- a/tests/diffusion/distributed/test_distributed_vae_executor.py
+++ b/tests/diffusion/distributed/test_distributed_vae_executor.py
@@ -1,4 +1,4 @@
-from unittest.mock import MagicMock, patch
+from types import SimpleNamespace
import pytest
import torch
@@ -61,40 +61,31 @@ def merge(self, coord_tensor_map, grid_spec):
class DummyMixin(DistributedVaeMixin):
def __init__(self):
self.use_tiling = True
- self.distributed_executor = MagicMock()
- self.distributed_executor.parallel_size = 2
- self.distributed_executor.group = None
+ self.distributed_executor = SimpleNamespace(parallel_size=2, group=None)
@pytest.fixture(autouse=True)
-def mock_dist():
- with (
- patch.object(dist, "get_world_size", return_value=2),
- patch.object(dist, "get_rank", return_value=0),
- patch.object(dist, "is_initialized", return_value=True),
- patch.object(dist, "all_reduce", return_value=None),
- patch.object(dist, "gather", return_value=None),
- patch.object(dist, "broadcast", return_value=None),
- ):
- yield
+def mock_dist(monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.setattr(dist, "get_world_size", lambda *args, **kwargs: 2)
+ monkeypatch.setattr(dist, "get_rank", lambda *args, **kwargs: 0)
+ monkeypatch.setattr(dist, "is_initialized", lambda: True)
+ monkeypatch.setattr(dist, "all_reduce", lambda *args, **kwargs: None)
+ monkeypatch.setattr(dist, "gather", lambda *args, **kwargs: None)
+ monkeypatch.setattr(dist, "broadcast", lambda *args, **kwargs: None)
@pytest.fixture(autouse=True)
-def mock_dit_group():
- with patch(
+def mock_dit_group(monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.setattr(
"vllm_omni.diffusion.distributed.autoencoders.distributed_vae_executor.get_dit_group",
- new=MagicMock(return_value=None),
- ):
- yield
+ lambda: None,
+ )
@pytest.fixture(autouse=True)
-def mock_dist_vae_executor():
- with (
- patch.object(DistributedVaeExecutor, "gather_tensors", side_effect=lambda x: [x]),
- patch.object(DistributedVaeExecutor, "broadcast_tensor", side_effect=lambda x: x),
- ):
- yield
+def mock_dist_vae_executor(monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.setattr(DistributedVaeExecutor, "gather_tensors", lambda self, x: [x])
+ monkeypatch.setattr(DistributedVaeExecutor, "broadcast_tensor", lambda self, x: x)
# ============================
diff --git a/tests/diffusion/models/bagel/test_trajectory_recording.py b/tests/diffusion/models/bagel/test_trajectory_recording.py
index 80b3f9d9ba7..345eac10784 100644
--- a/tests/diffusion/models/bagel/test_trajectory_recording.py
+++ b/tests/diffusion/models/bagel/test_trajectory_recording.py
@@ -4,10 +4,10 @@
import types
from dataclasses import dataclass
-from unittest.mock import MagicMock, patch
import pytest
import torch
+from pytest_mock import MockerFixture
from vllm_omni.diffusion.models.bagel.bagel_transformer import (
Bagel,
@@ -23,9 +23,9 @@
EXPECTED_STEPS = NUM_TIMESTEPS - 1
-def _make_mock_bagel():
+def _make_mock_bagel(mocker: MockerFixture):
"""Create a mock Bagel with forward returning constant velocity."""
- mock = MagicMock(spec=Bagel)
+ mock = mocker.MagicMock(spec=Bagel)
mock._sp_size = 1
# forward returns a small constant velocity so x_t changes each step
@@ -78,18 +78,22 @@ def _make_generate_args(num_tokens=NUM_TOKENS, hidden_dim=HIDDEN_DIM, cfg=False)
@pytest.fixture(params=[False, True], ids=["no_cfg", "batched_cfg"])
-def bagel_and_args(request):
+def bagel_and_args(
+ request,
+ monkeypatch: pytest.MonkeyPatch,
+ mocker: MockerFixture,
+):
"""Mock Bagel instance and generate_image arguments.
Parametrized over CFG mode so every test runs on both the no-CFG
and batched-CFG code paths.
"""
cfg = request.param
- with patch(
+ monkeypatch.setattr(
"vllm_omni.diffusion.models.bagel.bagel_transformer.get_classifier_free_guidance_world_size",
- return_value=1,
- ):
- yield _make_mock_bagel(), _make_generate_args(cfg=cfg)
+ lambda: 1,
+ )
+ yield _make_mock_bagel(mocker), _make_generate_args(cfg=cfg)
class TestTrajectoryRecording:
@@ -188,12 +192,16 @@ class TestTrajectoryLogProbs:
"""Tests for log-prob recording when a scheduler is provided."""
@pytest.fixture()
- def bagel_scheduler_args(self):
- with patch(
+ def bagel_scheduler_args(
+ self,
+ monkeypatch: pytest.MonkeyPatch,
+ mocker: MockerFixture,
+ ):
+ monkeypatch.setattr(
"vllm_omni.diffusion.models.bagel.bagel_transformer.get_classifier_free_guidance_world_size",
- return_value=1,
- ):
- yield _make_mock_bagel(), _make_generate_args(), _MockScheduler()
+ lambda: 1,
+ )
+ yield _make_mock_bagel(mocker), _make_generate_args(), _MockScheduler()
def test_log_probs_recorded_with_scheduler(self, bagel_scheduler_args):
bagel, args, scheduler = bagel_scheduler_args
diff --git a/tests/diffusion/models/flux2/test_flux2_transformer_tp.py b/tests/diffusion/models/flux2/test_flux2_transformer_tp.py
index faad08afd1c..54dda1dd07e 100644
--- a/tests/diffusion/models/flux2/test_flux2_transformer_tp.py
+++ b/tests/diffusion/models/flux2/test_flux2_transformer_tp.py
@@ -1,7 +1,6 @@
-from unittest.mock import MagicMock, patch
-
import pytest
import torch
+from pytest_mock import MockerFixture
from tests.utils import hardware_test
from vllm_omni.diffusion.models.flux2.flux2_transformer import (
@@ -12,14 +11,17 @@
# Initialize TP group before tests
@pytest.fixture(scope="function", autouse=True)
-def setup_tp_group():
+def setup_tp_group(mocker: MockerFixture):
"""Set up TP group for each test function"""
- with patch("vllm.model_executor.layers.linear.get_tensor_model_parallel_world_size", return_value=2):
- with patch("vllm.distributed.parallel_state.get_tp_group") as mock_get_tp_group:
- mock_tp_group = MagicMock()
- mock_tp_group.world_size = 2
- mock_get_tp_group.return_value = mock_tp_group
- yield
+ mocker.patch(
+ "vllm.model_executor.layers.linear.get_tensor_model_parallel_world_size",
+ return_value=2,
+ )
+ mock_get_tp_group = mocker.patch("vllm.distributed.parallel_state.get_tp_group")
+ mock_tp_group = mocker.MagicMock()
+ mock_tp_group.world_size = 2
+ mock_get_tp_group.return_value = mock_tp_group
+ yield
class TestFlux2TransformerWeightLoading:
diff --git a/tests/diffusion/offloader/test_sequential_backend.py b/tests/diffusion/offloader/test_sequential_backend.py
index d18637a780e..2539cc06895 100644
--- a/tests/diffusion/offloader/test_sequential_backend.py
+++ b/tests/diffusion/offloader/test_sequential_backend.py
@@ -3,8 +3,6 @@
"""Unit tests for SequentialOffloadBackend."""
-from unittest.mock import patch
-
import pytest
import torch
from torch import nn
@@ -44,7 +42,7 @@ def mock(self):
class TestMoveParamsPinMemory:
- def test_dtensor_skips_pin_memory(self, accelerator_device):
+ def test_dtensor_skips_pin_memory(self, accelerator_device, monkeypatch: pytest.MonkeyPatch):
"""DTensor should skip pin_memory to avoid RuntimeError."""
module = _create_simple_module().to(accelerator_device)
tracker, mock_pin = _track_pin_memory_calls()
@@ -56,73 +54,73 @@ def fake_isinstance(obj, cls):
return True
return original_isinstance(obj, cls)
- with patch.object(torch.Tensor, "pin_memory", mock_pin):
- with patch("builtins.isinstance", fake_isinstance):
- hook = SequentialOffloadHook(
- offload_targets=[],
- device=accelerator_device,
- pin_memory=True,
- use_hsdp=False,
- )
- hook._move_params(
- module,
- torch.device("cpu"),
- non_blocking=False,
- pin_memory=True,
- )
- assert not tracker["called"], "pin_memory should not be called for DTensor"
-
- def test_regular_tensor_calls_pin_memory(self, accelerator_device):
+ monkeypatch.setattr(torch.Tensor, "pin_memory", mock_pin)
+ monkeypatch.setattr("builtins.isinstance", fake_isinstance)
+ hook = SequentialOffloadHook(
+ offload_targets=[],
+ device=accelerator_device,
+ pin_memory=True,
+ use_hsdp=False,
+ )
+ hook._move_params(
+ module,
+ torch.device("cpu"),
+ non_blocking=False,
+ pin_memory=True,
+ )
+ assert not tracker["called"], "pin_memory should not be called for DTensor"
+
+ def test_regular_tensor_calls_pin_memory(self, accelerator_device, monkeypatch: pytest.MonkeyPatch):
"""Regular tensor should call pin_memory when moving to CPU."""
module = _create_simple_module().to(accelerator_device)
tracker, mock_pin = _track_pin_memory_calls()
- with patch.object(torch.Tensor, "pin_memory", mock_pin):
- hook = SequentialOffloadHook(
- offload_targets=[],
- device=accelerator_device,
- pin_memory=True,
- use_hsdp=False,
- )
- hook._move_params(
- module,
- torch.device("cpu"),
- non_blocking=False,
- pin_memory=True,
- )
- assert tracker["called"], "pin_memory should be called for regular tensors"
-
- def test_pin_memory_skipped_when_disabled(self, accelerator_device):
+ monkeypatch.setattr(torch.Tensor, "pin_memory", mock_pin)
+ hook = SequentialOffloadHook(
+ offload_targets=[],
+ device=accelerator_device,
+ pin_memory=True,
+ use_hsdp=False,
+ )
+ hook._move_params(
+ module,
+ torch.device("cpu"),
+ non_blocking=False,
+ pin_memory=True,
+ )
+ assert tracker["called"], "pin_memory should be called for regular tensors"
+
+ def test_pin_memory_skipped_when_disabled(self, accelerator_device, monkeypatch: pytest.MonkeyPatch):
"""pin_memory should not be called when pin_memory=False."""
module = _create_simple_module().to(accelerator_device)
tracker, mock_pin = _track_pin_memory_calls()
- with patch.object(torch.Tensor, "pin_memory", mock_pin):
- hook = SequentialOffloadHook(
- offload_targets=[],
- device=accelerator_device,
- pin_memory=False,
- use_hsdp=False,
- )
- hook._move_params(
- module,
- torch.device("cpu"),
- non_blocking=False,
- pin_memory=False,
- )
- assert not tracker["called"], "pin_memory should not be called when disabled"
-
- def test_pin_memory_skipped_for_non_cpu_target(self, accelerator_device):
+ monkeypatch.setattr(torch.Tensor, "pin_memory", mock_pin)
+ hook = SequentialOffloadHook(
+ offload_targets=[],
+ device=accelerator_device,
+ pin_memory=False,
+ use_hsdp=False,
+ )
+ hook._move_params(
+ module,
+ torch.device("cpu"),
+ non_blocking=False,
+ pin_memory=False,
+ )
+ assert not tracker["called"], "pin_memory should not be called when disabled"
+
+ def test_pin_memory_skipped_for_non_cpu_target(self, accelerator_device, monkeypatch: pytest.MonkeyPatch):
"""pin_memory should not be called for non-CPU targets."""
module = _create_simple_module().to("cpu")
tracker, mock_pin = _track_pin_memory_calls()
- with patch.object(torch.Tensor, "pin_memory", mock_pin):
- hook = SequentialOffloadHook(
- offload_targets=[],
- device=torch.device("cpu"),
- pin_memory=True,
- use_hsdp=False,
- )
- hook._move_params(module, accelerator_device, non_blocking=False, pin_memory=True)
- assert not tracker["called"], "pin_memory should not be called for non-CPU target"
+ monkeypatch.setattr(torch.Tensor, "pin_memory", mock_pin)
+ hook = SequentialOffloadHook(
+ offload_targets=[],
+ device=torch.device("cpu"),
+ pin_memory=True,
+ use_hsdp=False,
+ )
+ hook._move_params(module, accelerator_device, non_blocking=False, pin_memory=True)
+ assert not tracker["called"], "pin_memory should not be called for non-CPU target"
diff --git a/tests/diffusion/quantization/test_int8_config.py b/tests/diffusion/quantization/test_int8_config.py
index d4d5aa5a7fe..875277ece42 100644
--- a/tests/diffusion/quantization/test_int8_config.py
+++ b/tests/diffusion/quantization/test_int8_config.py
@@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for Int8 quantization config."""
-from unittest.mock import MagicMock, patch
-
import pytest
import torch
from pytest_mock import MockerFixture
@@ -102,7 +100,7 @@ def test_quantization_config_string_and_dict_equivalent():
assert config_str.quantization_config.activation_scheme == config_dict.quantization_config.activation_scheme
-def test_get_quant_method(mocker: MockerFixture):
+def test_get_quant_method(mocker: MockerFixture, monkeypatch: pytest.MonkeyPatch):
"""Test for get_quant_method method for GPU"""
from vllm_omni.quantization.int8_config import Int8OnlineLinearMethod
@@ -111,18 +109,16 @@ def test_get_quant_method(mocker: MockerFixture):
def _fake_init(self, quant_config):
pass
- layer = MagicMock(spec=LinearBase)
+ layer = mocker.Mock(spec=LinearBase)
mocker.patch.object(Int8OnlineLinearMethod, "__init__", _fake_init)
prefix = "test_layer"
# Mock the platform to be GPU
- with (
- patch("vllm_omni.platforms.current_omni_platform.is_cuda", return_value=True),
- patch("vllm_omni.platforms.current_omni_platform.is_npu", return_value=False),
- ):
- method = config.get_quant_method(layer, prefix)
- assert isinstance(method, Int8OnlineLinearMethod)
+ monkeypatch.setattr(current_omni_platform, "is_cuda", lambda: True)
+ monkeypatch.setattr(current_omni_platform, "is_npu", lambda: False)
+ method = config.get_quant_method(layer, prefix)
+ assert isinstance(method, Int8OnlineLinearMethod)
# Test skipping quantization for a layer
config.ignored_layers = [prefix]
@@ -130,22 +126,20 @@ def _fake_init(self, quant_config):
assert isinstance(method, UnquantizedLinearMethod)
-def test_get_npu_quant_method():
+def test_get_npu_quant_method(mocker: MockerFixture, monkeypatch: pytest.MonkeyPatch):
"""Test for get_quant_method method for NPU"""
from vllm_omni.quantization.int8_config import NPUInt8OnlineLinearMethod
config = build_quant_config("int8")
- layer = MagicMock(spec=LinearBase)
+ layer = mocker.Mock(spec=LinearBase)
prefix = "test_layer"
# Mock the platform to be NPU
- with (
- patch("vllm_omni.platforms.current_omni_platform.is_cuda", return_value=False),
- patch("vllm_omni.platforms.current_omni_platform.is_npu", return_value=True),
- ):
- method = config.get_quant_method(layer, prefix)
- assert isinstance(method, NPUInt8OnlineLinearMethod)
+ monkeypatch.setattr(current_omni_platform, "is_cuda", lambda: False)
+ monkeypatch.setattr(current_omni_platform, "is_npu", lambda: True)
+ method = config.get_quant_method(layer, prefix)
+ assert isinstance(method, NPUInt8OnlineLinearMethod)
# Test skipping quantization for a layer
config.ignored_layers = [prefix]
@@ -245,7 +239,7 @@ class TestNPUInt8LinearMethod:
@pytest.fixture
def mock_torch_npu(self, mocker):
- torch_npu = MagicMock()
+ torch_npu = mocker.MagicMock()
mocker.patch("vllm_omni.quantization.int8_config.torch_npu", return_value=torch_npu)
mocker.patch(
diff --git a/tests/diffusion/test_diffusion_scheduler.py b/tests/diffusion/test_diffusion_scheduler.py
index 4324ba1e630..a64d9920e03 100644
--- a/tests/diffusion/test_diffusion_scheduler.py
+++ b/tests/diffusion/test_diffusion_scheduler.py
@@ -4,10 +4,10 @@
import queue
import threading
from types import SimpleNamespace
-from unittest.mock import Mock, patch
import pytest
import torch
+from pytest_mock import MockerFixture
from vllm_omni.diffusion.data import DiffusionOutput, DiffusionRequestAbortedError
from vllm_omni.diffusion.diffusion_engine import DiffusionEngine
@@ -97,19 +97,19 @@ def initialize(self, od_config) -> None:
def add_request(self, request: OmniDiffusionRequest) -> str:
assert request is self._request
- self._state = Mock(sched_req_id=self._sched_req_id, req=request)
+ self._state = SimpleNamespace(sched_req_id=self._sched_req_id, req=request)
return self._sched_req_id
def schedule(self):
if self._scheduled or self._state is None:
- return Mock(
+ return SimpleNamespace(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
scheduled_req_ids=[],
is_empty=True,
)
self._scheduled = True
- return Mock(
+ return SimpleNamespace(
scheduled_new_reqs=[NewRequestData.from_state(self._state)],
scheduled_cached_reqs=CachedRequestData.make_empty(),
scheduled_req_ids=[self._state.sched_req_id],
@@ -153,7 +153,7 @@ def close(self) -> None:
class TestRequestScheduler:
def setup_method(self) -> None:
self.scheduler: RequestScheduler = RequestScheduler()
- self.scheduler.initialize(Mock())
+ self.scheduler.initialize(SimpleNamespace())
def test_single_request_success_lifecycle(self) -> None:
req_id = self.scheduler.add_request(_make_request("a"))
@@ -276,23 +276,23 @@ def test_request_id_mapping_lifecycle(self) -> None:
class TestDiffusionEngine:
- def test_add_req_and_wait_for_response_single_path(self) -> None:
+ def test_add_req_and_wait_for_response_single_path(self, mocker: MockerFixture) -> None:
engine = DiffusionEngine.__new__(DiffusionEngine)
engine.scheduler = RequestScheduler()
- engine.scheduler.initialize(Mock())
+ engine.scheduler.initialize(SimpleNamespace())
engine._rpc_lock = threading.RLock()
engine.abort_queue = queue.Queue()
request = _make_request("engine")
runner_output = _make_request_output("engine")
- engine.execute_fn = Mock(return_value=runner_output)
+ engine.execute_fn = mocker.Mock(return_value=runner_output)
output = engine.add_req_and_wait_for_response(request)
assert output is runner_output.result
engine.execute_fn.assert_called_once()
- def test_supports_scheduler_interface_injection(self) -> None:
+ def test_supports_scheduler_interface_injection(self, mocker: MockerFixture) -> None:
request = _make_request("engine_iface")
runner_output = _make_request_output("engine_iface")
scheduler = _StubScheduler(request, runner_output)
@@ -301,33 +301,45 @@ def test_supports_scheduler_interface_injection(self) -> None:
engine.scheduler = scheduler
engine._rpc_lock = threading.RLock()
engine.abort_queue = queue.Queue()
- engine.execute_fn = Mock(return_value=runner_output)
+ engine.execute_fn = mocker.Mock(return_value=runner_output)
output = engine.add_req_and_wait_for_response(request)
assert output is runner_output.result
engine.execute_fn.assert_called_once()
- def test_initializes_injected_scheduler(self) -> None:
+ def test_initializes_injected_scheduler(
+ self,
+ monkeypatch: pytest.MonkeyPatch,
+ mocker: MockerFixture,
+ ) -> None:
request = _make_request("init")
scheduler = _StubScheduler(request, DiffusionOutput(output=None))
- od_config = Mock(model_class_name="mock_model")
- fake_executor_cls = Mock(return_value=Mock())
+ od_config = SimpleNamespace(model_class_name="mock_model")
+ fake_executor_cls = mocker.Mock(return_value=mocker.Mock())
- with (
- patch("vllm_omni.diffusion.diffusion_engine.get_diffusion_post_process_func", return_value=None),
- patch("vllm_omni.diffusion.diffusion_engine.get_diffusion_pre_process_func", return_value=None),
- patch("vllm_omni.diffusion.diffusion_engine.DiffusionExecutor.get_class", return_value=fake_executor_cls),
- patch.object(DiffusionEngine, "_dummy_run", return_value=None),
- ):
- DiffusionEngine(od_config, scheduler=scheduler)
+ monkeypatch.setattr(
+ "vllm_omni.diffusion.diffusion_engine.get_diffusion_post_process_func",
+ lambda *args, **kwargs: None,
+ )
+ monkeypatch.setattr(
+ "vllm_omni.diffusion.diffusion_engine.get_diffusion_pre_process_func",
+ lambda *args, **kwargs: None,
+ )
+ monkeypatch.setattr(
+ "vllm_omni.diffusion.diffusion_engine.DiffusionExecutor.get_class",
+ lambda *args, **kwargs: fake_executor_cls,
+ )
+ monkeypatch.setattr(DiffusionEngine, "_dummy_run", lambda self: None)
+
+ DiffusionEngine(od_config, scheduler=scheduler)
assert scheduler.initialized_with is od_config
fake_executor_cls.assert_called_once_with(od_config)
def test_scheduler_alias_keeps_default_request_scheduler(self) -> None:
scheduler = Scheduler()
- scheduler.initialize(Mock())
+ scheduler.initialize(SimpleNamespace())
req_id = scheduler.add_request(_make_request("alias"))
sched_output = scheduler.schedule()
@@ -336,10 +348,10 @@ def test_scheduler_alias_keeps_default_request_scheduler(self) -> None:
assert req_id in finished
assert scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_COMPLETED
- def test_step_raises_aborted_error(self) -> None:
+ def test_step_raises_aborted_error(self, mocker: MockerFixture) -> None:
engine = DiffusionEngine.__new__(DiffusionEngine)
engine.pre_process_func = None
- engine.add_req_and_wait_for_response = Mock(
+ engine.add_req_and_wait_for_response = mocker.Mock(
return_value=DiffusionOutput(aborted=True, abort_message="Request req-abort aborted.")
)
@@ -349,7 +361,7 @@ def test_step_raises_aborted_error(self) -> None:
def test_abort_queue_marks_request_finished_aborted(self) -> None:
engine = DiffusionEngine.__new__(DiffusionEngine)
engine.scheduler = RequestScheduler()
- engine.scheduler.initialize(Mock())
+ engine.scheduler.initialize(SimpleNamespace())
engine.abort_queue = queue.Queue()
req_id = engine.scheduler.add_request(_make_request("req-abort"))
@@ -361,7 +373,7 @@ def test_abort_queue_marks_request_finished_aborted(self) -> None:
def test_finalize_finished_request_returns_aborted_output(self) -> None:
engine = DiffusionEngine.__new__(DiffusionEngine)
engine.scheduler = RequestScheduler()
- engine.scheduler.initialize(Mock())
+ engine.scheduler.initialize(SimpleNamespace())
req_id = engine.scheduler.add_request(_make_request("req-finalize"))
engine.scheduler.finish_requests(req_id, DiffusionRequestStatus.FINISHED_ABORTED)
@@ -371,29 +383,40 @@ def test_finalize_finished_request_returns_aborted_output(self) -> None:
assert output.aborted is True
assert output.abort_message == "Request req-finalize aborted."
- def test_initializes_step_scheduler_when_step_execution_enabled(self) -> None:
- od_config = Mock(model_class_name="mock_model")
+ def test_initializes_step_scheduler_when_step_execution_enabled(
+ self,
+ monkeypatch: pytest.MonkeyPatch,
+ mocker: MockerFixture,
+ ) -> None:
+ od_config = SimpleNamespace(model_class_name="mock_model")
od_config.step_execution = True
- fake_executor = Mock()
- fake_executor_cls = Mock(return_value=fake_executor)
+ fake_executor = mocker.Mock()
+ fake_executor_cls = mocker.Mock(return_value=fake_executor)
- with (
- patch("vllm_omni.diffusion.diffusion_engine.get_diffusion_post_process_func", return_value=None),
- patch("vllm_omni.diffusion.diffusion_engine.get_diffusion_pre_process_func", return_value=None),
- patch("vllm_omni.diffusion.diffusion_engine.DiffusionExecutor.get_class", return_value=fake_executor_cls),
- patch.object(DiffusionEngine, "_dummy_run", return_value=None),
- ):
- engine = DiffusionEngine(od_config)
+ monkeypatch.setattr(
+ "vllm_omni.diffusion.diffusion_engine.get_diffusion_post_process_func",
+ lambda *args, **kwargs: None,
+ )
+ monkeypatch.setattr(
+ "vllm_omni.diffusion.diffusion_engine.get_diffusion_pre_process_func",
+ lambda *args, **kwargs: None,
+ )
+ monkeypatch.setattr(
+ "vllm_omni.diffusion.diffusion_engine.DiffusionExecutor.get_class",
+ lambda *args, **kwargs: fake_executor_cls,
+ )
+ monkeypatch.setattr(DiffusionEngine, "_dummy_run", lambda self: None)
+ engine = DiffusionEngine(od_config)
assert isinstance(engine.scheduler, StepScheduler)
assert engine.execute_fn is fake_executor.execute_step
fake_executor_cls.assert_called_once_with(od_config)
- def test_dummy_run_raises_on_output_error(self) -> None:
+ def test_dummy_run_raises_on_output_error(self, mocker: MockerFixture) -> None:
engine = DiffusionEngine.__new__(DiffusionEngine)
- engine.od_config = Mock(model_class_name="mock_model")
+ engine.od_config = SimpleNamespace(model_class_name="mock_model")
engine.pre_process_func = None
- engine.add_req_and_wait_for_response = Mock(return_value=DiffusionOutput(error="boom"))
+ engine.add_req_and_wait_for_response = mocker.Mock(return_value=DiffusionOutput(error="boom"))
with pytest.raises(RuntimeError, match="Dummy run failed: boom"):
engine._dummy_run()
@@ -402,7 +425,7 @@ def test_dummy_run_raises_on_output_error(self) -> None:
class TestStepScheduler:
def setup_method(self) -> None:
self.scheduler: StepScheduler = StepScheduler()
- self.scheduler.initialize(Mock())
+ self.scheduler.initialize(SimpleNamespace())
def test_single_request_step_lifecycle(self) -> None:
request = _make_step_request("step", num_inference_steps=3)
diff --git a/tests/diffusion/test_diffusion_step_pipeline.py b/tests/diffusion/test_diffusion_step_pipeline.py
index 68aba9ba3bf..42687d4a1ed 100644
--- a/tests/diffusion/test_diffusion_step_pipeline.py
+++ b/tests/diffusion/test_diffusion_step_pipeline.py
@@ -7,10 +7,10 @@
import threading
from contextlib import contextmanager
from types import SimpleNamespace
-from unittest.mock import Mock
import pytest
import torch
+from pytest_mock import MockerFixture
import vllm_omni.diffusion.worker.diffusion_model_runner as model_runner_module
from tests.utils import hardware_test
@@ -542,11 +542,11 @@ def test_rejects_lora_requests_in_step_mode(self):
class TestExecutor:
"""MultiprocDiffusionExecutor.execute_step"""
- def test_execute_step_passes_through_runner_output(self):
+ def test_execute_step_passes_through_runner_output(self, mocker: MockerFixture):
executor = object.__new__(MultiprocDiffusionExecutor)
executor._ensure_open = lambda: None
expected = RunnerOutput(req_id="req-step", step_index=1, finished=False, result=None)
- executor.collective_rpc = Mock(return_value=expected)
+ executor.collective_rpc = mocker.Mock(return_value=expected)
request = _make_engine_request("req-step", num_inference_steps=2)
scheduler_output = _make_scheduler_output(request, sched_req_id="req-step")
@@ -578,9 +578,9 @@ class TestEngine:
),
],
)
- def test_step_engine_returns_error(self, execute_fn, expected_error):
+ def test_step_engine_returns_error(self, execute_fn, expected_error, mocker: MockerFixture):
scheduler = StepScheduler()
- scheduler.initialize(Mock())
+ scheduler.initialize(mocker.Mock())
engine = _make_engine(scheduler, execute_fn=execute_fn)
output = engine.add_req_and_wait_for_response(_make_engine_request("req-error", num_inference_steps=2))
@@ -588,9 +588,9 @@ def test_step_engine_returns_error(self, execute_fn, expected_error):
assert output.output is None
assert expected_error in output.error
- def test_step_execution_completes(self):
+ def test_step_execution_completes(self, mocker: MockerFixture):
scheduler = StepScheduler()
- scheduler.initialize(Mock())
+ scheduler.initialize(mocker.Mock())
engine = _make_engine(scheduler)
request = _make_engine_request("req-step", num_inference_steps=2)
@@ -614,9 +614,9 @@ def execute_fn(_):
assert output.error is None
assert torch.equal(output.output, torch.tensor([2.0]))
- def test_step_abort_stops_rescheduling_after_first_step(self):
+ def test_step_abort_stops_rescheduling_after_first_step(self, mocker: MockerFixture):
scheduler = StepScheduler()
- scheduler.initialize(Mock())
+ scheduler.initialize(mocker.Mock())
engine = _make_engine(scheduler)
request = _make_engine_request("req-stop", num_inference_steps=4)
@@ -639,9 +639,9 @@ def execute_fn(_):
assert step["n"] == 1
_assert_aborted_output(output, "req-stop")
- def test_step_abort_after_reschedule_returns_aborted_output(self):
+ def test_step_abort_after_reschedule_returns_aborted_output(self, mocker: MockerFixture):
scheduler = StepScheduler()
- scheduler.initialize(Mock())
+ scheduler.initialize(mocker.Mock())
engine = _make_engine(scheduler)
request = _make_engine_request("req-mid", num_inference_steps=4)
@@ -666,9 +666,9 @@ def execute_fn(sched_output):
assert step["n"] == 2
_assert_aborted_output(output, "req-mid")
- def test_finished_step_without_result_returns_error(self):
+ def test_finished_step_without_result_returns_error(self, mocker: MockerFixture):
scheduler = StepScheduler()
- scheduler.initialize(Mock())
+ scheduler.initialize(mocker.Mock())
engine = _make_engine(
scheduler,
execute_fn=lambda _: RunnerOutput(
diff --git a/tests/diffusion/test_diffusion_worker_cuda_profiler.py b/tests/diffusion/test_diffusion_worker_cuda_profiler.py
index ddc2aed2fc2..4a3b22c212e 100644
--- a/tests/diffusion/test_diffusion_worker_cuda_profiler.py
+++ b/tests/diffusion/test_diffusion_worker_cuda_profiler.py
@@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from unittest.mock import MagicMock
-
import pytest
from pytest_mock import MockerFixture
@@ -55,8 +53,8 @@ def test_profile_start_stop_delegates_to_cuda_profiler(
mock_diffusion_worker_dependencies,
):
fake_profiler = mocker.Mock()
- fake_profiler.start = MagicMock()
- fake_profiler.stop = MagicMock()
+ fake_profiler.start = mocker.Mock()
+ fake_profiler.stop = mocker.Mock()
mocker.patch(
"vllm_omni.diffusion.worker.diffusion_worker.CudaProfilerWrapper",
return_value=fake_profiler,
diff --git a/tests/diffusion/test_multiproc_engine_concurrency.py b/tests/diffusion/test_multiproc_engine_concurrency.py
index 517f98ddaa9..4bc3e05fe91 100644
--- a/tests/diffusion/test_multiproc_engine_concurrency.py
+++ b/tests/diffusion/test_multiproc_engine_concurrency.py
@@ -3,7 +3,7 @@
import queue
import threading
-from unittest.mock import Mock, patch
+from types import SimpleNamespace
import pytest
import torch
@@ -24,11 +24,9 @@ def _tagged_output(tag: str) -> DiffusionOutput:
return DiffusionOutput(output=torch.tensor([0]), error=tag)
-def _mock_request(tag: str) -> Mock:
- """Return a mock ``OmniDiffusionRequest`` identifiable by *tag*."""
- req = Mock()
- req.request_ids = [tag]
- return req
+def _mock_request(tag: str):
+ """Return a lightweight request object identifiable by *tag*."""
+ return SimpleNamespace(request_ids=[tag])
def _make_executor(num_gpus: int = 1):
@@ -36,20 +34,18 @@ def _make_executor(num_gpus: int = 1):
Returns ``(executor, request_queue, result_queue)``.
"""
- od_cfg = Mock()
- od_cfg.num_gpus = num_gpus
-
- with patch.object(MultiprocDiffusionExecutor, "_init_executor"):
- executor = MultiprocDiffusionExecutor(od_cfg)
+ od_cfg = SimpleNamespace(num_gpus=num_gpus)
+ monkeypatch = pytest.MonkeyPatch()
+ monkeypatch.setattr(MultiprocDiffusionExecutor, "_init_executor", lambda self: None)
+ executor = MultiprocDiffusionExecutor(od_cfg)
+ monkeypatch.undo()
req_q: queue.Queue = queue.Queue()
res_q: queue.Queue = queue.Queue()
- mock_broadcast_mq = Mock()
- mock_broadcast_mq.enqueue = req_q.put
+ mock_broadcast_mq = SimpleNamespace(enqueue=req_q.put)
- mock_rmq = Mock()
- mock_rmq.dequeue = lambda timeout=None: res_q.get(timeout=timeout if timeout is not None else 10)
+ mock_rmq = SimpleNamespace(dequeue=lambda timeout=None: res_q.get(timeout=timeout if timeout is not None else 10))
executor._broadcast_mq = mock_broadcast_mq
executor._result_mq = mock_rmq
@@ -63,7 +59,7 @@ def _make_engine(num_gpus: int = 1):
executor, req_q, res_q = _make_executor(num_gpus)
engine = DiffusionEngine.__new__(DiffusionEngine)
sched = RequestScheduler()
- sched.initialize(Mock())
+ sched.initialize(SimpleNamespace())
engine.scheduler = sched
engine.executor = executor
engine._rpc_lock = threading.RLock()
diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py
index cb1f31164ca..a1fc18f8456 100644
--- a/tests/engine/test_arg_utils.py
+++ b/tests/engine/test_arg_utils.py
@@ -6,7 +6,7 @@
import argparse
import inspect
-from unittest.mock import Mock
+from types import SimpleNamespace
import pytest
from pydantic import ValidationError
@@ -102,7 +102,7 @@ def test_qwen3_tts_codec_frame_rate_patching():
vllm_config = EngineArgs().create_model_config()
# Create a mock talking config with a dummy value for position_id_per_seconds
- mock_talker_config = Mock()
+ mock_talker_config = SimpleNamespace()
mock_talker_config.position_id_per_seconds = 12.3
vllm_config.hf_config.talker_config = mock_talker_config
@@ -146,13 +146,12 @@ def test_stage_specific_text_config_override():
# Switch the created hf text config with a mock whose
# values we want to pull through the text config helper
stage_text_config = vllm_config.hf_text_config
- vllm_config.hf_text_config = Mock()
+ vllm_config.hf_text_config = SimpleNamespace()
stage_text_config.sliding_window = 4096
stage_text_config.attention_chunk_size = 2048
# Move the stage config's text config getter & thinker config
- mock_stage_config = Mock()
- mock_stage_config.get_text_config.return_value = stage_text_config
+ mock_stage_config = SimpleNamespace(get_text_config=lambda: stage_text_config)
vllm_config.hf_config.thinker_config = mock_stage_config
# Ensure that create from a vLLM config correctly pulls the
diff --git a/tests/engine/test_async_omni_engine_input.py b/tests/engine/test_async_omni_engine_input.py
index ed6a7277b46..3700e426d42 100644
--- a/tests/engine/test_async_omni_engine_input.py
+++ b/tests/engine/test_async_omni_engine_input.py
@@ -1,6 +1,5 @@
-from unittest.mock import Mock
-
import pytest
+from pytest_mock import MockerFixture
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import EngineCoreRequest
@@ -24,18 +23,18 @@ def _make_engine_core_request() -> EngineCoreRequest:
)
-def test_build_add_request_message_preserves_additional_information():
+def test_build_add_request_message_preserves_additional_information(mocker: MockerFixture):
engine = object.__new__(AsyncOmniEngine)
params = SamplingParams(max_tokens=8)
engine.default_sampling_params_list = [params]
engine.stage_metadata = [{"stage_type": "llm"}]
engine.supported_tasks = ("speech",)
- input_processor = Mock()
+ input_processor = mocker.Mock()
input_processor.process_inputs.return_value = _make_engine_core_request()
engine.input_processor = input_processor
- output_processor = Mock()
+ output_processor = mocker.Mock()
engine.output_processors = [output_processor]
prompt = {
@@ -63,18 +62,18 @@ def test_build_add_request_message_preserves_additional_information():
output_processor.add_request.assert_called_once()
-def test_build_add_request_message_with_resumable_streaming():
+def test_build_add_request_message_with_resumable_streaming(mocker: MockerFixture):
engine = object.__new__(AsyncOmniEngine)
params = SamplingParams(max_tokens=8)
engine.default_sampling_params_list = [params]
engine.stage_metadata = [{"stage_type": "llm"}]
engine.supported_tasks = ("generate",)
- input_processor = Mock()
+ input_processor = mocker.Mock()
input_processor.process_inputs.return_value = _make_engine_core_request()
engine.input_processor = input_processor
- output_processor = Mock()
+ output_processor = mocker.Mock()
engine.output_processors = [output_processor]
msg = engine._build_add_request_message(
diff --git a/tests/engine/test_async_omni_engine_outputs.py b/tests/engine/test_async_omni_engine_outputs.py
index ccf9e8cb6b6..ef3cfab3bf8 100644
--- a/tests/engine/test_async_omni_engine_outputs.py
+++ b/tests/engine/test_async_omni_engine_outputs.py
@@ -5,36 +5,36 @@
"""
import queue
-from unittest.mock import MagicMock
import pytest
+from pytest_mock import MockerFixture
from vllm_omni.engine.async_omni_engine import AsyncOmniEngine
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-def _make_engine(output_queue, *, thread_alive: bool = True) -> AsyncOmniEngine:
+def _make_engine(output_queue, mocker: MockerFixture, *, thread_alive: bool = True) -> AsyncOmniEngine:
"""Create an AsyncOmniEngine bypassing __init__."""
engine = object.__new__(AsyncOmniEngine)
engine.output_queue = output_queue
- engine.orchestrator_thread = MagicMock(
- is_alive=MagicMock(return_value=thread_alive),
+ engine.orchestrator_thread = mocker.MagicMock(
+ is_alive=mocker.MagicMock(return_value=thread_alive),
)
return engine
-def test_try_get_output_raises_after_orchestrator_dies():
+def test_try_get_output_raises_after_orchestrator_dies(mocker: MockerFixture):
"""Draining remaining results then hitting an empty queue with a dead
orchestrator must raise RuntimeError so callers know the pipeline is gone."""
- mock_queue = MagicMock()
+ mock_queue = mocker.MagicMock()
# First call succeeds; second call finds the queue empty.
mock_queue.sync_q.get.side_effect = [
{"type": "output", "request_id": "r1"},
queue.Empty,
]
- engine = _make_engine(mock_queue, thread_alive=True)
+ engine = _make_engine(mock_queue, mocker, thread_alive=True)
# Collect the one buffered result.
assert engine.try_get_output()["request_id"] == "r1"
@@ -47,15 +47,15 @@ def test_try_get_output_raises_after_orchestrator_dies():
@pytest.mark.asyncio
-async def test_try_get_output_async_raises_after_orchestrator_dies():
+async def test_try_get_output_async_raises_after_orchestrator_dies(mocker: MockerFixture):
"""Same scenario as above but for the async variant."""
- mock_queue = MagicMock()
+ mock_queue = mocker.MagicMock()
mock_queue.sync_q.get_nowait.side_effect = [
{"type": "output", "request_id": "r1"},
queue.Empty,
]
- engine = _make_engine(mock_queue, thread_alive=True)
+ engine = _make_engine(mock_queue, mocker, thread_alive=True)
assert (await engine.try_get_output_async())["request_id"] == "r1"
diff --git a/tests/engine/test_single_stage_mode.py b/tests/engine/test_single_stage_mode.py
index 2c5bf6cc79c..608e92ac49e 100644
--- a/tests/engine/test_single_stage_mode.py
+++ b/tests/engine/test_single_stage_mode.py
@@ -17,10 +17,11 @@
import threading
from contextlib import contextmanager
+from types import SimpleNamespace
from typing import Any
-from unittest.mock import MagicMock, Mock, patch
import pytest
+from pytest_mock import MockerFixture
from vllm.v1.engine.utils import EngineZmqAddresses
from vllm_omni.engine.async_omni_engine import AsyncOmniEngine
@@ -41,31 +42,33 @@
# ---------------------------------------------------------------------------
-def _make_stage_cfg(stage_id: int, stage_type: str = "llm") -> Mock:
+def _make_stage_cfg(stage_id: int, stage_type: str = "llm"):
"""Return a lightweight stage config mock."""
- cfg = Mock()
- cfg.stage_id = stage_id
- cfg.stage_type = stage_type
- cfg.engine_args = MagicMock()
- cfg.engine_args.async_chunk = False
- cfg.engine_args.model_stage = None
- cfg.engine_args.engine_output_type = None
- return cfg
+ return SimpleNamespace(
+ stage_id=stage_id,
+ stage_type=stage_type,
+ engine_args=SimpleNamespace(
+ async_chunk=False,
+ model_stage=None,
+ engine_output_type=None,
+ ),
+ )
def _make_started_llm_stage(stage_id: int) -> StartedLlmStage:
"""Return a minimal StartedLlmStage for mocking."""
- addresses = Mock()
- addresses.inputs = ["tcp://127.0.0.1:5000"]
- addresses.outputs = ["tcp://127.0.0.1:5001"]
- addresses.frontend_stats_publish_address = None
+ addresses = SimpleNamespace(
+ inputs=["tcp://127.0.0.1:5000"],
+ outputs=["tcp://127.0.0.1:5001"],
+ frontend_stats_publish_address=None,
+ )
return StartedLlmStage(
stage_id=stage_id,
- metadata=Mock(stage_id=stage_id),
- vllm_config=Mock(),
- executor_class=Mock(),
- engine_manager=Mock(),
- coordinator=Mock(),
+ metadata=SimpleNamespace(stage_id=stage_id),
+ vllm_config=SimpleNamespace(),
+ executor_class=SimpleNamespace(),
+ engine_manager=SimpleNamespace(),
+ coordinator=SimpleNamespace(),
addresses=addresses,
)
@@ -348,74 +351,80 @@ class TestSingleStageModeDetection:
the orchestrator thread, so no actual engines are started.
"""
- def _make_engine_no_thread(self, **kwargs: Any) -> AsyncOmniEngine:
+ def _make_engine_no_thread(self, mocker: MockerFixture, **kwargs: Any) -> AsyncOmniEngine:
"""Create an AsyncOmniEngine without starting the orchestrator thread."""
stage_cfg = _make_stage_cfg(0)
mock_stage_configs = [stage_cfg]
- with (
- patch.object(
- AsyncOmniEngine,
- "_resolve_stage_configs",
- return_value=("/fake/path", mock_stage_configs),
- ),
- patch.object(
- AsyncOmniEngine,
- "_bootstrap_orchestrator",
- ),
- patch("threading.Thread") as mock_thread_cls,
- patch("concurrent.futures.Future") as mock_future_cls,
- ):
- mock_future = Mock()
- mock_future.result.return_value = Mock() # simulates a loop
- mock_future_cls.return_value = mock_future
+ mocker.patch.object(
+ AsyncOmniEngine,
+ "_resolve_stage_configs",
+ return_value=("/fake/path", mock_stage_configs),
+ )
+ mocker.patch.object(
+ AsyncOmniEngine,
+ "_bootstrap_orchestrator",
+ )
+ mock_thread_cls = mocker.patch("threading.Thread")
+ mock_future_cls = mocker.patch("concurrent.futures.Future")
+
+ mock_future = mocker.Mock()
+ mock_future.result.return_value = mocker.Mock() # simulates a loop
+ mock_future_cls.return_value = mock_future
- mock_thread = Mock()
- mock_thread.is_alive.return_value = False
- mock_thread_cls.return_value = mock_thread
+ mock_thread = mocker.Mock()
+ mock_thread.is_alive.return_value = False
+ mock_thread_cls.return_value = mock_thread
- engine = AsyncOmniEngine(model="fake-model", **kwargs)
+ engine = AsyncOmniEngine(model="fake-model", **kwargs)
return engine
- def test_explicit_single_stage_mode_true(self):
+ def test_explicit_single_stage_mode_true(self, mocker: MockerFixture):
engine = self._make_engine_no_thread(
+ mocker,
single_stage_mode=True,
omni_master_address="127.0.0.1",
omni_master_port=20000,
)
assert engine.single_stage_mode is True
- def test_stage_id_kwarg_promotes_to_single_stage_mode(self):
+ def test_stage_id_kwarg_promotes_to_single_stage_mode(self, mocker: MockerFixture):
engine = self._make_engine_no_thread(
+ mocker,
stage_id=0,
omni_master_address="127.0.0.1",
omni_master_port=20001,
)
assert engine.single_stage_mode is True
- def test_stage_id_kwarg_sets_filter(self):
+ def test_stage_id_kwarg_sets_filter(self, mocker: MockerFixture):
engine = self._make_engine_no_thread(
+ mocker,
stage_id=1,
omni_master_address="127.0.0.1",
omni_master_port=20002,
)
assert engine._single_stage_id_filter == 1
- def test_no_stage_id_no_single_stage_mode(self):
- engine = self._make_engine_no_thread()
+ def test_no_stage_id_no_single_stage_mode(self, mocker: MockerFixture):
+ engine = self._make_engine_no_thread(
+ mocker,
+ )
assert engine.single_stage_mode is False
assert engine._single_stage_id_filter is None
- def test_single_stage_mode_without_stage_id_has_no_filter(self):
+ def test_single_stage_mode_without_stage_id_has_no_filter(self, mocker: MockerFixture):
engine = self._make_engine_no_thread(
+ mocker,
single_stage_mode=True,
omni_master_address="127.0.0.1",
omni_master_port=20003,
)
assert engine._single_stage_id_filter is None
- def test_master_address_and_port_stored(self):
+ def test_master_address_and_port_stored(self, mocker: MockerFixture):
engine = self._make_engine_no_thread(
+ mocker,
stage_id=0,
omni_master_address="10.0.0.1",
omni_master_port=12345,
@@ -423,8 +432,10 @@ def test_master_address_and_port_stored(self):
assert engine._omni_master_address == "10.0.0.1"
assert engine._omni_master_port == 12345
- def test_omni_master_server_starts_as_none(self):
- engine = self._make_engine_no_thread()
+ def test_omni_master_server_starts_as_none(self, mocker: MockerFixture):
+ engine = self._make_engine_no_thread(
+ mocker,
+ )
assert engine._omni_master_server is None
@@ -448,7 +459,7 @@ class TestInitializeStagesRouting:
def _build_engine_skeleton(
self,
- stage_cfgs: list[Mock],
+ stage_cfgs: list[Any],
single_stage_mode: bool,
stage_id_filter: int | None,
omni_master_address: str = "127.0.0.1",
@@ -478,8 +489,8 @@ def _build_engine_skeleton(
engine.prompt_expand_func = None
return engine
- def _fake_metadata(self, stage_id: int, stage_type: str = "llm") -> Mock:
- meta = Mock()
+ def _fake_metadata(self, mocker: MockerFixture, stage_id: int, stage_type: str = "llm") -> Any:
+ meta = mocker.Mock()
meta.stage_id = stage_id
meta.stage_type = stage_type
meta.runtime_cfg = {}
@@ -492,13 +503,14 @@ def _fake_metadata(self, stage_id: int, stage_type: str = "llm") -> Mock:
def _run_initialize_stages_mocked(
self,
+ mocker: MockerFixture,
engine: AsyncOmniEngine,
- stage_cfgs: list[Mock],
+ stage_cfgs: list[Any],
*,
launch_side_effect: Any = None,
remote_side_effect: Any = None,
attach_result: Any = None,
- ) -> tuple[Mock, Mock]:
+ ) -> tuple[Any, Any]:
"""Execute _initialize_stages with all heavy helpers mocked.
Returns (mock_launch_llm_stage, mock_create_remote_llm_stage).
@@ -509,167 +521,217 @@ def _run_initialize_stages_mocked(
if getattr(cfg, "stage_type", "llm") != "diffusion"
}
- default_attach = (Mock(), Mock(), Mock(), Mock())
+ default_attach = (mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock())
- mock_launch = Mock(
+ mock_launch = mocker.Mock(
side_effect=launch_side_effect
or (lambda cfg, meta, spec, timeout, llm_stage_launch_lock, kv: started_by_stage[meta.stage_id])
)
- mock_remote = Mock(
+ mock_remote = mocker.Mock(
side_effect=remote_side_effect or (lambda cfg, meta, spec, timeout, srv: started_by_stage[meta.stage_id])
)
- mock_attach = Mock(return_value=attach_result or default_attach)
+ mock_attach = mocker.Mock(return_value=attach_result or default_attach)
- mock_oms = Mock(spec=OmniMasterServer)
- mock_oms.get_zmq_addresses.side_effect = lambda sid: Mock()
+ mock_oms = mocker.Mock(spec=OmniMasterServer)
+ mock_oms.get_zmq_addresses.side_effect = lambda sid: mocker.Mock()
finalized = (
- [Mock() for _ in stage_cfgs],
- [Mock() for _ in stage_cfgs],
+ [mocker.Mock() for _ in stage_cfgs],
+ [mocker.Mock() for _ in stage_cfgs],
[{"final_output": True, "final_output_type": None, "stage_type": "llm"} for _ in stage_cfgs],
)
- with (
- patch.object(engine, "_launch_llm_stage", mock_launch),
- patch.object(engine, "_create_remote_llm_stage", mock_remote),
- patch.object(engine, "_attach_llm_stage", mock_attach),
- patch("vllm_omni.engine.async_omni_engine.OmniMasterServer", return_value=mock_oms),
- patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment"),
- patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
- return_value={},
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(cfg.stage_id, getattr(cfg, "stage_type", "llm")),
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
- return_value=finalized,
+ mocker.patch.object(engine, "_launch_llm_stage", mock_launch)
+ mocker.patch.object(engine, "_create_remote_llm_stage", mock_remote)
+ mocker.patch.object(engine, "_attach_llm_stage", mock_attach)
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.OmniMasterServer",
+ return_value=mock_oms,
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.prepare_engine_environment",
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
+ return_value=None,
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
+ return_value={},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
+ side_effect=lambda cfg: self._fake_metadata(
+ mocker,
+ cfg.stage_id,
+ getattr(cfg, "stage_type", "llm"),
),
- ):
- engine._initialize_stages(stage_init_timeout=60)
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
+ return_value=finalized,
+ )
+
+ engine._initialize_stages(stage_init_timeout=60)
return mock_launch, mock_remote
# -- single-stage mode: stage matches filter → local launch ---------------
- def test_matching_stage_uses_launch_llm_stage(self):
+ def test_matching_stage_uses_launch_llm_stage(self, mocker: MockerFixture):
"""stage_id == _single_stage_id_filter → _launch_llm_stage is called."""
stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- mock_launch, mock_remote = self._run_initialize_stages_mocked(engine, stage_cfgs)
+ mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs)
launched_ids = [c.args[1].stage_id for c in mock_launch.call_args_list]
assert 0 in launched_ids, "_launch_llm_stage should be called for stage 0"
- def test_non_matching_stage_uses_create_remote_llm_stage(self):
+ def test_non_matching_stage_uses_create_remote_llm_stage(self, mocker: MockerFixture):
"""stage_id != _single_stage_id_filter → _create_remote_llm_stage is called."""
stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- mock_launch, mock_remote = self._run_initialize_stages_mocked(engine, stage_cfgs)
+ mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs)
remote_ids = [c.args[1].stage_id for c in mock_remote.call_args_list]
assert 1 in remote_ids, "_create_remote_llm_stage should be called for stage 1"
- def test_filter_1_routes_correctly(self):
+ def test_filter_1_routes_correctly(self, mocker: MockerFixture):
"""With filter=1, stage 0 is remote and stage 1 is local."""
stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=1)
- mock_launch, mock_remote = self._run_initialize_stages_mocked(engine, stage_cfgs)
+ mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs)
launched_ids = [c.args[1].stage_id for c in mock_launch.call_args_list]
remote_ids = [c.args[1].stage_id for c in mock_remote.call_args_list]
assert 1 in launched_ids, "stage 1 should be launched locally with filter=1"
assert 0 in remote_ids, "stage 0 should use remote path with filter=1"
- def test_no_filter_all_stages_use_launch_path(self):
+ def test_no_filter_all_stages_use_launch_path(self, mocker: MockerFixture):
"""single_stage_mode=True but no filter → all stages use _launch_llm_stage."""
stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=None)
- mock_launch, mock_remote = self._run_initialize_stages_mocked(engine, stage_cfgs)
+ mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs)
assert mock_remote.call_count == 0, "No remote launches without a filter"
launched_ids = [c.args[1].stage_id for c in mock_launch.call_args_list]
assert set(launched_ids) == {0, 1}
- def test_non_single_stage_mode_never_calls_create_remote(self):
+ def test_non_single_stage_mode_never_calls_create_remote(self, mocker: MockerFixture):
"""Outside single_stage_mode, _create_remote_llm_stage must not be called."""
stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=False, stage_id_filter=None)
- mock_launch, mock_remote = self._run_initialize_stages_mocked(engine, stage_cfgs)
+ mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs)
assert mock_remote.call_count == 0
- def test_omni_master_server_started_in_single_stage_mode(self):
+ def test_omni_master_server_started_in_single_stage_mode(self, mocker: MockerFixture):
"""OmniMasterServer.start() must be called when single_stage_mode=True."""
stage_cfgs = [_make_stage_cfg(0)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- mock_oms = Mock(spec=OmniMasterServer)
- mock_oms.get_zmq_addresses.return_value = Mock()
- finalized = ([Mock()], [Mock()], [{"final_output": True, "final_output_type": None, "stage_type": "llm"}])
-
- with (
- patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(0)),
- patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(0)),
- patch.object(engine, "_attach_llm_stage", return_value=(Mock(), Mock(), Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.OmniMasterServer", return_value=mock_oms),
- patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment"),
- patch("vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", return_value=None),
- patch("vllm_omni.engine.async_omni_engine.get_stage_connector_spec", return_value={}),
- patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", return_value=(None, None, None)
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(cfg.stage_id),
- ),
- patch("vllm_omni.engine.async_omni_engine.finalize_initialized_stages", return_value=finalized),
- ):
- engine._initialize_stages(stage_init_timeout=60)
+ mock_oms = mocker.Mock(spec=OmniMasterServer)
+ mock_oms.get_zmq_addresses.return_value = mocker.Mock()
+ finalized = (
+ [mocker.Mock()],
+ [mocker.Mock()],
+ [{"final_output": True, "final_output_type": None, "stage_type": "llm"}],
+ )
+
+ mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(0))
+ mocker.patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(0))
+ mocker.patch.object(
+ engine,
+ "_attach_llm_stage",
+ return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.OmniMasterServer",
+ return_value=mock_oms,
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
+ return_value=None,
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
+ return_value={},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
+ side_effect=lambda cfg: self._fake_metadata(mocker, cfg.stage_id),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
+ return_value=finalized,
+ )
+
+ engine._initialize_stages(stage_init_timeout=60)
mock_oms.start.assert_called_once()
- def test_omni_master_server_uses_configured_stage_ids(self):
+ def test_omni_master_server_uses_configured_stage_ids(self, mocker: MockerFixture):
"""Configured stage IDs, not list indexes, should drive pre-allocation."""
stage_cfgs = [_make_stage_cfg(7), _make_stage_cfg(11)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=7)
- mock_oms = Mock(spec=OmniMasterServer)
- mock_oms.get_zmq_addresses.return_value = Mock()
+ mock_oms = mocker.Mock(spec=OmniMasterServer)
+ mock_oms.get_zmq_addresses.return_value = mocker.Mock()
finalized = (
- [Mock(), Mock()],
- [Mock(), Mock()],
+ [mocker.Mock(), mocker.Mock()],
+ [mocker.Mock(), mocker.Mock()],
[{"final_output": False, "final_output_type": None, "stage_type": "llm"} for _ in stage_cfgs],
)
- with (
- patch.object(
- engine, "_launch_llm_stage", side_effect=[_make_started_llm_stage(7), _make_started_llm_stage(11)]
- ),
- patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(11)),
- patch.object(engine, "_attach_llm_stage", return_value=(Mock(), Mock(), Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.OmniMasterServer", return_value=mock_oms) as mock_oms_cls,
- patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment"),
- patch("vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", return_value=None),
- patch("vllm_omni.engine.async_omni_engine.get_stage_connector_spec", return_value={}),
- patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", return_value=(None, None, None)
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(cfg.stage_id),
- ),
- patch("vllm_omni.engine.async_omni_engine.finalize_initialized_stages", return_value=finalized),
- ):
- engine._initialize_stages(stage_init_timeout=60)
+ mocker.patch.object(
+ engine,
+ "_launch_llm_stage",
+ side_effect=[_make_started_llm_stage(7), _make_started_llm_stage(11)],
+ )
+ mocker.patch.object(
+ engine,
+ "_create_remote_llm_stage",
+ return_value=_make_started_llm_stage(11),
+ )
+ mocker.patch.object(
+ engine,
+ "_attach_llm_stage",
+ return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
+ )
+ mock_oms_cls = mocker.patch(
+ "vllm_omni.engine.async_omni_engine.OmniMasterServer",
+ return_value=mock_oms,
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
+ return_value=None,
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
+ return_value={},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
+ side_effect=lambda cfg: self._fake_metadata(mocker, cfg.stage_id),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
+ return_value=finalized,
+ )
+
+ engine._initialize_stages(stage_init_timeout=60)
mock_oms_cls.assert_called_once_with(
master_address=engine._omni_master_address,
@@ -677,73 +739,121 @@ def test_omni_master_server_uses_configured_stage_ids(self):
stage_ids=[7, 11],
)
- def test_single_stage_filter_uses_configured_stage_ids(self):
+ def test_single_stage_filter_uses_configured_stage_ids(self, mocker: MockerFixture):
"""Local/remote dispatch should compare against configured stage IDs."""
stage_cfgs = [_make_stage_cfg(7), _make_stage_cfg(11)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=7)
- mock_oms = Mock(spec=OmniMasterServer)
+ mock_oms = mocker.Mock(spec=OmniMasterServer)
finalized = (
- [Mock(), Mock()],
- [Mock(), Mock()],
+ [mocker.Mock(), mocker.Mock()],
+ [mocker.Mock(), mocker.Mock()],
[{"final_output": False, "final_output_type": None, "stage_type": "llm"} for _ in stage_cfgs],
)
- with (
- patch.object(engine, "_launch_llm_stage", side_effect=[_make_started_llm_stage(7)]) as mock_launch,
- patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(11)) as mock_remote,
- patch.object(engine, "_attach_llm_stage", return_value=(Mock(), Mock(), Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.OmniMasterServer", return_value=mock_oms),
- patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment"),
- patch("vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", return_value=None),
- patch("vllm_omni.engine.async_omni_engine.get_stage_connector_spec", return_value={}),
- patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", return_value=(None, None, None)
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(cfg.stage_id),
- ),
- patch("vllm_omni.engine.async_omni_engine.finalize_initialized_stages", return_value=finalized),
- ):
- engine._initialize_stages(stage_init_timeout=60)
+ mock_launch = mocker.patch.object(
+ engine,
+ "_launch_llm_stage",
+ side_effect=[_make_started_llm_stage(7)],
+ )
+ mock_remote = mocker.patch.object(
+ engine,
+ "_create_remote_llm_stage",
+ return_value=_make_started_llm_stage(11),
+ )
+ mocker.patch.object(
+ engine,
+ "_attach_llm_stage",
+ return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.OmniMasterServer",
+ return_value=mock_oms,
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
+ return_value=None,
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
+ return_value={},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
+ side_effect=lambda cfg: self._fake_metadata(mocker, cfg.stage_id),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
+ return_value=finalized,
+ )
+
+ engine._initialize_stages(stage_init_timeout=60)
assert [call.args[1].stage_id for call in mock_launch.call_args_list] == [7]
assert [call.args[1].stage_id for call in mock_remote.call_args_list] == [11]
- def test_omni_master_server_preallocates_diffusion_stage_ids(self):
+ def test_omni_master_server_preallocates_diffusion_stage_ids(self, mocker: MockerFixture):
"""Diffusion stages should also receive OmniMasterServer allocations."""
stage_cfgs = [_make_stage_cfg(7), _make_stage_cfg(11, stage_type="diffusion")]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=7)
- mock_oms = Mock(spec=OmniMasterServer)
+ mock_oms = mocker.Mock(spec=OmniMasterServer)
finalized = (
- [Mock(), Mock()],
- [Mock(), Mock()],
+ [mocker.Mock(), mocker.Mock()],
+ [mocker.Mock(), mocker.Mock()],
[
{"final_output": False, "final_output_type": None, "stage_type": "llm"},
{"final_output": False, "final_output_type": None, "stage_type": "diffusion"},
],
)
- with (
- patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(7)),
- patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(7)),
- patch.object(engine, "_launch_diffusion_stage", return_value=Mock()),
- patch.object(engine, "_create_remote_diffusion_stage", return_value=Mock()),
- patch.object(engine, "_attach_llm_stage", return_value=(Mock(), Mock(), Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.OmniMasterServer", return_value=mock_oms) as mock_oms_cls,
- patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment"),
- patch("vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", return_value=None),
- patch("vllm_omni.engine.async_omni_engine.get_stage_connector_spec", return_value={}),
- patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", return_value=(None, None, None)
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(cfg.stage_id, getattr(cfg, "stage_type", "llm")),
+ mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(7))
+ mocker.patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(7))
+ mocker.patch.object(engine, "_launch_diffusion_stage", return_value=mocker.Mock())
+ mocker.patch.object(
+ engine,
+ "_create_remote_diffusion_stage",
+ return_value=mocker.Mock(),
+ )
+ mocker.patch.object(
+ engine,
+ "_attach_llm_stage",
+ return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
+ )
+ mock_oms_cls = mocker.patch(
+ "vllm_omni.engine.async_omni_engine.OmniMasterServer",
+ return_value=mock_oms,
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
+ return_value=None,
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
+ return_value={},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
+ side_effect=lambda cfg: self._fake_metadata(
+ mocker,
+ cfg.stage_id,
+ getattr(cfg, "stage_type", "llm"),
),
- patch("vllm_omni.engine.async_omni_engine.finalize_initialized_stages", return_value=finalized),
- ):
- engine._initialize_stages(stage_init_timeout=60)
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
+ return_value=finalized,
+ )
+
+ engine._initialize_stages(stage_init_timeout=60)
mock_oms_cls.assert_called_once_with(
master_address=engine._omni_master_address,
@@ -751,135 +861,200 @@ def test_omni_master_server_preallocates_diffusion_stage_ids(self):
stage_ids=[7, 11],
)
- def test_duplicate_llm_stage_ids_raise(self):
+ def test_duplicate_llm_stage_ids_raise(self, mocker: MockerFixture):
"""Duplicate configured LLM stage IDs should fail fast."""
stage_cfgs = [_make_stage_cfg(3), _make_stage_cfg(3)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=3)
- with (
- patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment"),
- patch("vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", return_value=None),
- pytest.raises(ValueError, match="Duplicate stage_id"),
- ):
+ mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
+ return_value=None,
+ )
+ with pytest.raises(ValueError, match="Duplicate stage_id"):
engine._initialize_stages(stage_init_timeout=60)
- def test_omni_master_server_not_started_in_normal_mode(self):
+ def test_omni_master_server_not_started_in_normal_mode(self, mocker: MockerFixture):
"""OmniMasterServer must NOT be instantiated outside single_stage_mode."""
stage_cfgs = [_make_stage_cfg(0)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=False, stage_id_filter=None)
- finalized = ([Mock()], [Mock()], [{"final_output": True, "final_output_type": None, "stage_type": "llm"}])
-
- with (
- patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(0)),
- patch.object(engine, "_attach_llm_stage", return_value=(Mock(), Mock(), Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.OmniMasterServer") as mock_oms_cls,
- patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment"),
- patch("vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", return_value=None),
- patch("vllm_omni.engine.async_omni_engine.get_stage_connector_spec", return_value={}),
- patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", return_value=(None, None, None)
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(cfg.stage_id),
- ),
- patch("vllm_omni.engine.async_omni_engine.finalize_initialized_stages", return_value=finalized),
- ):
- engine._initialize_stages(stage_init_timeout=60)
+ finalized = (
+ [mocker.Mock()],
+ [mocker.Mock()],
+ [{"final_output": True, "final_output_type": None, "stage_type": "llm"}],
+ )
+
+ mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(0))
+ mocker.patch.object(
+ engine,
+ "_attach_llm_stage",
+ return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
+ )
+ mock_oms_cls = mocker.patch("vllm_omni.engine.async_omni_engine.OmniMasterServer")
+ mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
+ return_value=None,
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
+ return_value={},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
+ side_effect=lambda cfg: self._fake_metadata(mocker, cfg.stage_id),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
+ return_value=finalized,
+ )
+
+ engine._initialize_stages(stage_init_timeout=60)
mock_oms_cls.assert_not_called()
- def test_single_stage_mode_missing_master_address_raises(self):
+ def test_single_stage_mode_missing_master_address_raises(self, mocker: MockerFixture):
"""single_stage_mode without master address/port raises ValueError."""
stage_cfgs = [_make_stage_cfg(0)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
engine._omni_master_address = None # missing
engine._omni_master_port = None
- with (
- patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment"),
- patch("vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", return_value=None),
- pytest.raises(ValueError, match="omni_master_address"),
- ):
+ mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
+ return_value=None,
+ )
+ with pytest.raises(ValueError, match="omni_master_address"):
engine._initialize_stages(stage_init_timeout=60)
- def test_matching_diffusion_stage_uses_local_registered_launch(self):
+ def test_matching_diffusion_stage_uses_local_registered_launch(self, mocker: MockerFixture):
"""A local diffusion stage should use the registered single-stage launch path."""
stage_cfgs = [_make_stage_cfg(0, stage_type="diffusion"), _make_stage_cfg(1)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- mock_oms = Mock(spec=OmniMasterServer)
- diffusion_client = Mock(stage_type="diffusion")
+ mock_oms = mocker.Mock(spec=OmniMasterServer)
+ diffusion_client = mocker.Mock(stage_type="diffusion")
finalized = (
- [diffusion_client, Mock()],
- [Mock(), Mock()],
+ [diffusion_client, mocker.Mock()],
+ [mocker.Mock(), mocker.Mock()],
[
{"final_output": False, "final_output_type": None, "stage_type": "diffusion"},
{"final_output": False, "final_output_type": None, "stage_type": "llm"},
],
)
- with (
- patch.object(engine, "_launch_diffusion_stage", return_value=diffusion_client) as mock_local_diff,
- patch.object(engine, "_create_remote_diffusion_stage") as mock_remote_diff,
- patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(1)),
- patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(1)),
- patch.object(engine, "_attach_llm_stage", return_value=(Mock(), Mock(), Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.OmniMasterServer", return_value=mock_oms),
- patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment"),
- patch("vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", return_value=None),
- patch("vllm_omni.engine.async_omni_engine.get_stage_connector_spec", return_value={}),
- patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", return_value=(None, None, None)
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(cfg.stage_id, getattr(cfg, "stage_type", "llm")),
+ mock_local_diff = mocker.patch.object(
+ engine,
+ "_launch_diffusion_stage",
+ return_value=diffusion_client,
+ )
+ mock_remote_diff = mocker.patch.object(engine, "_create_remote_diffusion_stage")
+ mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(1))
+ mocker.patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(1))
+ mocker.patch.object(
+ engine,
+ "_attach_llm_stage",
+ return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.OmniMasterServer",
+ return_value=mock_oms,
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
+ return_value=None,
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
+ return_value={},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
+ side_effect=lambda cfg: self._fake_metadata(
+ mocker,
+ cfg.stage_id,
+ getattr(cfg, "stage_type", "llm"),
),
- patch("vllm_omni.engine.async_omni_engine.finalize_initialized_stages", return_value=finalized),
- ):
- engine._initialize_stages(stage_init_timeout=60)
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
+ return_value=finalized,
+ )
+
+ engine._initialize_stages(stage_init_timeout=60)
assert mock_local_diff.call_count == 1
assert mock_local_diff.call_args.args[1].stage_id == 0
mock_remote_diff.assert_not_called()
- def test_non_matching_diffusion_stage_uses_remote_diffusion_client(self):
+ def test_non_matching_diffusion_stage_uses_remote_diffusion_client(self, mocker: MockerFixture):
"""A non-local diffusion stage should attach via the remote diffusion path."""
stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1, stage_type="diffusion")]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- mock_oms = Mock(spec=OmniMasterServer)
- remote_diffusion_client = Mock(stage_type="diffusion")
+ mock_oms = mocker.Mock(spec=OmniMasterServer)
+ remote_diffusion_client = mocker.Mock(stage_type="diffusion")
finalized = (
- [Mock(), remote_diffusion_client],
- [Mock(), Mock()],
+ [mocker.Mock(), remote_diffusion_client],
+ [mocker.Mock(), mocker.Mock()],
[
{"final_output": False, "final_output_type": None, "stage_type": "llm"},
{"final_output": False, "final_output_type": None, "stage_type": "diffusion"},
],
)
- with (
- patch.object(engine, "_launch_diffusion_stage") as mock_local_diff,
- patch.object(
- engine, "_create_remote_diffusion_stage", return_value=remote_diffusion_client
- ) as mock_remote_diff,
- patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(0)),
- patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(0)),
- patch.object(engine, "_attach_llm_stage", return_value=(Mock(), Mock(), Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.OmniMasterServer", return_value=mock_oms),
- patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment"),
- patch("vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", return_value=None),
- patch("vllm_omni.engine.async_omni_engine.get_stage_connector_spec", return_value={}),
- patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", return_value=(None, None, None)
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(cfg.stage_id, getattr(cfg, "stage_type", "llm")),
+ mock_local_diff = mocker.patch.object(engine, "_launch_diffusion_stage")
+ mock_remote_diff = mocker.patch.object(
+ engine,
+ "_create_remote_diffusion_stage",
+ return_value=remote_diffusion_client,
+ )
+ mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(0))
+ mocker.patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(0))
+ mocker.patch.object(
+ engine,
+ "_attach_llm_stage",
+ return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.OmniMasterServer",
+ return_value=mock_oms,
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
+ return_value=None,
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
+ return_value={},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
+ side_effect=lambda cfg: self._fake_metadata(
+ mocker,
+ cfg.stage_id,
+ getattr(cfg, "stage_type", "llm"),
),
- patch("vllm_omni.engine.async_omni_engine.finalize_initialized_stages", return_value=finalized),
- ):
- engine._initialize_stages(stage_init_timeout=60)
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
+ return_value=finalized,
+ )
+
+ engine._initialize_stages(stage_init_timeout=60)
mock_local_diff.assert_not_called()
assert mock_remote_diff.call_count == 1
@@ -894,45 +1069,47 @@ def test_non_matching_diffusion_stage_uses_remote_diffusion_client(self):
class TestLaunchDiffusionStage:
"""Test local diffusion stage launch wiring."""
- def test_registers_stage_with_public_master_properties(self):
+ def test_registers_stage_with_public_master_properties(self, mocker: MockerFixture):
engine = object.__new__(AsyncOmniEngine)
engine.model = "fake-model"
engine.diffusion_batch_size = 4
stage_cfg = _make_stage_cfg(5, stage_type="diffusion")
- metadata = Mock(stage_id=5)
- omni_master_server = Mock(spec=OmniMasterServer)
+ metadata = mocker.Mock(stage_id=5)
+ omni_master_server = mocker.Mock(spec=OmniMasterServer)
omni_master_server.address = "127.0.0.1"
omni_master_server.port = 25000
- proc = Mock()
- diffusion_client = Mock()
-
- with (
- patch("vllm_omni.engine.async_omni_engine.build_diffusion_config", return_value="diffusion-config"),
- patch(
- "vllm_omni.engine.async_omni_engine.register_stage_with_omni_master",
- return_value=(
- "tcp://127.0.0.1:25001",
- "tcp://127.0.0.1:25002",
- "tcp://127.0.0.1:25003",
- ),
- ) as mock_register,
- patch(
- "vllm_omni.engine.async_omni_engine.spawn_diffusion_proc",
- return_value=(proc, None, None, None),
- ) as mock_spawn,
- patch("vllm_omni.engine.async_omni_engine.complete_diffusion_handshake") as mock_handshake,
- patch(
- "vllm_omni.engine.async_omni_engine.StageDiffusionClient.from_addresses",
- return_value=diffusion_client,
- ) as mock_from_addresses,
- ):
- result = engine._launch_diffusion_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- omni_master_server=omni_master_server,
- )
+ proc = mocker.Mock()
+ diffusion_client = mocker.Mock()
+
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_diffusion_config",
+ return_value="diffusion-config",
+ )
+ mock_register = mocker.patch(
+ "vllm_omni.engine.async_omni_engine.register_stage_with_omni_master",
+ return_value=(
+ "tcp://127.0.0.1:25001",
+ "tcp://127.0.0.1:25002",
+ "tcp://127.0.0.1:25003",
+ ),
+ )
+ mock_spawn = mocker.patch(
+ "vllm_omni.engine.async_omni_engine.spawn_diffusion_proc",
+ return_value=(proc, None, None, None),
+ )
+ mock_handshake = mocker.patch("vllm_omni.engine.async_omni_engine.complete_diffusion_handshake")
+ mock_from_addresses = mocker.patch(
+ "vllm_omni.engine.async_omni_engine.StageDiffusionClient.from_addresses",
+ return_value=diffusion_client,
+ )
+
+ result = engine._launch_diffusion_stage(
+ stage_cfg=stage_cfg,
+ metadata=metadata,
+ omni_master_server=omni_master_server,
+ )
mock_register.assert_called_once_with(
omni_master_address="127.0.0.1",
@@ -967,14 +1144,14 @@ def test_registers_stage_with_public_master_properties(self):
class TestCreateRemoteLlmStage:
"""Test _create_remote_llm_stage delegates correctly."""
- def _engine(self) -> AsyncOmniEngine:
+ def _engine(self, mocker: MockerFixture) -> AsyncOmniEngine:
engine = object.__new__(AsyncOmniEngine)
engine.model = "fake-model"
engine.single_stage_mode = True
engine._single_stage_id_filter = 0
- engine._omni_master_server = Mock(spec=OmniMasterServer)
- engine._omni_master_server.get_zmq_addresses.return_value = Mock()
- engine._omni_master_server.get_allocation.return_value = Mock()
+ engine._omni_master_server = mocker.Mock(spec=OmniMasterServer)
+ engine._omni_master_server.get_zmq_addresses.return_value = mocker.Mock()
+ engine._omni_master_server.get_allocation.return_value = mocker.Mock()
engine._omni_master_server.get_stage_config.return_value = {
"stage_id": 0,
"stage_type": "llm",
@@ -982,42 +1159,40 @@ def _engine(self) -> AsyncOmniEngine:
}
return engine
- @contextmanager
- def _patch_build_and_connect(self, stage_id: int):
- fake_vllm_config = Mock()
- fake_executor_cls = Mock()
- fake_addresses = Mock()
+ def _mock_build_and_connect(self, mocker: MockerFixture, stage_id: int):
+ fake_vllm_config = mocker.Mock()
+ fake_executor_cls = mocker.Mock()
+ fake_addresses = mocker.Mock()
fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
fake_addresses.frontend_stats_publish_address = None
- eng_mgr = Mock()
- coordinator = Mock()
+ eng_mgr = mocker.Mock()
+ coordinator = mocker.Mock()
@contextmanager
def fake_connect_cm(*args, **kwargs):
yield eng_mgr, coordinator, fake_addresses
- with (
- patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": stage_id},
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(fake_vllm_config, fake_executor_cls),
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.connect_remote_engine_cores",
- return_value=fake_connect_cm(),
- ) as mock_connect,
- ):
- yield mock_connect, fake_vllm_config, fake_executor_cls, fake_addresses
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
+ return_value={"model": "fake", "stage_id": stage_id},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_vllm_config",
+ return_value=(fake_vllm_config, fake_executor_cls),
+ )
+ mock_connect = mocker.patch(
+ "vllm_omni.engine.async_omni_engine.connect_remote_engine_cores",
+ return_value=fake_connect_cm(),
+ )
+
+ return mock_connect, fake_vllm_config, fake_executor_cls, fake_addresses
- def test_returns_started_llm_stage_with_correct_stage_id(self):
- engine = self._engine()
+ def test_returns_started_llm_stage_with_correct_stage_id(self, mocker: MockerFixture):
+ engine = self._engine(mocker)
stage_cfg = _make_stage_cfg(1)
- metadata = Mock(stage_id=1)
+ metadata = mocker.Mock(stage_id=1)
omni_ms = engine._omni_master_server
omni_ms.get_stage_config.return_value = {
"stage_id": 1,
@@ -1025,93 +1200,93 @@ def test_returns_started_llm_stage_with_correct_stage_id(self):
"engine_args": {},
}
- with self._patch_build_and_connect(1):
- result = engine._create_remote_llm_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- omni_master_server=omni_ms,
- )
+ self._mock_build_and_connect(mocker, 1)
+ result = engine._create_remote_llm_stage(
+ stage_cfg=stage_cfg,
+ metadata=metadata,
+ stage_connector_spec={},
+ stage_init_timeout=60,
+ omni_master_server=omni_ms,
+ )
assert isinstance(result, StartedLlmStage)
assert result.stage_id == 1
- def test_connect_remote_engine_cores_called_with_stage_id(self):
- engine = self._engine()
+ def test_connect_remote_engine_cores_called_with_stage_id(self, mocker: MockerFixture):
+ engine = self._engine(mocker)
stage_cfg = _make_stage_cfg(2)
- metadata = Mock(stage_id=2)
+ metadata = mocker.Mock(stage_id=2)
omni_ms = engine._omni_master_server
- omni_ms.get_zmq_addresses.return_value = Mock(inputs=["x"], outputs=["y"])
+ omni_ms.get_zmq_addresses.return_value = mocker.Mock(inputs=["x"], outputs=["y"])
omni_ms.get_stage_config.return_value = {
"stage_id": 2,
"stage_type": "llm",
"engine_args": {},
}
- fake_vllm_config = Mock()
- fake_executor_cls = Mock()
- fake_addresses = Mock()
+ fake_vllm_config = mocker.Mock()
+ fake_executor_cls = mocker.Mock()
+ fake_addresses = mocker.Mock()
fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
fake_addresses.frontend_stats_publish_address = None
@contextmanager
def fake_connect_cm(*args, **kwargs):
- yield Mock(), Mock(), fake_addresses
+ yield mocker.Mock(), mocker.Mock(), fake_addresses
- with (
- patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 2},
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(fake_vllm_config, fake_executor_cls),
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.connect_remote_engine_cores", return_value=fake_connect_cm()
- ) as mock_connect,
- ):
- engine._create_remote_llm_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- omni_master_server=omni_ms,
- )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
+ return_value={"model": "fake", "stage_id": 2},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_vllm_config",
+ return_value=(fake_vllm_config, fake_executor_cls),
+ )
+ mock_connect = mocker.patch(
+ "vllm_omni.engine.async_omni_engine.connect_remote_engine_cores",
+ return_value=fake_connect_cm(),
+ )
+
+ engine._create_remote_llm_stage(
+ stage_cfg=stage_cfg,
+ metadata=metadata,
+ stage_connector_spec={},
+ stage_init_timeout=60,
+ omni_master_server=omni_ms,
+ )
mock_connect.assert_called_once()
_, kwargs = mock_connect.call_args
assert kwargs.get("stage_id") == 2 or mock_connect.call_args.args[-1] == 2
omni_ms.get_stage_config.assert_called_once_with(2, timeout_s=60)
- def test_missing_registered_stage_config_raises_value_error(self):
- engine = self._engine()
+ def test_missing_registered_stage_config_raises_value_error(self, mocker: MockerFixture):
+ engine = self._engine(mocker)
stage_cfg = _make_stage_cfg(3)
- metadata = Mock(stage_id=3)
+ metadata = mocker.Mock(stage_id=3)
omni_ms = engine._omni_master_server
omni_ms.get_stage_config.return_value = None
- with patch("vllm_omni.engine.async_omni_engine.build_engine_args_dict") as mock_build_args:
- with pytest.raises(
- ValueError,
- match="Remote stage 3 registered without stage config",
- ):
- engine._create_remote_llm_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- omni_master_server=omni_ms,
- )
+ mock_build_args = mocker.patch("vllm_omni.engine.async_omni_engine.build_engine_args_dict")
+ with pytest.raises(
+ ValueError,
+ match="Remote stage 3 registered without stage config",
+ ):
+ engine._create_remote_llm_stage(
+ stage_cfg=stage_cfg,
+ metadata=metadata,
+ stage_connector_spec={},
+ stage_init_timeout=60,
+ omni_master_server=omni_ms,
+ )
mock_build_args.assert_not_called()
- def test_exception_during_connect_closes_started_stage(self):
+ def test_exception_during_connect_closes_started_stage(self, mocker: MockerFixture):
"""If an error occurs after StartedLlmStage creation, close_started_llm_stage is called."""
- engine = self._engine()
+ engine = self._engine(mocker)
stage_cfg = _make_stage_cfg(1)
- metadata = Mock(stage_id=1)
+ metadata = mocker.Mock(stage_id=1)
omni_ms = engine._omni_master_server
omni_ms.get_stage_config.return_value = {
"stage_id": 1,
@@ -1121,26 +1296,30 @@ def test_exception_during_connect_closes_started_stage(self):
@contextmanager
def boom(*args, **kwargs):
- yield Mock(), Mock(), Mock()
+ yield mocker.Mock(), mocker.Mock(), mocker.Mock()
raise RuntimeError("handshake failed")
- with (
- patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 1},
- ),
- patch("vllm_omni.engine.async_omni_engine.build_vllm_config", return_value=(Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.connect_remote_engine_cores", return_value=boom()),
- patch("vllm_omni.engine.async_omni_engine.close_started_llm_stage") as mock_close,
- ):
- with pytest.raises(RuntimeError, match="handshake failed"):
- engine._create_remote_llm_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- omni_master_server=omni_ms,
- )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
+ return_value={"model": "fake", "stage_id": 1},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_vllm_config",
+ return_value=(mocker.Mock(), mocker.Mock()),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.connect_remote_engine_cores",
+ return_value=boom(),
+ )
+ mock_close = mocker.patch("vllm_omni.engine.async_omni_engine.close_started_llm_stage")
+ with pytest.raises(RuntimeError, match="handshake failed"):
+ engine._create_remote_llm_stage(
+ stage_cfg=stage_cfg,
+ metadata=metadata,
+ stage_connector_spec={},
+ stage_init_timeout=60,
+ omni_master_server=omni_ms,
+ )
mock_close.assert_called_once()
@@ -1148,27 +1327,29 @@ class TestConnectRemoteEngineCoresCoordinator:
"""Test coordinator launch parity with launch_core_engines."""
@staticmethod
- def _build_vllm_config(*, dp_rank: int = 0, offline_mode: bool = False, needs_dp_coordinator: bool = True) -> Mock:
- parallel_config = Mock()
+ def _build_vllm_config(
+ mocker: MockerFixture, *, dp_rank: int = 0, offline_mode: bool = False, needs_dp_coordinator: bool = True
+ ) -> Any:
+ parallel_config = mocker.Mock()
parallel_config.data_parallel_size_local = 1
parallel_config.data_parallel_size = 2
parallel_config.data_parallel_rank = dp_rank
parallel_config.data_parallel_rank_local = 0 if offline_mode else None
- vllm_config = Mock()
+ vllm_config = mocker.Mock()
vllm_config.parallel_config = parallel_config
vllm_config.needs_dp_coordinator = needs_dp_coordinator
- vllm_config.model_config = Mock(is_moe=False)
+ vllm_config.model_config = mocker.Mock(is_moe=False)
return vllm_config
- def test_uses_registered_coordinator_addresses(self):
- vllm_config = self._build_vllm_config(dp_rank=0, offline_mode=False, needs_dp_coordinator=True)
+ def test_uses_registered_coordinator_addresses(self, mocker: MockerFixture):
+ vllm_config = self._build_vllm_config(mocker, dp_rank=0, offline_mode=False, needs_dp_coordinator=True)
- omni_master_server = Mock(spec=OmniMasterServer)
+ omni_master_server = mocker.Mock(spec=OmniMasterServer)
omni_master_server.get_zmq_addresses.return_value = EngineZmqAddresses(
inputs=["tcp://client-in"], outputs=["tcp://client-out"]
)
- omni_master_server.get_allocation.return_value = Mock(handshake_bind_address="tcp://127.0.0.1:26001")
+ omni_master_server.get_allocation.return_value = mocker.Mock(handshake_bind_address="tcp://127.0.0.1:26001")
omni_master_server.get_stage_coordinator_addresses.return_value = StageCoordinatorAddresses(
coordinator_input="tcp://coord-in",
coordinator_output="tcp://coord-out",
@@ -1177,103 +1358,107 @@ def test_uses_registered_coordinator_addresses(self):
@contextmanager
def fake_socket_ctx(*args, **kwargs):
- yield Mock()
+ yield mocker.Mock()
- with (
- patch("vllm_omni.engine.stage_engine_startup.zmq_socket_ctx", return_value=fake_socket_ctx()),
- patch("vllm_omni.engine.stage_engine_startup._wait_for_omni_engine_startup") as mock_wait,
- ):
- with connect_remote_engine_cores(
- vllm_config=vllm_config,
- omni_master_server=omni_master_server,
- stage_id=7,
- ) as (_, yielded_coordinator, yielded_addresses):
- assert yielded_coordinator is None
- assert yielded_addresses.coordinator_input == "tcp://coord-in"
- assert yielded_addresses.coordinator_output == "tcp://coord-out"
- assert yielded_addresses.frontend_stats_publish_address == "tcp://stats"
+ mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.zmq_socket_ctx",
+ return_value=fake_socket_ctx(),
+ )
+ mock_wait = mocker.patch("vllm_omni.engine.stage_engine_startup._wait_for_omni_engine_startup")
+ with connect_remote_engine_cores(
+ vllm_config=vllm_config,
+ omni_master_server=omni_master_server,
+ stage_id=7,
+ ) as (_, yielded_coordinator, yielded_addresses):
+ assert yielded_coordinator is None
+ assert yielded_addresses.coordinator_input == "tcp://coord-in"
+ assert yielded_addresses.coordinator_output == "tcp://coord-out"
+ assert yielded_addresses.frontend_stats_publish_address == "tcp://stats"
omni_master_server.get_stage_coordinator_addresses.assert_called_once_with(7)
mock_wait.assert_called_once()
- def test_defaults_to_no_coordinator_addresses_when_none_registered(self):
+ def test_defaults_to_no_coordinator_addresses_when_none_registered(self, mocker: MockerFixture):
vllm_config = self._build_vllm_config(
+ mocker,
dp_rank=0,
offline_mode=False,
needs_dp_coordinator=True,
)
- omni_master_server = Mock(spec=OmniMasterServer)
+ omni_master_server = mocker.Mock(spec=OmniMasterServer)
omni_master_server.get_zmq_addresses.return_value = EngineZmqAddresses(
inputs=["tcp://client-in"], outputs=["tcp://client-out"]
)
- omni_master_server.get_allocation.return_value = Mock(handshake_bind_address="tcp://127.0.0.1:26001")
+ omni_master_server.get_allocation.return_value = mocker.Mock(handshake_bind_address="tcp://127.0.0.1:26001")
omni_master_server.get_stage_coordinator_addresses.return_value = StageCoordinatorAddresses()
@contextmanager
def fake_socket_ctx(*args, **kwargs):
- yield Mock()
+ yield mocker.Mock()
- with (
- patch("vllm_omni.engine.stage_engine_startup.zmq_socket_ctx", return_value=fake_socket_ctx()),
- patch("vllm_omni.engine.stage_engine_startup._wait_for_omni_engine_startup"),
- ):
- with connect_remote_engine_cores(
- vllm_config=vllm_config,
- omni_master_server=omni_master_server,
- stage_id=7,
- ) as (_, yielded_coordinator, yielded_addresses):
- assert yielded_coordinator is None
- assert yielded_addresses.coordinator_input is None
- assert yielded_addresses.coordinator_output is None
- assert yielded_addresses.frontend_stats_publish_address is None
+ mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.zmq_socket_ctx",
+ return_value=fake_socket_ctx(),
+ )
+ mocker.patch("vllm_omni.engine.stage_engine_startup._wait_for_omni_engine_startup")
+ with connect_remote_engine_cores(
+ vllm_config=vllm_config,
+ omni_master_server=omni_master_server,
+ stage_id=7,
+ ) as (_, yielded_coordinator, yielded_addresses):
+ assert yielded_coordinator is None
+ assert yielded_addresses.coordinator_input is None
+ assert yielded_addresses.coordinator_output is None
+ assert yielded_addresses.frontend_stats_publish_address is None
class TestLaunchOmniCoreEngines:
"""Tests for local omni engine launch wiring."""
- def test_registers_stage_once_and_reuses_handshake_for_all_local_engines(self):
- parallel_config = Mock(
+ def test_registers_stage_once_and_reuses_handshake_for_all_local_engines(self, mocker: MockerFixture):
+ parallel_config = mocker.Mock(
data_parallel_size_local=2,
data_parallel_size=4,
data_parallel_rank=3,
)
- vllm_config = Mock(parallel_config=parallel_config)
+ vllm_config = mocker.Mock(parallel_config=parallel_config)
- omni_master_server = Mock(spec=OmniMasterServer)
+ omni_master_server = mocker.Mock(spec=OmniMasterServer)
omni_master_server.address = "127.0.0.1"
omni_master_server.port = 26000
- omni_master_server.get_allocation.return_value = Mock(handshake_bind_address="tcp://127.0.0.1:26001")
+ omni_master_server.get_allocation.return_value = mocker.Mock(handshake_bind_address="tcp://127.0.0.1:26001")
stage_config = {"stage_id": 7, "stage_type": "llm"}
- local_engine_manager = Mock()
+ local_engine_manager = mocker.Mock()
@contextmanager
def fake_socket_ctx(*args, **kwargs):
- yield Mock()
-
- with (
- patch(
- "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
- return_value="tcp://127.0.0.1:26001",
- ) as mock_register,
- patch("vllm_omni.engine.stage_engine_startup.zmq_socket_ctx", return_value=fake_socket_ctx()),
- patch(
- "vllm_omni.engine.stage_engine_startup.CoreEngineProcManager",
- return_value=local_engine_manager,
- ) as mock_manager_cls,
- patch("vllm_omni.engine.stage_engine_startup.wait_for_engine_startup"),
- ):
- with launch_omni_core_engines(
- vllm_config=vllm_config,
- executor_class=Mock(),
- log_stats=False,
- omni_master_server=omni_master_server,
- stage_id=7,
- stage_config=stage_config,
- ) as (yielded_manager, yielded_coordinator, yielded_addresses):
- assert yielded_manager is local_engine_manager
- assert yielded_coordinator is None
+ yield mocker.Mock()
+
+ mock_register = mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
+ return_value="tcp://127.0.0.1:26001",
+ )
+ mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.zmq_socket_ctx",
+ return_value=fake_socket_ctx(),
+ )
+ mock_manager_cls = mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.CoreEngineProcManager",
+ return_value=local_engine_manager,
+ )
+ mocker.patch("vllm_omni.engine.stage_engine_startup.wait_for_engine_startup")
+ with launch_omni_core_engines(
+ vllm_config=vllm_config,
+ executor_class=mocker.Mock(),
+ log_stats=False,
+ omni_master_server=omni_master_server,
+ stage_id=7,
+ stage_config=stage_config,
+ ) as (yielded_manager, yielded_coordinator, yielded_addresses):
+ assert yielded_manager is local_engine_manager
+ assert yielded_coordinator is None
mock_register.assert_called_once_with(
omni_master_address="127.0.0.1",
@@ -1292,55 +1477,56 @@ def fake_socket_ctx(*args, **kwargs):
assert manager_kwargs["handshake_address"] == "tcp://127.0.0.1:26001"
assert manager_kwargs["executor_class"] is not None
- def test_registers_stage_with_coordinator_when_started(self):
- parallel_config = Mock(
+ def test_registers_stage_with_coordinator_when_started(self, mocker: MockerFixture):
+ parallel_config = mocker.Mock(
data_parallel_size_local=1,
data_parallel_size=2,
data_parallel_rank=0,
)
- vllm_config = Mock(parallel_config=parallel_config)
+ vllm_config = mocker.Mock(parallel_config=parallel_config)
vllm_config.needs_dp_coordinator = True
- vllm_config.model_config = Mock(is_moe=False)
+ vllm_config.model_config = mocker.Mock(is_moe=False)
- omni_master_server = Mock(spec=OmniMasterServer)
+ omni_master_server = mocker.Mock(spec=OmniMasterServer)
omni_master_server.address = "127.0.0.1"
omni_master_server.port = 26000
omni_master_server.get_zmq_addresses.return_value = EngineZmqAddresses(
inputs=["tcp://client-in"], outputs=["tcp://client-out"]
)
- omni_master_server.get_allocation.return_value = Mock(handshake_bind_address="tcp://127.0.0.1:26001")
+ omni_master_server.get_allocation.return_value = mocker.Mock(handshake_bind_address="tcp://127.0.0.1:26001")
- coordinator = Mock()
+ coordinator = mocker.Mock()
coordinator.proc.pid = 1234
coordinator.get_engine_socket_addresses.return_value = ("tcp://coord-in", "tcp://coord-out")
coordinator.get_stats_publish_address.return_value = "tcp://stats"
@contextmanager
def fake_socket_ctx(*args, **kwargs):
- yield Mock()
-
- with (
- patch("vllm_omni.engine.stage_engine_startup.DPCoordinator", return_value=coordinator),
- patch(
- "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
- return_value="tcp://127.0.0.1:26001",
- ) as mock_register,
- patch("vllm_omni.engine.stage_engine_startup.zmq_socket_ctx", return_value=fake_socket_ctx()),
- patch(
- "vllm_omni.engine.stage_engine_startup.CoreEngineProcManager",
- return_value=Mock(),
- ) as mock_manager_cls,
- patch("vllm_omni.engine.stage_engine_startup.wait_for_engine_startup") as mock_wait,
+ yield mocker.Mock()
+
+ mocker.patch("vllm_omni.engine.stage_engine_startup.DPCoordinator", return_value=coordinator)
+ mock_register = mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
+ return_value="tcp://127.0.0.1:26001",
+ )
+ mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.zmq_socket_ctx",
+ return_value=fake_socket_ctx(),
+ )
+ mock_manager_cls = mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.CoreEngineProcManager",
+ return_value=mocker.Mock(),
+ )
+ mock_wait = mocker.patch("vllm_omni.engine.stage_engine_startup.wait_for_engine_startup")
+ with launch_omni_core_engines(
+ vllm_config=vllm_config,
+ executor_class=mocker.Mock(),
+ log_stats=False,
+ omni_master_server=omni_master_server,
+ stage_id=7,
+ stage_config={"stage_id": 7},
):
- with launch_omni_core_engines(
- vllm_config=vllm_config,
- executor_class=Mock(),
- log_stats=False,
- omni_master_server=omni_master_server,
- stage_id=7,
- stage_config={"stage_id": 7},
- ):
- pass
+ pass
mock_register.assert_called_once_with(
omni_master_address="127.0.0.1",
@@ -1363,19 +1549,19 @@ class TestLaunchLlmStageSingleStageMode:
"""Test that _launch_llm_stage selects launch_omni_core_engines when
single_stage_mode=True and _omni_master_server is set."""
- def _build_engine_with_oms(self) -> AsyncOmniEngine:
+ def _build_engine_with_oms(self, mocker: MockerFixture) -> AsyncOmniEngine:
engine = object.__new__(AsyncOmniEngine)
engine.model = "fake-model"
engine.single_stage_mode = True
engine._single_stage_id_filter = 0
engine._llm_stage_launch_lock = threading.Lock()
- mock_oms = Mock(spec=OmniMasterServer)
+ mock_oms = mocker.Mock(spec=OmniMasterServer)
mock_oms.address = "127.0.0.1"
mock_oms.port = 25000
- alloc = Mock()
+ alloc = mocker.Mock()
alloc.handshake_bind_address = "tcp://127.0.0.1:25001"
mock_oms.get_allocation.return_value = alloc
- fake_addresses = Mock()
+ fake_addresses = mocker.Mock()
fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
fake_addresses.frontend_stats_publish_address = None
@@ -1383,66 +1569,60 @@ def _build_engine_with_oms(self) -> AsyncOmniEngine:
engine._omni_master_server = mock_oms
return engine
- @contextmanager
- def _patch_launch_omni_cm(self, stage_id: int):
- fake_vllm_config = Mock()
- fake_executor_cls = Mock()
- fake_addresses = Mock()
+ def _mock_launch_omni(self, mocker: MockerFixture, stage_id: int):
+ fake_vllm_config = mocker.Mock()
+ fake_executor_cls = mocker.Mock()
+ fake_addresses = mocker.Mock()
fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
fake_addresses.frontend_stats_publish_address = None
- eng_mgr = Mock()
+ eng_mgr = mocker.Mock()
@contextmanager
def fake_launch_omni(*args, **kwargs):
yield eng_mgr, None, fake_addresses
- with (
- patch("vllm_omni.engine.async_omni_engine.setup_stage_devices"),
- patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": stage_id},
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(fake_vllm_config, fake_executor_cls),
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.acquire_device_locks",
- return_value=[],
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.release_device_locks",
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
- return_value=fake_launch_omni(),
- ) as mock_launch_omni,
- ):
- yield mock_launch_omni
+ mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
+ return_value={"model": "fake", "stage_id": stage_id},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_vllm_config",
+ return_value=(fake_vllm_config, fake_executor_cls),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.acquire_device_locks",
+ return_value=[],
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks")
+ return mocker.patch(
+ "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
+ return_value=fake_launch_omni(),
+ )
- def test_launch_omni_core_engines_used_in_single_stage_mode(self):
+ def test_launch_omni_core_engines_used_in_single_stage_mode(self, mocker: MockerFixture):
"""single_stage_mode + _omni_master_server → launch_omni_core_engines."""
- engine = self._build_engine_with_oms()
- metadata = Mock(stage_id=0, runtime_cfg={})
+ engine = self._build_engine_with_oms(mocker)
+ metadata = mocker.Mock(stage_id=0, runtime_cfg={})
stage_cfg = _make_stage_cfg(0)
- with self._patch_launch_omni_cm(0) as mock_launch_omni:
- result = engine._launch_llm_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- llm_stage_launch_lock=threading.Lock(),
- )
+ mock_launch_omni = self._mock_launch_omni(mocker, 0)
+ result = engine._launch_llm_stage(
+ stage_cfg=stage_cfg,
+ metadata=metadata,
+ stage_connector_spec={},
+ stage_init_timeout=60,
+ llm_stage_launch_lock=threading.Lock(),
+ )
mock_launch_omni.assert_called_once()
assert mock_launch_omni.call_args.kwargs["stage_config"] is stage_cfg
assert isinstance(result, StartedLlmStage)
assert result.stage_id == 0
- def test_spawn_stage_core_used_in_normal_mode(self):
+ def test_spawn_stage_core_used_in_normal_mode(self, mocker: MockerFixture):
"""~single_stage_mode → spawn_stage_core + complete_stage_handshake."""
engine = object.__new__(AsyncOmniEngine)
engine.model = "fake-model"
@@ -1450,44 +1630,45 @@ def test_spawn_stage_core_used_in_normal_mode(self):
engine._omni_master_server = None
engine._llm_stage_launch_lock = threading.Lock()
- fake_vllm_config = Mock()
- fake_executor_cls = Mock()
- fake_addresses = Mock()
+ fake_vllm_config = mocker.Mock()
+ fake_executor_cls = mocker.Mock()
+ fake_addresses = mocker.Mock()
fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
fake_addresses.frontend_stats_publish_address = None
- fake_proc = Mock()
+ fake_proc = mocker.Mock()
fake_handshake_address = "ipc:///tmp/fake-handshake"
stage_init_timeout = 60
- with (
- patch("vllm_omni.engine.async_omni_engine.setup_stage_devices"),
- patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 0},
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(fake_vllm_config, fake_executor_cls),
- ),
- patch("vllm_omni.engine.async_omni_engine.acquire_device_locks", return_value=[]),
- patch("vllm_omni.engine.async_omni_engine.release_device_locks"),
- patch(
- "vllm_omni.engine.async_omni_engine.spawn_stage_core",
- return_value=(fake_addresses, fake_proc, fake_handshake_address),
- ) as mock_spawn,
- patch("vllm_omni.engine.async_omni_engine.complete_stage_handshake") as mock_handshake,
- patch("vllm_omni.engine.async_omni_engine.launch_omni_core_engines") as mock_omni,
- ):
- metadata = Mock(stage_id=0, runtime_cfg={})
- result = engine._launch_llm_stage(
- stage_cfg=_make_stage_cfg(0),
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=stage_init_timeout,
- llm_stage_launch_lock=threading.Lock(),
- )
+ mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
+ return_value={"model": "fake", "stage_id": 0},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_vllm_config",
+ return_value=(fake_vllm_config, fake_executor_cls),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.acquire_device_locks",
+ return_value=[],
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks")
+ mock_spawn = mocker.patch(
+ "vllm_omni.engine.async_omni_engine.spawn_stage_core",
+ return_value=(fake_addresses, fake_proc, fake_handshake_address),
+ )
+ mock_handshake = mocker.patch("vllm_omni.engine.async_omni_engine.complete_stage_handshake")
+ mock_omni = mocker.patch("vllm_omni.engine.async_omni_engine.launch_omni_core_engines")
+ metadata = mocker.Mock(stage_id=0, runtime_cfg={})
+ result = engine._launch_llm_stage(
+ stage_cfg=_make_stage_cfg(0),
+ metadata=metadata,
+ stage_connector_spec={},
+ stage_init_timeout=stage_init_timeout,
+ llm_stage_launch_lock=threading.Lock(),
+ )
mock_spawn.assert_called_once_with(
vllm_config=fake_vllm_config,
@@ -1505,50 +1686,58 @@ def test_spawn_stage_core_used_in_normal_mode(self):
assert isinstance(result, StartedLlmStage)
assert result.proc is fake_proc
- def test_launch_omni_passes_stage_id_and_master_server(self):
+ def test_launch_omni_passes_stage_id_and_master_server(self, mocker: MockerFixture):
"""launch_omni_core_engines receives the correct stage_id and omni_master_server."""
- engine = self._build_engine_with_oms()
- metadata = Mock(stage_id=0, runtime_cfg={})
+ engine = self._build_engine_with_oms(mocker)
+ metadata = mocker.Mock(stage_id=0, runtime_cfg={})
captured_kwargs: dict[str, Any] = {}
@contextmanager
def capturing_launch(*args, **kwargs):
captured_kwargs.update(kwargs)
- fake_addresses = Mock()
+ fake_addresses = mocker.Mock()
fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
fake_addresses.frontend_stats_publish_address = None
- yield Mock(), None, fake_addresses
+ yield mocker.Mock(), None, fake_addresses
- with (
- patch("vllm_omni.engine.async_omni_engine.setup_stage_devices"),
- patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 0},
- ),
- patch("vllm_omni.engine.async_omni_engine.build_vllm_config", return_value=(Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.acquire_device_locks", return_value=[]),
- patch("vllm_omni.engine.async_omni_engine.release_device_locks"),
- patch("vllm_omni.engine.async_omni_engine.launch_omni_core_engines", side_effect=capturing_launch),
- ):
- engine._launch_llm_stage(
- stage_cfg=_make_stage_cfg(0),
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- llm_stage_launch_lock=threading.Lock(),
- )
+ mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
+ return_value={"model": "fake", "stage_id": 0},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_vllm_config",
+ return_value=(mocker.Mock(), mocker.Mock()),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.acquire_device_locks",
+ return_value=[],
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
+ side_effect=capturing_launch,
+ )
+
+ engine._launch_llm_stage(
+ stage_cfg=_make_stage_cfg(0),
+ metadata=metadata,
+ stage_connector_spec={},
+ stage_init_timeout=60,
+ llm_stage_launch_lock=threading.Lock(),
+ )
assert captured_kwargs.get("stage_id") == 0
assert captured_kwargs.get("omni_master_server") is engine._omni_master_server
- def test_launch_omni_context_exits_before_stage_cleanup_on_error(self):
+ def test_launch_omni_context_exits_before_stage_cleanup_on_error(self, mocker: MockerFixture):
"""Errors after entering the omni launch context still unwind it first."""
- engine = self._build_engine_with_oms()
- metadata = Mock(stage_id=0, runtime_cfg={})
+ engine = self._build_engine_with_oms(mocker)
+ metadata = mocker.Mock(stage_id=0, runtime_cfg={})
- fake_addresses = Mock()
+ fake_addresses = mocker.Mock()
fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
fake_addresses.frontend_stats_publish_address = None
@@ -1558,47 +1747,51 @@ def test_launch_omni_context_exits_before_stage_cleanup_on_error(self):
@contextmanager
def fake_launch_omni(*args, **kwargs):
try:
- yield Mock(), None, fake_addresses
+ yield mocker.Mock(), None, fake_addresses
finally:
events.append("launch_exit")
- with (
- patch("vllm_omni.engine.async_omni_engine.setup_stage_devices"),
- patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 0},
- ),
- patch("vllm_omni.engine.async_omni_engine.build_vllm_config", return_value=(Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.acquire_device_locks", return_value=[]),
- patch("vllm_omni.engine.async_omni_engine.release_device_locks"),
- patch(
- "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
- return_value=fake_launch_omni(),
- ),
- patch("vllm_omni.engine.async_omni_engine.logger.info", side_effect=RuntimeError("boom")),
- patch(
- "vllm_omni.engine.async_omni_engine.close_started_llm_stage",
- side_effect=lambda _started: events.append("stage_close"),
- ) as mock_close_stage,
- ):
- with pytest.raises(RuntimeError, match="boom"):
- engine._launch_llm_stage(
- stage_cfg=_make_stage_cfg(0),
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- llm_stage_launch_lock=threading.Lock(),
- )
+ mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
+ return_value={"model": "fake", "stage_id": 0},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_vllm_config",
+ return_value=(mocker.Mock(), mocker.Mock()),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.acquire_device_locks",
+ return_value=[],
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
+ return_value=fake_launch_omni(),
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.logger.info", side_effect=RuntimeError("boom"))
+ mock_close_stage = mocker.patch(
+ "vllm_omni.engine.async_omni_engine.close_started_llm_stage",
+ side_effect=lambda _started: events.append("stage_close"),
+ )
+ with pytest.raises(RuntimeError, match="boom"):
+ engine._launch_llm_stage(
+ stage_cfg=_make_stage_cfg(0),
+ metadata=metadata,
+ stage_connector_spec={},
+ stage_init_timeout=60,
+ llm_stage_launch_lock=threading.Lock(),
+ )
mock_close_stage.assert_called_once()
assert events == ["launch_exit", "stage_close"]
- def test_base_exception_propagates_without_started_stage_cleanup(self):
+ def test_base_exception_propagates_without_started_stage_cleanup(self, mocker: MockerFixture):
"""BaseException subclasses should bypass the Exception cleanup path."""
- engine = self._build_engine_with_oms()
- metadata = Mock(stage_id=0, runtime_cfg={})
+ engine = self._build_engine_with_oms(mocker)
+ metadata = mocker.Mock(stage_id=0, runtime_cfg={})
- fake_addresses = Mock()
+ fake_addresses = mocker.Mock()
fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
fake_addresses.frontend_stats_publish_address = None
@@ -1611,37 +1804,41 @@ class FatalLaunchInterrupt(BaseException):
@contextmanager
def fake_launch_omni(*args, **kwargs):
try:
- yield Mock(), None, fake_addresses
+ yield mocker.Mock(), None, fake_addresses
finally:
events.append("launch_exit")
- with (
- patch("vllm_omni.engine.async_omni_engine.setup_stage_devices"),
- patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 0},
- ),
- patch("vllm_omni.engine.async_omni_engine.build_vllm_config", return_value=(Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.acquire_device_locks", return_value=[]),
- patch("vllm_omni.engine.async_omni_engine.release_device_locks"),
- patch(
- "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
- return_value=fake_launch_omni(),
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.logger.info",
- side_effect=FatalLaunchInterrupt("stop"),
- ),
- patch("vllm_omni.engine.async_omni_engine.close_started_llm_stage") as mock_close_stage,
- ):
- with pytest.raises(FatalLaunchInterrupt, match="stop"):
- engine._launch_llm_stage(
- stage_cfg=_make_stage_cfg(0),
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- llm_stage_launch_lock=threading.Lock(),
- )
+ mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
+ return_value={"model": "fake", "stage_id": 0},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_vllm_config",
+ return_value=(mocker.Mock(), mocker.Mock()),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.acquire_device_locks",
+ return_value=[],
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
+ return_value=fake_launch_omni(),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.logger.info",
+ side_effect=FatalLaunchInterrupt("stop"),
+ )
+ mock_close_stage = mocker.patch("vllm_omni.engine.async_omni_engine.close_started_llm_stage")
+ with pytest.raises(FatalLaunchInterrupt, match="stop"):
+ engine._launch_llm_stage(
+ stage_cfg=_make_stage_cfg(0),
+ metadata=metadata,
+ stage_connector_spec={},
+ stage_init_timeout=60,
+ llm_stage_launch_lock=threading.Lock(),
+ )
mock_close_stage.assert_not_called()
assert events == ["launch_exit"]
diff --git a/tests/entrypoints/openai_api/test_serving_chat_speaker.py b/tests/entrypoints/openai_api/test_serving_chat_speaker.py
index 3b9151120e0..97c05e45b41 100644
--- a/tests/entrypoints/openai_api/test_serving_chat_speaker.py
+++ b/tests/entrypoints/openai_api/test_serving_chat_speaker.py
@@ -4,9 +4,9 @@
import asyncio
from types import SimpleNamespace
-from unittest.mock import AsyncMock, MagicMock
import pytest
+from pytest_mock import MockerFixture
from vllm_omni.entrypoints.openai.utils import (
get_supported_speakers_from_hf_config,
@@ -25,9 +25,9 @@ def serving_chat():
return instance
-def _make_hf_config(*, speaker_id: dict | None = None, spk_id: dict | None = None):
- hf_config = MagicMock()
- talker_config = MagicMock()
+def _make_hf_config(mocker: MockerFixture, *, speaker_id: dict | None = None, spk_id: dict | None = None):
+ hf_config = mocker.MagicMock()
+ talker_config = mocker.MagicMock()
talker_config.speaker_id = speaker_id
talker_config.spk_id = spk_id
hf_config.talker_config = talker_config
@@ -51,14 +51,14 @@ def test_validate_requested_speaker_skips_validation_when_supported_empty():
assert validate_requested_speaker(" ", {"vivian"}) is None
-def test_get_supported_speakers_from_hf_config_uses_spk_id_fallback():
- hf_config = _make_hf_config(speaker_id=None, spk_id={"Serena": 0})
+def test_get_supported_speakers_from_hf_config_uses_spk_id_fallback(mocker: MockerFixture):
+ hf_config = _make_hf_config(mocker, speaker_id=None, spk_id={"Serena": 0})
assert get_supported_speakers_from_hf_config(hf_config) == {"serena"}
-def test_get_supported_speakers_caches_normalized_keys(serving_chat):
- serving_chat.model_config = MagicMock()
- serving_chat.model_config.hf_config = _make_hf_config(speaker_id={"Vivian": 0, "Ethan": 1})
+def test_get_supported_speakers_caches_normalized_keys(mocker: MockerFixture, serving_chat):
+ serving_chat.model_config = mocker.MagicMock()
+ serving_chat.model_config.hf_config = _make_hf_config(mocker, speaker_id={"Vivian": 0, "Ethan": 1})
assert serving_chat._get_supported_speakers() == {"vivian", "ethan"}
@@ -67,15 +67,15 @@ def test_get_supported_speakers_caches_normalized_keys(serving_chat):
assert serving_chat._get_supported_speakers() == {"vivian", "ethan"}
-def test_create_chat_completion_converts_value_error_to_error_response(serving_chat):
+def test_create_chat_completion_converts_value_error_to_error_response(mocker: MockerFixture, serving_chat):
serving_chat._diffusion_mode = False
- serving_chat._check_model = AsyncMock(return_value=None)
- serving_chat.engine_client = MagicMock(errored=False)
- serving_chat._maybe_get_adapters = MagicMock(return_value=None)
- serving_chat.models = MagicMock()
+ serving_chat._check_model = mocker.AsyncMock(return_value=None)
+ serving_chat.engine_client = mocker.MagicMock(errored=False)
+ serving_chat._maybe_get_adapters = mocker.MagicMock(return_value=None)
+ serving_chat.models = mocker.MagicMock()
serving_chat.models.model_name.return_value = "test-model"
- serving_chat.renderer = MagicMock()
- serving_chat.renderer.get_tokenizer.return_value = MagicMock()
+ serving_chat.renderer = mocker.MagicMock()
+ serving_chat.renderer.get_tokenizer.return_value = mocker.MagicMock()
serving_chat.reasoning_parser_cls = None
serving_chat.tool_parser = None
serving_chat.use_harmony = False
@@ -85,12 +85,12 @@ def test_create_chat_completion_converts_value_error_to_error_response(serving_c
serving_chat.chat_template = None
serving_chat.chat_template_content_format = "string"
serving_chat.default_chat_template_kwargs = {}
- serving_chat._validate_chat_template = MagicMock(return_value=None)
- serving_chat._prepare_extra_chat_template_kwargs = MagicMock(return_value={})
- serving_chat._preprocess_chat = AsyncMock(
+ serving_chat._validate_chat_template = mocker.MagicMock(return_value=None)
+ serving_chat._prepare_extra_chat_template_kwargs = mocker.MagicMock(return_value={})
+ serving_chat._preprocess_chat = mocker.AsyncMock(
side_effect=ValueError("Invalid speaker 'uncle_fu'. Supported: ethan, vivian")
)
- serving_chat.create_error_response = MagicMock(return_value="error-response")
+ serving_chat.create_error_response = mocker.MagicMock(return_value="error-response")
request = SimpleNamespace(
tool_choice=None,
diff --git a/tests/entrypoints/openai_api/test_serving_speech.py b/tests/entrypoints/openai_api/test_serving_speech.py
index 06b6f5c16c1..c8841206207 100644
--- a/tests/entrypoints/openai_api/test_serving_speech.py
+++ b/tests/entrypoints/openai_api/test_serving_speech.py
@@ -6,7 +6,6 @@
from inspect import Signature, signature
from pathlib import Path
from types import SimpleNamespace
-from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
@@ -901,7 +900,7 @@ def test_load_supported_speakers(self, mocker: MockerFixture):
# Verify speakers are normalized to lowercase
assert server.supported_speakers == {"ryan", "vivian", "aiden"}
- def test_build_tts_params_with_uploaded_voice(self, speech_server):
+ def test_build_tts_params_with_uploaded_voice(self, speech_server, mocker: MockerFixture):
"""Test _build_tts_params auto-sets ref_audio for uploaded voices (x_vector only)."""
speech_server.uploaded_speakers = {
"custom_voice": {
@@ -914,18 +913,18 @@ def test_build_tts_params_with_uploaded_voice(self, speech_server):
}
speech_server.supported_speakers = {"ryan", "vivian", "custom_voice"}
- with patch.object(speech_server, "_get_uploaded_audio_data") as mock_get_audio:
- mock_get_audio.return_value = "data:audio/wav;base64,ZmFrZWF1ZGlv"
- req = OpenAICreateSpeechRequest(input="Hello", voice="custom_voice")
- params = speech_server._build_tts_params(req)
+ mock_get_audio = mocker.patch.object(speech_server, "_get_uploaded_audio_data")
+ mock_get_audio.return_value = "data:audio/wav;base64,ZmFrZWF1ZGlv"
+ req = OpenAICreateSpeechRequest(input="Hello", voice="custom_voice")
+ params = speech_server._build_tts_params(req)
- assert params["ref_audio"] == ["data:audio/wav;base64,ZmFrZWF1ZGlv"]
- assert params["x_vector_only_mode"] == [True]
- assert params["task_type"] == ["Base"]
- assert params["voice_created_at"] == [1711234567.89]
- assert "ref_text" not in params
+ assert params["ref_audio"] == ["data:audio/wav;base64,ZmFrZWF1ZGlv"]
+ assert params["x_vector_only_mode"] == [True]
+ assert params["task_type"] == ["Base"]
+ assert params["voice_created_at"] == [1711234567.89]
+ assert "ref_text" not in params
- def test_build_tts_params_with_uploaded_voice_ref_text(self, speech_server):
+ def test_build_tts_params_with_uploaded_voice_ref_text(self, speech_server, mocker: MockerFixture):
"""Test _build_tts_params enables in-context cloning when ref_text is stored."""
speech_server.uploaded_speakers = {
"custom_voice": {
@@ -938,16 +937,16 @@ def test_build_tts_params_with_uploaded_voice_ref_text(self, speech_server):
}
speech_server.supported_speakers = {"ryan", "vivian", "custom_voice"}
- with patch.object(speech_server, "_get_uploaded_audio_data") as mock_get_audio:
- mock_get_audio.return_value = "data:audio/wav;base64,ZmFrZWF1ZGlv"
- req = OpenAICreateSpeechRequest(input="Hello", voice="custom_voice")
- params = speech_server._build_tts_params(req)
+ mock_get_audio = mocker.patch.object(speech_server, "_get_uploaded_audio_data")
+ mock_get_audio.return_value = "data:audio/wav;base64,ZmFrZWF1ZGlv"
+ req = OpenAICreateSpeechRequest(input="Hello", voice="custom_voice")
+ params = speech_server._build_tts_params(req)
- assert params["ref_audio"] == ["data:audio/wav;base64,ZmFrZWF1ZGlv"]
- assert params["x_vector_only_mode"] == [False]
- assert params["task_type"] == ["Base"]
- assert params["ref_text"] == ["Hello world transcript"]
- assert params["voice_created_at"] == [1711234567.89]
+ assert params["ref_audio"] == ["data:audio/wav;base64,ZmFrZWF1ZGlv"]
+ assert params["x_vector_only_mode"] == [False]
+ assert params["task_type"] == ["Base"]
+ assert params["ref_text"] == ["Hello world transcript"]
+ assert params["voice_created_at"] == [1711234567.89]
def test_build_tts_params_without_uploaded_voice(self, speech_server):
"""Test _build_tts_params does not auto-set ref_audio for non-uploaded voices."""
@@ -989,45 +988,43 @@ def test_build_tts_params_with_explicit_ref_audio(self, speech_server):
# x_vector_only_mode should not be set when explicit ref_audio is provided
assert "x_vector_only_mode" not in params
- def test_get_uploaded_audio_data(self, speech_server):
+ def test_get_uploaded_audio_data(self, speech_server, mocker: MockerFixture):
"""Test _get_uploaded_audio_data function."""
# Mock file operations
- with (
- patch("builtins.open", create=True) as mock_open,
- patch("base64.b64encode") as mock_b64encode,
- patch("pathlib.Path.exists") as mock_exists,
- ):
- mock_exists.return_value = True
- mock_b64encode.return_value = b"ZmFrZWF1ZGlv"
-
- # Setup mock file
- mock_file = MagicMock()
- mock_file.read.return_value = b"fakeaudio"
- mock_open.return_value.__enter__.return_value = mock_file
-
- # Setup uploaded speaker
- speech_server.uploaded_speakers = {
- "test_voice": {"name": "test_voice", "file_path": "/tmp/test.wav", "mime_type": "audio/wav"}
- }
- result = speech_server._get_uploaded_audio_data("test_voice")
+ mock_open = mocker.patch("builtins.open", create=True)
+ mock_b64encode = mocker.patch("base64.b64encode")
+ mock_exists = mocker.patch("pathlib.Path.exists")
+ mock_exists.return_value = True
+ mock_b64encode.return_value = b"ZmFrZWF1ZGlv"
+
+ # Setup mock file
+ mock_file = mocker.MagicMock()
+ mock_file.read.return_value = b"fakeaudio"
+ mock_open.return_value.__enter__.return_value = mock_file
+
+ # Setup uploaded speaker
+ speech_server.uploaded_speakers = {
+ "test_voice": {"name": "test_voice", "file_path": "/tmp/test.wav", "mime_type": "audio/wav"}
+ }
+ result = speech_server._get_uploaded_audio_data("test_voice")
- assert result == "data:audio/wav;base64,ZmFrZWF1ZGlv"
- mock_open.assert_called_once_with(Path("/tmp/test.wav"), "rb")
- mock_b64encode.assert_called_once_with(b"fakeaudio")
+ assert result == "data:audio/wav;base64,ZmFrZWF1ZGlv"
+ mock_open.assert_called_once_with(Path("/tmp/test.wav"), "rb")
+ mock_b64encode.assert_called_once_with(b"fakeaudio")
- def test_get_uploaded_audio_data_missing_file(self, speech_server):
+ def test_get_uploaded_audio_data_missing_file(self, speech_server, mocker: MockerFixture):
"""Test _get_uploaded_audio_data when file is missing."""
- with patch("pathlib.Path.exists") as mock_exists:
- mock_exists.return_value = False
+ mock_exists = mocker.patch("pathlib.Path.exists")
+ mock_exists.return_value = False
- # Setup uploaded speaker
- speech_server.uploaded_speakers = {
- "test_voice": {"name": "test_voice", "file_path": "/tmp/test.wav", "mime_type": "audio/wav"}
- }
+ # Setup uploaded speaker
+ speech_server.uploaded_speakers = {
+ "test_voice": {"name": "test_voice", "file_path": "/tmp/test.wav", "mime_type": "audio/wav"}
+ }
- result = speech_server._get_uploaded_audio_data("test_voice")
+ result = speech_server._get_uploaded_audio_data("test_voice")
- assert result is None
+ assert result is None
def test_get_uploaded_audio_data_voice_not_found(self, speech_server):
"""Test _get_uploaded_audio_data when voice is not in uploaded_speakers."""
@@ -1049,7 +1046,7 @@ def test_voice_field_still_accepted(self):
req = OpenAICreateSpeechRequest.model_validate({"input": "Hello", "voice": "custom_voice"})
assert req.voice == "custom_voice"
- def test_speaker_alias_in_base_task_with_uploaded_voice(self, speech_server):
+ def test_speaker_alias_in_base_task_with_uploaded_voice(self, speech_server, mocker: MockerFixture):
"""Using 'speaker' key with an uploaded voice should work for Base task."""
speech_server.uploaded_speakers = {
"utesf": {
@@ -1061,13 +1058,13 @@ def test_speaker_alias_in_base_task_with_uploaded_voice(self, speech_server):
}
req = OpenAICreateSpeechRequest.model_validate({"input": "Hello", "speaker": "UTESF", "task_type": "Base"})
assert req.voice == "UTESF"
- with patch("pathlib.Path.exists", return_value=True):
- result = speech_server._validate_qwen_tts_request(req)
+ mocker.patch("pathlib.Path.exists", return_value=True)
+ result = speech_server._validate_qwen_tts_request(req)
assert result is None
# ── uploaded voice with embedding ──
- def test_build_tts_params_with_uploaded_voice_embedding(self, speech_server):
+ def test_build_tts_params_with_uploaded_voice_embedding(self, speech_server, mocker: MockerFixture):
"""Test _build_tts_params loads embedding for embedding-uploaded voices."""
speech_server.uploaded_speakers = {
"emb_voice": {
@@ -1083,20 +1080,20 @@ def test_build_tts_params_with_uploaded_voice_embedding(self, speech_server):
speech_server.supported_speakers = {"ryan", "vivian", "emb_voice"}
fake_embedding = [0.1] * 1024
- with patch.object(speech_server, "_get_uploaded_speaker_embedding") as mock_get_emb:
- mock_get_emb.return_value = fake_embedding
- req = OpenAICreateSpeechRequest(input="Hello", voice="emb_voice")
- params = speech_server._build_tts_params(req)
+ mock_get_emb = mocker.patch.object(speech_server, "_get_uploaded_speaker_embedding")
+ mock_get_emb.return_value = fake_embedding
+ req = OpenAICreateSpeechRequest(input="Hello", voice="emb_voice")
+ params = speech_server._build_tts_params(req)
- assert "voice_clone_prompt" in params
- assert params["voice_clone_prompt"][0]["ref_spk_embedding"] == fake_embedding
- assert params["task_type"] == ["Base"]
- assert params["x_vector_only_mode"] == [True]
- assert "ref_audio" not in params
+ assert "voice_clone_prompt" in params
+ assert params["voice_clone_prompt"][0]["ref_spk_embedding"] == fake_embedding
+ assert params["task_type"] == ["Base"]
+ assert params["x_vector_only_mode"] == [True]
+ assert "ref_audio" not in params
# ── regression: full flow from issue #1603 ──
- def test_regression_1603_speaker_key_with_uploaded_audio_voice(self, speech_server):
+ def test_regression_1603_speaker_key_with_uploaded_audio_voice(self, speech_server, mocker: MockerFixture):
"""Regression test for #1603: upload audio voice, then invoke TTS with 'speaker' key.
Verifies the full validate → build_params pipeline works end-to-end.
@@ -1116,14 +1113,14 @@ def test_regression_1603_speaker_key_with_uploaded_audio_voice(self, speech_serv
assert req.voice == "UTESF"
# Validation should pass (file exists)
- with patch("pathlib.Path.exists", return_value=True):
- err = speech_server._validate_qwen_tts_request(req)
+ mocker.patch("pathlib.Path.exists", return_value=True)
+ err = speech_server._validate_qwen_tts_request(req)
assert err is None, f"Validation failed: {err}"
# Build params should auto-set ref_audio from stored file
- with patch.object(speech_server, "_get_uploaded_audio_data") as mock_audio:
- mock_audio.return_value = "data:audio/wav;base64,ZmFrZQ=="
- params = speech_server._build_tts_params(req)
+ mock_audio = mocker.patch.object(speech_server, "_get_uploaded_audio_data")
+ mock_audio.return_value = "data:audio/wav;base64,ZmFrZQ=="
+ params = speech_server._build_tts_params(req)
assert params["task_type"] == ["Base"]
assert params["ref_audio"] == ["data:audio/wav;base64,ZmFrZQ=="]
@@ -1131,7 +1128,7 @@ def test_regression_1603_speaker_key_with_uploaded_audio_voice(self, speech_serv
assert params["x_vector_only_mode"] == [False]
assert params["speaker"] == ["utesf"]
- def test_regression_1603_speaker_key_with_uploaded_embedding_voice(self, speech_server):
+ def test_regression_1603_speaker_key_with_uploaded_embedding_voice(self, speech_server, mocker: MockerFixture):
"""Regression test for #1603: upload embedding voice, then invoke TTS with 'speaker' key.
Verifies embedding-uploaded voices are loaded as voice_clone_prompt, not as audio.
@@ -1154,15 +1151,15 @@ def test_regression_1603_speaker_key_with_uploaded_embedding_voice(self, speech_
assert req.voice == "myvoice"
# Validation should pass
- with patch("pathlib.Path.exists", return_value=True):
- err = speech_server._validate_qwen_tts_request(req)
+ mocker.patch("pathlib.Path.exists", return_value=True)
+ err = speech_server._validate_qwen_tts_request(req)
assert err is None, f"Validation failed: {err}"
# Build params should use embedding, NOT audio
fake_emb = [0.1] * 1024
- with patch.object(speech_server, "_get_uploaded_speaker_embedding") as mock_emb:
- mock_emb.return_value = fake_emb
- params = speech_server._build_tts_params(req)
+ mock_emb = mocker.patch.object(speech_server, "_get_uploaded_speaker_embedding")
+ mock_emb.return_value = fake_emb
+ params = speech_server._build_tts_params(req)
assert params["task_type"] == ["Base"]
assert params["x_vector_only_mode"] == [True]
@@ -1171,7 +1168,7 @@ def test_regression_1603_speaker_key_with_uploaded_embedding_voice(self, speech_
# Must NOT have ref_audio — that would fail for safetensors files
assert "ref_audio" not in params
- def test_validate_rejects_embedding_voice_with_pending_cache(self, speech_server):
+ def test_validate_rejects_embedding_voice_with_pending_cache(self, speech_server, mocker: MockerFixture):
"""Validation should reject embedding voices whose cache is not yet ready."""
speech_server.uploaded_speakers = {
"myvoice": {
@@ -1184,12 +1181,12 @@ def test_validate_rejects_embedding_voice_with_pending_cache(self, speech_server
}
}
req = OpenAICreateSpeechRequest.model_validate({"input": "Hello", "speaker": "myvoice", "task_type": "Base"})
- with patch("pathlib.Path.exists", return_value=True):
- err = speech_server._validate_qwen_tts_request(req)
+ mocker.patch("pathlib.Path.exists", return_value=True)
+ err = speech_server._validate_qwen_tts_request(req)
assert err is not None
assert "not yet ready" in err
- def test_x_vector_only_mode_not_overwritten_for_uploaded_embedding(self, speech_server):
+ def test_x_vector_only_mode_not_overwritten_for_uploaded_embedding(self, speech_server, mocker: MockerFixture):
"""x_vector_only_mode set by uploaded embedding must not be overwritten by request field."""
speech_server.uploaded_speakers = {
"emb_voice": {
@@ -1203,11 +1200,11 @@ def test_x_vector_only_mode_not_overwritten_for_uploaded_embedding(self, speech_
}
}
fake_emb = [0.1] * 1024
- with patch.object(speech_server, "_get_uploaded_speaker_embedding") as mock_emb:
- mock_emb.return_value = fake_emb
- # Client explicitly sends x_vector_only_mode=False, but embedding requires True
- req = OpenAICreateSpeechRequest(input="Hello", voice="emb_voice", x_vector_only_mode=False)
- params = speech_server._build_tts_params(req)
+ mock_emb = mocker.patch.object(speech_server, "_get_uploaded_speaker_embedding")
+ mock_emb.return_value = fake_emb
+ # Client explicitly sends x_vector_only_mode=False, but embedding requires True
+ req = OpenAICreateSpeechRequest(input="Hello", voice="emb_voice", x_vector_only_mode=False)
+ params = speech_server._build_tts_params(req)
assert params["x_vector_only_mode"] == [True]
assert "voice_clone_prompt" in params
@@ -1654,9 +1651,9 @@ async def test_omni_model_includes_generate(self):
assert "generate" in tasks
-def test_api_server_create_speech_wraps_error_response_status():
- handler = MagicMock()
- handler.create_speech = AsyncMock(
+def test_api_server_create_speech_wraps_error_response_status(mocker: MockerFixture):
+ handler = mocker.MagicMock()
+ handler.create_speech = mocker.AsyncMock(
return_value=ErrorResponse(
error=ErrorInfo(message="bad request", type="BadRequestError", param=None, code=400),
)
@@ -1851,9 +1848,9 @@ def test_build_fish_prompt_normalizes_legacy_speaker_tags(self, fish_speech_serv
assert "<|speaker:0|>你好,[laughing]欢迎回来。<|speaker:1|>我也来了。" in encoded_texts
assert all(allowed_special is None for _, _, allowed_special in tokenizer.calls)
- def test_build_fish_clone_prompt_normalizes_text_fields(self, fish_speech_server):
+ def test_build_fish_clone_prompt_normalizes_text_fields(self, fish_speech_server, mocker: MockerFixture):
fish_speech_server._fish_speech_tokenizer = _FakeFishTokenizer()
- fish_speech_server._estimate_fish_prompt_len = MagicMock(return_value=123)
+ fish_speech_server._estimate_fish_prompt_len = mocker.MagicMock(return_value=123)
request = OpenAICreateSpeechRequest(
input="你好,欢迎回来。",
@@ -1904,8 +1901,10 @@ def test_build_fish_prompt_rejects_unsafe_control_tokens(self, fish_speech_serve
with pytest.raises(ValueError, match="unsupported control token"):
fish_speech_server._build_fish_speech_prompt(request)
- def test_prepare_speech_generation_overrides_fish_default_max_tokens(self, fish_speech_server):
- fish_speech_server._build_fish_speech_prompt_async = AsyncMock(
+ def test_prepare_speech_generation_overrides_fish_default_max_tokens(
+ self, fish_speech_server, mocker: MockerFixture
+ ):
+ fish_speech_server._build_fish_speech_prompt_async = mocker.AsyncMock(
return_value={
"prompt_token_ids": [1, 2, 3],
"additional_information": {},
@@ -1924,8 +1923,8 @@ def test_prepare_speech_generation_overrides_fish_default_max_tokens(self, fish_
assert sampling_params_list[0].max_tokens == 4096
assert fish_speech_server.engine_client.default_sampling_params_list[0].max_tokens == 2048
- def test_prepare_speech_generation_uses_stage_default_max_tokens(self, fish_speech_server):
- fish_speech_server._build_fish_speech_prompt_async = AsyncMock(
+ def test_prepare_speech_generation_uses_stage_default_max_tokens(self, fish_speech_server, mocker: MockerFixture):
+ fish_speech_server._build_fish_speech_prompt_async = mocker.AsyncMock(
return_value={
"prompt_token_ids": [1, 2, 3],
"additional_information": {},
@@ -1956,9 +1955,9 @@ def test_prepare_speech_generation_rejects_invalid_fish_max_new_tokens(self, fis
fish_speech_server.engine_client.generate.assert_not_called()
- def test_create_speech_batch_allows_fish_text_only_items(self, fish_speech_server):
- fish_speech_server._check_model = AsyncMock(return_value=None)
- fish_speech_server._generate_audio_bytes = AsyncMock(return_value=("YWJj", "audio/wav"))
+ def test_create_speech_batch_allows_fish_text_only_items(self, fish_speech_server, mocker: MockerFixture):
+ fish_speech_server._check_model = mocker.AsyncMock(return_value=None)
+ fish_speech_server._generate_audio_bytes = mocker.AsyncMock(return_value=("YWJj", "audio/wav"))
batch = BatchSpeechRequest(items=[SpeechBatchItem(input="hello fish")])
response = asyncio.run(fish_speech_server.create_speech_batch(batch))
@@ -2154,8 +2153,8 @@ def test_validate_cosyvoice3_max_new_tokens_range(self, cosyvoice3_server):
assert error is not None
assert "max_new_tokens" in error
- def test_prepare_speech_generation_cosyvoice3(self, cosyvoice3_server):
- cosyvoice3_server._build_cosyvoice3_prompt = AsyncMock(
+ def test_prepare_speech_generation_cosyvoice3(self, cosyvoice3_server, mocker: MockerFixture):
+ cosyvoice3_server._build_cosyvoice3_prompt = mocker.AsyncMock(
return_value={
"prompt": "Hello",
"multi_modal_data": {"audio": (np.zeros(24000), 24000)},
@@ -2236,9 +2235,9 @@ def qwen3_tts_server(self, mocker: MockerFixture):
yield server
server.shutdown()
- def test_prepare_speech_generation_awaits_voxtral_async(self, voxtral_server):
+ def test_prepare_speech_generation_awaits_voxtral_async(self, voxtral_server, mocker: MockerFixture):
"""Voxtral path in _prepare_speech_generation should call the async wrapper."""
- voxtral_server._build_voxtral_prompt_async = AsyncMock(
+ voxtral_server._build_voxtral_prompt_async = mocker.AsyncMock(
return_value={
"prompt_token_ids": [1, 2, 3],
"additional_information": {"voice": ["test"]},
@@ -2248,13 +2247,13 @@ def test_prepare_speech_generation_awaits_voxtral_async(self, voxtral_server):
asyncio.run(voxtral_server._prepare_speech_generation(request))
voxtral_server._build_voxtral_prompt_async.assert_awaited_once()
- def test_prepare_speech_generation_awaits_qwen3_tts_async(self, qwen3_tts_server):
+ def test_prepare_speech_generation_awaits_qwen3_tts_async(self, qwen3_tts_server, mocker: MockerFixture):
"""Qwen3 TTS path should call _estimate_prompt_len_async."""
- qwen3_tts_server._validate_tts_request = MagicMock(return_value=None)
- qwen3_tts_server._build_tts_params = MagicMock(
+ qwen3_tts_server._validate_tts_request = mocker.MagicMock(return_value=None)
+ qwen3_tts_server._build_tts_params = mocker.MagicMock(
return_value={"text": ["hello"], "task_type": ["CustomVoice"], "speaker": ["Vivian"]}
)
- qwen3_tts_server._estimate_prompt_len_async = AsyncMock(return_value=512)
+ qwen3_tts_server._estimate_prompt_len_async = mocker.AsyncMock(return_value=512)
request = OpenAICreateSpeechRequest(input="hello")
asyncio.run(qwen3_tts_server._prepare_speech_generation(request))
qwen3_tts_server._build_tts_params.assert_called_once()
@@ -2281,8 +2280,8 @@ def test_shutdown_is_idempotent(self, mocker: MockerFixture):
server.shutdown() # Should not raise
assert server._tts_executor is None
- def test_diffusion_instance_shutdown_safe(self):
+ def test_diffusion_instance_shutdown_safe(self, mocker: MockerFixture):
"""Diffusion instances (created via for_diffusion) should have safe shutdown."""
- server = OmniOpenAIServingSpeech.for_diffusion(diffusion_engine=MagicMock(), model_name="test-model")
+ server = OmniOpenAIServingSpeech.for_diffusion(diffusion_engine=mocker.MagicMock(), model_name="test-model")
assert server._tts_executor is None
server.shutdown() # Should not raise
diff --git a/tests/entrypoints/openai_api/test_serving_speech_stream.py b/tests/entrypoints/openai_api/test_serving_speech_stream.py
index 1d26b5855f1..1b93ef58e24 100644
--- a/tests/entrypoints/openai_api/test_serving_speech_stream.py
+++ b/tests/entrypoints/openai_api/test_serving_speech_stream.py
@@ -1,8 +1,8 @@
import asyncio
-from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi import FastAPI, WebSocket
+from pytest_mock import MockerFixture
from starlette.testclient import TestClient
from starlette.websockets import WebSocketDisconnect
@@ -13,19 +13,26 @@
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-def _build_test_app(speech_service=None, *, idle_timeout=30.0, config_timeout=10.0):
+def _build_test_app(
+ speech_service=None,
+ *,
+ idle_timeout=30.0,
+ config_timeout=10.0,
+ mocker: MockerFixture | None = None,
+):
if speech_service is None:
- speech_service = MagicMock(spec=OmniOpenAIServingSpeech)
- speech_service._generate_audio_bytes = AsyncMock(return_value=(b"RIFF" + b"\x00" * 32, "audio/wav"))
- speech_service._prepare_speech_generation = AsyncMock(return_value=("req-1", object(), {}))
+ assert mocker is not None
+ speech_service = mocker.MagicMock(spec=OmniOpenAIServingSpeech)
+ speech_service._generate_audio_bytes = mocker.AsyncMock(return_value=(b"RIFF" + b"\x00" * 32, "audio/wav"))
+ speech_service._prepare_speech_generation = mocker.AsyncMock(return_value=("req-1", object(), {}))
async def mock_generate_pcm_chunks(_generator, _request_id):
for chunk in (b"\x01\x02", b"\x03\x04\x05"):
yield chunk
speech_service._generate_pcm_chunks = mock_generate_pcm_chunks
- speech_service.engine_client = MagicMock()
- speech_service.engine_client.abort = AsyncMock()
+ speech_service.engine_client = mocker.MagicMock()
+ speech_service.engine_client.abort = mocker.AsyncMock()
handler = OmniStreamingSpeechHandler(
speech_service=speech_service,
@@ -42,8 +49,8 @@ async def ws_endpoint(websocket: WebSocket):
class TestStreamingSpeechWebSocket:
- def test_non_streaming_single_frame(self):
- app, speech_service = _build_test_app()
+ def test_non_streaming_single_frame(self, mocker: MockerFixture):
+ app, speech_service = _build_test_app(mocker=mocker)
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -68,13 +75,13 @@ def test_non_streaming_single_frame(self):
assert speech_service._generate_audio_bytes.await_count == 1
- def test_streaming_multiple_binary_frames(self):
+ def test_streaming_multiple_binary_frames(self, mocker: MockerFixture):
captured_requests = []
- speech_service = MagicMock(spec=OmniOpenAIServingSpeech)
- speech_service._generate_audio_bytes = AsyncMock(return_value=(b"", "audio/wav"))
- speech_service.engine_client = MagicMock()
- speech_service.engine_client.abort = AsyncMock()
+ speech_service = mocker.MagicMock(spec=OmniOpenAIServingSpeech)
+ speech_service._generate_audio_bytes = mocker.AsyncMock(return_value=(b"", "audio/wav"))
+ speech_service.engine_client = mocker.MagicMock()
+ speech_service.engine_client.abort = mocker.AsyncMock()
async def mock_prepare_speech_generation(request):
captured_requests.append(request)
@@ -123,8 +130,8 @@ async def mock_generate_pcm_chunks(_generator, _request_id):
assert captured_requests[0].initial_codec_chunk_frames == 12
assert speech_service._generate_audio_bytes.await_count == 0
- def test_flush_on_input_done(self):
- app, _ = _build_test_app()
+ def test_flush_on_input_done(self, mocker: MockerFixture):
+ app, _ = _build_test_app(mocker=mocker)
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -142,8 +149,8 @@ def test_flush_on_input_done(self):
}
assert ws.receive_json() == {"type": "session.done", "total_sentences": 1}
- def test_invalid_streaming_config(self):
- app, _ = _build_test_app()
+ def test_invalid_streaming_config(self, mocker: MockerFixture):
+ app, _ = _build_test_app(mocker=mocker)
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -159,8 +166,8 @@ def test_invalid_streaming_config(self):
assert error["type"] == "error"
assert "response_format='pcm'" in error["message"]
- def test_empty_input_text_emits_no_audio(self):
- app, speech_service = _build_test_app()
+ def test_empty_input_text_emits_no_audio(self, mocker: MockerFixture):
+ app, speech_service = _build_test_app(mocker=mocker)
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -172,8 +179,8 @@ def test_empty_input_text_emits_no_audio(self):
assert speech_service._generate_audio_bytes.await_count == 0
- def test_multiple_sentences_increment_indices(self):
- app, _ = _build_test_app()
+ def test_multiple_sentences_increment_indices(self, mocker: MockerFixture):
+ app, _ = _build_test_app(mocker=mocker)
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -203,8 +210,8 @@ def test_multiple_sentences_increment_indices(self):
ws.send_json({"type": "input.done"})
assert ws.receive_json() == {"type": "session.done", "total_sentences": 2}
- def test_unknown_message_type_keeps_session_open(self):
- app, _ = _build_test_app()
+ def test_unknown_message_type_keeps_session_open(self, mocker: MockerFixture):
+ app, _ = _build_test_app(mocker=mocker)
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -227,21 +234,21 @@ def test_unknown_message_type_keeps_session_open(self):
ws.send_json({"type": "input.done"})
assert ws.receive_json() == {"type": "session.done", "total_sentences": 1}
- def test_config_timeout_closes_session(self):
- app, _ = _build_test_app(config_timeout=0.01)
+ def test_config_timeout_closes_session(self, mocker: MockerFixture):
+ app, _ = _build_test_app(config_timeout=0.01, mocker=mocker)
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
error = ws.receive_json()
assert error == {"type": "error", "message": "Timeout waiting for session.config"}
- def test_generation_error_marks_audio_done(self):
- speech_service = MagicMock(spec=OmniOpenAIServingSpeech)
- speech_service._generate_audio_bytes = AsyncMock(side_effect=RuntimeError("boom"))
- speech_service._prepare_speech_generation = AsyncMock(return_value=("req-err", object(), {}))
- speech_service._generate_pcm_chunks = AsyncMock()
- speech_service.engine_client = MagicMock()
- speech_service.engine_client.abort = AsyncMock()
+ def test_generation_error_marks_audio_done(self, mocker: MockerFixture):
+ speech_service = mocker.MagicMock(spec=OmniOpenAIServingSpeech)
+ speech_service._generate_audio_bytes = mocker.AsyncMock(side_effect=RuntimeError("boom"))
+ speech_service._prepare_speech_generation = mocker.AsyncMock(return_value=("req-err", object(), {}))
+ speech_service._generate_pcm_chunks = mocker.AsyncMock()
+ speech_service.engine_client = mocker.MagicMock()
+ speech_service.engine_client.abort = mocker.AsyncMock()
app, _ = _build_test_app(speech_service)
with TestClient(app) as client:
@@ -256,12 +263,12 @@ def test_generation_error_marks_audio_done(self):
ws.send_json({"type": "input.done"})
assert ws.receive_json() == {"type": "session.done", "total_sentences": 1}
- def test_streaming_generation_error_marks_audio_done(self):
- speech_service = MagicMock(spec=OmniOpenAIServingSpeech)
- speech_service._generate_audio_bytes = AsyncMock(return_value=(b"", "audio/wav"))
- speech_service._prepare_speech_generation = AsyncMock(return_value=("req-stream-err", object(), {}))
- speech_service.engine_client = MagicMock()
- speech_service.engine_client.abort = AsyncMock()
+ def test_streaming_generation_error_marks_audio_done(self, mocker: MockerFixture):
+ speech_service = mocker.MagicMock(spec=OmniOpenAIServingSpeech)
+ speech_service._generate_audio_bytes = mocker.AsyncMock(return_value=(b"", "audio/wav"))
+ speech_service._prepare_speech_generation = mocker.AsyncMock(return_value=("req-stream-err", object(), {}))
+ speech_service.engine_client = mocker.MagicMock()
+ speech_service.engine_client.abort = mocker.AsyncMock()
async def mock_generate_pcm_chunks(_generator, _request_id):
yield b"\x01\x02"
@@ -298,8 +305,8 @@ async def mock_generate_pcm_chunks(_generator, _request_id):
ws.send_json({"type": "input.done"})
assert ws.receive_json() == {"type": "session.done", "total_sentences": 1}
- def test_invalid_input_text_type_returns_validation_error(self):
- app, speech_service = _build_test_app()
+ def test_invalid_input_text_type_returns_validation_error(self, mocker: MockerFixture):
+ app, speech_service = _build_test_app(mocker=mocker)
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -316,9 +323,9 @@ def test_invalid_input_text_type_returns_validation_error(self):
assert speech_service._generate_audio_bytes.await_count == 0
- def test_input_text_message_too_large(self, monkeypatch):
+ def test_input_text_message_too_large(self, monkeypatch, mocker: MockerFixture):
monkeypatch.setattr(streaming_speech_module, "_MAX_INPUT_TEXT_MESSAGE_SIZE", 32)
- app, speech_service = _build_test_app()
+ app, speech_service = _build_test_app(mocker=mocker)
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -335,9 +342,9 @@ def test_input_text_message_too_large(self, monkeypatch):
assert speech_service._generate_audio_bytes.await_count == 0
- def test_session_config_message_too_large(self, monkeypatch):
+ def test_session_config_message_too_large(self, monkeypatch, mocker: MockerFixture):
monkeypatch.setattr(streaming_speech_module, "_MAX_CONFIG_MESSAGE_SIZE", 64)
- app, _ = _build_test_app()
+ app, _ = _build_test_app(mocker=mocker)
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -348,12 +355,12 @@ def test_session_config_message_too_large(self, monkeypatch):
"message": "session.config message too large",
}
- def test_disconnect_aborts_streaming_request(self):
- speech_service = MagicMock(spec=OmniOpenAIServingSpeech)
- speech_service._generate_audio_bytes = AsyncMock(return_value=(b"", "audio/wav"))
- speech_service._prepare_speech_generation = AsyncMock(return_value=("req-abort", object(), {}))
- speech_service.engine_client = MagicMock()
- speech_service.engine_client.abort = AsyncMock()
+ def test_disconnect_aborts_streaming_request(self, mocker: MockerFixture):
+ speech_service = mocker.MagicMock(spec=OmniOpenAIServingSpeech)
+ speech_service._generate_audio_bytes = mocker.AsyncMock(return_value=(b"", "audio/wav"))
+ speech_service._prepare_speech_generation = mocker.AsyncMock(return_value=("req-abort", object(), {}))
+ speech_service.engine_client = mocker.MagicMock()
+ speech_service.engine_client.abort = mocker.AsyncMock()
async def mock_generate_pcm_chunks(_generator, _request_id):
yield b"\x01\x02"
@@ -361,11 +368,11 @@ async def mock_generate_pcm_chunks(_generator, _request_id):
speech_service._generate_pcm_chunks = mock_generate_pcm_chunks
handler = OmniStreamingSpeechHandler(speech_service=speech_service)
- websocket = MagicMock()
- websocket.send_json = AsyncMock(side_effect=[None, WebSocketDisconnect()])
- websocket.send_bytes = AsyncMock(side_effect=WebSocketDisconnect())
+ websocket = mocker.MagicMock()
+ websocket.send_json = mocker.AsyncMock(side_effect=[None, WebSocketDisconnect()])
+ websocket.send_bytes = mocker.AsyncMock(side_effect=WebSocketDisconnect())
- config = MagicMock()
+ config = mocker.MagicMock()
config.model = None
config.voice = "Vivian"
config.task_type = None
diff --git a/tests/entrypoints/test_omni_base_profiler.py b/tests/entrypoints/test_omni_base_profiler.py
index 0c1ddc6a5db..ca10eed91f6 100644
--- a/tests/entrypoints/test_omni_base_profiler.py
+++ b/tests/entrypoints/test_omni_base_profiler.py
@@ -1,8 +1,7 @@
"""Unit tests for OmniBase and AsyncOmni profiler methods."""
-from unittest.mock import MagicMock, patch
-
import pytest
+from pytest_mock import MockerFixture
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
@@ -11,12 +10,12 @@ class TestOmniBaseProfiler:
"""Test suite for OmniBase profiler methods (start_profile, stop_profile)."""
@pytest.fixture
- def mock_engine(self):
+ def mock_engine(self, mocker: MockerFixture):
"""Create a mock AsyncOmniEngine for testing."""
- engine = MagicMock()
+ engine = mocker.MagicMock()
engine.num_stages = 3
engine.is_alive.return_value = True
- engine.default_sampling_params_list = [MagicMock() for _ in range(3)]
+ engine.default_sampling_params_list = [mocker.MagicMock() for _ in range(3)]
engine.get_stage_metadata.side_effect = lambda i: {
"final_output_type": "text" if i == 0 else "audio",
"final_output": True,
@@ -25,17 +24,15 @@ def mock_engine(self):
return engine
@pytest.fixture
- def omni_base_instance(self, mock_engine):
+ def omni_base_instance(self, mock_engine, mocker: MockerFixture):
"""Create an OmniBase instance with mocked dependencies."""
- with (
- patch("vllm_omni.entrypoints.omni_base.AsyncOmniEngine", return_value=mock_engine),
- patch("vllm_omni.entrypoints.omni_base.omni_snapshot_download", side_effect=lambda x: x),
- patch("vllm_omni.entrypoints.omni_base.weakref.finalize"),
- ):
- from vllm_omni.entrypoints.omni_base import OmniBase
-
- instance = OmniBase(model="test-model")
- return instance
+ mocker.patch("vllm_omni.entrypoints.omni_base.AsyncOmniEngine", return_value=mock_engine)
+ mocker.patch("vllm_omni.entrypoints.omni_base.omni_snapshot_download", side_effect=lambda x: x)
+ mocker.patch("vllm_omni.entrypoints.omni_base.weakref.finalize")
+ from vllm_omni.entrypoints.omni_base import OmniBase
+
+ instance = OmniBase(model="test-model")
+ return instance
def test_start_profile_calls_collective_rpc(self, omni_base_instance, mock_engine):
"""Test that start_profile calls collective_rpc with correct arguments."""
diff --git a/tests/entrypoints/test_serve.py b/tests/entrypoints/test_serve.py
index 916db3cc22a..afa7fa82e4b 100644
--- a/tests/entrypoints/test_serve.py
+++ b/tests/entrypoints/test_serve.py
@@ -3,9 +3,9 @@
from __future__ import annotations
import argparse
-from unittest.mock import Mock, patch
import pytest
+from pytest_mock import MockerFixture
from vllm_omni.entrypoints.cli.serve import run_headless
@@ -26,45 +26,43 @@ def _make_headless_args() -> argparse.Namespace:
)
-def test_run_headless_registers_stage_once_and_launches_all_local_engines() -> None:
+def test_run_headless_registers_stage_once_and_launches_all_local_engines(mocker: MockerFixture) -> None:
args = _make_headless_args()
- stage_cfg = Mock(stage_id=3)
+ stage_cfg = mocker.Mock(stage_id=3)
stage_cfgs = [stage_cfg]
- parallel_config = Mock(
+ parallel_config = mocker.Mock(
data_parallel_size_local=2,
data_parallel_rank=4,
data_parallel_rank_local=1,
node_rank_within_dp=0,
)
- vllm_config = Mock(parallel_config=parallel_config)
- executor_class = Mock()
- engine_manager = Mock()
-
- with (
- patch(
- "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs",
- return_value=("/fake/stages.yaml", stage_cfgs),
- ),
- patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment"),
- patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=Mock()),
- patch("vllm_omni.engine.stage_init_utils.get_stage_connector_spec", return_value={}),
- patch("vllm_omni.engine.stage_init_utils.build_engine_args_dict", return_value={}),
- patch(
- "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- ),
- patch(
- "vllm_omni.engine.stage_init_utils.build_vllm_config",
- return_value=(vllm_config, executor_class),
- ) as mock_build_vllm_config,
- patch(
- "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
- return_value="tcp://127.0.0.1:26001",
- ) as mock_register,
- patch("vllm.v1.engine.utils.CoreEngineProcManager", return_value=engine_manager) as mock_manager_cls,
- patch("signal.signal"),
- ):
- run_headless(args)
+ vllm_config = mocker.Mock(parallel_config=parallel_config)
+ executor_class = mocker.Mock()
+ engine_manager = mocker.Mock()
+
+ mocker.patch(
+ "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs",
+ return_value=("/fake/stages.yaml", stage_cfgs),
+ )
+ mocker.patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment")
+ mocker.patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=mocker.Mock())
+ mocker.patch("vllm_omni.engine.stage_init_utils.get_stage_connector_spec", return_value={})
+ mocker.patch("vllm_omni.engine.stage_init_utils.build_engine_args_dict", return_value={})
+ mocker.patch(
+ "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mock_build_vllm_config = mocker.patch(
+ "vllm_omni.engine.stage_init_utils.build_vllm_config",
+ return_value=(vllm_config, executor_class),
+ )
+ mock_register = mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
+ return_value="tcp://127.0.0.1:26001",
+ )
+ mock_manager_cls = mocker.patch("vllm.v1.engine.utils.CoreEngineProcManager", return_value=engine_manager)
+ mocker.patch("signal.signal")
+ run_headless(args)
mock_build_vllm_config.assert_called_once_with(
stage_cfg,
@@ -92,89 +90,85 @@ def test_run_headless_registers_stage_once_and_launches_all_local_engines() -> N
engine_manager.shutdown.assert_called_once_with()
-def test_run_headless_honors_explicit_log_stats_flag() -> None:
+def test_run_headless_honors_explicit_log_stats_flag(mocker: MockerFixture) -> None:
args = _make_headless_args()
args.log_stats = True
- stage_cfg = Mock(stage_id=3)
+ stage_cfg = mocker.Mock(stage_id=3)
stage_cfgs = [stage_cfg]
- parallel_config = Mock(
+ parallel_config = mocker.Mock(
data_parallel_size_local=2,
data_parallel_rank=4,
data_parallel_rank_local=1,
node_rank_within_dp=0,
)
- vllm_config = Mock(parallel_config=parallel_config)
- executor_class = Mock()
- engine_manager = Mock()
-
- with (
- patch(
- "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs",
- return_value=("/fake/stages.yaml", stage_cfgs),
- ),
- patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment"),
- patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=Mock()),
- patch("vllm_omni.engine.stage_init_utils.get_stage_connector_spec", return_value={}),
- patch("vllm_omni.engine.stage_init_utils.build_engine_args_dict", return_value={}),
- patch(
- "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- ),
- patch(
- "vllm_omni.engine.stage_init_utils.build_vllm_config",
- return_value=(vllm_config, executor_class),
- ),
- patch(
- "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
- return_value="tcp://127.0.0.1:26001",
- ),
- patch("vllm.v1.engine.utils.CoreEngineProcManager", return_value=engine_manager) as mock_manager_cls,
- patch("signal.signal"),
- ):
- run_headless(args)
+ vllm_config = mocker.Mock(parallel_config=parallel_config)
+ executor_class = mocker.Mock()
+ engine_manager = mocker.Mock()
+
+ mocker.patch(
+ "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs",
+ return_value=("/fake/stages.yaml", stage_cfgs),
+ )
+ mocker.patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment")
+ mocker.patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=mocker.Mock())
+ mocker.patch("vllm_omni.engine.stage_init_utils.get_stage_connector_spec", return_value={})
+ mocker.patch("vllm_omni.engine.stage_init_utils.build_engine_args_dict", return_value={})
+ mocker.patch(
+ "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mocker.patch(
+ "vllm_omni.engine.stage_init_utils.build_vllm_config",
+ return_value=(vllm_config, executor_class),
+ )
+ mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
+ return_value="tcp://127.0.0.1:26001",
+ )
+ mock_manager_cls = mocker.patch("vllm.v1.engine.utils.CoreEngineProcManager", return_value=engine_manager)
+ mocker.patch("signal.signal")
+ run_headless(args)
manager_kwargs = mock_manager_cls.call_args.kwargs
assert manager_kwargs["log_stats"] is True
-def test_run_headless_launches_diffusion_stage_via_omni_master() -> None:
+def test_run_headless_launches_diffusion_stage_via_omni_master(mocker: MockerFixture) -> None:
args = _make_headless_args()
- stage_cfg = Mock(stage_id=3, stage_type="diffusion")
- stage_cfg.engine_args = Mock()
+ stage_cfg = mocker.Mock(stage_id=3, stage_type="diffusion")
+ stage_cfg.engine_args = mocker.Mock()
stage_cfg.engine_input_source = []
stage_cfgs = [stage_cfg]
- metadata = Mock(stage_id=3)
- od_config = Mock()
- proc = Mock()
+ metadata = mocker.Mock(stage_id=3)
+ od_config = mocker.Mock()
+ proc = mocker.Mock()
proc.exitcode = 0
proc.is_alive.return_value = False
- with (
- patch(
- "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs",
- return_value=("/fake/stages.yaml", stage_cfgs),
- ),
- patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment"),
- patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=Mock()),
- patch(
- "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- ),
- patch("vllm_omni.engine.stage_init_utils.extract_stage_metadata", return_value=metadata),
- patch("vllm_omni.engine.stage_init_utils.inject_kv_stage_info") as mock_inject_stage_info,
- patch("vllm_omni.engine.stage_init_utils.build_diffusion_config", return_value=od_config),
- patch(
- "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
- return_value=("tcp://127.0.0.1:26001", "tcp://127.0.0.1:26002", "tcp://127.0.0.1:26003"),
- ) as mock_register,
- patch(
- "vllm_omni.diffusion.stage_diffusion_proc.spawn_diffusion_proc",
- return_value=(proc, "tcp://127.0.0.1:26001", "tcp://127.0.0.1:26002", "tcp://127.0.0.1:26003"),
- ) as mock_spawn,
- patch("vllm_omni.diffusion.stage_diffusion_proc.complete_diffusion_handshake") as mock_handshake,
- patch("signal.signal"),
- ):
- run_headless(args)
+ mocker.patch(
+ "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs",
+ return_value=("/fake/stages.yaml", stage_cfgs),
+ )
+ mocker.patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment")
+ mocker.patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=mocker.Mock())
+ mocker.patch(
+ "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mocker.patch("vllm_omni.engine.stage_init_utils.extract_stage_metadata", return_value=metadata)
+ mock_inject_stage_info = mocker.patch("vllm_omni.engine.stage_init_utils.inject_kv_stage_info")
+ mocker.patch("vllm_omni.engine.stage_init_utils.build_diffusion_config", return_value=od_config)
+ mock_register = mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
+ return_value=("tcp://127.0.0.1:26001", "tcp://127.0.0.1:26002", "tcp://127.0.0.1:26003"),
+ )
+ mock_spawn = mocker.patch(
+ "vllm_omni.diffusion.stage_diffusion_proc.spawn_diffusion_proc",
+ return_value=(proc, "tcp://127.0.0.1:26001", "tcp://127.0.0.1:26002", "tcp://127.0.0.1:26003"),
+ )
+ mock_handshake = mocker.patch("vllm_omni.diffusion.stage_diffusion_proc.complete_diffusion_handshake")
+ mocker.patch("signal.signal")
+ run_headless(args)
mock_inject_stage_info.assert_called_once_with(stage_cfg, 3)
mock_register.assert_called_once_with(
diff --git a/tests/model_executor/models/mimo_audio/test_mimo_audio_code2wav_batch_decode.py b/tests/model_executor/models/mimo_audio/test_mimo_audio_code2wav_batch_decode.py
index 85c0e8b56e4..8858d1f8f16 100644
--- a/tests/model_executor/models/mimo_audio/test_mimo_audio_code2wav_batch_decode.py
+++ b/tests/model_executor/models/mimo_audio/test_mimo_audio_code2wav_batch_decode.py
@@ -2,10 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from types import SimpleNamespace
-from unittest.mock import Mock
import pytest
import torch
+from pytest_mock import MockerFixture
from vllm_omni.model_executor.models.mimo_audio.config_mimo_audio import TALKER_CODEC_PAD_TOKEN_ID
from vllm_omni.model_executor.models.mimo_audio.mimo_audio_code2wav import (
@@ -51,7 +51,7 @@ def _make_invalid_flat_immediate_eostm(eostm_id: int = 666) -> torch.Tensor:
return g.reshape(-1)
-def _minimal_model():
+def _minimal_model(mocker: MockerFixture):
"""Avoid __init__ (HF tokenizer paths); only fields used by _batch_decode_waveforms."""
model = object.__new__(MiMoAudioToken2WavForConditionalGenerationVLLM)
model.device = torch.device("cpu")
@@ -59,7 +59,7 @@ def _minimal_model():
model.streamer_config = AudioStreamerConfig(group_size=_GROUP, audio_channels=_AC)
model.codes = _codes_ns()
- decode_vq = Mock(
+ decode_vq = mocker.Mock(
side_effect=lambda audio_codes: torch.ones(
audio_codes.shape[1],
7,
@@ -67,7 +67,7 @@ def _minimal_model():
device=audio_codes.device,
)
)
- decoder = Mock()
+ decoder = mocker.Mock()
audio_tok = SimpleNamespace(
encoder=SimpleNamespace(decode_vq=decode_vq),
@@ -78,9 +78,9 @@ def _minimal_model():
return model, audio_tok
-def test_batch_decode_waveforms_empty_input_list():
+def test_batch_decode_waveforms_empty_input_list(mocker: MockerFixture):
"""Empty input list returns a single zero-length float32 tensor on model device."""
- model, _ = _minimal_model()
+ model, _ = _minimal_model(mocker)
out = MiMoAudioToken2WavForConditionalGenerationVLLM._batch_decode_waveforms(model, [])
assert len(out) == 1
assert out[0].dtype == torch.float32
@@ -88,9 +88,9 @@ def test_batch_decode_waveforms_empty_input_list():
assert out[0].device == model.device
-def test_batch_decode_waveforms_single_vs_multiple_decoder_shapes():
+def test_batch_decode_waveforms_single_vs_multiple_decoder_shapes(mocker: MockerFixture):
"""Single and multi-request batches produce correctly shaped packed hidden states and trimmed waveforms."""
- model, audio_tok = _minimal_model()
+ model, audio_tok = _minimal_model(mocker)
decoder = audio_tok.decoder
# Single valid request: decoder output rank-3 for double squeeze path
@@ -118,9 +118,9 @@ def test_batch_decode_waveforms_single_vs_multiple_decoder_shapes():
assert out2[1].shape == (8 * _FTP,)
-def test_batch_decode_waveforms_mixed_valid_invalid_requests():
+def test_batch_decode_waveforms_mixed_valid_invalid_requests(mocker: MockerFixture):
"""Mixed valid and invalid requests: invalid slots get empty tensors, valid slots get decoded waveforms."""
- model, audio_tok = _minimal_model()
+ model, audio_tok = _minimal_model(mocker)
valid_a = _make_valid_flat_codes(1)
valid_b = _make_valid_flat_codes(1)
dummy = _make_dummy_code_tensor()
@@ -151,9 +151,9 @@ def test_batch_decode_waveforms_mixed_valid_invalid_requests():
assert input_lengths.tolist() == [4, 4]
-def test_batch_decode_waveforms_all_invalid_returns_per_request_empty():
+def test_batch_decode_waveforms_all_invalid_returns_per_request_empty(mocker: MockerFixture):
"""All-invalid batch skips decoder entirely and returns empty tensors for every slot."""
- model, audio_tok = _minimal_model()
+ model, audio_tok = _minimal_model(mocker)
out = MiMoAudioToken2WavForConditionalGenerationVLLM._batch_decode_waveforms(
model,
[None, _make_dummy_code_tensor(), torch.tensor([], dtype=torch.long)],
@@ -163,9 +163,9 @@ def test_batch_decode_waveforms_all_invalid_returns_per_request_empty():
audio_tok.decoder.assert_not_called()
-def test_batch_decode_waveforms_output_shape_trim_when_decoder_returns_extra_samples():
+def test_batch_decode_waveforms_output_shape_trim_when_decoder_returns_extra_samples(mocker: MockerFixture):
"""Decoder output longer than valid_len is trimmed to the exact expected waveform length."""
- model, audio_tok = _minimal_model()
+ model, audio_tok = _minimal_model(mocker)
flat = _make_valid_flat_codes(1)
# Longer than valid_len so branch wav = wav[:valid_len] runs
audio_tok.decoder.return_value = torch.ones(1, 1, 10_000, dtype=torch.float32)
@@ -175,9 +175,9 @@ def test_batch_decode_waveforms_output_shape_trim_when_decoder_returns_extra_sam
assert out[0].dtype == torch.float32
-def test_batch_decode_waveforms_multi_request_trims_each_row_when_decoder_returns_extra():
+def test_batch_decode_waveforms_multi_request_trims_each_row_when_decoder_returns_extra(mocker: MockerFixture):
"""Else-branch split: per-request wav[:valid_len] when decoder pads each batch row."""
- model, audio_tok = _minimal_model()
+ model, audio_tok = _minimal_model(mocker)
a = _make_valid_flat_codes(1)
b = _make_valid_flat_codes(2)
audio_tok.decoder.return_value = torch.ones(2, 1, 10_000, dtype=torch.float32)
@@ -189,9 +189,9 @@ def test_batch_decode_waveforms_multi_request_trims_each_row_when_decoder_return
assert out[1].dtype == torch.float32
-def test_batch_decode_waveforms_valid_only_at_edges_maps_to_correct_indices():
+def test_batch_decode_waveforms_valid_only_at_edges_maps_to_correct_indices(mocker: MockerFixture):
"""Tensor packing order must match valid_indices when invalid requests are in the middle."""
- model, audio_tok = _minimal_model()
+ model, audio_tok = _minimal_model(mocker)
first = _make_valid_flat_codes(1)
last = _make_valid_flat_codes(2)
inputs = [
@@ -212,9 +212,9 @@ def test_batch_decode_waveforms_valid_only_at_edges_maps_to_correct_indices():
assert input_lengths.tolist() == [4, 8]
-def test_batch_decode_waveforms_output_shapes_1d_float32_for_all_slots():
+def test_batch_decode_waveforms_output_shapes_1d_float32_for_all_slots(mocker: MockerFixture):
"""Every slot is a 1-D float32 vector (empty or waveform), matching downstream expectations."""
- model, audio_tok = _minimal_model()
+ model, audio_tok = _minimal_model(mocker)
inputs = [_make_valid_flat_codes(1), None, _make_valid_flat_codes(1)]
audio_tok.decoder.return_value = torch.ones(2, 1, 5000, dtype=torch.float32)
out = MiMoAudioToken2WavForConditionalGenerationVLLM._batch_decode_waveforms(model, inputs)
diff --git a/tests/model_executor/models/qwen2_5_omni/test_qwen2_5_omni_embed.py b/tests/model_executor/models/qwen2_5_omni/test_qwen2_5_omni_embed.py
index 8e04b04966b..587e7f7f8b1 100644
--- a/tests/model_executor/models/qwen2_5_omni/test_qwen2_5_omni_embed.py
+++ b/tests/model_executor/models/qwen2_5_omni/test_qwen2_5_omni_embed.py
@@ -10,10 +10,9 @@
- Interleaved (use_audio_in_video) should also work correctly.
"""
-from unittest.mock import Mock
-
import pytest
import torch
+from pytest_mock import MockerFixture
from vllm.model_executor.models.qwen2_5_omni_thinker import (
check_interleaved_audio_video,
merge_interleaved_embeddings,
@@ -107,7 +106,7 @@ def test_interleaved(self):
# ---------------------------------------------------------------------------
-def make_mock_model(hidden: int = 8):
+def make_mock_model(mocker: MockerFixture, hidden: int = 8):
"""
Return a minimal mock of Qwen2_5OmniThinkerForConditionalGeneration
that has enough structure to run embed_input_ids.
@@ -116,10 +115,10 @@ def make_mock_model(hidden: int = 8):
Qwen2_5OmniThinkerForConditionalGeneration,
)
- model = Mock(spec=Qwen2_5OmniThinkerForConditionalGeneration)
+ model = mocker.Mock(spec=Qwen2_5OmniThinkerForConditionalGeneration)
# Config with token IDs
- cfg = Mock()
+ cfg = mocker.Mock()
cfg.video_token_index = VIDEO_TOKEN_ID
cfg.audio_token_index = AUDIO_TOKEN_ID
model.config = cfg
@@ -130,9 +129,9 @@ def fake_lm_embed(ids: torch.Tensor) -> torch.Tensor:
# view with shared memory, which masked_scatter_ cannot handle).
return ids.float().unsqueeze(-1).expand(-1, hidden).clone()
- lang_model = Mock()
+ lang_model = mocker.Mock()
lang_model.embed_input_ids = fake_lm_embed
- model.get_language_model = Mock(return_value=lang_model)
+ model.get_language_model = mocker.Mock(return_value=lang_model)
from vllm.model_executor.models.interfaces import SupportsMultiModal
@@ -169,7 +168,7 @@ def build_mm_embeds(audio_n, image_n, video_n, hidden, audio_val=10.0, image_val
class TestEmbedInputIds:
- def _run(self, audio_n, image_n, video_n, hidden=8):
+ def _run(self, mocker: MockerFixture, audio_n, image_n, video_n, hidden=8):
"""
Run embed_input_ids for a non-interleaved mixed-modality sequence.
Returns (result_embeds, input_ids, is_multimodal).
@@ -177,33 +176,33 @@ def _run(self, audio_n, image_n, video_n, hidden=8):
input_ids, is_multimodal = make_token_seq(audio_n, image_n, video_n)
mm_embeds = build_mm_embeds(audio_n, image_n, video_n, hidden)
- model, _ = make_mock_model(hidden)
+ model, _ = make_mock_model(mocker, hidden)
result = model.embed_input_ids(input_ids, mm_embeds, is_multimodal=is_multimodal)
return result, input_ids, is_multimodal
- def test_audio_only(self):
+ def test_audio_only(self, mocker: MockerFixture):
"""Audio-only: audio positions get audio embeddings."""
audio_n, hidden = 5, 8
audio_val = 10.0
- result, input_ids, is_multimodal = self._run(audio_n, 0, 0, hidden)
+ result, input_ids, is_multimodal = self._run(mocker, audio_n, 0, 0, hidden)
audio_pos = (input_ids == AUDIO_TOKEN_ID).nonzero(as_tuple=True)[0]
assert result[audio_pos].allclose(torch.full((audio_n, hidden), audio_val)), (
"Audio positions should get audio embeddings"
)
- def test_video_only(self):
+ def test_video_only(self, mocker: MockerFixture):
"""Video-only: video positions get video embeddings."""
video_n, hidden = 6, 8
video_val = 30.0
- result, input_ids, is_multimodal = self._run(0, 0, video_n, hidden)
+ result, input_ids, is_multimodal = self._run(mocker, 0, 0, video_n, hidden)
video_pos = (input_ids == VIDEO_TOKEN_ID).nonzero(as_tuple=True)[0]
assert result[video_pos].allclose(torch.full((video_n, hidden), video_val)), (
"Video positions should get video embeddings"
)
- def test_mixed_modalities_audio_goes_to_audio_pos(self):
+ def test_mixed_modalities_audio_goes_to_audio_pos(self, mocker: MockerFixture):
"""
Regression test for GitHub issue #34506:
With audio + image + video (non-interleaved), audio positions must
@@ -212,7 +211,7 @@ def test_mixed_modalities_audio_goes_to_audio_pos(self):
audio_n, image_n, video_n, hidden = 5, 4, 6, 8
audio_val, image_val, video_val = 10.0, 20.0, 30.0
- result, input_ids, is_multimodal = self._run(audio_n, image_n, video_n, hidden)
+ result, input_ids, is_multimodal = self._run(mocker, audio_n, image_n, video_n, hidden)
audio_pos = (input_ids == AUDIO_TOKEN_ID).nonzero(as_tuple=True)[0]
image_pos = (input_ids == IMAGE_TOKEN_ID).nonzero(as_tuple=True)[0]
@@ -233,10 +232,10 @@ def test_mixed_modalities_audio_goes_to_audio_pos(self):
f"Video emb wrong: expected {video_val}, got mean={mean_v:.1f}"
)
- def test_text_positions_unchanged(self):
+ def test_text_positions_unchanged(self, mocker: MockerFixture):
"""Text positions should keep their text embeddings."""
audio_n, image_n, video_n, hidden = 3, 2, 4, 8
- result, input_ids, is_multimodal = self._run(audio_n, image_n, video_n, hidden)
+ result, input_ids, is_multimodal = self._run(mocker, audio_n, image_n, video_n, hidden)
text_pos = (~is_multimodal).nonzero(as_tuple=True)[0]
# Text tokens have value TEXT_TOKEN_ID=0, so embed -> 0.0
@@ -244,7 +243,7 @@ def test_text_positions_unchanged(self):
"Text positions should keep text embeddings"
)
- def test_interleaved_use_audio_in_video(self):
+ def test_interleaved_use_audio_in_video(self, mocker: MockerFixture):
"""
Interleaved (use_audio_in_video): video chunks interleaved with audio.
Video embeddings must go to video positions, audio to audio positions.
@@ -263,7 +262,7 @@ def test_interleaved_use_audio_in_video(self):
torch.full((audio_n, hidden), audio_val),
]
- model, _ = make_mock_model(hidden)
+ model, _ = make_mock_model(mocker, hidden)
result = model.embed_input_ids(input_ids, mm_embeds, is_multimodal=is_multimodal)
video_pos = (input_ids == VIDEO_TOKEN_ID).nonzero(as_tuple=True)[0]
diff --git a/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py b/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py
index e2970dcb2df..b0ce10a8d5e 100644
--- a/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py
+++ b/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py
@@ -15,9 +15,10 @@
import os
import sys
import types
-from unittest.mock import MagicMock, patch
+import pytest
import torch
+from pytest_mock import MockerFixture
# Direct file import to avoid vllm_omni.__init__ patch dependencies.
_BASE = os.path.join(
@@ -41,28 +42,31 @@ def _load_module(name: str, filename: str):
return mod
-def _build_mock_modules() -> dict[str, object]:
+def _build_mock_modules(mocker: MockerFixture) -> dict[str, object]:
"""Build the dict of modules to inject into sys.modules."""
- platforms_mock = MagicMock()
+ platforms_mock = mocker.MagicMock()
platforms_mock.current_omni_platform.supports_torch_inductor.return_value = False
- logger_mock = MagicMock()
- logger_mock.init_logger = lambda name: MagicMock()
+ logger_mock = mocker.MagicMock()
+ logger_mock.init_logger = lambda name: mocker.MagicMock()
- vllm_config_mod = MagicMock()
- vllm_config_mod.set_current_vllm_config = lambda cfg: MagicMock(__enter__=MagicMock(), __exit__=MagicMock())
+ vllm_config_mod = mocker.MagicMock()
+ vllm_config_mod.set_current_vllm_config = lambda cfg: mocker.MagicMock(
+ __enter__=mocker.MagicMock(),
+ __exit__=mocker.MagicMock(),
+ )
- weight_utils_mock = MagicMock()
+ weight_utils_mock = mocker.MagicMock()
weight_utils_mock.default_weight_loader = lambda p, w: None
pkg = types.ModuleType("vllm_omni.model_executor.models.qwen3_tts")
pkg.__path__ = [os.path.abspath(_BASE)]
return {
- "vllm_omni": MagicMock(),
+ "vllm_omni": mocker.MagicMock(),
"vllm_omni.platforms": platforms_mock,
"vllm.logger": logger_mock,
- "vllm.config": MagicMock(),
+ "vllm.config": mocker.MagicMock(),
"vllm.config.vllm": vllm_config_mod,
"vllm.model_executor.model_loader.weight_utils": weight_utils_mock,
"vllm_omni.model_executor": types.ModuleType("vllm_omni.model_executor"),
@@ -71,38 +75,47 @@ def _build_mock_modules() -> dict[str, object]:
}
-def _load_target_classes():
+def _load_target_classes(mocker: MockerFixture):
"""Load config and code predictor modules with mocked dependencies.
- Uses patch.dict to ensure sys.modules is always restored, even on failure.
+ Uses mocker.patch.dict to ensure sys.modules is always restored, even on failure.
"""
- mocks = _build_mock_modules()
- with patch.dict(sys.modules, mocks):
- config_mod = _load_module(
- "vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts",
- "configuration_qwen3_tts.py",
- )
- sys.modules["vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts"] = config_mod
+ mocks = _build_mock_modules(mocker)
+ mocker.patch.dict(sys.modules, mocks)
+ config_mod = _load_module(
+ "vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts",
+ "configuration_qwen3_tts.py",
+ )
+ sys.modules["vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts"] = config_mod
- cp_mod = _load_module(
- "vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code_predictor_vllm",
- "qwen3_tts_code_predictor_vllm.py",
- )
+ cp_mod = _load_module(
+ "vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code_predictor_vllm",
+ "qwen3_tts_code_predictor_vllm.py",
+ )
return config_mod, cp_mod
-_config_mod, _cp_mod = _load_target_classes()
-
-Qwen3TTSTalkerCodePredictorConfig = _config_mod.Qwen3TTSTalkerCodePredictorConfig
-Qwen3TTSTalkerConfig = _config_mod.Qwen3TTSTalkerConfig
-CodePredictorWrapper = _cp_mod.Qwen3TTSTalkerCodePredictorForConditionalGenerationVLLM
-CodePredictorModel = _cp_mod.Qwen3TTSTalkerCodePredictorModelVLLM
+@pytest.fixture
+def loaded_target_classes(mocker: MockerFixture):
+ config_mod, cp_mod = _load_target_classes(mocker)
+ return (
+ config_mod.Qwen3TTSTalkerCodePredictorConfig,
+ config_mod.Qwen3TTSTalkerConfig,
+ cp_mod.Qwen3TTSTalkerCodePredictorForConditionalGenerationVLLM,
+ cp_mod.Qwen3TTSTalkerCodePredictorModelVLLM,
+ )
-def _make_tiny_config() -> tuple:
+def _make_tiny_config(loaded_target_classes) -> tuple:
"""Create minimal configs for a tiny code predictor model."""
- cp_config = Qwen3TTSTalkerCodePredictorConfig(
+ (
+ qwen3_tts_talker_code_predictor_config,
+ qwen3_tts_talker_config,
+ _,
+ _,
+ ) = loaded_target_classes
+ cp_config = qwen3_tts_talker_code_predictor_config(
vocab_size=64,
hidden_size=32,
intermediate_size=64,
@@ -113,16 +126,16 @@ def _make_tiny_config() -> tuple:
num_code_groups=4,
rms_norm_eps=1e-6,
)
- talker_config = Qwen3TTSTalkerConfig(
+ talker_config = qwen3_tts_talker_config(
hidden_size=32,
num_code_groups=4,
)
return cp_config, talker_config
-def _make_vllm_config(max_num_seqs: int = 4) -> MagicMock:
+def _make_vllm_config(mocker: MockerFixture, max_num_seqs: int = 4):
"""Create a mock VllmConfig with scheduler_config."""
- vllm_config = MagicMock()
+ vllm_config = mocker.MagicMock()
vllm_config.scheduler_config.max_num_seqs = max_num_seqs
return vllm_config
@@ -130,12 +143,13 @@ def _make_vllm_config(max_num_seqs: int = 4) -> MagicMock:
class TestCodePredictorDtypeAlignment:
"""Test that code predictor buffers match model parameter dtype."""
- def test_ensure_buffers_uses_given_dtype(self) -> None:
+ def test_ensure_buffers_uses_given_dtype(self, mocker: MockerFixture, loaded_target_classes) -> None:
"""_ensure_buffers should create proj_buf with the given dtype."""
- cp_config, talker_config = _make_tiny_config()
- vllm_config = _make_vllm_config()
+ _, _, code_predictor_wrapper, _ = loaded_target_classes
+ cp_config, talker_config = _make_tiny_config(loaded_target_classes)
+ vllm_config = _make_vllm_config(mocker)
- predictor = CodePredictorWrapper(
+ predictor = code_predictor_wrapper(
vllm_config=vllm_config,
config=cp_config,
talker_config=talker_config,
@@ -150,12 +164,13 @@ def test_ensure_buffers_uses_given_dtype(self) -> None:
predictor._ensure_buffers(torch.device("cpu"), torch.float32)
assert predictor._proj_buf.dtype == torch.float32
- def test_warmup_aligns_buffer_to_model_params(self) -> None:
+ def test_warmup_aligns_buffer_to_model_params(self, mocker: MockerFixture, loaded_target_classes) -> None:
"""_warmup_buckets should align proj_buf dtype to model parameters."""
- cp_config, talker_config = _make_tiny_config()
- vllm_config = _make_vllm_config(max_num_seqs=2)
+ _, _, code_predictor_wrapper, _ = loaded_target_classes
+ cp_config, talker_config = _make_tiny_config(loaded_target_classes)
+ vllm_config = _make_vllm_config(mocker, max_num_seqs=2)
- predictor = CodePredictorWrapper(
+ predictor = code_predictor_wrapper(
vllm_config=vllm_config,
config=cp_config,
talker_config=talker_config,
@@ -177,12 +192,13 @@ def test_warmup_aligns_buffer_to_model_params(self) -> None:
assert predictor._proj_buf.dtype == torch.float16
- def test_setup_compile_caches_model_dtype(self) -> None:
+ def test_setup_compile_caches_model_dtype(self, mocker: MockerFixture, loaded_target_classes) -> None:
"""_setup_compile should cache model parameter dtype."""
- cp_config, talker_config = _make_tiny_config()
- vllm_config = _make_vllm_config(max_num_seqs=2)
+ _, _, code_predictor_wrapper, _ = loaded_target_classes
+ cp_config, talker_config = _make_tiny_config(loaded_target_classes)
+ vllm_config = _make_vllm_config(mocker, max_num_seqs=2)
- predictor = CodePredictorWrapper(
+ predictor = code_predictor_wrapper(
vllm_config=vllm_config,
config=cp_config,
talker_config=talker_config,
@@ -193,12 +209,13 @@ def test_setup_compile_caches_model_dtype(self) -> None:
predictor._setup_compile()
assert predictor._model_dtype == torch.float16
- def test_forward_with_mismatched_input_dtype(self) -> None:
+ def test_forward_with_mismatched_input_dtype(self, mocker: MockerFixture, loaded_target_classes) -> None:
"""forward() should not crash when inputs are float32 but model is float16."""
- cp_config, talker_config = _make_tiny_config()
- vllm_config = _make_vllm_config(max_num_seqs=2)
+ _, _, code_predictor_wrapper, _ = loaded_target_classes
+ cp_config, talker_config = _make_tiny_config(loaded_target_classes)
+ vllm_config = _make_vllm_config(mocker, max_num_seqs=2)
- predictor = CodePredictorWrapper(
+ predictor = code_predictor_wrapper(
vllm_config=vllm_config,
config=cp_config,
talker_config=talker_config,
@@ -231,10 +248,11 @@ def test_forward_with_mismatched_input_dtype(self) -> None:
class TestCodePredictorModelDtype:
"""Test the inner model forward with different dtypes."""
- def test_model_forward_float16(self) -> None:
+ def test_model_forward_float16(self, loaded_target_classes) -> None:
"""Inner model forward should work in float16."""
- cp_config, _ = _make_tiny_config()
- model = CodePredictorModel(cp_config, talker_hidden_size=32).to(torch.float16)
+ _, _, _, code_predictor_model = loaded_target_classes
+ cp_config, _ = _make_tiny_config(loaded_target_classes)
+ model = code_predictor_model(cp_config, talker_hidden_size=32).to(torch.float16)
bsz, seq_len = 1, 4
inputs = torch.randn(bsz, seq_len, 32, dtype=torch.float16)
@@ -244,10 +262,11 @@ def test_model_forward_float16(self) -> None:
assert output.dtype == torch.float16
assert output.shape == (bsz, seq_len, 32)
- def test_model_forward_float32(self) -> None:
+ def test_model_forward_float32(self, loaded_target_classes) -> None:
"""Inner model forward should work in float32."""
- cp_config, _ = _make_tiny_config()
- model = CodePredictorModel(cp_config, talker_hidden_size=32).to(torch.float32)
+ _, _, _, code_predictor_model = loaded_target_classes
+ cp_config, _ = _make_tiny_config(loaded_target_classes)
+ model = code_predictor_model(cp_config, talker_hidden_size=32).to(torch.float32)
bsz, seq_len = 1, 4
inputs = torch.randn(bsz, seq_len, 32, dtype=torch.float32)
diff --git a/tests/model_executor/models/test_fish_speech_voice_cache.py b/tests/model_executor/models/test_fish_speech_voice_cache.py
index 8fe7a4a4d11..fef4b551ab2 100644
--- a/tests/model_executor/models/test_fish_speech_voice_cache.py
+++ b/tests/model_executor/models/test_fish_speech_voice_cache.py
@@ -10,11 +10,11 @@
import os
import tempfile
-from unittest.mock import MagicMock, patch
import numpy as np
import pytest
import torch
+from pytest_mock import MockerFixture
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
@@ -61,18 +61,18 @@ class TestFishSpeechVoiceCacheIntegration:
"""Test the cache-hit / cache-miss / no-cache paths in the model."""
@pytest.fixture
- def mock_model(self):
+ def mock_model(self, mocker: MockerFixture):
"""Create a mock FishSpeechSlowARForConditionalGeneration with cache."""
from vllm_omni.utils.voice_cache import VoiceEmbeddingCache
- model = MagicMock()
+ model = mocker.MagicMock()
model._voice_cache = VoiceEmbeddingCache(max_entries=4)
model._semantic_begin_id = 151678
model._num_codebooks = 10
model._codebook_size = 4096
model.model_path = "/fake/model"
- model.codebook_embeddings = MagicMock()
- model.codebook_embeddings.weight = MagicMock()
+ model.codebook_embeddings = mocker.MagicMock()
+ model.codebook_embeddings.weight = mocker.MagicMock()
model.codebook_embeddings.weight.device = torch.device("cpu")
return model
@@ -166,9 +166,9 @@ def test_created_at_zero_disables_cache(self, mock_model):
class TestFishSpeechValidatorUploadedVoice:
"""Test _validate_fish_tts_request uploaded voice resolution."""
- def test_uploaded_voice_resolves_ref_audio(self):
+ def test_uploaded_voice_resolves_ref_audio(self, mocker: MockerFixture):
"""When voice matches an uploaded speaker, ref_audio should be auto-set."""
- request = MagicMock()
+ request = mocker.MagicMock()
request.input = "Hello"
request.voice = "alice"
request.ref_audio = None
@@ -185,17 +185,17 @@ def test_uploaded_voice_resolves_ref_audio(self):
}
# Simulate: voice in uploaded_speakers, file exists, get_audio returns data URL.
- with patch("pathlib.Path.exists", return_value=True):
- voice_lower = request.voice.lower()
- assert voice_lower in uploaded_speakers
+ mocker.patch("pathlib.Path.exists", return_value=True)
+ voice_lower = request.voice.lower()
+ assert voice_lower in uploaded_speakers
- speaker_info = uploaded_speakers[voice_lower]
- ref_text_from_upload = speaker_info.get("ref_text")
- assert ref_text_from_upload == "Hi this is Alice"
+ speaker_info = uploaded_speakers[voice_lower]
+ ref_text_from_upload = speaker_info.get("ref_text")
+ assert ref_text_from_upload == "Hi this is Alice"
- def test_uploaded_voice_without_ref_text_uses_request_ref_text(self):
+ def test_uploaded_voice_without_ref_text_uses_request_ref_text(self, mocker: MockerFixture):
"""If upload has no ref_text but request provides it, use request's."""
- request = MagicMock()
+ request = mocker.MagicMock()
request.input = "Hello"
request.voice = "bob"
request.ref_audio = None
diff --git a/tests/test_fish_speech_voice_cache.py b/tests/test_fish_speech_voice_cache.py
index 8fe7a4a4d11..1c299d80142 100644
--- a/tests/test_fish_speech_voice_cache.py
+++ b/tests/test_fish_speech_voice_cache.py
@@ -10,11 +10,12 @@
import os
import tempfile
-from unittest.mock import MagicMock, patch
+from pathlib import Path
import numpy as np
import pytest
import torch
+from pytest_mock import MockerFixture
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
@@ -61,18 +62,18 @@ class TestFishSpeechVoiceCacheIntegration:
"""Test the cache-hit / cache-miss / no-cache paths in the model."""
@pytest.fixture
- def mock_model(self):
+ def mock_model(self, mocker: MockerFixture):
"""Create a mock FishSpeechSlowARForConditionalGeneration with cache."""
from vllm_omni.utils.voice_cache import VoiceEmbeddingCache
- model = MagicMock()
+ model = mocker.MagicMock()
model._voice_cache = VoiceEmbeddingCache(max_entries=4)
model._semantic_begin_id = 151678
model._num_codebooks = 10
model._codebook_size = 4096
model.model_path = "/fake/model"
- model.codebook_embeddings = MagicMock()
- model.codebook_embeddings.weight = MagicMock()
+ model.codebook_embeddings = mocker.MagicMock()
+ model.codebook_embeddings.weight = mocker.MagicMock()
model.codebook_embeddings.weight.device = torch.device("cpu")
return model
@@ -166,9 +167,13 @@ def test_created_at_zero_disables_cache(self, mock_model):
class TestFishSpeechValidatorUploadedVoice:
"""Test _validate_fish_tts_request uploaded voice resolution."""
- def test_uploaded_voice_resolves_ref_audio(self):
+ def test_uploaded_voice_resolves_ref_audio(
+ self,
+ monkeypatch: pytest.MonkeyPatch,
+ mocker: MockerFixture,
+ ):
"""When voice matches an uploaded speaker, ref_audio should be auto-set."""
- request = MagicMock()
+ request = mocker.MagicMock()
request.input = "Hello"
request.voice = "alice"
request.ref_audio = None
@@ -185,17 +190,21 @@ def test_uploaded_voice_resolves_ref_audio(self):
}
# Simulate: voice in uploaded_speakers, file exists, get_audio returns data URL.
- with patch("pathlib.Path.exists", return_value=True):
- voice_lower = request.voice.lower()
- assert voice_lower in uploaded_speakers
+ monkeypatch.setattr(Path, "exists", lambda self: True)
- speaker_info = uploaded_speakers[voice_lower]
- ref_text_from_upload = speaker_info.get("ref_text")
- assert ref_text_from_upload == "Hi this is Alice"
+ voice_lower = request.voice.lower()
+ assert voice_lower in uploaded_speakers
+
+ speaker_info = uploaded_speakers[voice_lower]
+ ref_text_from_upload = speaker_info.get("ref_text")
+ assert ref_text_from_upload == "Hi this is Alice"
- def test_uploaded_voice_without_ref_text_uses_request_ref_text(self):
+ def test_uploaded_voice_without_ref_text_uses_request_ref_text(
+ self,
+ mocker: MockerFixture,
+ ):
"""If upload has no ref_text but request provides it, use request's."""
- request = MagicMock()
+ request = mocker.MagicMock()
request.input = "Hello"
request.voice = "bob"
request.ref_audio = None
From 2a1d5060abbae97648d86f57d70fe5af57d41467 Mon Sep 17 00:00:00 2001
From: amy-why-3459
Date: Mon, 13 Apr 2026 20:43:40 +0800
Subject: [PATCH 19/76] [skip ci][doc]Update async_chunk design diagram (#2420)
Signed-off-by: amy-why-3459
---
docs/design/feature/async_chunk_design.md | 76 +++++++++++++++---
.../architecture/qwen3-omni-async-chunk.png | Bin 198564 -> 68497 bytes
.../qwen3-omni-non-async-chunk.png | Bin 263596 -> 49242 bytes
3 files changed, 67 insertions(+), 9 deletions(-)
diff --git a/docs/design/feature/async_chunk_design.md b/docs/design/feature/async_chunk_design.md
index 202ef0e18e8..45314a0aec6 100644
--- a/docs/design/feature/async_chunk_design.md
+++ b/docs/design/feature/async_chunk_design.md
@@ -19,7 +19,7 @@ The `async_chunk` feature enables asynchronous, chunked processing of data acros
For qwen3-omni:
- **Thinker → Talker**: Per decode step (typically chunk_size=1)
-- **Talker → Code2Wav**: Accumulated to `codec_chunk_frames` (default=25) before sending. During the initial phase, a dynamic initial chunk size (IC) is automatically selected based on server load to reduce TTFA. Use the per-request `initial_codec_chunk_frames` API field to override.
+- **Talker → Code2Wav**: Accumulated to `codec_chunk_frames` (default=25) before sending. During the initial phase, a dynamic initial chunk size (IC) is automatically selected based on server load to reduce TTFP. Use the per-request `initial_codec_chunk_frames` API field to override.
- **Code2Wav**: Streaming decode with code2wav chunk_size
With `async_chunk`:
@@ -75,26 +75,84 @@ Enabling **async_chunk** (False→True) sharply reduces time-to-first-audio (TTF
## Architecture
-### Data Flow
-#### Sequential Flow
+### Async Chunk Pipeline Overview
+
+The following diagram illustrates the **Async Chunk Architecture** for multi-stage models (e.g., Qwen3-Omni with Thinker → Talker → Code2Wav), showing how data flows through the pipeline stages with parallel processing and dual-stream output:
+
-### Async Chunk architecture
+In sequential mode, each stage must wait for the previous stage to complete entirely before starting.
+
+### Async Chunk System Architecture