10 changes: 8 additions & 2 deletions docs/features/nixl_connector_usage.md
@@ -126,9 +126,15 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
- Set when prefiller and decoder are on different machines
- Connection info is passed via KVTransferParams from prefiller to decoder for handshake

- `VLLM_NIXL_ABORT_REQUEST_TIMEOUT`: Timeout (in seconds) for automatically releasing the prefiller’s KV cache for a particular request. (Optional)
- `kv_lease_duration` (via `kv_connector_extra_config`): Lease duration (in seconds) for the prefiller's KV cache blocks. (Optional)
- Default: 30
- When a prefill request finishes, its KV blocks are held for this duration waiting for the decoder to read them. While the request is queued on the decoder, periodic heartbeats automatically extend the lease. If neither a heartbeat nor a read notification arrives before the lease expires, the blocks are freed. The heartbeat interval and extension amount are derived automatically from this value (see the timing sketch after this list).
- Example: `--kv-transfer-config '{"kv_connector_extra_config": {"kv_lease_duration": 60}}'`

- `decoder_kv_blocks_ttl` (via `kv_connector_extra_config`): TTL (in seconds) for KV blocks cached on the decoder in bidirectional transfer mode. (Optional)
- Default: 480
- If a request is aborted and the decoder has not yet read the KV-cache blocks through the nixl channel, the prefill instance will release its KV-cache blocks after this timeout to avoid holding them indefinitely.
- In bidirectional mode, the decoder caches KV blocks for multi-turn conversations. This TTL controls how long those blocks are held before being released. Unlike the prefiller lease, this TTL is not renewed via heartbeats.
- Example: `--kv-transfer-config '{"kv_connector_extra_config": {"decoder_kv_blocks_ttl": 600}}'`

## Multi-Instance Setup

19 changes: 13 additions & 6 deletions tests/v1/kv_connector/unit/test_multi_connector.py
@@ -262,9 +262,11 @@ def test_multi_example_connector_consistency():

events = get_connector_events()
# First event is set_xfer_handshake_metadata from initialization, then
# on_new_request when the request is enqueued, then
# get_num_new_matched_tokens and update_state_after_alloc from generate().
assert events["storage1-SCHEDULER"][:4] == [
assert events["storage1-SCHEDULER"][:5] == [
"set_xfer_handshake_metadata",
"on_new_request",
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[0] 0",
"build_connector_meta",
@@ -281,8 +283,9 @@ def test_multi_example_connector_consistency():
"wait_for_layer_load",
"save_kv_layer",
]
assert events["storage2-SCHEDULER"][:4] == [
assert events["storage2-SCHEDULER"][:5] == [
"set_xfer_handshake_metadata",
"on_new_request",
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[0] 0",
"build_connector_meta",
@@ -310,12 +313,14 @@ def test_multi_example_connector_consistency():
# connector so update_state_after_alloc will be called with allocated blocks
# on that one but with zero blocks for others (first nonzero match is
# chosen).
assert events["storage1-SCHEDULER"][:3] == [
assert events["storage1-SCHEDULER"][:4] == [
"on_new_request",
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[7] 96",
"build_connector_meta",
]
assert events["storage2-SCHEDULER"][:3] == [
assert events["storage2-SCHEDULER"][:4] == [
"on_new_request",
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[0] 0",
"build_connector_meta",
@@ -336,12 +341,14 @@ def test_multi_example_connector_consistency():
# return 0 from the first connector, but the second connector should have
# a hit, so update_state_after_alloc will only be called with allocated
# blocks for the second connector.
assert events["storage1-SCHEDULER"][:3] == [
assert events["storage1-SCHEDULER"][:4] == [
"on_new_request",
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[0] 0",
"build_connector_meta",
]
assert events["storage2-SCHEDULER"][:3] == [
assert events["storage2-SCHEDULER"][:4] == [
"on_new_request",
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[7] 96",
"build_connector_meta",
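The selection rule the comments in this test describe ("first nonzero match is chosen") can be sketched in isolation. This is a hypothetical simplification for orientation, not the actual MultiConnector code:

```python
# Hypothetical simplification of the "first nonzero match wins" rule
# described in the test comments: the first connector reporting matched
# tokens receives the allocated blocks; every other connector sees zero.
def choose_connector(matched_tokens: list[int]) -> int | None:
    for idx, num_tokens in enumerate(matched_tokens):
        if num_tokens > 0:
            return idx  # update_state_after_alloc gets real blocks here
    return None         # no hit anywhere: all connectors see zero blocks
```

With a hit on the first connector the assertions show `num_blocks=[7] 96` for storage1 and `num_blocks=[0] 0` for storage2; with a hit only on the second, the pattern flips.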
6 changes: 2 additions & 4 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -1346,9 +1346,11 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
| {eventually free blocks}
"""
model_name = "Qwen/Qwen3-0.6B"
timeout = 6
kv_transfer_config = KVTransferConfig(
kv_connector="NixlConnector",
kv_role="kv_both",
kv_connector_extra_config={"kv_lease_duration": timeout},
)
llm_kwargs = {
"model": model_name,
@@ -1358,9 +1360,7 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
"distributed_executor_backend": distributed_executor_backend,
}

timeout = 6
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))

def run_test_and_cleanup():
llm = LLM(**llm_kwargs)
@@ -1375,8 +1375,6 @@ def run_test_and_cleanup():
runtime_env = {
"working_dir": working_dir, # ship fake nixl package
"env_vars": {
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout),
# TODO: for ray to carry over, remove once we set
"NIXL_TELEMETRY_ENABLE": "1",
},
}
165 changes: 165 additions & 0 deletions tests/v1/kv_connector/unit/test_nixl_heartbeat.py
@@ -0,0 +1,165 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for the scheduler-driven heartbeat / lease-renewal system."""

import time
from unittest.mock import MagicMock

import pytest

from vllm.v1.outputs import KVConnectorOutput

from .utils import create_request, make_nixl_scheduler

_ENGINE_A = "my-engine-id"


def _sched(kv_lease_duration: int = 30):
return make_nixl_scheduler(heartbeat=True, kv_lease_duration=kv_lease_duration)


def _req(request_id: int = 1):
return create_request(request_id=request_id, do_remote_prefill=True)


def _worker_stub():
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
NixlConnectorWorker,
)

w = object.__new__(NixlConnectorWorker)
w._reqs_to_send = {}
w._lease_extension = 20
return w


# ===================================================================
# Scheduler: on_new_request
# ===================================================================


def test_on_new_request_tracks_and_groups():
"""Add two reqs to same engine, one to another; verify grouping."""
s = _sched()
s.on_new_request(_req(1))
s.on_new_request(_req(2))

assert s._heartbeat_by_engine[_ENGINE_A].req_ids == {"prefill-1", "prefill-2"}
info = s._heartbeat_by_engine[_ENGINE_A]
assert (info.host, info.port, info.tp_size) == ("my-host", 1234, 1)
assert s._heartbeat_req_engine["id-1"] == (_ENGINE_A, "prefill-1")

# Different engine.
r3 = _req(3)
r3.kv_transfer_params["remote_engine_id"] = "engine-b"
s.on_new_request(r3)
assert len(s._heartbeat_by_engine) == 2


@pytest.mark.parametrize(
"make_req",
[
lambda: create_request(request_id=2, do_remote_decode=True),
lambda: create_request(request_id=3), # no kv_transfer_params
],
ids=["decode", "plain"],
)
def test_on_new_request_ignores_non_prefill(make_req):
s = _sched()
s.on_new_request(make_req())
assert len(s._heartbeat_by_engine) == 0


# ===================================================================
# Scheduler: _stop_heartbeat
# ===================================================================


def test_stop_heartbeat_partial_and_full():
"""Stop one of two reqs on same engine, then stop the other."""
s = _sched()
s.on_new_request(_req(1))
s.on_new_request(_req(2))

s._stop_heartbeat("id-1")
assert s._heartbeat_by_engine[_ENGINE_A].req_ids == {"prefill-2"}
assert "id-1" not in s._heartbeat_req_engine

s._stop_heartbeat("id-2")
assert len(s._heartbeat_by_engine) == 0
assert len(s._heartbeat_req_engine) == 0


# ===================================================================
# Scheduler: build_connector_meta throttling
# ===================================================================


def test_build_connector_meta_heartbeat_throttling():
# kv_lease_duration=30 => _heartbeat_interval = 30 // 6 = 5
s = _sched(kv_lease_duration=30)
s.on_new_request(_req(1))

# Ensure the first call triggers by placing last_heartbeat far in the past.
s._last_heartbeat_time = time.perf_counter() - 10
meta1 = s.build_connector_meta(MagicMock())
assert _ENGINE_A in meta1.heartbeat_by_engine

# Immediate second call is throttled (< 5s since last).
meta2 = s.build_connector_meta(MagicMock())
assert len(meta2.heartbeat_by_engine) == 0


# ===================================================================
# Scheduler: cleanup paths (update_connector_output / request_finished)
# ===================================================================


def test_update_connector_output_stops_heartbeat():
s = _sched()
s.on_new_request(_req(1))

s.update_connector_output(
KVConnectorOutput(
finished_sending=None,
finished_recving={"id-1"},
invalid_block_ids=set(),
)
)

assert len(s._heartbeat_by_engine) == 0
assert len(s._heartbeat_req_engine) == 0


def test_request_finished_stops_heartbeat():
s = _sched()
r = _req(1)
s.on_new_request(r)

# Simulate update_state_after_alloc having consumed do_remote_prefill.
r.kv_transfer_params["do_remote_prefill"] = False
s.request_finished(r, block_ids=())

assert len(s._heartbeat_by_engine) == 0
assert len(s._heartbeat_req_engine) == 0


# ===================================================================
# Worker: _handle_heartbeat
# ===================================================================


def test_handle_heartbeat():
w = _worker_stub()
far_future = time.perf_counter() + 99999
w._reqs_to_send = {"req-a": 100.0, "req-b": far_future}

before = time.perf_counter()
w._handle_heartbeat("req-a,req-b,req-unknown")

# req-a: pushed forward to ~now+20.
assert w._reqs_to_send["req-a"] >= before + 20
# req-b: already far out, max() keeps it.
assert w._reqs_to_send["req-b"] >= far_future
# req-unknown: not added.
assert "req-unknown" not in w._reqs_to_send
30 changes: 28 additions & 2 deletions tests/v1/kv_connector/unit/utils.py
@@ -223,6 +223,7 @@ def create_request(
remote_block_ids=list(range(num_remote_blocks)),
remote_host="my-host",
remote_port=1234,
tp_size=1,
)

max_tokens = 1 if do_remote_decode else max_tokens
@@ -477,10 +478,16 @@ def make_kv_cache_config(
)


def make_nixl_scheduler(has_mamba: bool = False, is_hma_required: bool = False):
def make_nixl_scheduler(
has_mamba: bool = False,
is_hma_required: bool = False,
heartbeat: bool = False,
kv_lease_duration: int = 30,
):
"""Create a NixlConnectorScheduler via __new__ (skipping __init__).

Only sets the two flags needed by the N-1 prefill logic.
Only sets the flags needed by the tests. When *heartbeat=True* the
scheduler-side heartbeat bookkeeping fields are also initialised.
"""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler import (
NixlConnectorScheduler,
@@ -489,4 +496,23 @@ def make_nixl_scheduler(has_mamba: bool = False, is_hma_required: bool = False):
sched = object.__new__(NixlConnectorScheduler)
sched._has_mamba = has_mamba
sched._is_hma_required = is_hma_required

if heartbeat:
sched._heartbeat_by_engine = {}
sched._heartbeat_req_engine = {}
sched._last_heartbeat_time = 0.0
sched._kv_lease_duration = kv_lease_duration
sched._heartbeat_interval = kv_lease_duration // 6
# Fields touched by build_connector_meta / request_finished:
sched._reqs_need_recv = {}
sched._reqs_need_send = {}
sched._reqs_in_batch = set()
sched._reqs_not_processed = set()
sched._reqs_need_save = {}
sched.use_host_buffer = False
sched.engine_id = "test-engine"
sched.side_channel_host = "localhost"
sched.side_channel_port = 5555
sched.blocks_per_sw = []
sched.is_bidirectional_kv_xfer_enabled = False
return sched
8 changes: 8 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -517,6 +517,14 @@ def build_connector_meta(
"""
pass

def on_new_request(self, request: "Request") -> None:
"""Called by the scheduler when a new request is added.

Connectors can override this to inspect the request and perform
bookkeeping. The default implementation is a no-op.
"""
return

def update_connector_output(self, connector_output: KVConnectorOutput):
"""
Update KVConnector state from worker-side connectors output.
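To make the contract of the new base-class hook concrete, here is a minimal sketch of an override; `TrackingConnector` and its `_track` helper are hypothetical illustrations, not part of this PR:

```python
# Hypothetical subclass showing where on_new_request fits; _track is an
# invented helper standing in for real per-request bookkeeping.
class TrackingConnector(KVConnectorBase_V1):
    def on_new_request(self, request: "Request") -> None:
        # Runs when the scheduler enqueues the request, before block
        # allocation, so state can be set up ahead of any transfer.
        params = request.kv_transfer_params or {}
        if params.get("do_remote_prefill"):
            self._track(request.request_id)  # hypothetical bookkeeping
```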
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
@@ -389,6 +389,10 @@ def update_state_after_alloc(
# Call with empty blocks for other connectors.
c.update_state_after_alloc(request, empty_blocks, 0)

def on_new_request(self, request: "Request") -> None:
for c in self._connectors:
c.on_new_request(request)

def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> MultiKVConnectorMetadata:
@@ -44,6 +44,7 @@
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import MambaSpec
from vllm.v1.outputs import KVConnectorOutput

if TYPE_CHECKING:
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
@@ -156,6 +157,14 @@ def build_connector_meta(
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)

def on_new_request(self, request: "Request") -> None:
assert self.connector_scheduler is not None
self.connector_scheduler.on_new_request(request)

def update_connector_output(self, connector_output: KVConnectorOutput):
assert self.connector_scheduler is not None
self.connector_scheduler.update_connector_output(connector_output)

def request_finished(
self,
request: "Request",