diff --git a/docs/api/README.md b/docs/api/README.md index f65cbb525d9..0147f19e126 100644 --- a/docs/api/README.md +++ b/docs/api/README.md @@ -5,7 +5,7 @@ Main entry points for vLLM-Omni inference and serving. - [vllm_omni.entrypoints.async_omni.AsyncOmni][] -- [vllm_omni.entrypoints.cfg_companion_tracker.CfgCompanionTracker][] +- [vllm_omni.engine.cfg_companion_tracker.CfgCompanionTracker][] - [vllm_omni.entrypoints.cli.benchmark.base.OmniBenchmarkSubcommandBase][] - [vllm_omni.entrypoints.cli.benchmark.main.OmniBenchmarkSubcommand][] - [vllm_omni.entrypoints.cli.benchmark.serve.OmniBenchmarkServingSubcommand][] diff --git a/tests/engine/test_cfg_companion_tracker.py b/tests/engine/test_cfg_companion_tracker.py new file mode 100644 index 00000000000..f856a38c3e3 --- /dev/null +++ b/tests/engine/test_cfg_companion_tracker.py @@ -0,0 +1,82 @@ +import pytest + +from vllm_omni.engine.cfg_companion_tracker import CfgCompanionTracker +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def test_register_companion_and_cleanup(): + tracker = CfgCompanionTracker() + + tracker.register_companion("req1", "cfg_text", "req1__cfg_text") + tracker.register_companion("req1", "cfg_img", "req1__cfg_img") + + assert tracker.is_companion("req1__cfg_text") + assert tracker.get_companion_request_ids("req1") == { + "cfg_text": "req1__cfg_text", + "cfg_img": "req1__cfg_img", + } + + removed = tracker.cleanup_parent("req1") + + assert sorted(removed) == ["req1__cfg_img", "req1__cfg_text"] + assert not tracker.is_companion("req1__cfg_text") + assert tracker.get_companion_request_ids("req1") == {} + + +def test_attach_cfg_request_ids_clones_diffusion_params(): + tracker = CfgCompanionTracker() + tracker.register_companion("req1", "cfg_text", "req1__cfg_text") + + params = OmniDiffusionSamplingParams() + updated = tracker.attach_cfg_request_ids("req1", params) + + assert updated is not params + assert params.cfg_kv_request_ids is None + assert updated.cfg_kv_request_ids == {"cfg_text": "req1__cfg_text"} + + +def test_abort_parent_expands_to_companions_and_cleans_up_deferred_parent(): + tracker = CfgCompanionTracker() + tracker.register_companion("req1", "cfg_text", "req1__cfg_text") + tracker.defer_parent("req1", {"out": 1}, stage_id=0) + + aborted = tracker.abort_parents(["req1"]) + + assert sorted(aborted) == ["req1", "req1__cfg_text"] + assert not tracker.is_companion("req1__cfg_text") + assert tracker.pop_pending_parent("req1") is None + + +def test_abort_companion_does_not_expand_to_parent(): + tracker = CfgCompanionTracker() + tracker.register_companion("req1", "cfg_text", "req1__cfg_text") + + aborted = tracker.abort_parents(["req1__cfg_text"]) + + assert aborted == ["req1__cfg_text"] + + +def test_companion_completion_flushes_deferred_parent(): + tracker = CfgCompanionTracker() + tracker.register_companion("req1", "cfg_text", "req1__cfg_text") + tracker.defer_parent("req1", {"out": 123}, stage_id=0) + + assert not tracker.all_companions_done("req1") + assert tracker.on_companion_completed("req1__cfg_text") == "req1" + assert tracker.all_companions_done("req1") + + popped = tracker.pop_pending_parent("req1") + assert popped is not None + assert popped["engine_outputs"] == {"out": 123} + assert popped["stage_id"] == 0 + + +def test_companion_completion_without_registered_parent_asserts(): + tracker = CfgCompanionTracker() + tracker._companion_ids.add("req1__cfg_text") + tracker._companion_to_parent["req1__cfg_text"] = "req1" + + with pytest.raises(AssertionError, match="completed before parent req1 was registered"): + tracker.on_companion_completed("req1__cfg_text") diff --git a/tests/engine/test_orchestrator_kv_sender_info.py b/tests/engine/test_orchestrator_kv_sender_info.py index 94da4ce7179..7e3fe0906e8 100644 --- a/tests/engine/test_orchestrator_kv_sender_info.py +++ b/tests/engine/test_orchestrator_kv_sender_info.py @@ -4,6 +4,7 @@ import pytest from vllm import SamplingParams +from vllm_omni.engine.cfg_companion_tracker import CfgCompanionTracker from vllm_omni.engine.orchestrator import Orchestrator, OrchestratorRequestState from vllm_omni.engine.stage_engine_core_client import StageEngineCoreClient from vllm_omni.inputs.data import OmniDiffusionSamplingParams @@ -130,7 +131,7 @@ def test_forward_to_diffusion_attaches_kv_sender_info(): orchestrator.num_stages = 2 orchestrator.stage_clients = [sender_stage, diffusion_stage] - orchestrator._companion_map = {} + orchestrator._cfg_tracker = CfgCompanionTracker() orchestrator.stage_vllm_configs = [None, None] orchestrator.output_processors = [None, None] @@ -161,7 +162,7 @@ def test_forward_to_diffusion_uses_engine_input_source_for_kv_sender_info(): orchestrator.num_stages = 3 orchestrator.stage_clients = [source_stage, previous_stage, diffusion_stage] - orchestrator._companion_map = {} + orchestrator._cfg_tracker = CfgCompanionTracker() orchestrator.stage_vllm_configs = [None, None, None] orchestrator.output_processors = [None, None, None] diff --git a/tests/entrypoints/test_cfg_companion_tracker.py b/tests/entrypoints/test_cfg_companion_tracker.py deleted file mode 100644 index 941ead41ff0..00000000000 --- a/tests/entrypoints/test_cfg_companion_tracker.py +++ /dev/null @@ -1,114 +0,0 @@ -import time -from types import SimpleNamespace - -import pytest - -from vllm_omni.entrypoints.cfg_companion_tracker import CfgCompanionTracker - -pytestmark = [pytest.mark.core_model, pytest.mark.cpu] - - -def dummy_expand_func(prompt, sp0): - if prompt == "expand_me": - return [SimpleNamespace(prompt={"prompt": "neg"}, role="cfg_text", request_id_suffix="__cfg_text")] - return [] - - -@pytest.fixture -def tracker(): - sp0 = SimpleNamespace() - return CfgCompanionTracker(prompt_expand_func=dummy_expand_func, stage0_sampling_params=sp0, timeout_s=0.1) - - -def test_companion_tracker_initialization(tracker): - assert not tracker.is_active - assert tracker.num_companions == 0 - - -def test_expand_prompts_registers_companions(tracker): - request_id_to_prompt = {"req1": "expand_me", "req2": "do_not_expand"} - - pairs = tracker.expand_prompts(request_id_to_prompt) - - assert len(pairs) == 1 - companion_id, prompt = pairs[0] - assert companion_id == "req1__cfg_text" - assert prompt == {"prompt": "neg"} - - assert tracker.is_active - assert tracker.num_companions == 1 - assert tracker.is_companion("req1__cfg_text") - assert not tracker.is_companion("req2__cfg_text") - assert tracker.has_companions("req1") - assert not tracker.has_companions("req2") - - comp_map = tracker.get_companion_request_ids("req1") - assert comp_map == {"cfg_text": "req1__cfg_text"} - - -def test_companion_lifecycle_success(tracker): - request_id_to_prompt = {"req1": "expand_me"} - tracker.expand_prompts(request_id_to_prompt) - - # Defer parent - engine_outputs = {"out": 123} - tracker.defer_parent("req1", engine_outputs, stage_id=0) - - # Initially not done - assert not tracker.all_companions_done("req1") - - # Companion completes - parent_id = tracker.on_companion_completed("req1__cfg_text") - - # Parent should be returned since all companions are done and it is pending - assert parent_id == "req1" - assert tracker.all_companions_done("req1") - - # Pop pending parent - popped = tracker.pop_pending_parent("req1") - assert popped is not None - assert popped["engine_outputs"] == engine_outputs - assert popped["stage_id"] == 0 - - -def test_companion_lifecycle_failure(tracker): - request_id_to_prompt = {"req1": "expand_me"} - tracker.expand_prompts(request_id_to_prompt) - - tracker.defer_parent("req1", {"out": 123}, stage_id=0) - - # Companion fails - parent_id, aborted = tracker.on_companion_error("req1__cfg_text") - - assert parent_id == "req1" - assert aborted is True - assert tracker.is_parent_failed("req1") - - # Parent should be removed from pending list - assert tracker.pop_pending_parent("req1") is None - - # Consume failure - tracker.consume_parent_failure("req1") - assert not tracker.is_parent_failed("req1") - - -def test_companion_lifecycle_timeout(tracker): - request_id_to_prompt = {"req1": "expand_me"} - tracker.expand_prompts(request_id_to_prompt) - - tracker.defer_parent("req1", {"out": 123}, stage_id=0) - - # Initially no timeouts - timeouts = tracker.check_timeouts() - assert len(timeouts) == 0 - - # Wait for timeout - time.sleep(0.15) - - # Check timeouts again - timeouts = tracker.check_timeouts() - assert len(timeouts) == 1 - assert timeouts[0] == "req1" - - # Should be removed from pending - assert tracker.pop_pending_parent("req1") is None diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 01c11cb9603..238bbdcdbd4 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -1072,7 +1072,6 @@ def _enqueue_cfg_companions( params=companion_params, supported_tasks=self.supported_tasks, ) - request = _upgrade_to_omni_request(request, companion_prompt) request.external_req_id = cid self.output_processors[0].add_request( diff --git a/vllm_omni/engine/cfg_companion_tracker.py b/vllm_omni/engine/cfg_companion_tracker.py new file mode 100644 index 00000000000..b9dfae833e2 --- /dev/null +++ b/vllm_omni/engine/cfg_companion_tracker.py @@ -0,0 +1,125 @@ +"""CFG companion request tracker for the Omni orchestrator. + +Encapsulates all bookkeeping for Classifier-Free Guidance companion +requests (parent/companion ID mapping, completion tracking, +deferred forwarding, and cleanup). +""" + +from __future__ import annotations + +import logging +from typing import Any + +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +logger = logging.getLogger(__name__) + + +class CfgCompanionTracker: + """Manages CFG companion request lifecycle in the orchestrator scheduling loop.""" + + def __init__(self) -> None: + self._companion_map: dict[str, dict[str, str]] = {} # parent -> {role: companion_id} + self._companion_ids: set[str] = set() + self._companion_to_parent: dict[str, str] = {} # companion -> parent + self._done: dict[str, set[str]] = {} # parent -> completed companion ids + self._pending_parents: dict[str, dict[str, Any]] = {} # parent -> deferred result + + def is_companion(self, req_id: str) -> bool: + return req_id in self._companion_ids + + def has_companions(self, parent_id: str) -> bool: + return parent_id in self._companion_map + + def all_companions_done(self, parent_id: str) -> bool: + role_map = self._companion_map.get(parent_id, {}) + done_set = self._done.get(parent_id, set()) + return all(cid in done_set for cid in role_map.values()) + + def get_companion_request_ids(self, parent_id: str) -> dict[str, str]: + """Return ``{role: companion_request_id}`` for a parent.""" + return self._companion_map.get(parent_id, {}) + + def register_parent(self, parent_id: str) -> None: + self._companion_map.setdefault(parent_id, {}) + self._done.setdefault(parent_id, set()) + + def register_companion(self, parent_id: str, role: str, companion_id: str) -> None: + self.register_parent(parent_id) + self._companion_map[parent_id][role] = companion_id + self._companion_ids.add(companion_id) + self._companion_to_parent[companion_id] = parent_id + + def attach_cfg_request_ids(self, parent_id: str, sampling_params: Any) -> Any: + cfg_ids = self.get_companion_request_ids(parent_id) + if not cfg_ids: + return sampling_params + + if isinstance(sampling_params, OmniDiffusionSamplingParams): + sampling_params = sampling_params.clone() + sampling_params.cfg_kv_request_ids = cfg_ids + logger.info( + "Attaching cfg_kv_request_ids=%s to request %s", + cfg_ids, + parent_id, + ) + return sampling_params + + def on_companion_completed(self, companion_id: str) -> str | None: + """Mark done. Returns parent_id only if parent is pending and all companions finished.""" + parent_id = self._companion_to_parent.get(companion_id) + if parent_id is None: + return None + done_set = self._done.get(parent_id) + assert done_set is not None, f"Companion {companion_id} completed before parent {parent_id} was registered" + if companion_id in done_set: + return None + done_set.add(companion_id) + logger.debug("CFG companion %s completed (parent=%s)", companion_id, parent_id) + if parent_id in self._pending_parents and self.all_companions_done(parent_id): + return parent_id + return None + + def defer_parent(self, parent_id: str, engine_outputs: Any, stage_id: int) -> None: + """Hold parent result while waiting for companions to finish.""" + # TODO: Add timeout/error recovery when the orchestrator grows a + # companion-failure path. Today deferred parents are released only when + # companions finish or the external layer aborts the request. + self._pending_parents[parent_id] = { + "engine_outputs": engine_outputs, + "stage_id": stage_id, + } + logger.debug("Parent %s deferred, waiting for CFG companions", parent_id) + + def pop_pending_parent(self, parent_id: str) -> dict[str, Any] | None: + return self._pending_parents.pop(parent_id, None) + + def cleanup_parent(self, parent_id: str) -> list[str]: + companion_ids = list(self._companion_map.pop(parent_id, {}).values()) + for companion_id in companion_ids: + self._companion_ids.discard(companion_id) + self._companion_to_parent.pop(companion_id, None) + self._done.pop(parent_id, None) + self._pending_parents.pop(parent_id, None) + return companion_ids + + def abort_parents(self, request_ids: list[str]) -> list[str]: + all_request_ids = list(request_ids) + seen = set(all_request_ids) + parents_to_cleanup: set[str] = set() + + for req_id in request_ids: + # The orchestrator calls this with parent request IDs. If a raw + # companion ID is passed here, keep it as a direct abort target and + # avoid tearing down parent tracking state implicitly. + if req_id not in self._companion_ids: + parents_to_cleanup.add(req_id) + + for parent_id in parents_to_cleanup: + companion_ids = self.cleanup_parent(parent_id) + for companion_id in companion_ids: + if companion_id not in seen: + seen.add(companion_id) + all_request_ids.append(companion_id) + + return all_request_ids diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py index 0fdab9c0d2b..1de4282fea2 100644 --- a/vllm_omni/engine/orchestrator.py +++ b/vllm_omni/engine/orchestrator.py @@ -27,6 +27,7 @@ from vllm_omni.engine import ( OmniEngineCoreRequest, ) +from vllm_omni.engine.cfg_companion_tracker import CfgCompanionTracker from vllm_omni.engine.serialization import serialize_additional_information from vllm_omni.metrics.stats import StageRequestStats as StageRequestMetrics from vllm_omni.metrics.stats import StageStats @@ -138,11 +139,7 @@ def __init__( self.request_states: dict[str, OrchestratorRequestState] = {} # CFG companion tracking - self._companion_map: dict[str, dict[str, str]] = {} - self._companion_to_parent: dict[str, str] = {} - self._companion_ids: set[str] = set() - self._companion_done: dict[str, set[str]] = {} - self._deferred_parents: dict[str, dict[str, Any]] = {} + self._cfg_tracker = CfgCompanionTracker() # Per-stage metrics accumulators. self._batch_seq: list[int] = [0] * self.num_stages @@ -334,7 +331,7 @@ async def _route_output( # CFG companion handling: companions don't produce user-visible output # and don't forward to the next stage directly. - if finished and req_id in self._companion_ids: + if finished and self._cfg_tracker.is_companion(req_id): await self._handle_cfg_companion_ready(req_id) self.request_states.pop(req_id, None) return @@ -368,57 +365,34 @@ async def _route_output( and not self.async_chunk and not self._next_stage_already_submitted(stage_id, req_state) ): - if req_id in self._companion_map and not self._all_companions_done(req_id): - self._deferred_parents[req_id] = { - "stage_id": stage_id, - "output": output, - } + if self._cfg_tracker.has_companions(req_id) and not self._cfg_tracker.all_companions_done(req_id): + self._cfg_tracker.defer_parent(req_id, output, stage_id) else: await self._forward_to_next_stage(req_id, stage_id, output, req_state) if finished and stage_id == req_state.final_stage_id: - self._cleanup_companion_state(req_id) + self._cfg_tracker.cleanup_parent(req_id) self.request_states.pop(req_id, None) - def _cleanup_companion_state(self, parent_id: str) -> None: - """Remove all companion tracking state for a completed parent.""" - role_map = self._companion_map.pop(parent_id, {}) - for cid in role_map.values(): - self._companion_ids.discard(cid) - self._companion_to_parent.pop(cid, None) - self._companion_done.pop(parent_id, None) - self._deferred_parents.pop(parent_id, None) - - def _all_companions_done(self, parent_id: str) -> bool: - """Check whether all CFG companions for a parent request have finished.""" - role_map = self._companion_map.get(parent_id, {}) - if not role_map: - return True - done_set = self._companion_done.get(parent_id, set()) - return all(cid in done_set for cid in role_map.values()) - def _next_stage_already_submitted(self, stage_id: int, req_state: OrchestratorRequestState) -> bool: return (stage_id + 1) in req_state.stage_submit_ts async def _handle_cfg_companion_ready(self, req_id: str) -> None: """Mark a CFG companion as done; if all companions are done, flush deferred parent.""" - parent_id = self._companion_to_parent.get(req_id) + parent_id = self._cfg_tracker.on_companion_completed(req_id) if parent_id is None: return - done_set = self._companion_done.setdefault(parent_id, set()) - if req_id in done_set: + deferred = self._cfg_tracker.pop_pending_parent(parent_id) + if deferred is None: return - done_set.add(req_id) - if parent_id in self._deferred_parents and self._all_companions_done(parent_id): - deferred = self._deferred_parents.pop(parent_id) - parent_state = self.request_states.get(parent_id) - if parent_state is not None and not self._next_stage_already_submitted(deferred["stage_id"], parent_state): - await self._forward_to_next_stage( - parent_id, - deferred["stage_id"], - deferred["output"], - parent_state, - ) + parent_state = self.request_states.get(parent_id) + if parent_state is not None and not self._next_stage_already_submitted(deferred["stage_id"], parent_state): + await self._forward_to_next_stage( + parent_id, + deferred["stage_id"], + deferred["engine_outputs"], + parent_state, + ) async def _handle_kv_ready_raw_outputs(self, stage_id: int, raw_outputs: EngineCoreOutputs) -> None: """Forward split requests once stage-0 KV is ready, not only when decode fully finishes.""" @@ -432,18 +406,15 @@ async def _handle_kv_ready_raw_outputs(self, stage_id: int, raw_outputs: EngineC req_state = self.request_states.get(req_id) if req_state is None: continue - if req_id in self._companion_ids: + if self._cfg_tracker.is_companion(req_id): await self._handle_cfg_companion_ready(req_id) continue if stage_id >= req_state.final_stage_id: continue if self._next_stage_already_submitted(stage_id, req_state): continue - if req_id in self._companion_map and not self._all_companions_done(req_id): - self._deferred_parents[req_id] = { - "stage_id": stage_id, - "output": raw_output, - } + if self._cfg_tracker.has_companions(req_id) and not self._cfg_tracker.all_companions_done(req_id): + self._cfg_tracker.defer_parent(req_id, raw_output, stage_id) else: await self._forward_to_next_stage(req_id, stage_id, raw_output, req_state) @@ -548,20 +519,7 @@ async def _forward_to_next_stage( else: diffusion_prompt = req_state.prompt - # Attach CFG companion KV request IDs so the diffusion model - # runner can fetch companion KV caches alongside the primary one. - cfg_ids = self._companion_map.get(req_id) - if cfg_ids: - from vllm_omni.inputs.data import OmniDiffusionSamplingParams - - if isinstance(params, OmniDiffusionSamplingParams): - params = copy.deepcopy(params) - params.cfg_kv_request_ids = cfg_ids - logger.info( - "[Orchestrator] Attaching cfg_kv_request_ids=%s to req %s", - cfg_ids, - req_id, - ) + params = self._cfg_tracker.attach_cfg_request_ids(req_id, params) source_stage_ids = list(getattr(next_client, "engine_input_source", None) or [stage_id]) kv_sender_info = self._build_kv_sender_info(sender_stage_ids=source_stage_ids) @@ -817,13 +775,7 @@ async def _handle_add_companion(self, msg: dict[str, Any]) -> None: companion_prompt = msg["prompt"] sampling_params_list = msg["sampling_params_list"] - # Register companion mapping - if parent_id not in self._companion_map: - self._companion_map[parent_id] = {} - self._companion_map[parent_id][role] = companion_id - self._companion_ids.add(companion_id) - self._companion_to_parent[companion_id] = parent_id - self._companion_done.setdefault(parent_id, set()) + self._cfg_tracker.register_companion(parent_id, role, companion_id) companion_state = OrchestratorRequestState( request_id=companion_id, @@ -848,22 +800,10 @@ async def _handle_add_companion(self, msg: dict[str, Any]) -> None: async def _handle_abort(self, msg: dict[str, Any]) -> None: """Handle an abort message from the main thread.""" request_ids = msg["request_ids"] - # Also abort any CFG companions for aborted parents - companion_ids_to_abort: list[str] = [] - for req_id in request_ids: - role_map = self._companion_map.pop(req_id, {}) - for cid in role_map.values(): - companion_ids_to_abort.append(cid) - self._companion_ids.discard(cid) - self._companion_to_parent.pop(cid, None) - self.request_states.pop(cid, None) - self._companion_done.pop(req_id, None) - self._deferred_parents.pop(req_id, None) - - all_ids_to_abort = list(request_ids) + companion_ids_to_abort + all_ids_to_abort = self._cfg_tracker.abort_parents(request_ids) for stage_id in range(self.num_stages): await self.stage_clients[stage_id].abort_requests_async(all_ids_to_abort) - for req_id in request_ids: + for req_id in all_ids_to_abort: self.request_states.pop(req_id, None) logger.info("[Orchestrator] Aborted request(s) %s", request_ids) diff --git a/vllm_omni/entrypoints/cfg_companion_tracker.py b/vllm_omni/entrypoints/cfg_companion_tracker.py deleted file mode 100644 index 9c2e835f074..00000000000 --- a/vllm_omni/entrypoints/cfg_companion_tracker.py +++ /dev/null @@ -1,233 +0,0 @@ -"""CFG companion request tracker for the Omni orchestrator. - -Encapsulates all bookkeeping for Classifier-Free Guidance companion -requests (prompt expansion, parent/companion ID mapping, completion -tracking, deferred forwarding, failure propagation, and timeouts) -so that ``Omni._run_generation`` stays clean. -""" - -from __future__ import annotations - -import copy -import logging -import os -import time -from collections.abc import Callable, Sequence -from typing import Any - -from vllm_omni.distributed.omni_connectors.adapter import try_send_via_connector -from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniSamplingParams - -logger = logging.getLogger(__name__) - - -class CfgCompanionTracker: - """Manages CFG companion request lifecycle in the orchestrator scheduling loop.""" - - def __init__( - self, - prompt_expand_func: Callable[..., Any] | None, - stage0_sampling_params: Any, - timeout_s: float | None = None, - ) -> None: - self._expand_func = prompt_expand_func - self._sp0 = stage0_sampling_params - self._timeout_s = ( - timeout_s if timeout_s is not None else float(os.environ.get("VLLM_CFG_PENDING_TIMEOUT_S", "120")) - ) - - self._companion_map: dict[str, dict[str, str]] = {} # parent -> {role: companion_id} - self._companion_ids: set[str] = set() - self._companion_to_parent: dict[str, str] = {} # companion -> parent - self._done: dict[str, set[str]] = {} # parent -> completed companion ids - self._pending_parents: dict[str, dict[str, Any]] = {} # parent -> deferred result - self._failed_parents: set[str] = set() - - @property - def is_active(self) -> bool: - return bool(self._companion_ids) - - @property - def num_companions(self) -> int: - return len(self._companion_ids) - - @property - def stage0_sampling_params(self) -> Any: - return self._sp0 - - def expand_prompts( - self, - request_id_to_prompt: dict[str, Any], - ) -> list[tuple[str, Any]]: - """Expand user prompts into ``(companion_id, prompt)`` pairs via model-specific func.""" - if not self._expand_func: - return [] - - pairs: list[tuple[str, Any]] = [] - for rid, prompt in request_id_to_prompt.items(): - expanded = self._expand_func(prompt, self._sp0) - if not expanded: - continue - role_map: dict[str, str] = {} - for ep in expanded: - cid = f"{rid}{ep.request_id_suffix}" - role_map[ep.role] = cid - self._companion_ids.add(cid) - self._companion_to_parent[cid] = rid - pairs.append((cid, ep.prompt)) - self._companion_map[rid] = role_map - self._done[rid] = set() - - logger.debug( - "CFG expansion: %d parent(s) -> %d companion(s)", - len(self._companion_map), - len(self._companion_ids), - ) - return pairs - - def is_companion(self, req_id: str) -> bool: - return req_id in self._companion_ids - - def has_companions(self, parent_id: str) -> bool: - return parent_id in self._companion_map - - def all_companions_done(self, parent_id: str) -> bool: - role_map = self._companion_map.get(parent_id, {}) - done_set = self._done.get(parent_id, set()) - return all(cid in done_set for cid in role_map.values()) - - def get_companion_request_ids(self, parent_id: str) -> dict[str, str]: - """Return ``{role: companion_request_id}`` for a parent.""" - return self._companion_map.get(parent_id, {}) - - def is_parent_failed(self, parent_id: str) -> bool: - return parent_id in self._failed_parents - - # -- Lifecycle events -- - - def on_companion_error(self, companion_id: str) -> tuple[str | None, bool]: - """Record failure. Returns ``(parent_id, parent_was_aborted)``.""" - parent_id = self._companion_to_parent.get(companion_id) - if parent_id is None: - return None, False - self._failed_parents.add(parent_id) - logger.error("CFG companion %s failed; marking parent %s as failed", companion_id, parent_id) - aborted = parent_id in self._pending_parents - if aborted: - self._pending_parents.pop(parent_id, None) - return parent_id, aborted - - def on_companion_completed(self, companion_id: str) -> str | None: - """Mark done. Returns parent_id only if parent is pending and all companions finished.""" - parent_id = self._companion_to_parent.get(companion_id) - if parent_id is None: - return None - self._done[parent_id].add(companion_id) - logger.debug("CFG companion %s completed (parent=%s)", companion_id, parent_id) - if parent_id in self._pending_parents and self.all_companions_done(parent_id): - return parent_id - return None - - def consume_parent_failure(self, parent_id: str) -> None: - self._failed_parents.discard(parent_id) - - # -- Deferred parent management -- - - def defer_parent(self, parent_id: str, engine_outputs: Any, stage_id: int) -> None: - """Hold parent result while waiting for companions to finish.""" - self._pending_parents[parent_id] = { - "engine_outputs": engine_outputs, - "stage_id": stage_id, - "pending_since": time.time(), - } - logger.debug("Parent %s deferred, waiting for CFG companions", parent_id) - - def pop_pending_parent(self, parent_id: str) -> dict[str, Any] | None: - return self._pending_parents.pop(parent_id, None) - - def check_timeouts(self) -> list[str]: - """Return and remove parent IDs that exceeded the pending timeout.""" - if not self._pending_parents: - return [] - now = time.time() - timed_out: list[str] = [] - for pid in list(self._pending_parents): - pending_since = self._pending_parents[pid].get("pending_since", now) - if now - pending_since > self._timeout_s: - self._pending_parents.pop(pid) - self._failed_parents.discard(pid) - timed_out.append(pid) - logger.error("Parent %s timed out waiting for CFG companions (>%.0fs)", pid, self._timeout_s) - return timed_out - - # -- Forward parent with CFG KV -- - - def forward_parent_with_cfg( - self, - req_id: str, - parent_result: dict[str, Any], - stage_list: Sequence[Any], - connectors: dict[tuple[str, str], Any], - sampling_params_list: Sequence[OmniSamplingParams], - request_id_to_prompt: dict[str, Any], - final_stage_id_to_prompt: dict[str, int], - metrics: Any, - remaining_by_stage: list[int], - ) -> bool: - """Forward a parent request to the next stage with CFG KV request IDs attached.""" - stage_id = parent_result["stage_id"] - next_stage_id = stage_id + 1 - if next_stage_id > final_stage_id_to_prompt.get(req_id, 0): - return True - - next_stage = stage_list[next_stage_id] - try: - with metrics.stage_postprocess_timer(stage_id, req_id): - next_inputs = next_stage.process_engine_inputs( - stage_list, - [request_id_to_prompt[req_id]], - source_outputs_override=parent_result["engine_outputs"], - ) - except Exception as e: - logger.exception( - "Process engine inputs error for req %s at stage %d: %s", - req_id, - next_stage_id, - e, - ) - return False - - sp_next = copy.deepcopy(sampling_params_list[next_stage_id]) - if isinstance(sp_next, OmniDiffusionSamplingParams): - sp_next.cfg_kv_request_ids = self.get_companion_request_ids(req_id) - logger.info( - "Attaching cfg_kv_request_ids=%s to request %s", - sp_next.cfg_kv_request_ids, - req_id, - ) - - connector_key = (str(stage_id), str(next_stage_id)) - connector = connectors.get(connector_key) - sent_via_connector = False - if connector: - sent_via_connector = try_send_via_connector( - connector=connector, - stage_id=stage_id, - next_stage_id=next_stage_id, - req_id=req_id, - next_inputs=next_inputs, - sampling_params=sp_next, - original_prompt=request_id_to_prompt[req_id], - next_stage_queue_submit_fn=stage_list[next_stage_id].submit, - metrics=metrics, - ) - - if not sent_via_connector: - raise RuntimeError( - f"Failed to send CFG request {req_id} to stage-{next_stage_id} via connector. " - "Configure a connector for this edge or inspect connector logs for details." - ) - - logger.debug("Forwarded CFG-enabled request %s to stage-%d", req_id, next_stage_id) - remaining_by_stage[next_stage_id] += 1 - return True