diff --git a/vllm/distributed/eplb/eplb_communicator.py b/vllm/distributed/eplb/eplb_communicator.py index 95a5ae5ff45d..f8ee90b934fb 100644 --- a/vllm/distributed/eplb/eplb_communicator.py +++ b/vllm/distributed/eplb/eplb_communicator.py @@ -19,6 +19,7 @@ batch_isend_irecv, ) +import vllm.distributed.nixl_utils as nixl_utils from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl_wrapper import ( ncclDataTypeEnum, @@ -37,9 +38,7 @@ def has_nixl() -> bool: """Whether the optional NIXL / RIXL package is available.""" - from vllm.distributed.nixl_utils import NixlWrapper - - return NixlWrapper is not None + return nixl_utils.NixlWrapper is not None class EplbCommunicator(ABC): @@ -233,10 +232,9 @@ def __init__( expert_weights: Sequence[torch.Tensor], cuda_stream: torch.cuda.Stream | None = None, ) -> None: - from vllm.distributed.nixl_utils import NixlWrapper, nixl_agent_config - assert expert_weights, "NixlEplbCommunicator requires non-empty expert_weights." - if NixlWrapper is None: + nixl_wrapper_cls = nixl_utils.NixlWrapper + if nixl_wrapper_cls is None: raise RuntimeError("NIXL/ RIXL is unavailable.") self._cpu_group = cpu_group self._cuda_stream = cuda_stream @@ -254,12 +252,13 @@ def __init__( f"expected={self._device}, got={tensor.device}" ) + nixl_agent_config = nixl_utils.nixl_agent_config config = ( nixl_agent_config(capture_telemetry=False) if nixl_agent_config is not None else None ) - self._nixl_wrapper = NixlWrapper(self._make_agent_name(), config) + self._nixl_wrapper = nixl_wrapper_cls(self._make_agent_name(), config) self._nixl_memory_type = "VRAM" self._registered_desc: object | None = None self._remote_agents: dict[int, str] = {} diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/stats.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/stats.py index 65c553cfec30..1e4f5c48e0f7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/stats.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/stats.py @@ -4,7 +4,7 @@ import copy from dataclasses import dataclass -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np @@ -15,9 +15,11 @@ PromMetric, PromMetricT, ) -from vllm.distributed.nixl_utils import nixlXferTelemetry from vllm.v1.metrics.utils import create_metric_per_engine +if TYPE_CHECKING: + from vllm.distributed.nixl_utils import nixlXferTelemetry + @dataclass class NixlKVConnectorStats(KVConnectorStats): @@ -40,7 +42,7 @@ def reset(self): "num_kv_expired_reqs": [], } - def record_transfer(self, res: nixlXferTelemetry): + def record_transfer(self, res: "nixlXferTelemetry"): # Keep metrics units consistent with rest of the code: time us->s self.data["transfer_duration"].append(res.xferDuration / 1e6) self.data["post_duration"].append(res.postDuration / 1e6) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index eb90a4e17295..37a810691f37 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -198,7 +198,8 @@ def __init__( engine_id: str, kv_cache_config: "KVCacheConfig", ): - if NixlWrapper is None: + nixl_wrapper_cls = NixlWrapper + if nixl_wrapper_cls is None: logger.error("NIXL is not available") raise RuntimeError("NIXL is not available") logger.info("Initializing NIXL wrapper") @@ -284,7 +285,7 @@ def __init__( else nixl_agent_config(num_threads=num_threads, capture_telemetry=True) ) - self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config) + self.nixl_wrapper = nixl_wrapper_cls(str(uuid.uuid4()), config) # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) diff --git a/vllm/distributed/nixl_utils.py b/vllm/distributed/nixl_utils.py index 2da37017a37f..d7d262672d39 100644 --- a/vllm/distributed/nixl_utils.py +++ b/vllm/distributed/nixl_utils.py @@ -1,54 +1,82 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib import os import sys +from typing import Any from vllm.logger import init_logger from vllm.platforms import current_platform logger = init_logger(__name__) -if "UCX_RCACHE_MAX_UNRELEASED" not in os.environ: +# declaration for static analyzers +NixlWrapper: Any +nixl_agent_config: Any +nixlXferTelemetry: Any + + +def _maybe_set_ucx_rcache_limit() -> None: + if "UCX_RCACHE_MAX_UNRELEASED" in os.environ: + return + if "nixl" in sys.modules or "rixl" in sys.modules: logger.warning_once( "NIXL was already imported, we can't reset " "UCX_RCACHE_MAX_UNRELEASED. " "Please set it to '1024' manually." ) - else: - logger.info_once( - "Setting UCX_RCACHE_MAX_UNRELEASED to '1024' to avoid a rare " - "memory leak in UCX when using NIXL." - ) - os.environ["UCX_RCACHE_MAX_UNRELEASED"] = "1024" + return -try: - if not current_platform.is_rocm(): - from nixl._api import nixl_agent as NixlWrapper - else: - from rixl._api import nixl_agent as NixlWrapper + logger.info_once( + "Setting UCX_RCACHE_MAX_UNRELEASED to '1024' to avoid a rare " + "memory leak in UCX when using NIXL." + ) + os.environ["UCX_RCACHE_MAX_UNRELEASED"] = "1024" - logger.info_once("NIXL is available") -except ImportError: - logger.warning_once("NIXL is not available") - NixlWrapper = None # type: ignore[assignment, misc] -try: - if not current_platform.is_rocm(): - from nixl._api import nixl_agent_config - else: - from rixl._api import nixl_agent_config -except ImportError: - nixl_agent_config = None # type: ignore[assignment] - logger.warning_once("NIXL agent config is not available") - -try: - if not current_platform.is_rocm(): - from nixl._bindings import nixlXferTelemetry +def _get_nixl_module_name(name: str) -> str: + package_name = "rixl" if current_platform.is_rocm() else "nixl" + if name == "nixlXferTelemetry": + return f"{package_name}._bindings" + return f"{package_name}._api" + + +def _load_nixl_attr(name: str) -> Any: + attr_name = { + "NixlWrapper": "nixl_agent", + "nixl_agent_config": "nixl_agent_config", + "nixlXferTelemetry": "nixlXferTelemetry", + }[name] + + _maybe_set_ucx_rcache_limit() + try: + module = importlib.import_module(_get_nixl_module_name(name)) + except ImportError: + if name == "NixlWrapper": + logger.warning_once("NIXL is not available") + elif name == "nixl_agent_config": + logger.warning_once("NIXL agent config is not available") + value = None else: - from rixl._bindings import nixlXferTelemetry -except ImportError: - nixlXferTelemetry = None # type: ignore[assignment, misc] + value = getattr(module, attr_name, None) + if name == "NixlWrapper": + if value is None: + logger.warning_once("NIXL is not available") + else: + logger.info_once("NIXL is available") + elif name == "nixl_agent_config" and value is None: + logger.warning_once("NIXL agent config is not available") + + globals()[name] = value + return value + + +def __getattr__(name: str) -> Any: + if name in __all__: + return _load_nixl_attr(name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + __all__ = ["NixlWrapper", "nixl_agent_config", "nixlXferTelemetry"]