Skip to content
Merged
117 changes: 42 additions & 75 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import tempfile
import textwrap
import time
import uuid
from collections import defaultdict
from typing import Optional
from unittest.mock import patch

import pytest
import ray

from vllm import LLM
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
NixlConnectorWorker)
from vllm.forward_context import ForwardContext
from vllm.mocks.mock_nixl_connector import FakeNixlWrapper
from vllm.sampling_params import SamplingParams

from .utils import create_request, create_scheduler, create_vllm_config


def _make_stub_pkg() -> str:
"""Return a directory that makes
`from nixl._api import nixl_agent` resolve to our FakeNixlWrapper."""
td = tempfile.mkdtemp()
pkg_root = os.path.join(td, "nixl", "_api")
os.makedirs(pkg_root, exist_ok=True)

stub = textwrap.dedent("""\
# Forward the real FakeNixlWrapper that the driver already defined.
print("In fake package")
from vllm.mocks.mock_nixl_connector import FakeNixlWrapper as nixl_agent
""")
with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
f.write(stub)

# touch parent package
open(os.path.join(td, "nixl", "__init__.py"), "w").close()
return td


def test_basic_interface():
"""Unit test for basic NixlConnector interface functionality."""

Expand Down Expand Up @@ -87,77 +109,6 @@ def test_prompt_less_than_block_size():
assert len(scheduler_output.scheduled_new_reqs) == 1


class FakeNixlWrapper:
    """Mock implementation of NixlWrapper for testing.

    We don't inherit from nixl._api.nixl_agent because nixl may not be
    installed.
    """

    AGENT_METADATA = b"fake_agent_metadata"
    REMOTE_AGENT_NAME = "remote_agent"

    def __init__(self, agent_name: str, *args, **kwargs):
        # How many check_xfer_state() polls a handle must see before it
        # reports "DONE" (0 means transfers complete immediately).
        self._cycles_before_xfer_done = 0
        # Per-handle poll counter: [handle -> polls observed so far].
        # defaultdict(int) is the idiomatic spelling of lambda: 0.
        self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(int)

    def get_reg_descs(self, caches_data, memory_type: str) -> list:
        # Opaque descriptor per cache entry; uniqueness is all that matters.
        return [str(uuid.uuid4()) for _ in caches_data]

    def register_memory(self, descs) -> None:
        pass

    def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
        return [str(uuid.uuid4()) for _ in blocks_data]

    def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
        # Random int stands in for a native handle.
        return uuid.uuid4().int

    def get_agent_metadata(self) -> bytes:
        return self.AGENT_METADATA

    def add_remote_agent(self, agent_metadata: bytes) -> str:
        return self.REMOTE_AGENT_NAME

    def get_new_notifs(self) -> dict[str, list[bytes]]:
        # Used to collect done_sending, which we don't test yet.
        return {}

    def check_xfer_state(self, handle: int) -> str:
        # Report "PROC" until this handle has been polled
        # _cycles_before_xfer_done times, then "DONE".
        if self._check_xfer_state_cycles[
                handle] >= self._cycles_before_xfer_done:
            return "DONE"
        self._check_xfer_state_cycles[handle] += 1
        return "PROC"

    def release_xfer_handle(self, handle: int) -> None:
        pass

    def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
        pass

    def make_prepped_xfer(self,
                          xfer_type: str,
                          local_xfer_side_handle: int,
                          local_block_descs_ids: list[int],
                          remote_xfer_side_handle: int,
                          remote_block_descs_ids: list[int],
                          notif_msg: Optional[bytes] = None) -> int:
        return uuid.uuid4().int

    def transfer(self, handle: int) -> str:
        # Always in-progress; tests poll completion via check_xfer_state().
        return "PROC"

    ############################################################
    # Follow are for changing the behavior during testing.
    ############################################################

    def set_cycles_before_xfer_done(self, cycles: int):
        """Set the number of cycles before a transfer is considered done."""
        self._cycles_before_xfer_done = cycles


class FakeNixlConnectorWorker(NixlConnectorWorker):

REMOTE_ENGINE_ID = "remote_engine"
Expand Down Expand Up @@ -378,10 +329,14 @@ def test_concurrent_load_kv(
raise TimeoutError("Took too long to complete async handshake.")


# NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kouroshHakha could you elaborate on the resource cleanup problem with mp. Are there processes left running?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah — the second parametrization, regardless of whether it's "mp" or "ray", hits a GPU OOM issue. I tried using vllm.distributed.cleanup_dist_env_and_memory, but it didn't quite work.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realized this test originally did not run on mp — distributed_executor_backend used to be None — so in the latest version I kept None and only added ray on top of what was covered before.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, we should look into the mp thing as a follow-on I guess

# we put here is important. First run ray, it will clean up the resources, then
# the rest of the tests.
@pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
def test_abort_timeout_on_prefiller(monkeypatch):
def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
"""
Test lifecycle of an aborted Remote Prefill request hitting the timeout.
-----> P
Expand All @@ -399,11 +354,23 @@ def test_abort_timeout_on_prefiller(monkeypatch):
timeout = 6
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))

# Build runtime_env only if we’re using Ray
if distributed_executor_backend == "ray":
runtime_env = {
"working_dir": _make_stub_pkg(), # ship stub package
"env_vars": {
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout),
},
}
ray.init(runtime_env=runtime_env)

llm = LLM(
model=model_name,
enforce_eager=True,
gpu_memory_utilization=0.5,
kv_transfer_config=kv_transfer_config,
distributed_executor_backend=distributed_executor_backend,
)
remote_prefill_opts = {
"do_remote_decode": True,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,28 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding the tests and bug fix from #21048

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's now been merged to main so can rebase.

# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from collections import defaultdict
from concurrent.futures import Future
from typing import Optional

from vllm.v1.executor.multiproc_executor import MultiprocExecutor
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.v1.outputs import ModelRunnerOutput


class DummyMultiprocExecutor(MultiprocExecutor):
    """Minimal stand-in for MultiprocExecutor used in unit tests.

    Deliberately skips super().__init__() and only sets the attributes
    that the aggregation logic under test actually reads.
    """

    def __init__(self, output_rank, world_size):
        self.output_rank = output_rank
        self.world_size = world_size
        # Each request starts with `world_size` workers yet to report.
        def _full_count():
            return self.world_size

        self._send_remaining_count = defaultdict[str, int](_full_count)
        self._recv_remaining_count = defaultdict[str, int](_full_count)
        self.io_thread_pool = None
        self.shutdown_event = threading.Event()


class DummyModelRunnerOutput(ModelRunnerOutput):

def __init__(self,
Expand All @@ -33,14 +17,14 @@ def __init__(self,


def test_aggregate_workers_output():
executor = DummyMultiprocExecutor(output_rank=0, world_size=2)
aggregator = KVOutputAggregator(world_size=2)

output1 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving={'req2'})
output2 = DummyModelRunnerOutput(finished_sending=None,
finished_recving=None)

aggregated = executor._aggregate_workers_output([output1, output2])
aggregated = aggregator.aggregate([output1, output2])

assert aggregated is output1
assert aggregated.finished_sending is None
Expand All @@ -51,7 +35,7 @@ def test_aggregate_workers_output():
output2 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving=None)

aggregated = executor._aggregate_workers_output([output1, output2])
aggregated = aggregator.aggregate([output1, output2])

assert aggregated is output1
assert aggregated.finished_sending == {'req1'}
Expand All @@ -62,20 +46,19 @@ def test_aggregate_workers_output():
output2 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving={'req2'})

aggregated = executor._aggregate_workers_output([output1, output2])
aggregated = aggregator.aggregate([output1, output2])

assert aggregated is output1
assert aggregated.finished_sending is None
assert aggregated.finished_recving == {'req2'}


def test_async_aggregate_workers_output():
executor = DummyMultiprocExecutor(output_rank=0, world_size=2)
aggregator = KVOutputAggregator(world_size=2)

future1: Future[DummyModelRunnerOutput] = Future()
future2: Future[DummyModelRunnerOutput] = Future()
result_future = executor._async_aggregate_workers_output(
[future1, future2])
result_future = aggregator.async_aggregate([future1, future2])

output1 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving={'req2'})
Expand All @@ -92,8 +75,7 @@ def test_async_aggregate_workers_output():

future1 = Future()
future2 = Future()
result_future = executor._async_aggregate_workers_output(
[future1, future2])
result_future = aggregator.async_aggregate([future1, future2])

output1 = DummyModelRunnerOutput(finished_sending=None,
finished_recving=None)
Expand All @@ -110,8 +92,7 @@ def test_async_aggregate_workers_output():

future1 = Future()
future2 = Future()
result_future = executor._async_aggregate_workers_output(
[future1, future2])
result_future = aggregator.async_aggregate([future1, future2])

output1 = DummyModelRunnerOutput(finished_sending=None,
finished_recving=None)
Expand Down
90 changes: 90 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,18 @@
"""
KV cache helper for store.
"""
from collections import defaultdict
from collections.abc import Sequence
from concurrent.futures import CancelledError, Future
from typing import Optional, cast

import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.v1.outputs import ModelRunnerOutput

logger = init_logger(__name__)

Expand Down Expand Up @@ -107,3 +113,87 @@ def get_kv_connector_cache_layout():
"layout to HND for better xfer performance.")
return "HND"
return "NHD"


class KVOutputAggregator:
    """Utility class to aggregate the output of all workers into a single
    output corresponding to Rank 0 for scheduler."""

    def __init__(self, world_size: int):
        # Complete transfer tracker. Used to track finished requests:
        # [req_id -> number of workers that have not yet reported finished].
        self._recv_remaining_count = defaultdict[str, int](lambda: world_size)
        self._send_remaining_count = defaultdict[str, int](lambda: world_size)

    def aggregate(self,
                  outputs: list[ModelRunnerOutput],
                  output_rank: int = 0) -> ModelRunnerOutput:
        """Fold per-worker outputs into the single output of `output_rank`.

        A request id is only reported in finished_sending/finished_recving
        once *every* worker has reported it, possibly across multiple
        aggregate() calls (counters persist on the instance).
        """

        # aggregate finished_sending, finished_recving from all workers
        def update_finished_set(req_ids: Optional[set[str]],
                                remaining_count_dict: dict[str, int],
                                finished_set: set[str]) -> None:
            for req_id in req_ids or ():
                new_count = remaining_count_dict[req_id] - 1
                if new_count == 0:
                    finished_set.add(req_id)
                    del remaining_count_dict[req_id]
                else:
                    remaining_count_dict[req_id] = new_count

        finished_sending = set[str]()
        finished_recving = set[str]()
        for output in outputs:
            update_finished_set(output.finished_sending,
                                self._send_remaining_count, finished_sending)
            update_finished_set(output.finished_recving,
                                self._recv_remaining_count, finished_recving)

        # select output of the worker specified by output_rank
        output = outputs[output_rank]

        # set the aggregated finished_sending / finished_recving
        # if output.finished_sending/recving is not empty, but the other ranks
        # still have unfinished send/recv, we want to set the aggregated
        # finished_sending/recving to None until all ranks have finished
        # send/recv
        output.finished_sending = finished_sending if finished_sending else None
        output.finished_recving = finished_recving if finished_recving else None

        return output

    def async_aggregate(self,
                        output_futures: Sequence[Future[ModelRunnerOutput]],
                        output_rank: int = 0) -> Future[ModelRunnerOutput]:
        """Takes a list of futures and returns a single future which resolves
        to the respective list of outputs."""
        result_future: Future[ModelRunnerOutput] = Future()

        outputs: list[Optional[ModelRunnerOutput]] = [None
                                                      ] * len(output_futures)

        def make_callback(idx):

            def callback(fut):
                if result_future.done():
                    return

                try:
                    outputs[idx] = fut.result()
                except CancelledError:
                    result_future.cancel()
                    return
                except Exception as e:
                    result_future.set_exception(e)
                    return

                # Check for None explicitly rather than truthiness so a
                # falsy-but-valid worker output cannot stall aggregation.
                # NOTE: this check assumes io_thread_pool uses a single
                # thread; with concurrent callbacks it would need a lock.
                if all(o is not None for o in outputs):
                    result_future.set_result(
                        self.aggregate(cast(list[ModelRunnerOutput], outputs),
                                       output_rank))

            return callback

        for i, output_future in enumerate(output_futures):
            output_future.add_done_callback(make_callback(i))

        return result_future
2 changes: 1 addition & 1 deletion vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def get_finished(
"""
Notifies worker-side connector ids of requests that have
finished generating tokens on the worker.
The scheduler process (via the MultiprocExecutor) will use this output
The scheduler process (via the Executors) will use this output
to track which workers are done.

Returns:
Expand Down
Empty file added vllm/mocks/__init__.py
Empty file.
Loading