Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Main entry points for vLLM-Omni inference and serving.

- [vllm_omni.entrypoints.async_omni.AsyncOmni][]
- [vllm_omni.entrypoints.cfg_companion_tracker.CfgCompanionTracker][]
- [vllm_omni.engine.cfg_companion_tracker.CfgCompanionTracker][]
- [vllm_omni.entrypoints.cli.benchmark.base.OmniBenchmarkSubcommandBase][]
- [vllm_omni.entrypoints.cli.benchmark.main.OmniBenchmarkSubcommand][]
- [vllm_omni.entrypoints.cli.benchmark.serve.OmniBenchmarkServingSubcommand][]
Expand Down
82 changes: 82 additions & 0 deletions tests/engine/test_cfg_companion_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import pytest

from vllm_omni.engine.cfg_companion_tracker import CfgCompanionTracker
from vllm_omni.inputs.data import OmniDiffusionSamplingParams

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]


def test_register_companion_and_cleanup():
tracker = CfgCompanionTracker()

tracker.register_companion("req1", "cfg_text", "req1__cfg_text")
tracker.register_companion("req1", "cfg_img", "req1__cfg_img")

assert tracker.is_companion("req1__cfg_text")
assert tracker.get_companion_request_ids("req1") == {
"cfg_text": "req1__cfg_text",
"cfg_img": "req1__cfg_img",
}

removed = tracker.cleanup_parent("req1")

assert sorted(removed) == ["req1__cfg_img", "req1__cfg_text"]
assert not tracker.is_companion("req1__cfg_text")
assert tracker.get_companion_request_ids("req1") == {}


def test_attach_cfg_request_ids_clones_diffusion_params():
tracker = CfgCompanionTracker()
tracker.register_companion("req1", "cfg_text", "req1__cfg_text")

params = OmniDiffusionSamplingParams()
updated = tracker.attach_cfg_request_ids("req1", params)

assert updated is not params
assert params.cfg_kv_request_ids is None
assert updated.cfg_kv_request_ids == {"cfg_text": "req1__cfg_text"}


def test_abort_parent_expands_to_companions_and_cleans_up_deferred_parent():
tracker = CfgCompanionTracker()
tracker.register_companion("req1", "cfg_text", "req1__cfg_text")
tracker.defer_parent("req1", {"out": 1}, stage_id=0)

aborted = tracker.abort_parents(["req1"])

assert sorted(aborted) == ["req1", "req1__cfg_text"]
assert not tracker.is_companion("req1__cfg_text")
assert tracker.pop_pending_parent("req1") is None


def test_abort_companion_does_not_expand_to_parent():
tracker = CfgCompanionTracker()
tracker.register_companion("req1", "cfg_text", "req1__cfg_text")

aborted = tracker.abort_parents(["req1__cfg_text"])

assert aborted == ["req1__cfg_text"]


def test_companion_completion_flushes_deferred_parent():
tracker = CfgCompanionTracker()
tracker.register_companion("req1", "cfg_text", "req1__cfg_text")
tracker.defer_parent("req1", {"out": 123}, stage_id=0)

assert not tracker.all_companions_done("req1")
assert tracker.on_companion_completed("req1__cfg_text") == "req1"
assert tracker.all_companions_done("req1")

popped = tracker.pop_pending_parent("req1")
assert popped is not None
assert popped["engine_outputs"] == {"out": 123}
assert popped["stage_id"] == 0


def test_companion_completion_without_registered_parent_asserts():
tracker = CfgCompanionTracker()
tracker._companion_ids.add("req1__cfg_text")
tracker._companion_to_parent["req1__cfg_text"] = "req1"

with pytest.raises(AssertionError, match="completed before parent req1 was registered"):
tracker.on_companion_completed("req1__cfg_text")
5 changes: 3 additions & 2 deletions tests/engine/test_orchestrator_kv_sender_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from vllm import SamplingParams

from vllm_omni.engine.cfg_companion_tracker import CfgCompanionTracker
from vllm_omni.engine.orchestrator import Orchestrator, OrchestratorRequestState
from vllm_omni.engine.stage_engine_core_client import StageEngineCoreClient
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
Expand Down Expand Up @@ -130,7 +131,7 @@ def test_forward_to_diffusion_attaches_kv_sender_info():

orchestrator.num_stages = 2
orchestrator.stage_clients = [sender_stage, diffusion_stage]
orchestrator._companion_map = {}
orchestrator._cfg_tracker = CfgCompanionTracker()
orchestrator.stage_vllm_configs = [None, None]
orchestrator.output_processors = [None, None]

Expand Down Expand Up @@ -161,7 +162,7 @@ def test_forward_to_diffusion_uses_engine_input_source_for_kv_sender_info():

orchestrator.num_stages = 3
orchestrator.stage_clients = [source_stage, previous_stage, diffusion_stage]
orchestrator._companion_map = {}
orchestrator._cfg_tracker = CfgCompanionTracker()
orchestrator.stage_vllm_configs = [None, None, None]
orchestrator.output_processors = [None, None, None]

Expand Down
114 changes: 0 additions & 114 deletions tests/entrypoints/test_cfg_companion_tracker.py

This file was deleted.

1 change: 0 additions & 1 deletion vllm_omni/engine/async_omni_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,6 @@ def _enqueue_cfg_companions(
params=companion_params,
supported_tasks=self.supported_tasks,
)
request = _upgrade_to_omni_request(request, companion_prompt)
request.external_req_id = cid

self.output_processors[0].add_request(
Expand Down
125 changes: 125 additions & 0 deletions vllm_omni/engine/cfg_companion_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""CFG companion request tracker for the Omni orchestrator.

Encapsulates all bookkeeping for Classifier-Free Guidance companion
requests (parent/companion ID mapping, completion tracking,
deferred forwarding, and cleanup).
"""

from __future__ import annotations

import logging
from typing import Any

from vllm_omni.inputs.data import OmniDiffusionSamplingParams

logger = logging.getLogger(__name__)


class CfgCompanionTracker:
"""Manages CFG companion request lifecycle in the orchestrator scheduling loop."""

def __init__(self) -> None:
self._companion_map: dict[str, dict[str, str]] = {} # parent -> {role: companion_id}
self._companion_ids: set[str] = set()
self._companion_to_parent: dict[str, str] = {} # companion -> parent
self._done: dict[str, set[str]] = {} # parent -> completed companion ids
self._pending_parents: dict[str, dict[str, Any]] = {} # parent -> deferred result

def is_companion(self, req_id: str) -> bool:
return req_id in self._companion_ids

def has_companions(self, parent_id: str) -> bool:
return parent_id in self._companion_map

def all_companions_done(self, parent_id: str) -> bool:
role_map = self._companion_map.get(parent_id, {})
done_set = self._done.get(parent_id, set())
return all(cid in done_set for cid in role_map.values())

def get_companion_request_ids(self, parent_id: str) -> dict[str, str]:
"""Return ``{role: companion_request_id}`` for a parent."""
return self._companion_map.get(parent_id, {})

def register_parent(self, parent_id: str) -> None:
self._companion_map.setdefault(parent_id, {})
self._done.setdefault(parent_id, set())

def register_companion(self, parent_id: str, role: str, companion_id: str) -> None:
self.register_parent(parent_id)
self._companion_map[parent_id][role] = companion_id
self._companion_ids.add(companion_id)
self._companion_to_parent[companion_id] = parent_id

def attach_cfg_request_ids(self, parent_id: str, sampling_params: Any) -> Any:
cfg_ids = self.get_companion_request_ids(parent_id)
if not cfg_ids:
return sampling_params

if isinstance(sampling_params, OmniDiffusionSamplingParams):
sampling_params = sampling_params.clone()
sampling_params.cfg_kv_request_ids = cfg_ids
logger.info(
"Attaching cfg_kv_request_ids=%s to request %s",
cfg_ids,
parent_id,
)
return sampling_params

def on_companion_completed(self, companion_id: str) -> str | None:
"""Mark done. Returns parent_id only if parent is pending and all companions finished."""
parent_id = self._companion_to_parent.get(companion_id)
if parent_id is None:
return None
done_set = self._done.get(parent_id)
assert done_set is not None, f"Companion {companion_id} completed before parent {parent_id} was registered"
if companion_id in done_set:
return None
done_set.add(companion_id)
logger.debug("CFG companion %s completed (parent=%s)", companion_id, parent_id)
if parent_id in self._pending_parents and self.all_companions_done(parent_id):
return parent_id
return None

def defer_parent(self, parent_id: str, engine_outputs: Any, stage_id: int) -> None:
"""Hold parent result while waiting for companions to finish."""
# TODO: Add timeout/error recovery when the orchestrator grows a
# companion-failure path. Today deferred parents are released only when
# companions finish or the external layer aborts the request.
self._pending_parents[parent_id] = {
"engine_outputs": engine_outputs,
"stage_id": stage_id,
}
logger.debug("Parent %s deferred, waiting for CFG companions", parent_id)

def pop_pending_parent(self, parent_id: str) -> dict[str, Any] | None:
return self._pending_parents.pop(parent_id, None)

def cleanup_parent(self, parent_id: str) -> list[str]:
companion_ids = list(self._companion_map.pop(parent_id, {}).values())
for companion_id in companion_ids:
self._companion_ids.discard(companion_id)
self._companion_to_parent.pop(companion_id, None)
self._done.pop(parent_id, None)
self._pending_parents.pop(parent_id, None)
return companion_ids

def abort_parents(self, request_ids: list[str]) -> list[str]:
all_request_ids = list(request_ids)
seen = set(all_request_ids)
parents_to_cleanup: set[str] = set()

for req_id in request_ids:
# The orchestrator calls this with parent request IDs. If a raw
# companion ID is passed here, keep it as a direct abort target and
# avoid tearing down parent tracking state implicitly.
if req_id not in self._companion_ids:
parents_to_cleanup.add(req_id)

for parent_id in parents_to_cleanup:
companion_ids = self.cleanup_parent(parent_id)
for companion_id in companion_ids:
if companion_id not in seen:
seen.add(companion_id)
all_request_ids.append(companion_id)

return all_request_ids
Loading
Loading