Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -247,16 +247,19 @@ def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig":

# TODO: merge notify_port and handshake_port to simplify port management
# supports non-contiguous ports
assert vllm_config.kv_transfer_config is not None, (
"kv_transfer_config must be set for MoRIIOConnector"
)
assert (
vllm_config.kv_transfer_config is not None
), "kv_transfer_config must be set for MoRIIOConnector"
_warn_deprecated_env_vars()
kv_transfer_config = vllm_config.kv_transfer_config
extra_config = kv_transfer_config.kv_connector_extra_config
tp_rank = get_tensor_model_parallel_rank()
dp_rank = vllm_config.parallel_config.data_parallel_rank
dp_rank = (
vllm_config.parallel_config.data_parallel_rank
% vllm_config.parallel_config.data_parallel_size_local
)
base_notify_port = int(extra_config["notify_port"])
dp_size = vllm_config.parallel_config.data_parallel_size
dp_size = vllm_config.parallel_config.data_parallel_size_local
tp_size = get_tensor_model_parallel_world_size()
port_offset = get_port_offset(dp_rank, tp_rank)
backend = str(extra_config.get("backend", "rdma")).lower()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import math
import os
import queue
import threading
import time
Expand Down Expand Up @@ -97,9 +98,9 @@ def __init__(
kv_cache_config: "KVCacheConfig",
):
super().__init__(vllm_config, role, kv_cache_config)
assert vllm_config.kv_transfer_config is not None, (
"kv_transfer_config must be set for MoRIIOConnector"
)
assert (
vllm_config.kv_transfer_config is not None
), "kv_transfer_config must be set for MoRIIOConnector"

self.kv_transfer_config = vllm_config.kv_transfer_config
self._set_port_defaults(vllm_config)
Expand Down Expand Up @@ -129,9 +130,9 @@ def __init__(
############################################################

def _set_port_defaults(self, vllm_config: VllmConfig):
assert vllm_config.kv_transfer_config is not None, (
"kv_transfer_config must be set for MoRIIOConnector"
)
assert (
vllm_config.kv_transfer_config is not None
), "kv_transfer_config must be set for MoRIIOConnector"
kv_transfer_config = vllm_config.kv_transfer_config
extra_config = kv_transfer_config.kv_connector_extra_config

Expand Down Expand Up @@ -209,13 +210,13 @@ def save_kv_layer(
# Only producer/prefill saves KV Cache
if get_role() == ROLE.CONSUMER:
return
assert self.connector_worker is not None, (
"save_kv_layer called on scheduler role"
)
assert (
self.connector_worker is not None
), "save_kv_layer called on scheduler role"

assert isinstance(self._connector_metadata, MoRIIOConnectorMetadata), (
"Connector metadata not initialized yet"
)
assert isinstance(
self._connector_metadata, MoRIIOConnectorMetadata
), "Connector metadata not initialized yet"
self.connector_worker.save_kv_layer(
self._connector_metadata, layer_name, kv_layer, attn_metadata, **kwargs
)
Expand Down Expand Up @@ -249,9 +250,9 @@ class MoRIIOConnectorScheduler:
def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.vllm_config = vllm_config

assert vllm_config.kv_transfer_config is not None, (
"kv_transfer_config must be set for MoRIIOConnector"
)
assert (
vllm_config.kv_transfer_config is not None
), "kv_transfer_config must be set for MoRIIOConnector"
self.kv_transfer_config = vllm_config.kv_transfer_config
self.block_size = vllm_config.cache_config.block_size
self.engine_id: EngineId = engine_id
Expand All @@ -266,13 +267,21 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
"notify_port"
]
self.tp_size = self.vllm_config.parallel_config.tensor_parallel_size
self.dp_rank = self.vllm_config.parallel_config.data_parallel_rank
self.dp_rank = (
self.vllm_config.parallel_config.data_parallel_rank
% self.vllm_config.parallel_config.data_parallel_size_local
)
self._is_kv_master = (
self.vllm_config.parallel_config.data_parallel_rank
< self.vllm_config.parallel_config.data_parallel_size_local
)
self.is_producer = self.kv_transfer_config.kv_role == "kv_producer"
# Requests that need to start recv/send.
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
self._req_kv_params: dict[ReqId, dict] = {}

# For chunked prefill, we perform layer-wise access within the final chunk.
# TODO: Perform transfer at end chunk.
Expand Down Expand Up @@ -393,6 +402,7 @@ def update_state_after_alloc(
if params.get("do_remote_decode"):
local_block_ids = blocks.get_block_ids()[0]
self._reqs_need_save[request.request_id] = (request, local_block_ids)
self._req_kv_params[request.request_id] = dict(params)

if params is not None and params.get("do_remote_prefill"):
if self.mode == MoRIIOMode.READ:
Expand All @@ -416,6 +426,7 @@ def update_state_after_alloc(
request,
local_block_ids,
)
self._req_kv_params[request.request_id] = dict(params)
else:
logger.warning(
"Got invalid KVTransferParams: %s. This "
Expand All @@ -425,30 +436,33 @@ def update_state_after_alloc(

else:
# WRITE mode, decode side: notify P that blocks are ready
assert request.kv_transfer_params is not None, (
"kv_transfer_params should not be None"
)
assert (
request.kv_transfer_params is not None
), "kv_transfer_params should not be None"

remote_dp_rank = request.kv_transfer_params.get("remote_dp_rank", 0)

peer_zmq = get_peer_zmq_from_request_id(
request.request_id, is_producer=False
)
remote_host, _, remote_notify_port = parse_moriio_zmq_address(peer_zmq)

for tp_index in range(self.tp_size):
target_port = remote_notify_port + get_port_offset(
remote_dp_rank, tp_index
if self._is_kv_master:
peer_zmq = get_peer_zmq_from_request_id(
request.request_id, is_producer=False
)

self.send_notify_block(
req_id=request.request_id,
transfer_id=request.kv_transfer_params["transfer_id"],
block_notify_list=blocks.get_block_ids()[0],
host=remote_host,
port=target_port,
remote_host, _, remote_notify_port = parse_moriio_zmq_address(
peer_zmq
)

for tp_index in range(self.tp_size):
target_port = remote_notify_port + get_port_offset(
remote_dp_rank, tp_index
)

self.send_notify_block(
req_id=request.request_id,
transfer_id=request.kv_transfer_params["transfer_id"],
block_notify_list=blocks.get_block_ids()[0],
host=remote_host,
port=target_port,
)

# Only trigger 1 KV transfer per request.

params["do_remote_prefill"] = False
Expand All @@ -468,12 +482,12 @@ def build_connector_meta(
for new_req in scheduler_output.scheduled_new_reqs:
red_id = new_req.req_id
local_block_ids = list(new_req.block_ids)[0]
assert new_req.sampling_params is not None, (
f"sampling_params is None for req {new_req.req_id}"
)
assert hasattr(new_req.sampling_params, "extra_args"), (
f"sampling_params missing extra_args for req {new_req.req_id}"
)
assert (
new_req.sampling_params is not None
), f"sampling_params is None for req {new_req.req_id}"
assert hasattr(
new_req.sampling_params, "extra_args"
), f"sampling_params missing extra_args for req {new_req.req_id}"
kv_transfer_params = (
new_req.sampling_params.extra_args.get("kv_transfer_params", {})
if new_req.sampling_params.extra_args
Expand Down Expand Up @@ -507,39 +521,47 @@ def build_connector_meta(
* self.block_size
>= req.num_prompt_tokens
):
kv_params = self._req_kv_params.pop(
req_id, req.kv_transfer_params or {}
)
meta.add_new_req(
request_id=req_id,
local_block_ids=self._reqs_need_pending_save[req_id][1],
kv_transfer_params=req.kv_transfer_params or {},
kv_transfer_params=kv_params,
write_mode=True,
)
del self._reqs_need_pending_save[req_id]

# Loop through scheduled reqs and convert to ReqMeta.
for req_id, (req, block_ids) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
kv_params = self._req_kv_params.get(req_id, req.kv_transfer_params or {})
meta.add_new_req(
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
kv_transfer_params=kv_params,
)

for req_id, (req, block_ids) in self._reqs_need_save.items():
assert req.kv_transfer_params is not None
kv_params = self._req_kv_params.get(req_id, req.kv_transfer_params or {})
if req.num_prompt_tokens > len(block_ids) * self.block_size:
# not last chunk prefill
self._reqs_need_pending_save[req_id] = (req, block_ids)
continue
meta.add_new_req(
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
kv_transfer_params=kv_params,
write_mode=True,
)
# Clear the list once workers start the transfers

meta.reqs_to_send = self._reqs_need_send

for req_id in self._reqs_need_recv:
self._req_kv_params.pop(req_id, None)
for req_id in self._reqs_need_save:
if req_id not in self._reqs_need_pending_save:
self._req_kv_params.pop(req_id, None)
self._reqs_need_recv.clear()
self._reqs_need_save.clear()
self._reqs_need_send = {}
Expand Down Expand Up @@ -700,9 +722,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):

# Config.
self.vllm_config = vllm_config
assert vllm_config.kv_transfer_config is not None, (
"kv_transfer_config must be set for MoRIIOConnector"
)
assert (
vllm_config.kv_transfer_config is not None
), "kv_transfer_config must be set for MoRIIOConnector"
self.kv_transfer_config = vllm_config.kv_transfer_config
self.is_producer = self.kv_transfer_config.is_kv_producer

Expand Down Expand Up @@ -839,6 +861,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
self._recving_transfers: defaultdict[ReqId, list] = defaultdict(list)
# Values are (remote_host, remote_notify_port, transfer_id).
self._recving_transfers_callback_addr: dict[ReqId, tuple[str, str, str]] = {}
# Track when each request's first transfer was posted (keyed by request ID,
# as set in _read_blocks) so stale in-flight transfers can be timed out.
self._recving_transfers_start: dict[str, float] = {}

# Track the expiration time of requests that are waiting to be sent.
self._reqs_to_send: dict[ReqId, float] = {}
Expand Down Expand Up @@ -1242,7 +1267,14 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
self.slot_size_bytes = (
kv_elem_size * n_kv_heads * head_dim
) # 1 token 1 layer size , slot size
assert block_size == self.block_size
if block_size != self.block_size:
logger.info(
"KV cache block_size=%d differs from config block_size=%d; "
"using actual tensor shape (attention backend override).",
block_size,
self.block_size,
)
self.block_size = block_size
# TODO(tms): self.block_len needs to be per-layer for sliding window,
# hybrid attn, etc
# block size in bytes
Expand Down Expand Up @@ -1382,6 +1414,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:

def _pop_done_transfers(self) -> set[str]:
done_req_ids: set[str] = set()
_xfer_timeout = int(os.environ.get("VLLM_MORIIO_TRANSFER_TIMEOUT_S", "120"))
with self.moriio_wrapper.lock:
to_remove = []
for req_id, status_list in self._recving_transfers.items():
Expand Down Expand Up @@ -1411,10 +1444,20 @@ def _pop_done_transfers(self) -> set[str]:
to_remove.append(req_id)
# Do NOT add to done_req_ids: decode KV cache is incomplete.
# The request will expire via the normal request timeout.
elif req_id in self._recving_transfers_start:
# PR #39276: timeout tracking — abort still-in-flight
# transfers that exceeded the deadline, so we don't hang
# forever waiting on a lost RDMA completion.
_age = time.monotonic() - self._recving_transfers_start[req_id]
if _age > _xfer_timeout:
logger.error(
"RDMA read TIMED OUT for req %s after %.1fs", req_id, _age
)
to_remove.append(req_id)
for req_id in to_remove:
del self._recving_transfers[req_id]
del self._recving_transfers_callback_addr[req_id]

self._recving_transfers_start.pop(req_id, None)
return done_req_ids

def save_kv_layer(
Expand Down Expand Up @@ -1737,6 +1780,7 @@ def _read_blocks(
)
with self.moriio_wrapper.lock:
self._recving_transfers[request_id].append(transfer_status)
self._recving_transfers_start.setdefault(request_id, time.monotonic())
self._recving_transfers_callback_addr[request_id] = (
remote_host,
str(remote_notify_port + self.tp_rank),
Expand Down
Loading
Loading