From 3173e2f7634e60fffb7c82ce0dd3852d2eed14c3 Mon Sep 17 00:00:00 2001 From: yinpe <11810305@mail.sustech.edu.cn> Date: Thu, 9 Apr 2026 12:49:59 +0800 Subject: [PATCH 1/6] modify cfg tracker Signed-off-by: yinpe <11810305@mail.sustech.edu.cn> --- .../entrypoints/test_cfg_companion_tracker.py | 124 +++------- vllm_omni/engine/async_omni_engine.py | 1 - vllm_omni/engine/orchestrator.py | 108 ++------ .../entrypoints/cfg_companion_tracker.py | 231 +++++------------- 4 files changed, 124 insertions(+), 340 deletions(-) diff --git a/tests/entrypoints/test_cfg_companion_tracker.py b/tests/entrypoints/test_cfg_companion_tracker.py index 941ead41ff0..d316410c52e 100644 --- a/tests/entrypoints/test_cfg_companion_tracker.py +++ b/tests/entrypoints/test_cfg_companion_tracker.py @@ -1,114 +1,64 @@ -import time -from types import SimpleNamespace - import pytest from vllm_omni.entrypoints.cfg_companion_tracker import CfgCompanionTracker +from vllm_omni.inputs.data import OmniDiffusionSamplingParams pytestmark = [pytest.mark.core_model, pytest.mark.cpu] -def dummy_expand_func(prompt, sp0): - if prompt == "expand_me": - return [SimpleNamespace(prompt={"prompt": "neg"}, role="cfg_text", request_id_suffix="__cfg_text")] - return [] - +def test_register_companion_and_cleanup(): + tracker = CfgCompanionTracker() -@pytest.fixture -def tracker(): - sp0 = SimpleNamespace() - return CfgCompanionTracker(prompt_expand_func=dummy_expand_func, stage0_sampling_params=sp0, timeout_s=0.1) + tracker.register_companion("req1", "cfg_text", "req1__cfg_text") + tracker.register_companion("req1", "cfg_img", "req1__cfg_img") + assert tracker.is_companion("req1__cfg_text") + assert tracker.get_companion_request_ids("req1") == { + "cfg_text": "req1__cfg_text", + "cfg_img": "req1__cfg_img", + } -def test_companion_tracker_initialization(tracker): - assert not tracker.is_active - assert tracker.num_companions == 0 + removed = tracker.cleanup_parent("req1") + assert sorted(removed) == ["req1__cfg_img", "req1__cfg_text"] + assert not tracker.is_companion("req1__cfg_text") + assert tracker.get_companion_request_ids("req1") == {} -def test_expand_prompts_registers_companions(tracker): - request_id_to_prompt = {"req1": "expand_me", "req2": "do_not_expand"} - pairs = tracker.expand_prompts(request_id_to_prompt) +def test_attach_cfg_request_ids_clones_diffusion_params(): + tracker = CfgCompanionTracker() + tracker.register_companion("req1", "cfg_text", "req1__cfg_text") - assert len(pairs) == 1 - companion_id, prompt = pairs[0] - assert companion_id == "req1__cfg_text" - assert prompt == {"prompt": "neg"} + params = OmniDiffusionSamplingParams() + updated = tracker.attach_cfg_request_ids("req1", params) - assert tracker.is_active - assert tracker.num_companions == 1 - assert tracker.is_companion("req1__cfg_text") - assert not tracker.is_companion("req2__cfg_text") - assert tracker.has_companions("req1") - assert not tracker.has_companions("req2") + assert updated is not params + assert params.cfg_kv_request_ids is None + assert updated.cfg_kv_request_ids == {"cfg_text": "req1__cfg_text"} - comp_map = tracker.get_companion_request_ids("req1") - assert comp_map == {"cfg_text": "req1__cfg_text"} +def test_abort_parents_with_parent_or_companion_ids(): + tracker = CfgCompanionTracker() + tracker.register_companion("req1", "cfg_text", "req1__cfg_text") + tracker.defer_parent("req1", {"out": 1}, stage_id=0) -def test_companion_lifecycle_success(tracker): - request_id_to_prompt = {"req1": "expand_me"} - tracker.expand_prompts(request_id_to_prompt) + aborted = tracker.abort_parents(["req1__cfg_text"]) - # Defer parent - engine_outputs = {"out": 123} - tracker.defer_parent("req1", engine_outputs, stage_id=0) + assert sorted(aborted) == ["req1", "req1__cfg_text"] + assert not tracker.is_companion("req1__cfg_text") + assert tracker.pop_pending_parent("req1") is None - # Initially not done - assert not tracker.all_companions_done("req1") - # Companion completes - parent_id = tracker.on_companion_completed("req1__cfg_text") +def test_companion_completion_flushes_deferred_parent(): + tracker = CfgCompanionTracker() + tracker.register_companion("req1", "cfg_text", "req1__cfg_text") + tracker.defer_parent("req1", {"out": 123}, stage_id=0) - # Parent should be returned since all companions are done and it is pending - assert parent_id == "req1" + assert not tracker.all_companions_done("req1") + assert tracker.on_companion_completed("req1__cfg_text") == "req1" assert tracker.all_companions_done("req1") - # Pop pending parent popped = tracker.pop_pending_parent("req1") assert popped is not None - assert popped["engine_outputs"] == engine_outputs + assert popped["engine_outputs"] == {"out": 123} assert popped["stage_id"] == 0 - - -def test_companion_lifecycle_failure(tracker): - request_id_to_prompt = {"req1": "expand_me"} - tracker.expand_prompts(request_id_to_prompt) - - tracker.defer_parent("req1", {"out": 123}, stage_id=0) - - # Companion fails - parent_id, aborted = tracker.on_companion_error("req1__cfg_text") - - assert parent_id == "req1" - assert aborted is True - assert tracker.is_parent_failed("req1") - - # Parent should be removed from pending list - assert tracker.pop_pending_parent("req1") is None - - # Consume failure - tracker.consume_parent_failure("req1") - assert not tracker.is_parent_failed("req1") - - -def test_companion_lifecycle_timeout(tracker): - request_id_to_prompt = {"req1": "expand_me"} - tracker.expand_prompts(request_id_to_prompt) - - tracker.defer_parent("req1", {"out": 123}, stage_id=0) - - # Initially no timeouts - timeouts = tracker.check_timeouts() - assert len(timeouts) == 0 - - # Wait for timeout - time.sleep(0.15) - - # Check timeouts again - timeouts = tracker.check_timeouts() - assert len(timeouts) == 1 - assert timeouts[0] == "req1" - - # Should be removed from pending - assert tracker.pop_pending_parent("req1") is None diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index f7e7d53d58b..0a1d9669b25 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -806,7 +806,6 @@ def _enqueue_cfg_companions( params=companion_params, supported_tasks=self.supported_tasks, ) - request = _upgrade_to_omni_request(request, companion_prompt) request.external_req_id = cid self.output_processors[0].add_request( diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py index 386b545eb75..69050c19f6d 100644 --- a/vllm_omni/engine/orchestrator.py +++ b/vllm_omni/engine/orchestrator.py @@ -27,6 +27,7 @@ from vllm_omni.engine import ( OmniEngineCoreRequest, ) +from vllm_omni.entrypoints.cfg_companion_tracker import CfgCompanionTracker from vllm_omni.engine.serialization import serialize_additional_information from vllm_omni.metrics.stats import StageRequestStats as StageRequestMetrics from vllm_omni.metrics.stats import StageStats @@ -138,11 +139,7 @@ def __init__( self.request_states: dict[str, OrchestratorRequestState] = {} # CFG companion tracking - self._companion_map: dict[str, dict[str, str]] = {} - self._companion_to_parent: dict[str, str] = {} - self._companion_ids: set[str] = set() - self._companion_done: dict[str, set[str]] = {} - self._deferred_parents: dict[str, dict[str, Any]] = {} + self._cfg_tracker = CfgCompanionTracker() # Per-stage metrics accumulators. self._batch_seq: list[int] = [0] * self.num_stages @@ -317,7 +314,7 @@ async def _route_output( # CFG companion handling: companions don't produce user-visible output # and don't forward to the next stage directly. - if finished and req_id in self._companion_ids: + if finished and self._cfg_tracker.is_companion(req_id): await self._handle_cfg_companion_ready(req_id) self.request_states.pop(req_id, None) return @@ -351,57 +348,34 @@ async def _route_output( and not self.async_chunk and not self._next_stage_already_submitted(stage_id, req_state) ): - if req_id in self._companion_map and not self._all_companions_done(req_id): - self._deferred_parents[req_id] = { - "stage_id": stage_id, - "output": output, - } + if self._cfg_tracker.has_companions(req_id) and not self._cfg_tracker.all_companions_done(req_id): + self._cfg_tracker.defer_parent(req_id, output, stage_id) else: await self._forward_to_next_stage(req_id, stage_id, output, req_state) if finished and stage_id == req_state.final_stage_id: - self._cleanup_companion_state(req_id) + self._cfg_tracker.cleanup_parent(req_id) self.request_states.pop(req_id, None) - def _cleanup_companion_state(self, parent_id: str) -> None: - """Remove all companion tracking state for a completed parent.""" - role_map = self._companion_map.pop(parent_id, {}) - for cid in role_map.values(): - self._companion_ids.discard(cid) - self._companion_to_parent.pop(cid, None) - self._companion_done.pop(parent_id, None) - self._deferred_parents.pop(parent_id, None) - - def _all_companions_done(self, parent_id: str) -> bool: - """Check whether all CFG companions for a parent request have finished.""" - role_map = self._companion_map.get(parent_id, {}) - if not role_map: - return True - done_set = self._companion_done.get(parent_id, set()) - return all(cid in done_set for cid in role_map.values()) - def _next_stage_already_submitted(self, stage_id: int, req_state: OrchestratorRequestState) -> bool: return (stage_id + 1) in req_state.stage_submit_ts async def _handle_cfg_companion_ready(self, req_id: str) -> None: """Mark a CFG companion as done; if all companions are done, flush deferred parent.""" - parent_id = self._companion_to_parent.get(req_id) + parent_id = self._cfg_tracker.on_companion_completed(req_id) if parent_id is None: return - done_set = self._companion_done.setdefault(parent_id, set()) - if req_id in done_set: + deferred = self._cfg_tracker.pop_pending_parent(parent_id) + if deferred is None: return - done_set.add(req_id) - if parent_id in self._deferred_parents and self._all_companions_done(parent_id): - deferred = self._deferred_parents.pop(parent_id) - parent_state = self.request_states.get(parent_id) - if parent_state is not None and not self._next_stage_already_submitted(deferred["stage_id"], parent_state): - await self._forward_to_next_stage( - parent_id, - deferred["stage_id"], - deferred["output"], - parent_state, - ) + parent_state = self.request_states.get(parent_id) + if parent_state is not None and not self._next_stage_already_submitted(deferred["stage_id"], parent_state): + await self._forward_to_next_stage( + parent_id, + deferred["stage_id"], + deferred["engine_outputs"], + parent_state, + ) async def _handle_kv_ready_raw_outputs(self, stage_id: int, raw_outputs: EngineCoreOutputs) -> None: """Forward split requests once stage-0 KV is ready, not only when decode fully finishes.""" @@ -415,18 +389,15 @@ async def _handle_kv_ready_raw_outputs(self, stage_id: int, raw_outputs: EngineC req_state = self.request_states.get(req_id) if req_state is None: continue - if req_id in self._companion_ids: + if self._cfg_tracker.is_companion(req_id): await self._handle_cfg_companion_ready(req_id) continue if stage_id >= req_state.final_stage_id: continue if self._next_stage_already_submitted(stage_id, req_state): continue - if req_id in self._companion_map and not self._all_companions_done(req_id): - self._deferred_parents[req_id] = { - "stage_id": stage_id, - "output": raw_output, - } + if self._cfg_tracker.has_companions(req_id) and not self._cfg_tracker.all_companions_done(req_id): + self._cfg_tracker.defer_parent(req_id, raw_output, stage_id) else: await self._forward_to_next_stage(req_id, stage_id, raw_output, req_state) @@ -531,20 +502,7 @@ async def _forward_to_next_stage( else: diffusion_prompt = req_state.prompt - # Attach CFG companion KV request IDs so the diffusion model - # runner can fetch companion KV caches alongside the primary one. - cfg_ids = self._companion_map.get(req_id) - if cfg_ids: - from vllm_omni.inputs.data import OmniDiffusionSamplingParams - - if isinstance(params, OmniDiffusionSamplingParams): - params = copy.deepcopy(params) - params.cfg_kv_request_ids = cfg_ids - logger.info( - "[Orchestrator] Attaching cfg_kv_request_ids=%s to req %s", - cfg_ids, - req_id, - ) + params = self._cfg_tracker.attach_cfg_request_ids(req_id, params) source_stage_ids = list(getattr(next_client, "engine_input_source", None) or [stage_id]) kv_sender_info = self._build_kv_sender_info(sender_stage_ids=source_stage_ids) @@ -800,13 +758,7 @@ async def _handle_add_companion(self, msg: dict[str, Any]) -> None: companion_prompt = msg["prompt"] sampling_params_list = msg["sampling_params_list"] - # Register companion mapping - if parent_id not in self._companion_map: - self._companion_map[parent_id] = {} - self._companion_map[parent_id][role] = companion_id - self._companion_ids.add(companion_id) - self._companion_to_parent[companion_id] = parent_id - self._companion_done.setdefault(parent_id, set()) + self._cfg_tracker.register_companion(parent_id, role, companion_id) companion_state = OrchestratorRequestState( request_id=companion_id, @@ -831,22 +783,10 @@ async def _handle_add_companion(self, msg: dict[str, Any]) -> None: async def _handle_abort(self, msg: dict[str, Any]) -> None: """Handle an abort message from the main thread.""" request_ids = msg["request_ids"] - # Also abort any CFG companions for aborted parents - companion_ids_to_abort: list[str] = [] - for req_id in request_ids: - role_map = self._companion_map.pop(req_id, {}) - for cid in role_map.values(): - companion_ids_to_abort.append(cid) - self._companion_ids.discard(cid) - self._companion_to_parent.pop(cid, None) - self.request_states.pop(cid, None) - self._companion_done.pop(req_id, None) - self._deferred_parents.pop(req_id, None) - - all_ids_to_abort = list(request_ids) + companion_ids_to_abort + all_ids_to_abort = self._cfg_tracker.abort_parents(request_ids) for stage_id in range(self.num_stages): await self.stage_clients[stage_id].abort_requests_async(all_ids_to_abort) - for req_id in request_ids: + for req_id in all_ids_to_abort: self.request_states.pop(req_id, None) logger.info("[Orchestrator] Aborted request(s) %s", request_ids) diff --git a/vllm_omni/entrypoints/cfg_companion_tracker.py b/vllm_omni/entrypoints/cfg_companion_tracker.py index 9c2e835f074..be3db9d3e35 100644 --- a/vllm_omni/entrypoints/cfg_companion_tracker.py +++ b/vllm_omni/entrypoints/cfg_companion_tracker.py @@ -1,22 +1,16 @@ """CFG companion request tracker for the Omni orchestrator. Encapsulates all bookkeeping for Classifier-Free Guidance companion -requests (prompt expansion, parent/companion ID mapping, completion -tracking, deferred forwarding, failure propagation, and timeouts) -so that ``Omni._run_generation`` stays clean. +requests (parent/companion ID mapping, completion tracking, +deferred forwarding, and cleanup). """ from __future__ import annotations -import copy import logging -import os -import time -from collections.abc import Callable, Sequence from typing import Any -from vllm_omni.distributed.omni_connectors.adapter import try_send_via_connector -from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniSamplingParams +from vllm_omni.inputs.data import OmniDiffusionSamplingParams logger = logging.getLogger(__name__) @@ -24,66 +18,12 @@ class CfgCompanionTracker: """Manages CFG companion request lifecycle in the orchestrator scheduling loop.""" - def __init__( - self, - prompt_expand_func: Callable[..., Any] | None, - stage0_sampling_params: Any, - timeout_s: float | None = None, - ) -> None: - self._expand_func = prompt_expand_func - self._sp0 = stage0_sampling_params - self._timeout_s = ( - timeout_s if timeout_s is not None else float(os.environ.get("VLLM_CFG_PENDING_TIMEOUT_S", "120")) - ) - + def __init__(self) -> None: self._companion_map: dict[str, dict[str, str]] = {} # parent -> {role: companion_id} self._companion_ids: set[str] = set() self._companion_to_parent: dict[str, str] = {} # companion -> parent self._done: dict[str, set[str]] = {} # parent -> completed companion ids self._pending_parents: dict[str, dict[str, Any]] = {} # parent -> deferred result - self._failed_parents: set[str] = set() - - @property - def is_active(self) -> bool: - return bool(self._companion_ids) - - @property - def num_companions(self) -> int: - return len(self._companion_ids) - - @property - def stage0_sampling_params(self) -> Any: - return self._sp0 - - def expand_prompts( - self, - request_id_to_prompt: dict[str, Any], - ) -> list[tuple[str, Any]]: - """Expand user prompts into ``(companion_id, prompt)`` pairs via model-specific func.""" - if not self._expand_func: - return [] - - pairs: list[tuple[str, Any]] = [] - for rid, prompt in request_id_to_prompt.items(): - expanded = self._expand_func(prompt, self._sp0) - if not expanded: - continue - role_map: dict[str, str] = {} - for ep in expanded: - cid = f"{rid}{ep.request_id_suffix}" - role_map[ep.role] = cid - self._companion_ids.add(cid) - self._companion_to_parent[cid] = rid - pairs.append((cid, ep.prompt)) - self._companion_map[rid] = role_map - self._done[rid] = set() - - logger.debug( - "CFG expansion: %d parent(s) -> %d companion(s)", - len(self._companion_map), - len(self._companion_ids), - ) - return pairs def is_companion(self, req_id: str) -> bool: return req_id in self._companion_ids @@ -100,37 +40,44 @@ def get_companion_request_ids(self, parent_id: str) -> dict[str, str]: """Return ``{role: companion_request_id}`` for a parent.""" return self._companion_map.get(parent_id, {}) - def is_parent_failed(self, parent_id: str) -> bool: - return parent_id in self._failed_parents + def register_parent(self, parent_id: str) -> None: + self._companion_map.setdefault(parent_id, {}) + self._done.setdefault(parent_id, set()) - # -- Lifecycle events -- + def register_companion(self, parent_id: str, role: str, companion_id: str) -> None: + self.register_parent(parent_id) + self._companion_map[parent_id][role] = companion_id + self._companion_ids.add(companion_id) + self._companion_to_parent[companion_id] = parent_id - def on_companion_error(self, companion_id: str) -> tuple[str | None, bool]: - """Record failure. Returns ``(parent_id, parent_was_aborted)``.""" - parent_id = self._companion_to_parent.get(companion_id) - if parent_id is None: - return None, False - self._failed_parents.add(parent_id) - logger.error("CFG companion %s failed; marking parent %s as failed", companion_id, parent_id) - aborted = parent_id in self._pending_parents - if aborted: - self._pending_parents.pop(parent_id, None) - return parent_id, aborted + def attach_cfg_request_ids(self, parent_id: str, sampling_params: Any) -> Any: + cfg_ids = self.get_companion_request_ids(parent_id) + if not cfg_ids: + return sampling_params + + if isinstance(sampling_params, OmniDiffusionSamplingParams): + sampling_params = sampling_params.clone() + sampling_params.cfg_kv_request_ids = cfg_ids + logger.info( + "Attaching cfg_kv_request_ids=%s to request %s", + cfg_ids, + parent_id, + ) + return sampling_params def on_companion_completed(self, companion_id: str) -> str | None: """Mark done. Returns parent_id only if parent is pending and all companions finished.""" parent_id = self._companion_to_parent.get(companion_id) if parent_id is None: return None + if companion_id in self._done.get(parent_id, set()): + return None self._done[parent_id].add(companion_id) logger.debug("CFG companion %s completed (parent=%s)", companion_id, parent_id) if parent_id in self._pending_parents and self.all_companions_done(parent_id): return parent_id return None - def consume_parent_failure(self, parent_id: str) -> None: - self._failed_parents.discard(parent_id) - # -- Deferred parent management -- def defer_parent(self, parent_id: str, engine_outputs: Any, stage_id: int) -> None: @@ -138,96 +85,44 @@ def defer_parent(self, parent_id: str, engine_outputs: Any, stage_id: int) -> No self._pending_parents[parent_id] = { "engine_outputs": engine_outputs, "stage_id": stage_id, - "pending_since": time.time(), } logger.debug("Parent %s deferred, waiting for CFG companions", parent_id) def pop_pending_parent(self, parent_id: str) -> dict[str, Any] | None: return self._pending_parents.pop(parent_id, None) - def check_timeouts(self) -> list[str]: - """Return and remove parent IDs that exceeded the pending timeout.""" - if not self._pending_parents: - return [] - now = time.time() - timed_out: list[str] = [] - for pid in list(self._pending_parents): - pending_since = self._pending_parents[pid].get("pending_since", now) - if now - pending_since > self._timeout_s: - self._pending_parents.pop(pid) - self._failed_parents.discard(pid) - timed_out.append(pid) - logger.error("Parent %s timed out waiting for CFG companions (>%.0fs)", pid, self._timeout_s) - return timed_out - - # -- Forward parent with CFG KV -- - - def forward_parent_with_cfg( - self, - req_id: str, - parent_result: dict[str, Any], - stage_list: Sequence[Any], - connectors: dict[tuple[str, str], Any], - sampling_params_list: Sequence[OmniSamplingParams], - request_id_to_prompt: dict[str, Any], - final_stage_id_to_prompt: dict[str, int], - metrics: Any, - remaining_by_stage: list[int], - ) -> bool: - """Forward a parent request to the next stage with CFG KV request IDs attached.""" - stage_id = parent_result["stage_id"] - next_stage_id = stage_id + 1 - if next_stage_id > final_stage_id_to_prompt.get(req_id, 0): - return True - - next_stage = stage_list[next_stage_id] - try: - with metrics.stage_postprocess_timer(stage_id, req_id): - next_inputs = next_stage.process_engine_inputs( - stage_list, - [request_id_to_prompt[req_id]], - source_outputs_override=parent_result["engine_outputs"], - ) - except Exception as e: - logger.exception( - "Process engine inputs error for req %s at stage %d: %s", - req_id, - next_stage_id, - e, - ) - return False - - sp_next = copy.deepcopy(sampling_params_list[next_stage_id]) - if isinstance(sp_next, OmniDiffusionSamplingParams): - sp_next.cfg_kv_request_ids = self.get_companion_request_ids(req_id) - logger.info( - "Attaching cfg_kv_request_ids=%s to request %s", - sp_next.cfg_kv_request_ids, - req_id, - ) - - connector_key = (str(stage_id), str(next_stage_id)) - connector = connectors.get(connector_key) - sent_via_connector = False - if connector: - sent_via_connector = try_send_via_connector( - connector=connector, - stage_id=stage_id, - next_stage_id=next_stage_id, - req_id=req_id, - next_inputs=next_inputs, - sampling_params=sp_next, - original_prompt=request_id_to_prompt[req_id], - next_stage_queue_submit_fn=stage_list[next_stage_id].submit, - metrics=metrics, - ) - - if not sent_via_connector: - raise RuntimeError( - f"Failed to send CFG request {req_id} to stage-{next_stage_id} via connector. " - "Configure a connector for this edge or inspect connector logs for details." - ) - - logger.debug("Forwarded CFG-enabled request %s to stage-%d", req_id, next_stage_id) - remaining_by_stage[next_stage_id] += 1 - return True + # -- Cleanup / abort -- + + def cleanup_parent(self, parent_id: str) -> list[str]: + companion_ids = list(self._companion_map.pop(parent_id, {}).values()) + for companion_id in companion_ids: + self._companion_ids.discard(companion_id) + self._companion_to_parent.pop(companion_id, None) + self._done.pop(parent_id, None) + self._pending_parents.pop(parent_id, None) + return companion_ids + + def abort_parents(self, request_ids: list[str]) -> list[str]: + all_request_ids = list(request_ids) + seen = set(all_request_ids) + parents_to_cleanup: set[str] = set() + + for req_id in request_ids: + if req_id in self._companion_ids: + parent_id = self._companion_to_parent.get(req_id) + if parent_id is not None: + parents_to_cleanup.add(parent_id) + if parent_id not in seen: + seen.add(parent_id) + all_request_ids.append(parent_id) + else: + parents_to_cleanup.add(req_id) + + for parent_id in parents_to_cleanup: + companion_ids = self.cleanup_parent(parent_id) + for companion_id in companion_ids: + if companion_id not in seen: + seen.add(companion_id) + all_request_ids.append(companion_id) + + return all_request_ids From 40e19ea166690ac054900a958919ebfe81b005e5 Mon Sep 17 00:00:00 2001 From: yinpe <11810305@mail.sustech.edu.cn> Date: Thu, 9 Apr 2026 13:03:34 +0800 Subject: [PATCH 2/6] fix precommit Signed-off-by: yinpe <11810305@mail.sustech.edu.cn> --- vllm_omni/engine/orchestrator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py index 69050c19f6d..3a94361466d 100644 --- a/vllm_omni/engine/orchestrator.py +++ b/vllm_omni/engine/orchestrator.py @@ -27,8 +27,8 @@ from vllm_omni.engine import ( OmniEngineCoreRequest, ) -from vllm_omni.entrypoints.cfg_companion_tracker import CfgCompanionTracker from vllm_omni.engine.serialization import serialize_additional_information +from vllm_omni.entrypoints.cfg_companion_tracker import CfgCompanionTracker from vllm_omni.metrics.stats import StageRequestStats as StageRequestMetrics from vllm_omni.metrics.stats import StageStats from vllm_omni.metrics.utils import count_tokens_from_outputs From 2237c940c759f72e66bcc6bf02dc8b58dd156534 Mon Sep 17 00:00:00 2001 From: yinpe <11810305@mail.sustech.edu.cn> Date: Thu, 9 Apr 2026 16:26:00 +0800 Subject: [PATCH 3/6] mv to engine Signed-off-by: yinpe <11810305@mail.sustech.edu.cn> --- docs/api/README.md | 2 +- tests/{entrypoints => engine}/test_cfg_companion_tracker.py | 2 +- vllm_omni/{entrypoints => engine}/cfg_companion_tracker.py | 4 ---- vllm_omni/engine/orchestrator.py | 2 +- 4 files changed, 3 insertions(+), 7 deletions(-) rename tests/{entrypoints => engine}/test_cfg_companion_tracker.py (96%) rename vllm_omni/{entrypoints => engine}/cfg_companion_tracker.py (98%) diff --git a/docs/api/README.md b/docs/api/README.md index f65cbb525d9..0147f19e126 100644 --- a/docs/api/README.md +++ b/docs/api/README.md @@ -5,7 +5,7 @@ Main entry points for vLLM-Omni inference and serving. - [vllm_omni.entrypoints.async_omni.AsyncOmni][] -- [vllm_omni.entrypoints.cfg_companion_tracker.CfgCompanionTracker][] +- [vllm_omni.engine.cfg_companion_tracker.CfgCompanionTracker][] - [vllm_omni.entrypoints.cli.benchmark.base.OmniBenchmarkSubcommandBase][] - [vllm_omni.entrypoints.cli.benchmark.main.OmniBenchmarkSubcommand][] - [vllm_omni.entrypoints.cli.benchmark.serve.OmniBenchmarkServingSubcommand][] diff --git a/tests/entrypoints/test_cfg_companion_tracker.py b/tests/engine/test_cfg_companion_tracker.py similarity index 96% rename from tests/entrypoints/test_cfg_companion_tracker.py rename to tests/engine/test_cfg_companion_tracker.py index d316410c52e..7729060ab55 100644 --- a/tests/entrypoints/test_cfg_companion_tracker.py +++ b/tests/engine/test_cfg_companion_tracker.py @@ -1,6 +1,6 @@ import pytest -from vllm_omni.entrypoints.cfg_companion_tracker import CfgCompanionTracker +from vllm_omni.engine.cfg_companion_tracker import CfgCompanionTracker from vllm_omni.inputs.data import OmniDiffusionSamplingParams pytestmark = [pytest.mark.core_model, pytest.mark.cpu] diff --git a/vllm_omni/entrypoints/cfg_companion_tracker.py b/vllm_omni/engine/cfg_companion_tracker.py similarity index 98% rename from vllm_omni/entrypoints/cfg_companion_tracker.py rename to vllm_omni/engine/cfg_companion_tracker.py index be3db9d3e35..de4bc75090d 100644 --- a/vllm_omni/entrypoints/cfg_companion_tracker.py +++ b/vllm_omni/engine/cfg_companion_tracker.py @@ -78,8 +78,6 @@ def on_companion_completed(self, companion_id: str) -> str | None: return parent_id return None - # -- Deferred parent management -- - def defer_parent(self, parent_id: str, engine_outputs: Any, stage_id: int) -> None: """Hold parent result while waiting for companions to finish.""" self._pending_parents[parent_id] = { @@ -91,8 +89,6 @@ def defer_parent(self, parent_id: str, engine_outputs: Any, stage_id: int) -> No def pop_pending_parent(self, parent_id: str) -> dict[str, Any] | None: return self._pending_parents.pop(parent_id, None) - # -- Cleanup / abort -- - def cleanup_parent(self, parent_id: str) -> list[str]: companion_ids = list(self._companion_map.pop(parent_id, {}).values()) for companion_id in companion_ids: diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py index 3a94361466d..e838718cd33 100644 --- a/vllm_omni/engine/orchestrator.py +++ b/vllm_omni/engine/orchestrator.py @@ -27,8 +27,8 @@ from vllm_omni.engine import ( OmniEngineCoreRequest, ) +from vllm_omni.engine.cfg_companion_tracker import CfgCompanionTracker from vllm_omni.engine.serialization import serialize_additional_information -from vllm_omni.entrypoints.cfg_companion_tracker import CfgCompanionTracker from vllm_omni.metrics.stats import StageRequestStats as StageRequestMetrics from vllm_omni.metrics.stats import StageStats from vllm_omni.metrics.utils import count_tokens_from_outputs From d15335a24d80bb97ab92fd3988ee812d49934e93 Mon Sep 17 00:00:00 2001 From: yinpe <11810305@mail.sustech.edu.cn> Date: Mon, 13 Apr 2026 14:25:31 +0800 Subject: [PATCH 4/6] fix Signed-off-by: yinpe <11810305@mail.sustech.edu.cn> --- tests/engine/test_cfg_companion_tracker.py | 13 +++++++++++-- tests/engine/test_orchestrator_kv_sender_info.py | 5 +++-- vllm_omni/engine/cfg_companion_tracker.py | 9 +-------- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/tests/engine/test_cfg_companion_tracker.py b/tests/engine/test_cfg_companion_tracker.py index 7729060ab55..763652314a4 100644 --- a/tests/engine/test_cfg_companion_tracker.py +++ b/tests/engine/test_cfg_companion_tracker.py @@ -37,18 +37,27 @@ def test_attach_cfg_request_ids_clones_diffusion_params(): assert updated.cfg_kv_request_ids == {"cfg_text": "req1__cfg_text"} -def test_abort_parents_with_parent_or_companion_ids(): +def test_abort_parent_expands_to_companions_and_cleans_up_deferred_parent(): tracker = CfgCompanionTracker() tracker.register_companion("req1", "cfg_text", "req1__cfg_text") tracker.defer_parent("req1", {"out": 1}, stage_id=0) - aborted = tracker.abort_parents(["req1__cfg_text"]) + aborted = tracker.abort_parents(["req1"]) assert sorted(aborted) == ["req1", "req1__cfg_text"] assert not tracker.is_companion("req1__cfg_text") assert tracker.pop_pending_parent("req1") is None +def test_abort_companion_does_not_expand_to_parent(): + tracker = CfgCompanionTracker() + tracker.register_companion("req1", "cfg_text", "req1__cfg_text") + + aborted = tracker.abort_parents(["req1__cfg_text"]) + + assert aborted == ["req1__cfg_text"] + + def test_companion_completion_flushes_deferred_parent(): tracker = CfgCompanionTracker() tracker.register_companion("req1", "cfg_text", "req1__cfg_text") diff --git a/tests/engine/test_orchestrator_kv_sender_info.py b/tests/engine/test_orchestrator_kv_sender_info.py index 94da4ce7179..7e3fe0906e8 100644 --- a/tests/engine/test_orchestrator_kv_sender_info.py +++ b/tests/engine/test_orchestrator_kv_sender_info.py @@ -4,6 +4,7 @@ import pytest from vllm import SamplingParams +from vllm_omni.engine.cfg_companion_tracker import CfgCompanionTracker from vllm_omni.engine.orchestrator import Orchestrator, OrchestratorRequestState from vllm_omni.engine.stage_engine_core_client import StageEngineCoreClient from vllm_omni.inputs.data import OmniDiffusionSamplingParams @@ -130,7 +131,7 @@ def test_forward_to_diffusion_attaches_kv_sender_info(): orchestrator.num_stages = 2 orchestrator.stage_clients = [sender_stage, diffusion_stage] - orchestrator._companion_map = {} + orchestrator._cfg_tracker = CfgCompanionTracker() orchestrator.stage_vllm_configs = [None, None] orchestrator.output_processors = [None, None] @@ -161,7 +162,7 @@ def test_forward_to_diffusion_uses_engine_input_source_for_kv_sender_info(): orchestrator.num_stages = 3 orchestrator.stage_clients = [source_stage, previous_stage, diffusion_stage] - orchestrator._companion_map = {} + orchestrator._cfg_tracker = CfgCompanionTracker() orchestrator.stage_vllm_configs = [None, None, None] orchestrator.output_processors = [None, None, None] diff --git a/vllm_omni/engine/cfg_companion_tracker.py b/vllm_omni/engine/cfg_companion_tracker.py index de4bc75090d..706d783cdf9 100644 --- a/vllm_omni/engine/cfg_companion_tracker.py +++ b/vllm_omni/engine/cfg_companion_tracker.py @@ -104,14 +104,7 @@ def abort_parents(self, request_ids: list[str]) -> list[str]: parents_to_cleanup: set[str] = set() for req_id in request_ids: - if req_id in self._companion_ids: - parent_id = self._companion_to_parent.get(req_id) - if parent_id is not None: - parents_to_cleanup.add(parent_id) - if parent_id not in seen: - seen.add(parent_id) - all_request_ids.append(parent_id) - else: + if req_id not in self._companion_ids: parents_to_cleanup.add(req_id) for parent_id in parents_to_cleanup: From a4cfc9b3ba9710fc406e882e6b04dd18c8932685 Mon Sep 17 00:00:00 2001 From: yinpe <11810305@mail.sustech.edu.cn> Date: Fri, 17 Apr 2026 11:12:43 +0800 Subject: [PATCH 5/6] update cfg tracker Signed-off-by: yinpe <11810305@mail.sustech.edu.cn> --- tests/engine/test_cfg_companion_tracker.py | 9 +++++++++ vllm_omni/engine/cfg_companion_tracker.py | 14 ++++++++++++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/tests/engine/test_cfg_companion_tracker.py b/tests/engine/test_cfg_companion_tracker.py index 763652314a4..f856a38c3e3 100644 --- a/tests/engine/test_cfg_companion_tracker.py +++ b/tests/engine/test_cfg_companion_tracker.py @@ -71,3 +71,12 @@ def test_companion_completion_flushes_deferred_parent(): assert popped is not None assert popped["engine_outputs"] == {"out": 123} assert popped["stage_id"] == 0 + + +def test_companion_completion_without_registered_parent_asserts(): + tracker = CfgCompanionTracker() + tracker._companion_ids.add("req1__cfg_text") + tracker._companion_to_parent["req1__cfg_text"] = "req1" + + with pytest.raises(AssertionError, match="completed before parent req1 was registered"): + tracker.on_companion_completed("req1__cfg_text") diff --git a/vllm_omni/engine/cfg_companion_tracker.py b/vllm_omni/engine/cfg_companion_tracker.py index 706d783cdf9..7f540280882 100644 --- a/vllm_omni/engine/cfg_companion_tracker.py +++ b/vllm_omni/engine/cfg_companion_tracker.py @@ -70,9 +70,13 @@ def on_companion_completed(self, companion_id: str) -> str | None: parent_id = self._companion_to_parent.get(companion_id) if parent_id is None: return None - if companion_id in self._done.get(parent_id, set()): + done_set = self._done.get(parent_id) + assert done_set is not None, ( + f"Companion {companion_id} completed before parent {parent_id} was registered" + ) + if companion_id in done_set: return None - self._done[parent_id].add(companion_id) + done_set.add(companion_id) logger.debug("CFG companion %s completed (parent=%s)", companion_id, parent_id) if parent_id in self._pending_parents and self.all_companions_done(parent_id): return parent_id @@ -80,6 +84,9 @@ def on_companion_completed(self, companion_id: str) -> str | None: def defer_parent(self, parent_id: str, engine_outputs: Any, stage_id: int) -> None: """Hold parent result while waiting for companions to finish.""" + # TODO: Add timeout/error recovery when the orchestrator grows a + # companion-failure path. Today deferred parents are released only when + # companions finish or the external layer aborts the request. self._pending_parents[parent_id] = { "engine_outputs": engine_outputs, "stage_id": stage_id, @@ -104,6 +111,9 @@ def abort_parents(self, request_ids: list[str]) -> list[str]: parents_to_cleanup: set[str] = set() for req_id in request_ids: + # The orchestrator calls this with parent request IDs. If a raw + # companion ID is passed here, keep it as a direct abort target and + # avoid tearing down parent tracking state implicitly. if req_id not in self._companion_ids: parents_to_cleanup.add(req_id) From e520144ba03a61ccafc21367817e4862549edde0 Mon Sep 17 00:00:00 2001 From: yinpe <11810305@mail.sustech.edu.cn> Date: Fri, 17 Apr 2026 11:15:00 +0800 Subject: [PATCH 6/6] fix precommit Signed-off-by: yinpe <11810305@mail.sustech.edu.cn> --- vllm_omni/engine/cfg_companion_tracker.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm_omni/engine/cfg_companion_tracker.py b/vllm_omni/engine/cfg_companion_tracker.py index 7f540280882..b9dfae833e2 100644 --- a/vllm_omni/engine/cfg_companion_tracker.py +++ b/vllm_omni/engine/cfg_companion_tracker.py @@ -71,9 +71,7 @@ def on_companion_completed(self, companion_id: str) -> str | None: if parent_id is None: return None done_set = self._done.get(parent_id) - assert done_set is not None, ( - f"Companion {companion_id} completed before parent {parent_id} was registered" - ) + assert done_set is not None, f"Companion {companion_id} completed before parent {parent_id} was registered" if companion_id in done_set: return None done_set.add(companion_id)