Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 4 additions & 12 deletions tests/v1/kv_connector/unit/test_moriio_connector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib.util
import os
import subprocess
import uuid
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -202,6 +201,7 @@ def create_vllm_config(
enable_chunked_prefill: bool = True,
enable_permute_local_kv: bool = False,
role="kv_consumer",
read_mode: bool = False,
) -> VllmConfig:
"""Initialize VllmConfig for testing."""
scheduler_config = SchedulerConfig(
Expand All @@ -228,6 +228,7 @@ def create_vllm_config(
kv_connector="MoRIIOConnector",
kv_role=role,
enable_permute_local_kv=enable_permute_local_kv,
kv_connector_extra_config={"read_mode": read_mode},
)
return VllmConfig(
scheduler_config=scheduler_config,
Expand All @@ -238,15 +239,6 @@ def create_vllm_config(
)


@pytest.fixture
def moriio_read_mode():
    """Force the connector into read mode via env for tests.

    Sets ``VLLM_MORIIO_CONNECTOR_READ_MODE`` to ``"True"`` while the
    requesting test runs, then puts the variable back the way it was found.
    """
    env_var = "VLLM_MORIIO_CONNECTOR_READ_MODE"
    # Remember any value that is already present so that teardown restores it
    # instead of deleting a setting that existed before the test ran.
    previous = os.environ.get(env_var)
    os.environ[env_var] = "True"
    yield
    # Cleanup after test
    if previous is None:
        os.environ.pop(env_var, None)
    else:
        os.environ[env_var] = previous


def test_write_mode_saves_local_block_ids():
"""Write mode records local block ids in MoRIIOConnectorMetadata.reqs_to_save."""

Expand Down Expand Up @@ -358,11 +350,11 @@ def test_write_mode_with_chunked_prefill_saves_local_block_ids():
assert block_id == block.block_id, f"{block_id} != {block.block_id}"


def test_read_mode_loads_remote_block_ids(moriio_read_mode):
def test_read_mode_loads_remote_block_ids():
"""Read mode loads remote block ids into local cache mapping."""

# Setup Scheduler and Request
vllm_config = create_vllm_config(role="kv_consumer")
vllm_config = create_vllm_config(role="kv_consumer", read_mode=True)
scheduler = create_scheduler(vllm_config)

# 2 Full Blocks and 1 Half Block.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import threading
import time
from collections.abc import Iterator
Expand All @@ -12,8 +13,7 @@
import torch
import zmq

from vllm import envs
from vllm.config import VllmConfig
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata,
)
Expand Down Expand Up @@ -162,8 +162,10 @@ class TransferError(MoRIIOError):
pass


def get_moriio_mode() -> MoRIIOMode:
read_mode = envs.VLLM_MORIIO_CONNECTOR_READ_MODE
def get_moriio_mode(kv_transfer_config: KVTransferConfig) -> MoRIIOMode:
read_mode = str(
kv_transfer_config.kv_connector_extra_config.get("read_mode", "false")
).lower().strip() in ("true", "1")
logger.debug("MoRIIO Connector read_mode: %s", read_mode)
if read_mode:
return MoRIIOMode.READ
Expand All @@ -175,6 +177,26 @@ def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int:
return (dp_rank) * tp_size + tp_rank


# Deprecated VLLM_MORIIO_* environment variables, each mapped to the
# kv_connector_extra_config key that users are told to set instead.
_DEPRECATED_ENV_VARS: dict[str, str] = {
    "VLLM_MORIIO_CONNECTOR_READ_MODE": "read_mode",
    "VLLM_MORIIO_QP_PER_TRANSFER": "qp_per_transfer",
    "VLLM_MORIIO_POST_BATCH_SIZE": "post_batch_size",
    "VLLM_MORIIO_NUM_WORKERS": "num_workers",
}


def _warn_deprecated_env_vars() -> None:
    """Warn about every deprecated MoRIIO environment variable that is set.

    Walks ``_DEPRECATED_ENV_VARS`` and, for each name present in
    ``os.environ``, calls ``logger.warning_once`` with a message naming the
    ``kv_connector_extra_config`` key to use instead. Only the presence of
    the variable is checked; its value is never read.
    """
    message = (
        "The environment variable %s is deprecated and ignored. "
        "Set %r inside kv_transfer_config.kv_connector_extra_config "
        "instead."
    )
    for name, replacement in _DEPRECATED_ENV_VARS.items():
        if name not in os.environ:
            continue
        logger.warning_once(message, name, replacement)


@dataclass
class MoRIIOConfig:
local_ip: str
Expand All @@ -189,6 +211,10 @@ class MoRIIOConfig:
dp_rank: int
dp_size: int
tp_size: int
read_mode: bool = False
qp_per_transfer: int = 1
post_batch_size: int = -1
num_workers: int = 1
backend: str = "rdma"

@classmethod
Expand All @@ -201,11 +227,24 @@ def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig":
# notify_port -> For synchronizing stages between prefill and decode
# handshake_port -> For initial handshake between mori engine

# Optional tuning knobs
# read_mode -> If true, run the connector in READ mode (consumer
# pulls KV from producer) instead of the default
# WRITE mode.

# Knobs for RDMA transfers, ignored if on xgmi backend
# qp_per_transfer -> Number of RDMA Queue Pairs per KV transfer.
# post_batch_size -> Batch size for posting transfer work requests
# (-1 lets the MoRI backend choose).
# num_workers -> Number of background worker threads the MoRI
# engine uses for transfer processing.

# TODO : merge notify_port and handshake_port to simplify port management
# supports non-contiguous ports
assert vllm_config.kv_transfer_config is not None, (
"kv_transfer_config must be set for MoRIIOConnector"
)
_warn_deprecated_env_vars()
kv_transfer_config = vllm_config.kv_transfer_config
extra_config = kv_transfer_config.kv_connector_extra_config
tp_rank = get_tensor_model_parallel_rank()
Expand Down Expand Up @@ -234,6 +273,10 @@ def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig":
dp_rank=dp_rank,
dp_size=dp_size,
tp_size=tp_size,
read_mode=get_moriio_mode(kv_transfer_config) == MoRIIOMode.READ,
qp_per_transfer=int(extra_config.get("qp_per_transfer", 1)),
post_batch_size=int(extra_config.get("post_batch_size", -1)),
num_workers=int(extra_config.get("num_workers", 1)),
backend=backend,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(
+ ":"
+ str(self.kv_transfer_config.kv_connector_extra_config["handshake_port"])
)
self.mode = get_moriio_mode()
self.mode = get_moriio_mode(self.kv_transfer_config)
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler: MoRIIOConnectorScheduler | None = (
MoRIIOConnectorScheduler(vllm_config, self.engine_id)
Expand Down Expand Up @@ -250,7 +250,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.kv_transfer_config = vllm_config.kv_transfer_config
self.block_size = vllm_config.cache_config.block_size
self.engine_id: EngineId = engine_id
self.mode = get_moriio_mode()
self.mode = get_moriio_mode(self.kv_transfer_config)
self.host_ip = get_ip()
self.handshake_port = self.kv_transfer_config.kv_connector_extra_config[
"handshake_port"
Expand Down Expand Up @@ -615,8 +615,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
"is installed and properly configured."
)

assert vllm_config.kv_transfer_config is not None
self.moriio_config = MoRIIOConfig.from_vllm_config(vllm_config)
self.mode = get_moriio_mode()
self.mode = (
MoRIIOMode.READ if self.moriio_config.read_mode else MoRIIOMode.WRITE
)

logger.info("Initializing MoRIIO worker %s", engine_id)

Expand Down Expand Up @@ -700,7 +703,12 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
if self.moriio_config.backend == "xgmi"
else BackendType.RDMA
)
self.moriio_wrapper.set_backend_type(backend)
self.moriio_wrapper.set_backend_type(
backend,
qp_per_transfer=self.moriio_config.qp_per_transfer,
post_batch_size=self.moriio_config.post_batch_size,
num_workers=self.moriio_config.num_workers,
)
self.moriio_wrapper.notify_port = self.moriio_config.notify_port
self.local_kv_cache_metadata: list[bytes] = []
self.local_kv_cache_size: list[int] = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@
import torch
import zmq

from vllm import envs
from vllm.logger import init_logger
from vllm.utils.network_utils import (
make_zmq_path,
make_zmq_socket,
)

if TYPE_CHECKING:
pass
from mori.io import BackendType

from queue import Empty, Queue

Expand Down Expand Up @@ -376,7 +375,13 @@ def set_moriio_engine(self, moriio_engine):
)
self.moriio_engine = moriio_engine

def set_backend_type(self, backend_type):
def set_backend_type(
self,
backend_type: "BackendType",
qp_per_transfer: int = 1,
post_batch_size: int = -1,
num_workers: int = 1,
) -> None:
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
if backend_type == BackendType.XGMI:
logger.info("Using MoRIIO backend: XGMI")
Expand All @@ -385,14 +390,14 @@ def set_backend_type(self, backend_type):
logger.info(
"Using MoRIIO backend: RDMA "
"(qp_per_transfer=%d, post_batch_size=%d, num_workers=%d)",
envs.VLLM_MORIIO_QP_PER_TRANSFER,
envs.VLLM_MORIIO_POST_BATCH_SIZE,
envs.VLLM_MORIIO_NUM_WORKERS,
qp_per_transfer,
post_batch_size,
num_workers,
)
rdma_cfg = RdmaBackendConfig(
envs.VLLM_MORIIO_QP_PER_TRANSFER,
envs.VLLM_MORIIO_POST_BATCH_SIZE,
envs.VLLM_MORIIO_NUM_WORKERS,
qp_per_transfer,
post_batch_size,
num_workers,
PollCqMode.POLLING,
)
self.moriio_engine.create_backend(backend_type, rdma_cfg)
Expand Down
18 changes: 0 additions & 18 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,6 @@
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None
VLLM_ROCM_QUICK_REDUCE_MIN_SIZE_BYTES_MB: int | None = None
VLLM_ROCM_QUICK_REDUCE_QUANTIZATION_MIN_SIZE_KB: int | None = None
VLLM_MORIIO_CONNECTOR_READ_MODE: bool = False
VLLM_MORIIO_QP_PER_TRANSFER: int = 1
VLLM_MORIIO_POST_BATCH_SIZE: int = -1
VLLM_MORIIO_NUM_WORKERS: int = 1
VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT: int = 480
VLLM_ENABLE_CUDAGRAPH_GC: bool = False
VLLM_LOOPBACK_IP: str = ""
Expand Down Expand Up @@ -1642,20 +1638,6 @@ def _resolve_rust_frontend_path() -> str | None:
"Use --linear-backend emulation.",
lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))),
),
# Controls the read mode for the Mori-IO connector
"VLLM_MORIIO_CONNECTOR_READ_MODE": lambda: (
os.getenv("VLLM_MORIIO_CONNECTOR_READ_MODE", "False").lower() in ("true", "1")
),
# Controls the QP (Queue Pair) per transfer configuration for the Mori-IO connector
"VLLM_MORIIO_QP_PER_TRANSFER": lambda: int(
os.getenv("VLLM_MORIIO_QP_PER_TRANSFER", "1")
),
# Controls the batch size used when posting transfer work requests for the
# Mori-IO connector (-1 lets the MoRI backend choose)
"VLLM_MORIIO_POST_BATCH_SIZE": lambda: int(
os.getenv("VLLM_MORIIO_POST_BATCH_SIZE", "-1")
),
# Controls the number of workers for Mori operations for the Mori-IO connector
"VLLM_MORIIO_NUM_WORKERS": lambda: int(os.getenv("VLLM_MORIIO_NUM_WORKERS", "1")),
# Timeout (in seconds) for MooncakeConnector in PD disaggregated setup.
"VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT": lambda: int(
os.getenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "480")
Expand Down
Loading