diff --git a/tests/v1/kv_connector/unit/test_moriio_connector.py b/tests/v1/kv_connector/unit/test_moriio_connector.py index da78b62b9a03..78269bfe40a7 100644 --- a/tests/v1/kv_connector/unit/test_moriio_connector.py +++ b/tests/v1/kv_connector/unit/test_moriio_connector.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import importlib.util -import os import subprocess import uuid from unittest.mock import MagicMock, patch @@ -202,6 +201,7 @@ def create_vllm_config( enable_chunked_prefill: bool = True, enable_permute_local_kv: bool = False, role="kv_consumer", + read_mode: bool = False, ) -> VllmConfig: """Initialize VllmConfig for testing.""" scheduler_config = SchedulerConfig( @@ -228,6 +228,7 @@ def create_vllm_config( kv_connector="MoRIIOConnector", kv_role=role, enable_permute_local_kv=enable_permute_local_kv, + kv_connector_extra_config={"read_mode": read_mode}, ) return VllmConfig( scheduler_config=scheduler_config, @@ -238,15 +239,6 @@ def create_vllm_config( ) -@pytest.fixture -def moriio_read_mode(): - """Force the connector into read mode via env for tests.""" - os.environ["VLLM_MORIIO_CONNECTOR_READ_MODE"] = "True" - yield - # Cleanup after test - os.environ.pop("VLLM_MORIIO_CONNECTOR_READ_MODE", None) - - def test_write_mode_saves_local_block_ids(): """Write mode records local block ids in MoRIIOConnectorMetadata.reqs_to_save.""" @@ -358,11 +350,11 @@ def test_write_mode_with_chunked_prefill_saves_local_block_ids(): assert block_id == block.block_id, f"{block_id} != {block.block_id}" -def test_read_mode_loads_remote_block_ids(moriio_read_mode): +def test_read_mode_loads_remote_block_ids(): """Read mode loads remote block ids into local cache mapping.""" # Setup Scheduler and Request - vllm_config = create_vllm_config(role="kv_consumer") + vllm_config = create_vllm_config(role="kv_consumer", read_mode=True) scheduler = create_scheduler(vllm_config) # 2 Full Blocks and 1 Half Block. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py index 2733b2e0a878..89c314a66fe8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib +import os import threading import time from collections.abc import Iterator @@ -12,8 +13,7 @@ import torch import zmq -from vllm import envs -from vllm.config import VllmConfig +from vllm.config import KVTransferConfig, VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorMetadata, ) @@ -162,8 +162,10 @@ class TransferError(MoRIIOError): pass -def get_moriio_mode() -> MoRIIOMode: - read_mode = envs.VLLM_MORIIO_CONNECTOR_READ_MODE +def get_moriio_mode(kv_transfer_config: KVTransferConfig) -> MoRIIOMode: + read_mode = str( + kv_transfer_config.kv_connector_extra_config.get("read_mode", "false") + ).lower().strip() in ("true", "1") logger.debug("MoRIIO Connector read_mode: %s", read_mode) if read_mode: return MoRIIOMode.READ @@ -175,6 +177,26 @@ def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int: return (dp_rank) * tp_size + tp_rank +_DEPRECATED_ENV_VARS: dict[str, str] = { + "VLLM_MORIIO_CONNECTOR_READ_MODE": "read_mode", + "VLLM_MORIIO_QP_PER_TRANSFER": "qp_per_transfer", + "VLLM_MORIIO_POST_BATCH_SIZE": "post_batch_size", + "VLLM_MORIIO_NUM_WORKERS": "num_workers", +} + + +def _warn_deprecated_env_vars() -> None: + for env_var, new_key in _DEPRECATED_ENV_VARS.items(): + if env_var in os.environ: + logger.warning_once( + "The environment variable %s is deprecated and ignored. " + "Set %r inside kv_transfer_config.kv_connector_extra_config " + "instead.", + env_var, + new_key, + ) + + @dataclass class MoRIIOConfig: local_ip: str @@ -189,6 +211,10 @@ class MoRIIOConfig: dp_rank: int dp_size: int tp_size: int + read_mode: bool = False + qp_per_transfer: int = 1 + post_batch_size: int = -1 + num_workers: int = 1 backend: str = "rdma" @classmethod @@ -201,11 +227,24 @@ def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig": # notify_port -> For synchronizing stages between prefill and decode # handshake_port -> For initial handshake between mori engine + # Optional tuning knobs + # read_mode -> If true, run the connector in READ mode (consumer + # pulls KV from producer) instead of the default + # WRITE mode. + + # Knobs for RDMA transfers, ignored if on xgmi backend + # qp_per_transfer -> Number of RDMA Queue Pairs per KV transfer. + # post_batch_size -> Batch size for posting transfer work requests + # (-1 lets the MoRI backend choose). + # num_workers -> Number of background worker threads the MoRI + # engine uses for transfer processing. + # TODO : merge notify_port and handshake_port to simplify port management # supports non-contiguous ports assert vllm_config.kv_transfer_config is not None, ( "kv_transfer_config must be set for MoRIIOConnector" ) + _warn_deprecated_env_vars() kv_transfer_config = vllm_config.kv_transfer_config extra_config = kv_transfer_config.kv_connector_extra_config tp_rank = get_tensor_model_parallel_rank() @@ -234,6 +273,10 @@ def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig": dp_rank=dp_rank, dp_size=dp_size, tp_size=tp_size, + read_mode=get_moriio_mode(kv_transfer_config) == MoRIIOMode.READ, + qp_per_transfer=int(extra_config.get("qp_per_transfer", 1)), + post_batch_size=int(extra_config.get("post_batch_size", -1)), + num_workers=int(extra_config.get("num_workers", 1)), backend=backend, ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py index 804103275ae3..dc7264c8f7cf 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py @@ -108,7 +108,7 @@ def __init__( + ":" + str(self.kv_transfer_config.kv_connector_extra_config["handshake_port"]) ) - self.mode = get_moriio_mode() + self.mode = get_moriio_mode(self.kv_transfer_config) if role == KVConnectorRole.SCHEDULER: self.connector_scheduler: MoRIIOConnectorScheduler | None = ( MoRIIOConnectorScheduler(vllm_config, self.engine_id) @@ -250,7 +250,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.kv_transfer_config = vllm_config.kv_transfer_config self.block_size = vllm_config.cache_config.block_size self.engine_id: EngineId = engine_id - self.mode = get_moriio_mode() + self.mode = get_moriio_mode(self.kv_transfer_config) self.host_ip = get_ip() self.handshake_port = self.kv_transfer_config.kv_connector_extra_config[ "handshake_port" @@ -615,8 +615,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): "is installed and properly configured." ) + assert vllm_config.kv_transfer_config is not None self.moriio_config = MoRIIOConfig.from_vllm_config(vllm_config) - self.mode = get_moriio_mode() + self.mode = ( + MoRIIOMode.READ if self.moriio_config.read_mode else MoRIIOMode.WRITE + ) logger.info("Initializing MoRIIO worker %s", engine_id) @@ -700,7 +703,12 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): if self.moriio_config.backend == "xgmi" else BackendType.RDMA ) - self.moriio_wrapper.set_backend_type(backend) + self.moriio_wrapper.set_backend_type( + backend, + qp_per_transfer=self.moriio_config.qp_per_transfer, + post_batch_size=self.moriio_config.post_batch_size, + num_workers=self.moriio_config.num_workers, + ) self.moriio_wrapper.notify_port = self.moriio_config.notify_port self.local_kv_cache_metadata: list[bytes] = [] self.local_kv_cache_size: list[int] = [] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py index 78c8d4860c1b..86f2533837ba 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py @@ -8,7 +8,6 @@ import torch import zmq -from vllm import envs from vllm.logger import init_logger from vllm.utils.network_utils import ( make_zmq_path, @@ -16,7 +15,7 @@ ) if TYPE_CHECKING: - pass + from mori.io import BackendType from queue import Empty, Queue @@ -376,7 +375,13 @@ def set_moriio_engine(self, moriio_engine): ) self.moriio_engine = moriio_engine - def set_backend_type(self, backend_type): + def set_backend_type( + self, + backend_type: "BackendType", + qp_per_transfer: int = 1, + post_batch_size: int = -1, + num_workers: int = 1, + ) -> None: assert self.moriio_engine is not None, "MoRIIO engine must be set first" if backend_type == BackendType.XGMI: logger.info("Using MoRIIO backend: XGMI") @@ -385,14 +390,14 @@ def set_backend_type(self, backend_type): logger.info( "Using MoRIIO backend: RDMA " "(qp_per_transfer=%d, post_batch_size=%d, num_workers=%d)", - envs.VLLM_MORIIO_QP_PER_TRANSFER, - envs.VLLM_MORIIO_POST_BATCH_SIZE, - envs.VLLM_MORIIO_NUM_WORKERS, + qp_per_transfer, + post_batch_size, + num_workers, ) rdma_cfg = RdmaBackendConfig( - envs.VLLM_MORIIO_QP_PER_TRANSFER, - envs.VLLM_MORIIO_POST_BATCH_SIZE, - envs.VLLM_MORIIO_NUM_WORKERS, + qp_per_transfer, + post_batch_size, + num_workers, PollCqMode.POLLING, ) self.moriio_engine.create_backend(backend_type, rdma_cfg) diff --git a/vllm/envs.py b/vllm/envs.py index b2c5f22567fa..9b4a6790b6cf 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -216,10 +216,6 @@ VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None VLLM_ROCM_QUICK_REDUCE_MIN_SIZE_BYTES_MB: int | None = None VLLM_ROCM_QUICK_REDUCE_QUANTIZATION_MIN_SIZE_KB: int | None = None - VLLM_MORIIO_CONNECTOR_READ_MODE: bool = False - VLLM_MORIIO_QP_PER_TRANSFER: int = 1 - VLLM_MORIIO_POST_BATCH_SIZE: int = -1 - VLLM_MORIIO_NUM_WORKERS: int = 1 VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT: int = 480 VLLM_ENABLE_CUDAGRAPH_GC: bool = False VLLM_LOOPBACK_IP: str = "" @@ -1642,20 +1638,6 @@ def _resolve_rust_frontend_path() -> str | None: "Use --linear-backend emulation.", lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))), ), - # Controls the read mode for the Mori-IO connector - "VLLM_MORIIO_CONNECTOR_READ_MODE": lambda: ( - os.getenv("VLLM_MORIIO_CONNECTOR_READ_MODE", "False").lower() in ("true", "1") - ), - # Controls the QP (Queue Pair) per transfer configuration for the Mori-IO connector - "VLLM_MORIIO_QP_PER_TRANSFER": lambda: int( - os.getenv("VLLM_MORIIO_QP_PER_TRANSFER", "1") - ), - # Controls the post-processing batch size for the Mori-IO connector - "VLLM_MORIIO_POST_BATCH_SIZE": lambda: int( - os.getenv("VLLM_MORIIO_POST_BATCH_SIZE", "-1") - ), - # Controls the number of workers for Mori operations for the Mori-IO connector - "VLLM_MORIIO_NUM_WORKERS": lambda: int(os.getenv("VLLM_MORIIO_NUM_WORKERS", "1")), # Timeout (in seconds) for MooncakeConnector in PD disaggregated setup. "VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT": lambda: int( os.getenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "480")