Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion vllm_omni/diffusion/stage_diffusion_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import asyncio
import contextlib
import os
import signal
import time
from concurrent.futures import ThreadPoolExecutor
Expand All @@ -21,7 +22,7 @@
from PIL import Image
from vllm.logger import init_logger
from vllm.utils.network_utils import get_open_zmq_ipc_path, zmq_socket_ctx
from vllm.utils.system_utils import get_mp_context
from vllm.utils.system_utils import decorate_logs, get_mp_context, set_process_title
from vllm.v1.utils import shutdown

from vllm_omni.diffusion.data import DiffusionRequestAbortedError
Expand Down Expand Up @@ -550,6 +551,12 @@ def signal_handler(signum: int, frame: Any) -> None:
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)

stage_label = f"stage{omni_stage_id}" if omni_stage_id is not None else "noid"
replica_id = max(int(omni_replica_id), 0)
set_process_title(f"StageDiffusionProc_{stage_label}_replica{replica_id}")
decorate_logs()
os.environ["VLLM_OMNI_REPLICA_ID"] = str(replica_id)

proc = cls(model, od_config)
coord_client: OmniCoordClientForStage | None = None
try:
Expand Down
16 changes: 13 additions & 3 deletions vllm_omni/distributed/omni_connectors/kv_transfer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,13 +1018,18 @@ def receive_kv_cache_for_request(
)
pending_pairs = list(recv_key_pairs)
received_payloads: dict[str, tuple[dict[str, Any], int]] = {}
replica_id = get_omni_replica_id()

logger.info(
"Wait for KV cache for request %s from stage %s to %s via %s key(s)...",
"[KV recv stage-%s rep-%s] Wait for KV cache for request %s from stage %s via %s key(s) "
"(sender=%s:%s)...",
to_stage,
replica_id,
request_id,
from_stage,
to_stage,
len(recv_key_pairs),
self._sender_base_host,
self._sender_base_zmq_port,
)

try:
Expand Down Expand Up @@ -1110,12 +1115,17 @@ def receive_kv_cache_for_request(
logger.exception("Failed to move KV cache tensors to target device")

logger.info(
"Successfully received KV cache for %s, %s bytes across %s key(s), wait=%.3fs, link=%.1fms",
"[KV recv stage-%s rep-%s] Successfully received KV cache for %s, "
"%s bytes across %s key(s), wait=%.3fs, link=%.1fms (sender=%s:%s)",
to_stage,
replica_id,
request_id,
total_size,
len(recv_key_pairs),
elapsed,
link_ms,
self._sender_base_host,
self._sender_base_zmq_port,
)
return data, total_size

Expand Down
12 changes: 12 additions & 0 deletions vllm_omni/engine/stage_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,12 @@ async def submit_initial(
request_id,
affinity_request_id=affinity_request_id,
)
logger.debug(
"[StagePool] stage-%s selected diffusion replica %s for request %s",
self.stage_id,
replica_id,
request_id,
)
client = self._diffusion_client(replica_id)
if isinstance(request, list):
await client.add_batch_request_async(request_id, request, params, **submit_kwargs)
Expand All @@ -505,6 +511,12 @@ async def submit_initial(
request_id,
affinity_request_id=affinity_request_id,
)
logger.debug(
"[StagePool] stage-%s selected LLM replica %s for request %s",
self.stage_id,
replica_id,
request_id,
)
client = self.clients[replica_id]
if client is None:
raise RuntimeError(f"stage {self.stage_id} replica {replica_id} is not attached")
Expand Down