Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions vllm/distributed/eplb/eplb_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
batch_isend_irecv,
)

import vllm.distributed.nixl_utils as nixl_utils
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import (
ncclDataTypeEnum,
Expand All @@ -37,9 +38,7 @@

def has_nixl() -> bool:
"""Whether the optional NIXL / RIXL package is available."""
from vllm.distributed.nixl_utils import NixlWrapper

return NixlWrapper is not None
return nixl_utils.NixlWrapper is not None


class EplbCommunicator(ABC):
Expand Down Expand Up @@ -233,10 +232,9 @@ def __init__(
expert_weights: Sequence[torch.Tensor],
cuda_stream: torch.cuda.Stream | None = None,
) -> None:
from vllm.distributed.nixl_utils import NixlWrapper, nixl_agent_config

assert expert_weights, "NixlEplbCommunicator requires non-empty expert_weights."
if NixlWrapper is None:
nixl_wrapper_cls = nixl_utils.NixlWrapper
if nixl_wrapper_cls is None:
raise RuntimeError("NIXL/ RIXL is unavailable.")
self._cpu_group = cpu_group
self._cuda_stream = cuda_stream
Expand All @@ -254,12 +252,13 @@ def __init__(
f"expected={self._device}, got={tensor.device}"
)

nixl_agent_config = nixl_utils.nixl_agent_config
config = (
nixl_agent_config(capture_telemetry=False)
if nixl_agent_config is not None
else None
)
self._nixl_wrapper = NixlWrapper(self._make_agent_name(), config)
self._nixl_wrapper = nixl_wrapper_cls(self._make_agent_name(), config)
self._nixl_memory_type = "VRAM"
self._registered_desc: object | None = None
self._remote_agents: dict[int, str] = {}
Expand Down
8 changes: 5 additions & 3 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import copy
from dataclasses import dataclass
from typing import Any
from typing import TYPE_CHECKING, Any

import numpy as np

Expand All @@ -15,9 +15,11 @@
PromMetric,
PromMetricT,
)
from vllm.distributed.nixl_utils import nixlXferTelemetry
from vllm.v1.metrics.utils import create_metric_per_engine

if TYPE_CHECKING:
from vllm.distributed.nixl_utils import nixlXferTelemetry


@dataclass
class NixlKVConnectorStats(KVConnectorStats):
Expand All @@ -40,7 +42,7 @@ def reset(self):
"num_kv_expired_reqs": [],
}

def record_transfer(self, res: nixlXferTelemetry):
def record_transfer(self, res: "nixlXferTelemetry"):
# Keep metrics units consistent with rest of the code: time us->s
self.data["transfer_duration"].append(res.xferDuration / 1e6)
self.data["post_duration"].append(res.postDuration / 1e6)
Expand Down
5 changes: 3 additions & 2 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ def __init__(
engine_id: str,
kv_cache_config: "KVCacheConfig",
):
if NixlWrapper is None:
nixl_wrapper_cls = NixlWrapper
if nixl_wrapper_cls is None:
logger.error("NIXL is not available")
raise RuntimeError("NIXL is not available")
logger.info("Initializing NIXL wrapper")
Expand Down Expand Up @@ -283,7 +284,7 @@ def __init__(
else nixl_agent_config(num_threads=num_threads, capture_telemetry=True)
)

self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
self.nixl_wrapper = nixl_wrapper_cls(str(uuid.uuid4()), config)
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)

Expand Down
90 changes: 59 additions & 31 deletions vllm/distributed/nixl_utils.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,82 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import importlib
import os
import sys
from typing import Any

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

if "UCX_RCACHE_MAX_UNRELEASED" not in os.environ:
# declaration for static analyzers
NixlWrapper: Any
nixl_agent_config: Any
nixlXferTelemetry: Any


def _maybe_set_ucx_rcache_limit() -> None:
if "UCX_RCACHE_MAX_UNRELEASED" in os.environ:
return

if "nixl" in sys.modules or "rixl" in sys.modules:
logger.warning_once(
"NIXL was already imported, we can't reset "
"UCX_RCACHE_MAX_UNRELEASED. "
"Please set it to '1024' manually."
)
else:
logger.info_once(
"Setting UCX_RCACHE_MAX_UNRELEASED to '1024' to avoid a rare "
"memory leak in UCX when using NIXL."
)
os.environ["UCX_RCACHE_MAX_UNRELEASED"] = "1024"
return

try:
if not current_platform.is_rocm():
from nixl._api import nixl_agent as NixlWrapper
else:
from rixl._api import nixl_agent as NixlWrapper
logger.info_once(
"Setting UCX_RCACHE_MAX_UNRELEASED to '1024' to avoid a rare "
"memory leak in UCX when using NIXL."
)
os.environ["UCX_RCACHE_MAX_UNRELEASED"] = "1024"

logger.info_once("NIXL is available")
except ImportError:
logger.warning_once("NIXL is not available")
NixlWrapper = None # type: ignore[assignment, misc]

try:
if not current_platform.is_rocm():
from nixl._api import nixl_agent_config
else:
from rixl._api import nixl_agent_config
except ImportError:
nixl_agent_config = None # type: ignore[assignment]
logger.warning_once("NIXL agent config is not available")

try:
if not current_platform.is_rocm():
from nixl._bindings import nixlXferTelemetry
def _get_nixl_module_name(name: str) -> str:
package_name = "rixl" if current_platform.is_rocm() else "nixl"
if name == "nixlXferTelemetry":
return f"{package_name}._bindings"
return f"{package_name}._api"


def _load_nixl_attr(name: str) -> Any:
attr_name = {
"NixlWrapper": "nixl_agent",
"nixl_agent_config": "nixl_agent_config",
"nixlXferTelemetry": "nixlXferTelemetry",
}[name]

_maybe_set_ucx_rcache_limit()
try:
module = importlib.import_module(_get_nixl_module_name(name))
except ImportError:
if name == "NixlWrapper":
logger.warning_once("NIXL is not available")
elif name == "nixl_agent_config":
logger.warning_once("NIXL agent config is not available")
value = None
else:
from rixl._bindings import nixlXferTelemetry
except ImportError:
nixlXferTelemetry = None # type: ignore[assignment, misc]
value = getattr(module, attr_name, None)
if name == "NixlWrapper":
if value is None:
logger.warning_once("NIXL is not available")
else:
logger.info_once("NIXL is available")
elif name == "nixl_agent_config" and value is None:
logger.warning_once("NIXL agent config is not available")

globals()[name] = value
return value


def __getattr__(name: str) -> Any:
if name in __all__:
return _load_nixl_attr(name)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


__all__ = ["NixlWrapper", "nixl_agent_config", "nixlXferTelemetry"]
Loading