Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def test_multi_xfer_one_engine(
metadata = NixlConnectorMetadata()
if num_xfers > 0:
num_xfers -= 1
metadata.add_new_req(
metadata.add_new_req_to_recv(
request_id=request_id,
local_block_ids=[num_xfers + 1, num_xfers + 2, num_xfers + 3],
kv_transfer_params={
Expand Down Expand Up @@ -532,7 +532,7 @@ def test_async_load_kv(
vllm_config, connector.engine_id
)
metadata = NixlConnectorMetadata()
metadata.add_new_req(
metadata.add_new_req_to_recv(
request_id="id",
local_block_ids=[1, 2, 3],
kv_transfer_params={
Expand Down Expand Up @@ -588,7 +588,7 @@ def test_concurrent_load_kv(
metadata = NixlConnectorMetadata()
total_reqs = 5
for i in range(total_reqs):
metadata.add_new_req(
metadata.add_new_req_to_recv(
request_id=f"id_{i}",
local_block_ids=[1, 2, 3],
kv_transfer_params={
Expand Down Expand Up @@ -752,7 +752,7 @@ def test_kv_connector_stats(dist_init):
# Create transfer metadata
request_id = "test_req_for_stats"
metadata = NixlConnectorMetadata()
metadata.add_new_req(
metadata.add_new_req_to_recv(
request_id=request_id,
local_block_ids=[1, 2, 3],
kv_transfer_params={
Expand Down Expand Up @@ -1515,7 +1515,7 @@ def test_handshake_failure_returns_finished(dist_init):

request_id = "test_handshake_fail"
metadata = NixlConnectorMetadata()
metadata.add_new_req(
metadata.add_new_req_to_recv(
request_id=request_id,
local_block_ids=[1, 2, 3],
kv_transfer_params={
Expand Down Expand Up @@ -1565,7 +1565,7 @@ def test_transfer_setup_failure_returns_finished(dist_init):

request_id = "test_transfer_fail"
metadata = NixlConnectorMetadata()
metadata.add_new_req(
metadata.add_new_req_to_recv(
request_id=request_id,
local_block_ids=[7, 8, 9],
kv_transfer_params={
Expand Down
95 changes: 56 additions & 39 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,17 +202,22 @@ def compute_nixl_compatibility_hash(
return compat_hash


@dataclass
Copy link
Copy Markdown
Member

@njhill njhill Dec 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: perf benefit

Suggested change
@dataclass
@dataclass(slots=True)

can add it to other dataclasses here too!

Edit: can do this in a follow-on

class RemoteMeta:
    """Location of a request's KV-cache blocks on a remote peer engine."""

    # Block ids on the remote engine holding the KV data to transfer.
    block_ids: list[int]
    # Host/port used to initiate the NIXL handshake with the remote engine.
    host: str
    port: int
    # Id of the remote engine that holds the KV cache.
    engine_id: str
    # Id of the request on the remote engine (tracked separately from the
    # local request id).
    request_id: str


@dataclass
class ReqMeta:
    """Per-request transfer metadata tracked by the NIXL connector.

    The remote-side details are grouped in ``remote``: it stays ``None`` for
    save-to-host requests and is populated for remote-recv requests.
    """

    # Logical KV block ids owned by this request on the local worker.
    local_block_ids: list[int]
    # To be used when logical block size does not match the kernel block size
    local_physical_block_ids: list[int]
    # TP size of the remote deployment (defaults to 1 when the proxy does not
    # send it — see _add_new_req).
    tp_size: int
    # Remote endpoint/block info; set only when KV must be read from a peer.
    remote: RemoteMeta | None = None


class NixlConnectorMetadata(KVConnectorMetadata):
Expand All @@ -223,31 +228,43 @@ def __init__(self):
self.reqs_in_batch: set[ReqId] = set()
self.reqs_not_processed: set[ReqId] = set()

def add_new_req(
def _add_new_req(
self,
request_id: ReqId,
local_block_ids: list[int],
kv_transfer_params: dict[str, Any],
load_remote_cache: bool = True,
save_to_host: bool = False,
):
# save and load are mutually exclusive
assert load_remote_cache ^ save_to_host
_req = ReqMeta(
) -> ReqMeta:
return ReqMeta(
local_block_ids=local_block_ids,
local_physical_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params["remote_block_ids"],
remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_request_id=kv_transfer_params["remote_request_id"],
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
# P workers don't need to receive tp_size from proxy here.
tp_size=kv_transfer_params.get("tp_size", 1),
)
if save_to_host:
self.reqs_to_save[request_id] = _req
if load_remote_cache:
self.reqs_to_recv[request_id] = _req

def add_new_req_to_save(
self,
request_id: ReqId,
local_block_ids: list[int],
kv_transfer_params: dict[str, Any],
):
self.reqs_to_save[request_id] = self._add_new_req(
local_block_ids, kv_transfer_params
)

def add_new_req_to_recv(
self,
request_id: ReqId,
local_block_ids: list[int],
kv_transfer_params: dict[str, Any],
):
req = self._add_new_req(local_block_ids, kv_transfer_params)
req.remote = RemoteMeta(
block_ids=kv_transfer_params["remote_block_ids"],
engine_id=kv_transfer_params["remote_engine_id"],
request_id=kv_transfer_params["remote_request_id"],
host=kv_transfer_params["remote_host"],
port=kv_transfer_params["remote_port"],
)
self.reqs_to_recv[request_id] = req


class NixlConnector(KVConnectorBase_V1):
Expand Down Expand Up @@ -666,22 +683,18 @@ def build_connector_meta(
# Loop through scheduled reqs and convert to ReqMeta.
for req_id, (req, block_ids) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
meta.add_new_req(
meta.add_new_req_to_recv(
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
load_remote_cache=True,
save_to_host=False,
)

for req_id, (req, block_ids) in self._reqs_need_save.items():
assert req.kv_transfer_params is not None
meta.add_new_req(
meta.add_new_req_to_save(
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
load_remote_cache=False,
save_to_host=True,
)

meta.reqs_to_send = self._reqs_need_send
Expand Down Expand Up @@ -1124,10 +1137,11 @@ def _background_nixl_handshake(
# Do NIXL handshake in background and add to _ready_requests when done.
fut = self._handshake_futures.get(remote_engine_id)
if fut is None:
assert meta.remote is not None
fut = self._handshake_initiation_executor.submit(
self._nixl_handshake,
meta.remote_host,
meta.remote_port,
meta.remote.host,
meta.remote.port,
meta.tp_size,
remote_engine_id,
)
Expand Down Expand Up @@ -1774,14 +1788,15 @@ def get_finished(self) -> tuple[set[str], set[str]]:
# clean up metadata for completed requests
meta = self._recving_metadata.pop(req_id, None)
assert meta is not None, f"{req_id} not found in recving_metadata list"
assert meta.remote is not None
if self.use_host_buffer:
self.sync_recved_kv_to_device(req_id, meta)
if self.enable_permute_local_kv:
block_ids_to_permute += meta.local_physical_block_ids

# post processing for heteroblocksize
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
meta.remote_engine_id
meta.remote.engine_id
)
if (
not self.use_mla
Expand Down Expand Up @@ -1916,17 +1931,18 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
meta.local_block_ids
)
meta.remote_block_ids = self._logical_to_kernel_block_ids(
meta.remote_block_ids
assert meta.remote is not None
meta.remote.block_ids = self._logical_to_kernel_block_ids(
meta.remote.block_ids
)
remote_engine_id = meta.remote_engine_id
remote_engine_id = meta.remote.engine_id
logger.debug(
"start_load_kv for request %s from remote engine %s. "
"Num local_block_ids: %s. Num remote_block_ids: %s. ",
req_id,
remote_engine_id,
len(meta.local_physical_block_ids),
len(meta.remote_block_ids),
len(meta.remote.block_ids),
)
# always store metadata for failure recovery
self._recving_metadata[req_id] = meta
Expand Down Expand Up @@ -1965,17 +1981,18 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
self._reqs_to_send[req_id] = expiration_time

def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
assert meta.remote is not None
logger.debug(
"Remote agent %s available, calling _read_blocks for req %s",
meta.remote_engine_id,
meta.remote.engine_id,
req_id,
)
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote_engine_id,
remote_request_id=meta.remote_request_id,
dst_engine_id=meta.remote.engine_id,
remote_request_id=meta.remote.request_id,
local_block_ids=meta.local_physical_block_ids,
remote_block_ids=meta.remote_block_ids,
remote_block_ids=meta.remote.block_ids,
)

def _read_blocks(
Expand Down