Skip to content
261 changes: 101 additions & 160 deletions vllm_omni/entrypoints/omni_llm.py

Large diffs are not rendered by default.

12 changes: 8 additions & 4 deletions vllm_omni/entrypoints/omni_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
_to_dict,
maybe_dump_to_shm,
maybe_load_from_ipc_with_metrics,
set_stage_gpu_devices,
set_stage_devices,
)
from vllm_omni.inputs.data import OmniTokensPrompt

Expand Down Expand Up @@ -281,7 +281,9 @@ def filter(self, record: _logging.LogRecord) -> bool:

# Device mapping
try:
set_stage_gpu_devices(stage_id, runtime_cfg.get("devices"))
from vllm_omni.utils import detect_device_type
device_type = detect_device_type()
set_stage_devices(stage_id, runtime_cfg.get("devices"), device_type=device_type)
except Exception as e:
_logging.getLogger(__name__).warning("[Stage-%s] Device setup failed: %s", stage_id, e)

Expand Down Expand Up @@ -556,7 +558,9 @@ def filter(self, record: _logging.LogRecord) -> bool:

# Device mapping
try:
set_stage_gpu_devices(stage_id, runtime_cfg.get("devices"))
from vllm_omni.utils import detect_device_type
device_type = detect_device_type()
set_stage_devices(stage_id, runtime_cfg.get("devices"), device_type=device_type)
except Exception as e:
_logging.getLogger(__name__).warning("[Stage-%s] Device setup failed: %s", stage_id, e)

Expand Down Expand Up @@ -727,4 +731,4 @@ def filter(self, record: _logging.LogRecord) -> bool:
"stage_id": stage_id,
"error": str(e),
}
)
)
2 changes: 1 addition & 1 deletion vllm_omni/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,4 +776,4 @@ def _create_audio_choice(self, omni_outputs: OmniRequestOutput, role: str):
stop_reason=None,
)
choices.append(choice_data)
return choices
return choices
189 changes: 104 additions & 85 deletions vllm_omni/entrypoints/stage_utils.py
Original file line number Diff line number Diff line change
@@ -1,122 +1,144 @@
from __future__ import annotations

import json
from typing import Any, Dict, Optional, Tuple, Union

import logging
import json
import os
import pickle
from multiprocessing import shared_memory as _shm
from typing import Any

import cloudpickle
from omegaconf import OmegaConf
import cloudpickle

logger = logging.getLogger(__name__)


def set_stage_devices(
    stage_id: int,
    devices: Optional[Union[str, int]],
    device_type: Optional[str] = None,
) -> None:
    """Configure per-stage device visibility and current device (CUDA or NPU).

    Sets the environment variable that controls which devices are visible to
    this process and then selects logical device 0 as the current device. It
    must be called BEFORE worker initialization so that workers see the
    correct devices.

    The whole operation is best-effort: failures are logged and never raised.

    Args:
        stage_id: Stage identifier, used only in log messages.
        devices: Device specification:
            - Comma-separated string (e.g. "2,5,7"): the visibility env var is
              set exactly to this string; logical index 0 becomes the current
              device.
            - Integer or digit-string: treated as a 0-based logical index into
              the current visibility list (negative integers are clamped to
              0). It is mapped to the physical device id when the env var is
              already set and parseable; otherwise the index itself is used as
              the physical id. The env var is then set to that single device.
            - None or "cpu": visibility is left unchanged.
            - Anything else: converted with ``int(str(devices))`` and used as
              a single physical device id; a non-numeric value is reported
              with a warning and leaves visibility unchanged.
        device_type: "cuda" or "npu". When None, it is auto-detected via
            ``vllm_omni.utils.detect_device_type``. Any other value is logged
            at debug level and the function returns without changing anything.

    Behavior:
        - CUDA: uses ``torch.cuda`` (``set_device``, ``device_count``, ...).
        - NPU: uses ``torch.npu``; if it cannot be imported the function logs
          at debug level and returns without changing anything.
    """
    from vllm_omni.utils import detect_device_type, get_device_control_env_var

    if device_type is None:
        device_type = detect_device_type()

    # NOTE(review): the env var name is resolved without passing device_type;
    # presumably the helper auto-detects the platform. Confirm it agrees with
    # an explicitly supplied device_type (CUDA_VISIBLE_DEVICES for "cuda",
    # ASCEND_RT_VISIBLE_DEVICES for "npu").
    env_var = get_device_control_env_var()

    # Pick the torch backend module once; its functions are looked up lazily
    # inside the best-effort query below so a missing attribute is only logged.
    if device_type == "npu":
        try:
            import torch.npu  # type: ignore[import-untyped]
        except ImportError:
            logger.debug("[Stage-%s] torch.npu not available, skipping NPU device setup", stage_id)
            return
        backend = torch.npu
        device_type_label = "NPU"
    elif device_type == "cuda":
        import torch  # noqa: WPS433

        backend = torch.cuda
        device_type_label = "CUDA"
    else:
        logger.debug("[Stage-%s] Unsupported device type: %s", stage_id, device_type)
        return

    try:
        selected_physical: Optional[int] = None

        if isinstance(devices, str) and "," in devices:
            # Explicit list: export it verbatim.
            os.environ[env_var] = devices
            toks = [t.strip() for t in devices.split(",") if t.strip()]
            if toks:
                # The first entry is parsed only for the diagnostic message.
                try:
                    selected_physical = int(toks[0])
                    logger.debug(
                        "[Stage-%s] Set %s to %s; logical 0 -> physical %s",
                        stage_id,
                        env_var,
                        devices,
                        selected_physical,
                    )
                except ValueError as e:
                    logger.debug(
                        "[Stage-%s] Failed to parse first %s device: %s",
                        stage_id,
                        device_type_label,
                        e,
                    )
        elif isinstance(devices, int) or (isinstance(devices, str) and devices.isdigit()):
            # Logical index into the currently visible devices.
            logical_idx = max(0, int(devices))
            vis = os.environ.get(env_var)
            if vis:
                try:
                    mapping = [int(x) for x in vis.split(",") if x.strip()]
                    if logical_idx < len(mapping):
                        selected_physical = mapping[logical_idx]
                except ValueError as e:
                    logger.debug(
                        "[Stage-%s] Failed to map logical index via %s: %s",
                        stage_id,
                        env_var,
                        e,
                    )
            if selected_physical is None:
                # No usable mapping: treat the logical index as the physical id.
                selected_physical = logical_idx
            os.environ[env_var] = str(selected_physical)
            # Report the 0-based index exactly as documented (the previous
            # message printed logical_idx + 1).
            logger.debug(
                "[Stage-%s] Logical index %d -> physical %s; set %s to single device",
                stage_id,
                logical_idx,
                selected_physical,
                env_var,
            )
        elif devices in (None, "cpu"):
            logger.debug("[Stage-%s] Using default device visibility (devices=%s)", stage_id, devices)
        else:
            # Raises ValueError for non-numeric input; handled by the outer
            # except, which leaves the environment untouched.
            selected_physical = int(str(devices))
            os.environ[env_var] = str(selected_physical)
            logger.debug(
                "[Stage-%s] Set %s to single device %s (fallback)",
                stage_id,
                env_var,
                selected_physical,
            )

        # Best-effort: select logical device 0 and log what is visible.
        try:
            if backend.is_available():
                try:
                    backend.set_device(0)
                except Exception as e:
                    logger.debug(
                        "[Stage-%s] %s set_device(0) failed: %s",
                        stage_id,
                        device_type_label,
                        e,
                        exc_info=True,
                    )
                num = backend.device_count()
                info = []
                for i in range(num):
                    total = backend.get_device_properties(i).total_memory
                    free, _ = backend.mem_get_info(i)
                    info.append(
                        {
                            "idx": i,
                            "name": backend.get_device_name(i),
                            "total": int(total),
                            "free": int(free),
                        }
                    )
                logger.debug(
                    "[Stage-%s] %s devices visible=%s info=%s",
                    stage_id,
                    device_type_label,
                    num,
                    info,
                )
        except Exception as e:
            logger.debug(
                "[Stage-%s] Failed to query %s devices: %s",
                stage_id,
                device_type_label,
                e,
                exc_info=True,
            )
    except Exception as e:
        logger.warning("Failed to interpret devices for stage %s: %s", stage_id, e)

Expand All @@ -126,7 +148,7 @@ def serialize_obj(obj: Any) -> bytes:
return cloudpickle.dumps(obj)


def shm_write_bytes(payload: bytes) -> dict[str, Any]:
def shm_write_bytes(payload: bytes) -> Dict[str, Any]:
"""Write bytes into SharedMemory and return meta dict {name,size}.

Caller should close the segment; the receiver should unlink.
Expand All @@ -143,7 +165,7 @@ def shm_write_bytes(payload: bytes) -> dict[str, Any]:
return meta


def shm_read_bytes(meta: dict[str, Any]) -> bytes:
def shm_read_bytes(meta: Dict[str, Any]) -> bytes:
"""Read bytes from SharedMemory by meta {name,size} and cleanup."""
shm = _shm.SharedMemory(name=meta["name"]) # type: ignore[index]
mv = memoryview(shm.buf)
Expand All @@ -170,7 +192,7 @@ def _ensure_parent_dir(path: str) -> None:
pass


def append_jsonl(path: str, record: dict[str, Any]) -> None:
def append_jsonl(path: str, record: Dict[str, Any]) -> None:
"""Append a JSON record as one line to a JSONL file (best-effort).

This is safe to call from multiple processes when each process writes
Expand All @@ -187,7 +209,7 @@ def append_jsonl(path: str, record: dict[str, Any]) -> None:
logger.exception("Failed to append JSONL to %s", path)


def maybe_dump_to_shm(obj: Any, threshold: int) -> tuple[bool, Any]:
def maybe_dump_to_shm(obj: Any, threshold: int) -> Tuple[bool, Any]:
"""Dump object to SHM if serialized size exceeds threshold.

Returns (True, meta) when dumped; otherwise (False, original_obj).
Expand All @@ -198,7 +220,7 @@ def maybe_dump_to_shm(obj: Any, threshold: int) -> tuple[bool, Any]:
return False, obj


def maybe_load_from_ipc(container: dict[str, Any], obj_key: str, shm_key: str) -> Any:
def maybe_load_from_ipc(container: Dict[str, Any], obj_key: str, shm_key: str) -> Any:
"""Load object from container that may carry SHM or inline object.

Deprecated: prefer `maybe_load_from_ipc_with_metrics` to also obtain
Expand All @@ -209,17 +231,14 @@ def maybe_load_from_ipc(container: dict[str, Any], obj_key: str, shm_key: str) -
return container[obj_key]


def maybe_load_from_ipc_with_metrics(
container: dict[str, Any], obj_key: str, shm_key: str
) -> tuple[Any, dict[str, float]]:
def maybe_load_from_ipc_with_metrics(container: Dict[str, Any], obj_key: str, shm_key: str) -> tuple[Any, Dict[str, float]]:
"""Load object and return (object, metrics) with RX bytes and decode time.

Metrics keys:
- rx_transfer_bytes: int
- rx_decode_time_ms: float
"""
import time as _time # local import to avoid overhead at module import

t0 = _time.time()
if shm_key in container:
meta = container[shm_key] # type: ignore[index]
Expand All @@ -243,13 +262,13 @@ def maybe_load_from_ipc_with_metrics(
}


def encode_for_ipc(obj: Any, threshold: int, obj_key: str, shm_key: str) -> dict[str, Any]:
def encode_for_ipc(obj: Any, threshold: int, obj_key: str, shm_key: str) -> Dict[str, Any]:
"""Return a dict payload for IPC: inline (obj_key) or SHM (shm_key).

When serialized size exceeds threshold, returns {shm_key: {name,size}};
otherwise returns {obj_key: obj}.
"""
payload: dict[str, Any] = {}
payload: Dict[str, Any] = {}
use_shm, data = maybe_dump_to_shm(obj, threshold)
if use_shm:
payload[shm_key] = data
Expand All @@ -259,7 +278,7 @@ def encode_for_ipc(obj: Any, threshold: int, obj_key: str, shm_key: str) -> dict


# Convert OmegaConf/objects to plain dicts
def _to_dict(x: Any) -> dict[str, Any]:
def _to_dict(x: Any) -> Dict[str, Any]:
try:
if isinstance(x, dict):
return dict(x)
Expand All @@ -268,4 +287,4 @@ def _to_dict(x: Any) -> dict[str, Any]:
try:
return dict(x)
except Exception:
return {}
return {}
Loading