diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py index d072c03c440b..defe802a078e 100644 --- a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -132,7 +132,7 @@ async def _run_prefill( payload: dict, headers: dict[str, str], request_id: str, - ): + ) -> dict: url = f"{PREFILL_BASE}{request_path}" start_ts = time.perf_counter() logger.info("[prefill] start request_id=%s url=%s", request_id, url) @@ -146,13 +146,14 @@ async def _run_prefill( raise RuntimeError( f"Prefill backend error {resp.status}: {error_text}" ) - await resp.read() + response_data = await resp.json() logger.info( "[prefill] done request_id=%s status=%s elapsed=%.2fs", request_id, resp.status, time.perf_counter() - start_ts, ) + return response_data except asyncio.TimeoutError as exc: raise RuntimeError(f"Prefill service timeout at {url}") from exc except aiohttp.ClientError as exc: @@ -203,29 +204,31 @@ async def process_request(): try: original_request_data = await request.get_json() - # Create prefill request (max_tokens=1) prefill_request = original_request_data.copy() prefill_request["max_tokens"] = 1 + prefill_request["stream"] = False if "max_completion_tokens" in prefill_request: prefill_request["max_completion_tokens"] = 1 + prefill_request["kv_transfer_params"] = { + "remote_kv_addr": DECODE_KV_ADDR, + } - # Execute prefill stage - # The request id encodes both KV socket addresses so the backend can - # shuttle tensors directly via NCCL once the prefill response - # completes. 
- request_id = ( - f"___prefill_addr_{PREFILL_KV_ADDR}___decode_addr_" - f"{DECODE_KV_ADDR}_{uuid.uuid4().hex}" - ) + request_id = str(uuid.uuid4()) headers = _build_headers(request_id) - await _run_prefill(request.path, prefill_request, headers, request_id) + prefill_response = await _run_prefill( + request.path, prefill_request, headers, request_id + ) + + kv_transfer_params = prefill_response.get("kv_transfer_params", {}) + logger.info("[proxy] kv_transfer_params: %s", kv_transfer_params) + + decode_request = original_request_data.copy() + if kv_transfer_params: + decode_request["kv_transfer_params"] = kv_transfer_params - # Execute decode stage and stream response - # Pass the unmodified user request so the decode phase can continue - # sampling with the already-populated KV cache. generator = _stream_decode( - request.path, original_request_data, headers, request_id + request.path, decode_request, headers, request_id ) response = await make_response(generator) response.timeout = None # Disable timeout for streaming response diff --git a/docs/design/p2p_nccl_connector.md b/docs/design/p2p_nccl_connector.md index 4674bef8d2b6..2a2590571a0f 100644 --- a/docs/design/p2p_nccl_connector.md +++ b/docs/design/p2p_nccl_connector.md @@ -9,9 +9,9 @@ An implementation of xPyD with dynamic scaling based on point-to-point communica As shown in Figure 1, the overall process of this **PD disaggregation** solution is described through a request flow: 1. The client sends an HTTP request to the Proxy/Router's `/v1/completions` interface. -2. The Proxy/Router selects a **1P1D (1 Prefill instance + 1 Decode instance)** through either through round-robin or random selection, generates a `request_id` (rules to be introduced later), modifies the `max_tokens` in the HTTP request message to **1**, and then forwards the request to the **P instance**. -3. Immediately afterward, the Proxy/Router forwards the **original HTTP request** to the **D instance**. -4. 
The **P instance** performs **Prefill** and then **actively sends the generated KV cache** to the D instance (using **PUT_ASYNC** mode). The D instance's `zmq_addr` can be resolved through the `request_id`. +2. The Proxy/Router selects a **1P1D (1 Prefill instance + 1 Decode instance)** through either round-robin or random selection, generates a `request_id`, modifies the `max_tokens` in the HTTP request message to **1**, disables streaming, injects `kv_transfer_params` containing the D instance's KV address, and then forwards the request to the **P instance**. +3. The Proxy/Router waits for the P instance's response, extracts the returned `kv_transfer_params` (containing the P instance's `request_id` and KV address), and forwards them along with the **original HTTP request** to the **D instance**. +4. The **P instance** performs **Prefill** and then **actively sends the generated KV cache** to the D instance (using **PUT_ASYNC** mode). The D instance's KV address is provided via `kv_transfer_params`. 5. The **D instance** has a **dedicated thread** for receiving the KV cache (to avoid blocking the main process). The received KV cache is saved into the **GPU memory buffer**, the size of which is determined by the vLLM startup parameter `kv_buffer_size`. When the GPU buffer is full, the KV cache is stored in the **local Tensor memory pool**. 6. During the **Decode**, the D instance's main process retrieves the KV cache (transmitted by the P instance) from either the **GPU buffer** or the **memory pool**, thereby **skipping Prefill**. 7. After completing **Decode**, the D instance returns the result to the **Proxy/Router**, which then forwards it to the **client**. @@ -22,11 +22,11 @@ As shown in Figure 1, the overall process of this **PD disaggregation** solution A simple HTTP service acts as the entry point for client requests and starts a background thread to listen for P/D instances reporting their HTTP IP and PORT, as well as ZMQ IP and PORT. 
It maintains a dictionary of `http_addr -> zmq_addr`. The `http_addr` is the IP:PORT for the vLLM instance's request, while the `zmq_addr` is the address for KV cache handshake and metadata reception. -The Proxy/Router is responsible for selecting 1P1D based on the characteristics of the client request, such as the prompt, and generating a corresponding `request_id`, for example: +The Proxy/Router is responsible for selecting 1P1D based on the characteristics of the client request and coordinating the two-phase handshake via `kv_transfer_params`: -```text -cmpl-___prefill_addr_10.0.1.2:21001___decode_addr_10.0.1.3:22001_93923d63113b4b338973f24d19d4bf11-0 -``` +1. **Prefill request**: The proxy generates a UUID `request_id` and injects `kv_transfer_params` containing the D instance's KV address (`remote_kv_addr`) into the request body. Streaming is disabled so the proxy can read the full JSON response. +2. **Prefill response**: The P instance's completion response includes `kv_transfer_params` with its `request_id` and KV address, which the proxy extracts from the JSON body. +3. **Decode request**: The proxy forwards the prefill's `kv_transfer_params` to the D instance, which uses them to coordinate the KV cache transfer. Currently, to quickly verify whether xPyD can work, a round-robin selection of 1P1D is used. In the future, it is planned to use a trie combined with the load status of instances to select appropriate P and D. 
diff --git a/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py b/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py index 0c7d32d7862e..9275e19be336 100644 --- a/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py +++ b/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py @@ -104,20 +104,34 @@ def random_uuid() -> str: return str(uuid.uuid4().hex) -async def forward_request(url, data, request_id): - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - headers = { - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - "X-Request-Id": request_id, - } - async with session.post(url=url, json=data, headers=headers) as response: - if response.status == 200: - if True: - async for chunk_bytes in response.content.iter_chunked(1024): - yield chunk_bytes - else: - content = await response.read() - yield content +def _build_headers(request_id): + headers = {"X-Request-Id": request_id} + api_key = os.environ.get("OPENAI_API_KEY") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + return headers + + +async def forward_request(url, data, headers): + async with ( + aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session, + session.post(url=url, json=data, headers=headers) as response, + ): + if response.status == 200: + async for chunk_bytes in response.content.iter_chunked(1024): + yield chunk_bytes + + +async def run_prefill(url, data, headers): + async with ( + aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session, + session.post(url=url, json=data, headers=headers) as response, + ): + if response.status == 200: + return await response.json() + raise RuntimeError( + f"Prefill backend error {response.status}: {await response.text()}" + ) @app.route("/v1/completions", methods=["POST"]) @@ -154,20 +168,28 @@ async def handle_request(): ) count += 1 - request_id = ( - 
f"___prefill_addr_{prefill_zmq_addr}___decode_addr_" - f"{decode_zmq_addr}_{random_uuid()}" - ) + request_id = random_uuid() + headers = _build_headers(request_id) + + prefill_request["stream"] = False + prefill_request["kv_transfer_params"] = { + "remote_kv_addr": decode_zmq_addr, + } # finish prefill - async for _ in forward_request( - f"http://{prefill_addr}{request.path}", prefill_request, request_id - ): - continue + prefill_response = await run_prefill( + f"http://{prefill_addr}{request.path}", prefill_request, headers + ) + + # forward kv_transfer_params from prefill to decode + kv_transfer_params = prefill_response.get("kv_transfer_params", {}) + decode_request = original_request_data.copy() + if kv_transfer_params: + decode_request["kv_transfer_params"] = kv_transfer_params # return decode generator = forward_request( - f"http://{decode_addr}{request.path}", original_request_data, request_id + f"http://{decode_addr}{request.path}", decode_request, headers ) response = await make_response(generator) response.timeout = None diff --git a/tests/v1/kv_connector/unit/test_p2p_nccl_connector.py b/tests/v1/kv_connector/unit/test_p2p_nccl_connector.py new file mode 100644 index 000000000000..455284895d59 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_p2p_nccl_connector.py @@ -0,0 +1,337 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for P2pNcclConnector's kv_transfer_params flow. + +Tests the scheduler-side contract without GPU, NCCL, or distributed init. 
+""" + +import pytest + +from vllm import SamplingParams +from vllm.config import ( + AttentionConfig, + CacheConfig, + DeviceConfig, + KVTransferConfig, + ModelConfig, + SchedulerConfig, + VllmConfig, +) +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole +from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector import ( + P2pNcclConnector, +) +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.kv_cache_utils import KVCacheBlock +from vllm.v1.core.sched.output import ( + CachedRequestData, + NewRequestData, + SchedulerOutput, +) +from vllm.v1.request import Request + +pytestmark = pytest.mark.cpu_test + +BLOCK_SIZE = 16 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_connector(kv_role: str) -> P2pNcclConnector: + """Build a scheduler-side P2pNcclConnector.""" + model_config = ModelConfig( + model="facebook/opt-125m", + trust_remote_code=True, + dtype="float16", + seed=42, + ) + config = VllmConfig( + model_config=model_config, + scheduler_config=SchedulerConfig( + max_num_seqs=16, + max_num_batched_tokens=64, + max_model_len=1024, + enable_chunked_prefill=True, + is_encoder_decoder=model_config.is_encoder_decoder, + ), + cache_config=CacheConfig( + block_size=BLOCK_SIZE, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + ), + kv_transfer_config=KVTransferConfig( + kv_connector="P2pNcclConnector", + kv_role=kv_role, + kv_port="14579", + ), + device_config=DeviceConfig("cpu"), + attention_config=AttentionConfig(), + ) + return P2pNcclConnector(config, KVConnectorRole.SCHEDULER) + + +def _make_request( + request_id: str, + num_tokens: int = 10, + kv_transfer_params: dict | None = None, +) -> Request: + req = Request( + request_id=request_id, + prompt_token_ids=list(range(num_tokens)), + sampling_params=SamplingParams(max_tokens=16), + 
pooling_params=None, + eos_token_id=50256, + ) + req.kv_transfer_params = kv_transfer_params + return req + + +def _make_blocks(block_ids: list[int]) -> KVCacheBlocks: + blocks = tuple([KVCacheBlock(block_id=bid) for bid in block_ids]) + return KVCacheBlocks(blocks=(blocks,)) + + +def _make_scheduler_output( + new_reqs: list[NewRequestData] | None = None, + num_scheduled_tokens: dict[str, int] | None = None, + cached_req_ids: list[str] | None = None, + cached_num_computed: list[int] | None = None, + cached_new_block_ids: list | None = None, + resumed_req_ids: set[str] | None = None, +) -> SchedulerOutput: + cached = CachedRequestData( + req_ids=cached_req_ids or [], + resumed_req_ids=resumed_req_ids or set(), + new_token_ids=[[] for _ in (cached_req_ids or [])], + all_token_ids={}, + new_block_ids=cached_new_block_ids or [None] * len(cached_req_ids or []), + num_computed_tokens=cached_num_computed or [], + num_output_tokens=[0] * len(cached_req_ids or []), + ) + return SchedulerOutput( + scheduled_new_reqs=new_reqs or [], + scheduled_cached_reqs=cached, + num_scheduled_tokens=num_scheduled_tokens or {}, + total_num_scheduled_tokens=sum((num_scheduled_tokens or {}).values()), + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[], + finished_req_ids=set(), + free_encoder_mm_hashes=[], + ) + + +def _new_req( + req_id: str, + block_ids: list[int], + prompt_token_ids: list[int], + num_computed_tokens: int = 0, +) -> NewRequestData: + return NewRequestData( + req_id=req_id, + prompt_token_ids=prompt_token_ids, + mm_features=[], + sampling_params=SamplingParams(max_tokens=16), + pooling_params=None, + block_ids=(block_ids,), + num_computed_tokens=num_computed_tokens, + lora_request=None, + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_producer_consumer_handoff(): + """The params returned by 
the producer must produce correct consumer + metadata -- this is the core contract of the kv_transfer_params design.""" + + producer = _make_connector("kv_producer") + consumer = _make_connector("kv_consumer") + + # -- Producer side -- + prod_req = _make_request( + "prefill-42", + num_tokens=10, + kv_transfer_params={"remote_kv_addr": "10.0.1.3:22001"}, + ) + producer.update_state_after_alloc( + prod_req, _make_blocks([0, 1]), num_external_tokens=0 + ) + + tokens = list(range(10)) + prod_sched = _make_scheduler_output( + new_reqs=[_new_req("prefill-42", [0, 1], tokens)], + num_scheduled_tokens={"prefill-42": 10}, + ) + prod_meta = producer.build_connector_meta(prod_sched) + assert len(prod_meta.requests) == 1 + assert prod_meta.requests[0].remote_kv_addr == "10.0.1.3:22001" + + # Producer finishes -> returns params (these go through the proxy) + _, params = producer.request_finished(prod_req, block_ids=[0, 1]) + assert params is not None + assert params["remote_request_id"] == "prefill-42" + assert params["remote_kv_addr"] == producer._kv_addr + + # -- Consumer side (receives params via proxy) -- + cons_req = _make_request("decode-99", num_tokens=10, kv_transfer_params=params) + consumer.update_state_after_alloc( + cons_req, _make_blocks([5, 6]), num_external_tokens=9 + ) + + cons_sched = _make_scheduler_output( + new_reqs=[_new_req("decode-99", [5, 6], tokens)], + num_scheduled_tokens={"decode-99": 10}, + ) + cons_meta = consumer.build_connector_meta(cons_sched) + + # Consumer metadata must reference producer's request_id + assert len(cons_meta.requests) == 1 + assert cons_meta.requests[0].request_id == "prefill-42" + assert cons_meta.requests[0].local_request_id == "decode-99" + assert cons_meta.requests[0].remote_kv_addr == producer._kv_addr + + +def test_two_step_chunked_prefill(): + """A prompt too large for one scheduling step must be handled + across two steps without losing the remote address.""" + + producer = _make_connector("kv_producer") + + tokens = 
list(range(20)) # 20 tokens + req = _make_request( + "req-chunk", + num_tokens=20, + kv_transfer_params={"remote_kv_addr": "10.0.1.3:22001"}, + ) + producer.update_state_after_alloc(req, _make_blocks([0, 1]), num_external_tokens=0) + + # Step 1: only 8 of 20 tokens scheduled -> chunked + step1 = _make_scheduler_output( + new_reqs=[_new_req("req-chunk", [0, 1], tokens, num_computed_tokens=0)], + num_scheduled_tokens={"req-chunk": 8}, + ) + meta1 = producer.build_connector_meta(step1) + # Not ready yet -- should produce no metadata + assert len(meta1.requests) == 0 + assert "req-chunk" in producer.chunked_prefill + + # Step 2: remaining 12 tokens scheduled as cached req + step2 = _make_scheduler_output( + cached_req_ids=["req-chunk"], + cached_num_computed=[8], + cached_new_block_ids=[([2, 3],)], + num_scheduled_tokens={"req-chunk": 12}, + ) + meta2 = producer.build_connector_meta(step2) + + # Now the full prompt is prefilled -> metadata emitted + assert len(meta2.requests) == 1 + assert meta2.requests[0].request_id == "req-chunk" + assert meta2.requests[0].remote_kv_addr == "10.0.1.3:22001" + # Block IDs accumulated across both steps + assert list(meta2.requests[0].block_ids.numpy()) == [0, 1, 2, 3] + # Chunked state cleaned up + assert "req-chunk" not in producer.chunked_prefill + + +def test_non_disagg_request_skipped(): + """A request arriving at the producer without kv_transfer_params + must be silently skipped, not crash in _get_remote_kv_addr.""" + + producer = _make_connector("kv_producer") + + # Request has no kv_transfer_params -> not in _requests_need_save + tokens = list(range(10)) + sched = _make_scheduler_output( + new_reqs=[_new_req("plain-req", [0, 1], tokens)], + num_scheduled_tokens={"plain-req": 10}, + ) + meta = producer.build_connector_meta(sched) + assert len(meta.requests) == 0 + + +def test_mixed_disagg_and_plain_requests(): + """Only disagg requests produce metadata; non-disagg are skipped.""" + + producer = _make_connector("kv_producer") + + 
disagg_req = _make_request( + "disagg-1", + num_tokens=10, + kv_transfer_params={"remote_kv_addr": "10.0.1.3:22001"}, + ) + producer.update_state_after_alloc( + disagg_req, _make_blocks([0]), num_external_tokens=0 + ) + + tokens = list(range(10)) + sched = _make_scheduler_output( + new_reqs=[ + _new_req("plain-1", [1], tokens), + _new_req("disagg-1", [0], tokens), + _new_req("plain-2", [2], tokens), + ], + num_scheduled_tokens={ + "plain-1": 10, + "disagg-1": 10, + "plain-2": 10, + }, + ) + meta = producer.build_connector_meta(sched) + assert len(meta.requests) == 1 + assert meta.requests[0].request_id == "disagg-1" + + +def test_request_finished_cleanup(): + """request_finished must clean up internal state and return the right + kv_transfer_params for the proxy to forward.""" + + producer = _make_connector("kv_producer") + + req = _make_request( + "req-1", + num_tokens=10, + kv_transfer_params={"remote_kv_addr": "10.0.1.3:22001"}, + ) + producer.update_state_after_alloc(req, _make_blocks([0, 1]), num_external_tokens=0) + assert "req-1" in producer._requests_need_save + + delay, params = producer.request_finished(req, block_ids=[0, 1]) + + # Producer returns False (no async send at scheduler level) + assert delay is False + # Params contain what the proxy needs + assert params == { + "remote_request_id": "req-1", + "remote_kv_addr": producer._kv_addr, + } + # Internal state cleaned up + assert "req-1" not in producer._requests_need_save + + +def test_consumer_request_finished_returns_no_params(): + """Consumer's request_finished should return no kv_transfer_params.""" + + consumer = _make_connector("kv_consumer") + + req = _make_request( + "decode-1", + num_tokens=10, + kv_transfer_params={ + "remote_request_id": "prefill-1", + "remote_kv_addr": "10.0.0.1:14579", + }, + ) + consumer.update_state_after_alloc(req, _make_blocks([0, 1]), num_external_tokens=9) + + delay, params = consumer.request_finished(req, block_ids=[0, 1]) + assert delay is False + assert params 
is None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py index 3be1be18e534..84f1254d02ab 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any -import regex as re import torch from vllm.config import VllmConfig @@ -19,6 +18,7 @@ from vllm.distributed.parallel_state import get_world_group from vllm.logger import init_logger from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata +from vllm.utils.network_utils import get_ip from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.core.sched.output import SchedulerOutput @@ -39,16 +39,27 @@ class ReqMeta: block_ids: torch.Tensor # Request num tokens num_tokens: int + # Remote side's base KV address (no rank offset) + remote_kv_addr: str = "" + # Decode's own request_id, for ID translation in get_finished + local_request_id: str = "" @staticmethod def make_meta( - request_id: str, token_ids: list[int], block_ids: list[int], block_size: int + request_id: str, + token_ids: list[int], + block_ids: list[int], + block_size: int, + remote_kv_addr: str = "", + local_request_id: str = "", ) -> "ReqMeta": block_ids_tensor = torch.tensor(block_ids) return ReqMeta( request_id=request_id, block_ids=block_ids_tensor, num_tokens=len(token_ids), + remote_kv_addr=remote_kv_addr, + local_request_id=local_request_id, ) @@ -65,9 +76,18 @@ def add_request( token_ids: list[int], block_ids: list[int], block_size: int, + remote_kv_addr: str = "", + local_request_id: str = "", ) -> None: self.requests.append( - ReqMeta.make_meta(request_id, token_ids, block_ids, block_size) + ReqMeta.make_meta( + request_id, + token_ids, + block_ids, + block_size, + remote_kv_addr, + local_request_id, + ) ) @@ -85,14 +105,21 
@@ def __init__( ) self._block_size = vllm_config.cache_config.block_size self._requests_need_load: dict[str, Any] = {} + self._requests_need_save: dict[str, Any] = {} self.is_producer = self._kv_transfer_config.is_kv_producer self.chunked_prefill: dict[str, tuple[list[int], list[int] | None]] = {} + self._kv_addr = f"{get_ip()}:{self._kv_transfer_config.kv_port}" + self._rank = get_world_group().rank if role == KVConnectorRole.WORKER else 0 self._local_rank = ( get_world_group().local_rank if role == KVConnectorRole.WORKER else 0 ) + # Needed so P2pNcclEngine can access the prefill request_id for + # cleanup in get_finished + self._local_to_remote_id: dict[str, str] = {} + self.p2p_nccl_engine = ( P2pNcclEngine( local_rank=self._local_rank, @@ -108,6 +135,11 @@ def __init__( # Worker-side methods # ============================== + def _resolve_remote_address(self, base_addr: str) -> str: + """Add rank offset to a base host:port address.""" + ip, base_port_str = base_addr.rsplit(":", 1) + return f"{ip}:{int(base_port_str) + self._rank}" + def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: """Start loading the KV cache from the connector buffer to vLLM's paged KV buffer. 
@@ -201,9 +233,11 @@ def inject_kv_into_layer( # Load the KV for each request each layer for request in metadata.requests: - request_id = request.request_id - ip, port = self.parse_request_id(request_id, False) - remote_address = ip + ":" + str(port + self._rank) + remote_address = self._resolve_remote_address(request.remote_kv_addr) + + if request.local_request_id: + self._local_to_remote_id[request.local_request_id] = request.request_id + for layer_name in forward_context.no_compile_layers: layer = forward_context.no_compile_layers[layer_name] @@ -297,13 +331,13 @@ def extract_kv_from_layer( connector_metadata = self._get_connector_metadata() assert isinstance(connector_metadata, P2pNcclConnectorMetadata) for request in connector_metadata.requests: - request_id = request.request_id - ip, port = self.parse_request_id(request_id, True) - remote_address = ip + ":" + str(port + self._rank) + remote_address = self._resolve_remote_address(request.remote_kv_addr) kv_cache = extract_kv_from_layer(kv_layer, request.block_ids) self.p2p_nccl_engine.send_tensor( - request_id + "#" + layer_name, kv_cache, remote_address + request.request_id + "#" + layer_name, + kv_cache, + remote_address, ) def wait_for_save(self): @@ -327,6 +361,14 @@ def get_finished( assert self.p2p_nccl_engine is not None + # Decode uses local req IDs but the engine keys on prefill's req IDs + if not self.is_producer: + translated_ids = set() + for req_id in finished_req_ids: + remote_id = self._local_to_remote_id.pop(req_id, None) + translated_ids.add(remote_id if remote_id is not None else req_id) + finished_req_ids = translated_ids + no_compile_layers = self._vllm_config.compilation_config.static_forward_context return self.p2p_nccl_engine.get_finished(finished_req_ids, no_compile_layers) @@ -369,12 +411,24 @@ def update_state_after_alloc( """ Update KVConnector state after block allocation. 
""" - if not self.is_producer and num_external_tokens > 0: + if self.is_producer: + if request.kv_transfer_params: + self._requests_need_save[request.request_id] = request + elif num_external_tokens > 0: self._requests_need_load[request.request_id] = ( request, blocks.get_block_ids()[0], ) + def _get_remote_kv_addr(self, req_id: str) -> str: + """Look up the remote KV address from stored request params.""" + req = self._requests_need_save.get(req_id) + if not req or not req.kv_transfer_params: + return "" + addr = req.kv_transfer_params.get("remote_kv_addr", "") + assert addr, f"kv_transfer_params for {req_id} missing 'remote_kv_addr'" + return addr + def build_connector_meta( self, scheduler_output: SchedulerOutput, @@ -392,6 +446,9 @@ def build_connector_meta( for new_req in scheduler_output.scheduled_new_reqs: if self.is_producer: + if new_req.req_id not in self._requests_need_save: + continue + remote_kv_addr = self._get_remote_kv_addr(new_req.req_id) num_scheduled_tokens = (scheduler_output.num_scheduled_tokens)[ new_req.req_id ] @@ -410,16 +467,26 @@ def build_connector_meta( token_ids=new_req.prompt_token_ids or [], block_ids=new_req.block_ids[0], block_size=self._block_size, + remote_kv_addr=remote_kv_addr, ) continue + if new_req.req_id in self._requests_need_load: + request, _ = self._requests_need_load.pop(new_req.req_id) + assert request.kv_transfer_params is not None, ( + f"Consumer request {new_req.req_id} missing kv_transfer_params" + ) + kv_params = request.kv_transfer_params + remote_request_id = kv_params["remote_request_id"] + remote_kv_addr = kv_params["remote_kv_addr"] meta.add_request( - request_id=new_req.req_id, + request_id=remote_request_id, token_ids=new_req.prompt_token_ids or [], block_ids=new_req.block_ids[0], block_size=self._block_size, + remote_kv_addr=remote_kv_addr, + local_request_id=new_req.req_id, ) - self._requests_need_load.pop(new_req.req_id) cached_reqs = scheduler_output.scheduled_cached_reqs for i, req_id in 
enumerate(cached_reqs.req_ids): @@ -442,11 +509,13 @@ def build_connector_meta( self.chunked_prefill[req_id] = (block_ids, prompt_token_ids) continue # the request's prompt is all prefilled finally + remote_kv_addr = self._get_remote_kv_addr(req_id) meta.add_request( request_id=req_id, token_ids=prompt_token_ids, block_ids=block_ids, block_size=self._block_size, + remote_kv_addr=remote_kv_addr, ) self.chunked_prefill.pop(req_id, None) continue @@ -465,11 +534,19 @@ def build_connector_meta( assert new_block_ids is not None block_ids = new_block_ids[0] + assert request.kv_transfer_params is not None, ( + f"Consumer request {req_id} missing kv_transfer_params" + ) + kv_params = request.kv_transfer_params + remote_request_id = kv_params["remote_request_id"] + remote_kv_addr = kv_params["remote_kv_addr"] meta.add_request( - request_id=req_id, + request_id=remote_request_id, token_ids=token_ids, block_ids=block_ids, block_size=self._block_size, + remote_kv_addr=remote_kv_addr, + local_request_id=req_id, ) self._requests_need_load.clear() @@ -492,6 +569,13 @@ def request_finished( """ self.chunked_prefill.pop(request.request_id, None) + self._requests_need_save.pop(request.request_id, None) + + if self.is_producer: + return False, { + "remote_request_id": request.request_id, + "remote_kv_addr": self._kv_addr, + } return False, None @@ -499,24 +583,6 @@ def request_finished( # Static methods # ============================== - @staticmethod - def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]: - # Regular expression to match the string hostname and integer port - if is_prefill: - pattern = r"___decode_addr_(.*):(\d+)" - else: - pattern = r"___prefill_addr_(.*):(\d+)___" - - # Use re.search to find the pattern in the request_id - match = re.search(pattern, request_id) - if match: - # Extract the ranks - ip = match.group(1) - port = int(match.group(2)) - - return ip, port - raise ValueError(f"Request id {request_id} does not contain hostname and port") 
- @staticmethod def check_tensors_except_dim(tensor1, tensor2, dim): shape1 = tensor1.size()