diff --git a/tests/v1/kv_connector/unit/test_bidirectional_kv_transfer.py b/tests/v1/kv_connector/unit/test_bidirectional_kv_transfer.py new file mode 100644 index 000000000000..f55ac87d03a9 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_bidirectional_kv_transfer.py @@ -0,0 +1,612 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for bi-directional KV cache transfer between P and D nodes. + +Tests cover the new behaviors added by the bi-directional KV transfer PR: +1. P-node scheduler lifecycle: P pulls KV from D using remote_block_ids, + eliminating redundant prefill computation in multi-turn conversations. +2. P-node metadata: NixlConnectorMetadata correctly populates recv metadata + when P pulls KV from D (do_remote_decode=True + remote_block_ids). +3. P-node worker: start_load_kv processes reqs_to_recv for KV pull from D. +4. D-node request_finished: returns kv_transfer_params with remote_block_ids + and remote_num_tokens so P can pull KV in future turns. + +P-node flags: do_remote_prefill=False (prefill locally), +do_remote_decode=True (don't decode locally, send KV to D). +P pulls KV from D when remote_block_ids is not None and +external tokens > 0. +""" + +import copy +import time +from unittest.mock import patch + +import pytest + +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole +from vllm.distributed.kv_transfer.kv_connector.v1.nixl.connector import ( + NixlConnector, + NixlConnectorMetadata, +) +from vllm.forward_context import ForwardContext +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + KVConnectorOutput, +) +from vllm.v1.request import RequestStatus + +from .test_nixl_connector import FakeNixlConnectorWorker, FakeNixlWrapper +from .utils import ( + assert_scheduler_empty, + create_model_runner_output, + create_request, + create_scheduler, + create_vllm_config, + make_kv_cache_config, +) + +pytestmark = pytest.mark.cpu_test + +# Common extra config for all bi-directional KV transfer tests. +BIDIR_KV_EXTRA_CONFIG = {"bidirectional_kv_xfer": True} + + +# ----------------------------------------------------------------------- +# Helpers +# ----------------------------------------------------------------------- + + +def _make_p_node_turn2_request( + request_id, block_size, num_tokens, num_remote_blocks=3, remote_num_tokens=None +): + """Create a P-node Turn 2 request with remote_block_ids from D.""" + request = create_request( + request_id=request_id, + block_size=block_size, + num_tokens=num_tokens, + do_remote_decode=True, + ) + if remote_num_tokens is None: + remote_num_tokens = num_remote_blocks * block_size + request.kv_transfer_params["remote_block_ids"] = [list(range(num_remote_blocks))] + request.kv_transfer_params["remote_num_tokens"] = remote_num_tokens + request.kv_transfer_params["remote_engine_id"] = "decode-engine" + request.kv_transfer_params["remote_request_id"] = f"decode-{request_id}" + request.kv_transfer_params["remote_host"] = "decode-host" + request.kv_transfer_params["remote_port"] = 5678 + return request + + +def _make_connector_with_fake_worker( + hand_shake_latency=0, cycles_before_done=0, do_handshake=True +): + """Create a NixlConnector with FakeNixlConnectorWorker.""" + vllm_config = create_vllm_config() + kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2) + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER, kv_cache_config) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, + connector.engine_id, + hand_shake_latency=hand_shake_latency, + kv_cache_config=kv_cache_config, + ) + worker = connector.connector_worker + assert isinstance(worker.nixl_wrapper, FakeNixlWrapper) + worker.nixl_wrapper.set_cycles_before_xfer_done(cycles_before_done) + worker.kv_cache_layout = "HND" + if do_handshake: + remote_agents = worker._nixl_handshake( + host="localhost", + port=1234, + remote_tp_size=1, + expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + ) + worker._remote_agents[FakeNixlConnectorWorker.REMOTE_ENGINE_ID] = remote_agents + return connector, worker + + +def _make_p_node_recv_metadata(request_id, local_blocks, remote_blocks): + """Build NixlConnectorMetadata for P-node pulling KV from D.""" + meta = NixlConnectorMetadata() + meta.add_new_req_to_recv( + request_id=request_id, + local_block_ids=(local_blocks,), + kv_transfer_params={ + "do_remote_prefill": False, + "do_remote_decode": True, + "remote_block_ids": (remote_blocks,), + "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_request_id": f"decode-{request_id}", + "remote_host": "localhost", + "remote_port": 1234, + "remote_tp_size": 1, + }, + ) + return meta + + +def _do_load_kv(connector, metadata): + """Bind metadata and call start_load_kv.""" + connector.bind_connector_metadata(metadata) + ctx = ForwardContext(no_compile_layers={}, attn_metadata={}, slot_mapping={}) + connector.start_load_kv(ctx) + + +# ----------------------------------------------------------------------- +# 1. P-node scheduler lifecycle tests +# ----------------------------------------------------------------------- + + +def test_multiturn_lifecycle(): + """Full two-turn lifecycle on the P node: + Turn 1: P prefills locally (do_remote_prefill=False), sends KV to D + (do_remote_decode=True). Finishes LENGTH_CAPPED with remote_block_ids. + Turn 2: P receives remote_block_ids from D. P pulls KV from D because + remote_block_ids is not None and external tokens > 0. Computes only + new tokens, finishes LENGTH_CAPPED.""" + vllm_config = create_vllm_config( + kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG, + ) + scheduler = create_scheduler(vllm_config) + BS = vllm_config.cache_config.block_size + + t1 = create_request( + request_id=100, block_size=BS, num_tokens=int(BS * 2.5), do_remote_decode=True + ) + scheduler.add_request(t1) + t1_id = t1.request_id + so = scheduler.schedule() + mro = create_model_runner_output(reqs=[t1]) + eco = scheduler.update_from_output(so, mro) + assert t1.status == RequestStatus.FINISHED_LENGTH_CAPPED + kv = eco[0].outputs[0].kv_transfer_params + assert kv and sum(len(g) for g in kv["remote_block_ids"]) > 0 + so = scheduler.schedule() + scheduler.update_from_output(so, EMPTY_MODEL_RUNNER_OUTPUT) + + t2 = _make_p_node_turn2_request(200, BS, int(BS * 2.5)) + scheduler.add_request(t2) + t2_id = t2.request_id + so = scheduler.schedule() + assert t2.status == RequestStatus.WAITING_FOR_REMOTE_KVS + scheduler.update_from_output(so, EMPTY_MODEL_RUNNER_OUTPUT) + so = scheduler.schedule() + mro = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + mro.kv_connector_output = KVConnectorOutput(finished_recving={t2_id}) + scheduler.update_from_output(so, mro) + so = scheduler.schedule() + mro = create_model_runner_output(reqs=[t2]) + scheduler.update_from_output(so, mro) + assert t2.status == RequestStatus.FINISHED_LENGTH_CAPPED + so = scheduler.schedule() + scheduler.update_from_output(so, EMPTY_MODEL_RUNNER_OUTPUT) + so = scheduler.schedule() + mro = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + mro.kv_connector_output = KVConnectorOutput(finished_sending={t1_id, t2_id}) + scheduler.update_from_output(so, mro) + assert_scheduler_empty(scheduler) + + +def test_first_turn_no_remote_blocks(): + """First turn: P has no remote_block_ids from D yet. + Standard local prefill, returns kv_transfer_params for future turns.""" + vllm_config = create_vllm_config( + kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG, + ) + scheduler = create_scheduler(vllm_config) + BS = vllm_config.cache_config.block_size + req = create_request( + request_id=3, block_size=BS, num_tokens=int(BS * 2.5), do_remote_decode=True + ) + scheduler.add_request(req) + req_id = req.request_id + so = scheduler.schedule() + assert req.status != RequestStatus.WAITING_FOR_REMOTE_KVS + mro = create_model_runner_output(reqs=[req]) + eco = scheduler.update_from_output(so, mro) + assert req.status == RequestStatus.FINISHED_LENGTH_CAPPED + assert eco[0].outputs[0].kv_transfer_params is not None + so = scheduler.schedule() + scheduler.update_from_output(so, EMPTY_MODEL_RUNNER_OUTPUT) + so = scheduler.schedule() + mro = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + mro.kv_connector_output = KVConnectorOutput(finished_sending={req_id}) + scheduler.update_from_output(so, mro) + assert_scheduler_empty(scheduler) + + +def test_abort_p_side_during_send(): + """P-side do_remote_decode=True: blocks held until finished_sending.""" + vllm_config = create_vllm_config( + kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG, + ) + scheduler = create_scheduler(vllm_config) + BS = vllm_config.cache_config.block_size + req = create_request( + request_id=42, block_size=BS, num_tokens=int(BS * 2.5), do_remote_decode=True + ) + scheduler.add_request(req) + req_id = req.request_id + so = scheduler.schedule() + mro = create_model_runner_output(reqs=[req]) + scheduler.update_from_output(so, mro) + assert req_id in scheduler.requests + so = scheduler.schedule() + scheduler.update_from_output(so, EMPTY_MODEL_RUNNER_OUTPUT) + assert req_id in scheduler.requests + so = scheduler.schedule() + mro = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + mro.kv_connector_output = KVConnectorOutput(finished_sending={req_id}) + scheduler.update_from_output(so, mro) + assert_scheduler_empty(scheduler) + + +def test_abort_p_side_non_length_capped(): + """P-side abort with non-LENGTH_CAPPED → immediate block free.""" + vllm_config = create_vllm_config( + kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG, + ) + scheduler = create_scheduler(vllm_config) + BS = vllm_config.cache_config.block_size + req = create_request( + request_id=44, block_size=BS, num_tokens=int(BS * 2.5), do_remote_decode=True + ) + req.sampling_params.max_tokens = 100 + req.max_tokens = 100 + scheduler.add_request(req) + req_id = req.request_id + so = scheduler.schedule() + mro = create_model_runner_output(reqs=[req]) + scheduler.update_from_output(so, mro) + scheduler.finish_requests([req_id], RequestStatus.FINISHED_ABORTED) + conn = scheduler.connector.connector_scheduler + assert req_id in conn._reqs_not_processed + assert req_id not in scheduler.requests + so = scheduler.schedule() + scheduler.update_from_output(so, EMPTY_MODEL_RUNNER_OUTPUT) + assert_scheduler_empty(scheduler) + + +def test_remote_blocks_exceed_prompt_tokens(): + """D provides more remote tokens than P's prompt needs. + P caps external tokens to prompt length.""" + vllm_config = create_vllm_config( + kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG, + ) + scheduler = create_scheduler(vllm_config) + BS = vllm_config.cache_config.block_size + NUM_TOKENS = int(BS * 2.5) + req = _make_p_node_turn2_request( + 300, BS, NUM_TOKENS, num_remote_blocks=5, remote_num_tokens=5 * BS + ) + scheduler.add_request(req) + req_id = req.request_id + so = scheduler.schedule() + assert req.status == RequestStatus.WAITING_FOR_REMOTE_KVS + assert req.num_computed_tokens == NUM_TOKENS + scheduler.update_from_output(so, EMPTY_MODEL_RUNNER_OUTPUT) + so = scheduler.schedule() + mro = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + mro.kv_connector_output = KVConnectorOutput(finished_recving={req_id}) + scheduler.update_from_output(so, mro) + so = scheduler.schedule() + mro = create_model_runner_output(reqs=[req]) + scheduler.update_from_output(so, mro) + assert req.status == RequestStatus.FINISHED_LENGTH_CAPPED + so = scheduler.schedule() + scheduler.update_from_output(so, EMPTY_MODEL_RUNNER_OUTPUT) + so = scheduler.schedule() + mro = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + mro.kv_connector_output = KVConnectorOutput(finished_sending={req_id}) + scheduler.update_from_output(so, mro) + assert_scheduler_empty(scheduler) + + +def test_p_node_pulls_partial_last_block_from_d(): + """D sends remote_block_ids with partially filled last block. + remote_num_tokens < len(remote_block_ids) * block_size. + P pulls only remote_num_tokens worth of external tokens.""" + vllm_config = create_vllm_config( + kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG, + ) + scheduler = create_scheduler(vllm_config) + BS = vllm_config.cache_config.block_size + num_remote_blocks = 3 + remote_num_tokens = int(BS * 2.5) + assert remote_num_tokens < num_remote_blocks * BS + NUM_TOKENS = int(BS * 3.5) + req = _make_p_node_turn2_request( + 400, + BS, + NUM_TOKENS, + num_remote_blocks=num_remote_blocks, + remote_num_tokens=remote_num_tokens, + ) + scheduler.add_request(req) + req_id = req.request_id + so = scheduler.schedule() + assert req.status == RequestStatus.WAITING_FOR_REMOTE_KVS + scheduler.update_from_output(so, EMPTY_MODEL_RUNNER_OUTPUT) + so = scheduler.schedule() + mro = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + mro.kv_connector_output = KVConnectorOutput(finished_recving={req_id}) + scheduler.update_from_output(so, mro) + so = scheduler.schedule() + assert len(scheduler.running) == 1 + mro = create_model_runner_output(reqs=[req]) + scheduler.update_from_output(so, mro) + assert req.status == RequestStatus.FINISHED_LENGTH_CAPPED + so = scheduler.schedule() + scheduler.update_from_output(so, EMPTY_MODEL_RUNNER_OUTPUT) + so = scheduler.schedule() + mro = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + mro.kv_connector_output = KVConnectorOutput(finished_sending={req_id}) + scheduler.update_from_output(so, mro) + assert_scheduler_empty(scheduler) + + +# ----------------------------------------------------------------------- +# 2. P-node metadata tests +# ----------------------------------------------------------------------- + + +def test_add_new_req_to_recv_populates_remote_meta(): + """add_new_req_to_recv correctly populates RemoteMeta for P-node + bi-directional KV pull from D.""" + meta = NixlConnectorMetadata() + kv_params = { + "remote_block_ids": [[0, 1, 2]], + "remote_engine_id": "decode-engine", + "remote_request_id": "decode-req-123", + "remote_host": "decode-host", + "remote_port": 5678, + } + local_block_ids = ([10, 11, 12],) + meta.add_new_req_to_recv( + request_id="test-req", + local_block_ids=local_block_ids, + kv_transfer_params=kv_params, + ) + assert "test-req" in meta.reqs_to_recv + rm = meta.reqs_to_recv["test-req"] + assert rm.remote is not None + assert rm.remote.block_ids == kv_params["remote_block_ids"] + assert rm.remote.engine_id == "decode-engine" + assert rm.remote.request_id == "decode-req-123" + assert rm.remote.host == "decode-host" + assert rm.remote.port == 5678 + assert rm.local_block_ids == local_block_ids + + +def test_build_connector_meta_recv_entries(): + """P-node scheduler: do_remote_decode=True + remote_block_ids → + _reqs_need_recv populated, build_connector_meta produces reqs_to_recv.""" + vllm_config = create_vllm_config( + kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG, + ) + scheduler = create_scheduler(vllm_config) + BS = vllm_config.cache_config.block_size + req = _make_p_node_turn2_request(1, BS, int(BS * 2.5)) + scheduler.add_request(req) + req_id = req.request_id + so = scheduler.schedule() + assert req.status == RequestStatus.WAITING_FOR_REMOTE_KVS + meta = so.kv_connector_metadata + assert isinstance(meta, NixlConnectorMetadata) + assert req_id in meta.reqs_to_recv + rm = meta.reqs_to_recv[req_id] + assert rm.remote is not None + assert rm.remote.engine_id == "decode-engine" + + +def test_build_connector_meta_clears_reqs_need_recv(): + """After build_connector_meta, _reqs_need_recv is cleared.""" + vllm_config = create_vllm_config( + kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG, + ) + scheduler = create_scheduler(vllm_config) + BS = vllm_config.cache_config.block_size + req = _make_p_node_turn2_request(2, BS, int(BS * 2.5)) + scheduler.add_request(req) + conn = scheduler.connector.connector_scheduler + scheduler.schedule() + assert len(conn._reqs_need_recv) == 0 + + +def test_build_connector_meta_multiple_requests(): + """Multiple P-node requests all included in reqs_to_recv.""" + vllm_config = create_vllm_config( + kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG, + ) + scheduler = create_scheduler(vllm_config) + BS = vllm_config.cache_config.block_size + reqs = [_make_p_node_turn2_request(10 + i, BS, int(BS * 2.5)) for i in range(3)] + for r in reqs: + scheduler.add_request(r) + so = scheduler.schedule() + meta = so.kv_connector_metadata + assert isinstance(meta, NixlConnectorMetadata) + assert len(meta.reqs_to_recv) == 3 + for r in reqs: + assert r.request_id in meta.reqs_to_recv + + +# ----------------------------------------------------------------------- +# 3. P-node worker tests (FakeNixlWrapper) +# ----------------------------------------------------------------------- + + +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper", + FakeNixlWrapper, +) +def test_p_node_pull_kv_from_d(dist_init): + """P node pulls KV from D via start_load_kv with reqs_to_recv.""" + connector, worker = _make_connector_with_fake_worker() + meta = _make_p_node_recv_metadata("req-p1", [10, 11, 12], [20, 21, 22]) + _do_load_kv(connector, meta) + assert "req-p1" in worker._recving_metadata + _, done_recving = connector.get_finished(finished_req_ids=set()) + assert "req-p1" in done_recving + + +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper", + FakeNixlWrapper, +) +def test_p_node_pull_then_send_kv(dist_init): + """Full P-node bi-directional: pull KV from D → prefill → + send KV back to D via notification.""" + connector, worker = _make_connector_with_fake_worker() + meta = _make_p_node_recv_metadata("req-p2", [10, 11], [20, 21]) + _do_load_kv(connector, meta) + _, done_recving = connector.get_finished(finished_req_ids=set()) + assert "req-p2" in done_recving + worker._reqs_to_send["req-p2"] = time.perf_counter() + 60 + worker._reqs_to_process.add("req-p2") + notif = f"req-p2:{worker.world_size}".encode() + orig = worker.nixl_wrapper.get_new_notifs + worker.nixl_wrapper.get_new_notifs = lambda: {"agent": [notif]} + done_sending, _ = connector.get_finished(finished_req_ids=set()) + assert "req-p2" in done_sending + worker.nixl_wrapper.get_new_notifs = orig + + +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper", + FakeNixlWrapper, +) +def test_p_node_deferred_pull_on_no_handshake(dist_init): + """P defers KV pull when no prior handshake exists.""" + connector, worker = _make_connector_with_fake_worker( + hand_shake_latency=0, do_handshake=False + ) + meta = _make_p_node_recv_metadata("req-p3", [10, 11], [20, 21]) + _do_load_kv(connector, meta) + assert "req-p3" in worker._recving_metadata + timeout = 3.0 + start = time.perf_counter() + while time.perf_counter() - start < timeout: + connector.bind_connector_metadata(NixlConnectorMetadata()) + ctx = ForwardContext(no_compile_layers={}, attn_metadata={}, slot_mapping={}) + connector.start_load_kv(ctx) + _, done = connector.get_finished(finished_req_ids=set()) + if "req-p3" in done: + return + time.sleep(0.2) + raise AssertionError("Transfer did not complete after async handshake") + + +# ----------------------------------------------------------------------- +# 4. D-node request_finished returns kv_transfer_params (new behavior) +# ----------------------------------------------------------------------- + + +def test_d_node_request_finished_returns_kv_params(): + """D-node request_finished returns kv_transfer_params with + do_remote_decode=True, remote_block_ids, remote_num_tokens + for P to pull. These params go directly to P node.""" + vllm_config = create_vllm_config( + kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG, + ) + scheduler = create_scheduler(vllm_config) + BS = vllm_config.cache_config.block_size + req = create_request( + request_id=1, block_size=BS, num_tokens=int(BS * 2.5), do_remote_prefill=True + ) + scheduler.add_request(req) + req_id = req.request_id + so = scheduler.schedule() + scheduler.update_from_output( + so, create_model_runner_output(reqs=[], finished_recving={req_id}) + ) + so = scheduler.schedule() + eco = scheduler.update_from_output( + so, create_model_runner_output(reqs=[req], use_eos=True) + ) + assert req.status == RequestStatus.FINISHED_STOPPED + kv = eco[0].outputs[0].kv_transfer_params + assert kv is not None + assert kv["do_remote_decode"] is True + assert kv["do_remote_prefill"] is False + assert "remote_block_ids" in kv + assert "remote_num_tokens" in kv + assert kv["remote_num_tokens"] > 0 + + +def test_d_node_request_finished_delays_block_free(): + """D-node holds blocks (delay_free=True) until P reads them.""" + vllm_config = create_vllm_config( + kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG, + ) + scheduler = create_scheduler(vllm_config) + BS = vllm_config.cache_config.block_size + req = create_request( + request_id=2, block_size=BS, num_tokens=int(BS * 2.5), do_remote_prefill=True + ) + scheduler.add_request(req) + req_id = req.request_id + so = scheduler.schedule() + scheduler.update_from_output( + so, create_model_runner_output(reqs=[], finished_recving={req_id}) + ) + so = scheduler.schedule() + scheduler.update_from_output( + so, create_model_runner_output(reqs=[req], use_eos=True) + ) + assert req_id in scheduler.requests + conn = scheduler.connector.connector_scheduler + assert req_id in conn._reqs_need_send + + +def test_d_node_request_finished_remote_num_tokens(): + """D-node kv_transfer_params includes correct remote_num_tokens.""" + vllm_config = create_vllm_config( + kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG, + ) + scheduler = create_scheduler(vllm_config) + BS = vllm_config.cache_config.block_size + req = create_request( + request_id=3, block_size=BS, num_tokens=int(BS * 2.5), do_remote_prefill=True + ) + scheduler.add_request(req) + req_id = req.request_id + so = scheduler.schedule() + scheduler.update_from_output( + so, create_model_runner_output(reqs=[], finished_recving={req_id}) + ) + so = scheduler.schedule() + eco = scheduler.update_from_output( + so, create_model_runner_output(reqs=[req], use_eos=True) + ) + kv = eco[0].outputs[0].kv_transfer_params + assert kv["remote_num_tokens"] > 0 + assert sum(len(g) for g in kv["remote_block_ids"]) > 0 + + +def test_d_node_partial_last_block_remote_num_tokens(): + """D-node: remote_num_tokens < len(remote_block_ids) * block_size + when last block is partially filled.""" + vllm_config = create_vllm_config( + kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG, + ) + scheduler = create_scheduler(vllm_config) + BS = vllm_config.cache_config.block_size + req = create_request( + request_id=5, block_size=BS, num_tokens=int(BS * 2.5), do_remote_prefill=True + ) + scheduler.add_request(req) + req_id = req.request_id + so = scheduler.schedule() + scheduler.update_from_output( + so, create_model_runner_output(reqs=[], finished_recving={req_id}) + ) + so = scheduler.schedule() + eco = scheduler.update_from_output( + so, create_model_runner_output(reqs=[req], use_eos=True) + ) + kv = eco[0].outputs[0].kv_transfer_params + total_blocks = sum(len(g) for g in kv["remote_block_ids"]) + assert total_blocks == 3 + assert kv["remote_num_tokens"] < total_blocks * BS + assert kv["remote_num_tokens"] > 0 diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py index 9f67d0fc525d..289bd8798e1d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py @@ -119,6 +119,28 @@ def __init__( for n_tokens, block_size in sw_sizes_tokens ] + # Threshold to decide whether to compute kv cache locally + # or pull from a remote node: minimum number of remote + # tokens to amortize the xfer latencies + self.kv_recompute_threshold: int = int( + vllm_config.kv_transfer_config.get_from_extra_config( + "kv_recompute_threshold", 0 + ) + ) + if self.kv_recompute_threshold > 0: + logger.info( + "Remote pull threshold set to %d tokens", + self.kv_recompute_threshold, + ) + + # Bi-directional KV transfer feature supports KV block + # transfers from D node to P node + self.is_bidirectional_kv_xfer_enabled = ( + vllm_config.kv_transfer_config.get_from_extra_config( + "bidirectional_kv_xfer", False + ) + ) + def shutdown(self): self._stop_event.set() if self._nixl_handshake_listener_t is not None: @@ -298,6 +320,47 @@ def get_num_new_matched_tokens( if params is not None and params.get("do_remote_decode") and self._has_mamba: self._truncate_mamba_request_for_prefill(request) + if ( + params is not None + and params.get("do_remote_decode") + and params.get("remote_block_ids") + and all( + p in params for p in ("remote_engine_id", "remote_host", "remote_port") + ) + ): + # Decode node has kv blocks for part of prefill request, so, provide them + # as an external token count to scheduler. + # The tokens will be loaded if not already present + # in the prefill node local cache + max_external = len(request.prompt_token_ids or []) + remote_num_tokens = params.get("remote_num_tokens") or 0 + remote_block_ids = params.get("remote_block_ids") or [] + + # remote_block_ids is list[list[int]] — one list per KV cache group + # Use the largest group (full attention) to validate token count + if remote_block_ids and isinstance(remote_block_ids[0], list): + max_blocks = max(len(g) for g in remote_block_ids) + else: + max_blocks = len(remote_block_ids) + + assert remote_num_tokens <= max_blocks * self.block_size + count = min(remote_num_tokens, max_external) - num_computed_tokens + if count > 0: + # Check kv_recompute_threshold: skip pull if + # remote tokens are below the threshold. + if ( + self.kv_recompute_threshold > 0 + and count < self.kv_recompute_threshold + ): + logger.debug( + "Skipping remote pull for %s: %d remote tokens < threshold %d", + request.request_id, + count, + self.kv_recompute_threshold, + ) + return 0, False + return count, True + # No remote prefill for this request. return 0, False @@ -315,13 +378,17 @@ def update_state_after_alloc( if not params: return - if params.get("do_remote_decode"): + if params.get("do_remote_decode") or ( + params.get("do_remote_prefill") and self.is_bidirectional_kv_xfer_enabled + ): self._reqs_in_batch.add(request.request_id) if self.use_host_buffer and params.get("do_remote_decode"): # NOTE: when accelerator is not directly supported by Nixl, # prefilled blocks need to be saved to host memory before transfer. self._reqs_need_save[request.request_id] = request - elif params.get("do_remote_prefill"): + elif params.get("do_remote_prefill") or ( + params.get("do_remote_decode") and self.is_bidirectional_kv_xfer_enabled + ): if params.get("remote_block_ids"): if all( p in params @@ -333,8 +400,8 @@ def update_state_after_alloc( ) ): # If remote_blocks and num_external_tokens = 0, we have - # a full prefix cache hit on the D worker. We need to call - # send_notif in _read_blocks to free the memory on the P. + # a full prefix cache hit on the local node. We need to call + # send_notif in _read_blocks to free the memory on the remote node. unhashed_local_block_ids: BlockIds = ( blocks.get_unhashed_block_ids_all_groups() @@ -461,9 +528,15 @@ def request_finished( params["do_remote_prefill"] = False return False, None - if not params.get("do_remote_decode"): + is_p_node = bool(params.get("do_remote_decode")) + + if not is_p_node and not self.is_bidirectional_kv_xfer_enabled: return False, None - if request.status != RequestStatus.FINISHED_LENGTH_CAPPED: + + if request.status not in ( + RequestStatus.FINISHED_LENGTH_CAPPED, + RequestStatus.FINISHED_STOPPED, + ): # Also include the case of a P/D Prefill request with immediate # block free (eg abort). Stop tracking this request. self._reqs_not_processed.add(request.request_id) @@ -474,6 +547,7 @@ def request_finished( # TODO: check whether block_ids actually ever be 0. If not we could # remove the conditional below delay_free_blocks = any(len(group) > 0 for group in block_ids) + remote_num_tokens = 0 if delay_free_blocks: # Prefill request on remote. It will be read from D upon completion @@ -492,13 +566,24 @@ def request_finished( # Here we "unpad" blocks to send the actual remote blocks to be read. block_ids = self.get_sw_clipped_blocks(block_ids) + remote_num_tokens = min( + request.num_tokens, + max( + len(g) * self.block_size + for g in ( + block_ids if isinstance(block_ids[0], list) else [block_ids] + ) + ), + ) + return delay_free_blocks, dict( - do_remote_prefill=True, - do_remote_decode=False, + do_remote_prefill=is_p_node, + do_remote_decode=not is_p_node, remote_block_ids=block_ids, remote_engine_id=self.engine_id, remote_request_id=request.request_id, remote_host=self.side_channel_host, remote_port=self.side_channel_port, tp_size=self.vllm_config.parallel_config.tensor_parallel_size, + remote_num_tokens=remote_num_tokens, )