Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 19 additions & 16 deletions benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ async def _run_prefill(
payload: dict,
headers: dict[str, str],
request_id: str,
):
) -> dict:
url = f"{PREFILL_BASE}{request_path}"
start_ts = time.perf_counter()
logger.info("[prefill] start request_id=%s url=%s", request_id, url)
Expand All @@ -146,13 +146,14 @@ async def _run_prefill(
raise RuntimeError(
f"Prefill backend error {resp.status}: {error_text}"
)
await resp.read()
response_data = await resp.json()
logger.info(
"[prefill] done request_id=%s status=%s elapsed=%.2fs",
request_id,
resp.status,
time.perf_counter() - start_ts,
)
return response_data
except asyncio.TimeoutError as exc:
raise RuntimeError(f"Prefill service timeout at {url}") from exc
except aiohttp.ClientError as exc:
Expand Down Expand Up @@ -203,29 +204,31 @@ async def process_request():
try:
original_request_data = await request.get_json()

# Create prefill request (max_tokens=1)
prefill_request = original_request_data.copy()
prefill_request["max_tokens"] = 1
prefill_request["stream"] = False
if "max_completion_tokens" in prefill_request:
prefill_request["max_completion_tokens"] = 1
prefill_request["kv_transfer_params"] = {
"remote_kv_addr": DECODE_KV_ADDR,
}

# Execute prefill stage
# The request id encodes both KV socket addresses so the backend can
# shuttle tensors directly via NCCL once the prefill response
# completes.
request_id = (
f"___prefill_addr_{PREFILL_KV_ADDR}___decode_addr_"
f"{DECODE_KV_ADDR}_{uuid.uuid4().hex}"
)
request_id = str(uuid.uuid4())

headers = _build_headers(request_id)
await _run_prefill(request.path, prefill_request, headers, request_id)
prefill_response = await _run_prefill(
request.path, prefill_request, headers, request_id
)

kv_transfer_params = prefill_response.get("kv_transfer_params", {})
logger.info("[proxy] kv_transfer_params: %s", kv_transfer_params)

decode_request = original_request_data.copy()
if kv_transfer_params:
decode_request["kv_transfer_params"] = kv_transfer_params

# Execute decode stage and stream response
# Pass the unmodified user request so the decode phase can continue
# sampling with the already-populated KV cache.
generator = _stream_decode(
request.path, original_request_data, headers, request_id
request.path, decode_request, headers, request_id
)
response = await make_response(generator)
response.timeout = None # Disable timeout for streaming response
Expand Down
14 changes: 7 additions & 7 deletions docs/design/p2p_nccl_connector.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ An implementation of xPyD with dynamic scaling based on point-to-point communica
As shown in Figure 1, the overall process of this **PD disaggregation** solution is described through a request flow:

1. The client sends an HTTP request to the Proxy/Router's `/v1/completions` interface.
2. The Proxy/Router selects a **1P1D (1 Prefill instance + 1 Decode instance)** through either through round-robin or random selection, generates a `request_id` (rules to be introduced later), modifies the `max_tokens` in the HTTP request message to **1**, and then forwards the request to the **P instance**.
3. Immediately afterward, the Proxy/Router forwards the **original HTTP request** to the **D instance**.
4. The **P instance** performs **Prefill** and then **actively sends the generated KV cache** to the D instance (using **PUT_ASYNC** mode). The D instance's `zmq_addr` can be resolved through the `request_id`.
2. The Proxy/Router selects a **1P1D (1 Prefill instance + 1 Decode instance)** through either round-robin or random selection, generates a `request_id`, modifies the `max_tokens` in the HTTP request message to **1**, disables streaming, injects `kv_transfer_params` containing the D instance's KV address, and then forwards the request to the **P instance**.
3. The Proxy/Router waits for the P instance's response, extracts the returned `kv_transfer_params` (containing the P instance's `request_id` and KV address), and forwards them along with the **original HTTP request** to the **D instance**.
4. The **P instance** performs **Prefill** and then **actively sends the generated KV cache** to the D instance (using **PUT_ASYNC** mode). The D instance's KV address is provided via `kv_transfer_params`.
5. The **D instance** has a **dedicated thread** for receiving the KV cache (to avoid blocking the main process). The received KV cache is saved into the **GPU memory buffer**, the size of which is determined by the vLLM startup parameter `kv_buffer_size`. When the GPU buffer is full, the KV cache is stored in the **local Tensor memory pool**.
6. During the **Decode**, the D instance's main process retrieves the KV cache (transmitted by the P instance) from either the **GPU buffer** or the **memory pool**, thereby **skipping Prefill**.
7. After completing **Decode**, the D instance returns the result to the **Proxy/Router**, which then forwards it to the **client**.
Expand All @@ -22,11 +22,11 @@ As shown in Figure 1, the overall process of this **PD disaggregation** solution

A simple HTTP service acts as the entry point for client requests and starts a background thread to listen for P/D instances reporting their HTTP IP and PORT, as well as ZMQ IP and PORT. It maintains a dictionary of `http_addr -> zmq_addr`. The `http_addr` is the IP:PORT for the vLLM instance's request, while the `zmq_addr` is the address for KV cache handshake and metadata reception.

The Proxy/Router is responsible for selecting 1P1D based on the characteristics of the client request, such as the prompt, and generating a corresponding `request_id`, for example:
The Proxy/Router is responsible for selecting 1P1D based on the characteristics of the client request and coordinating the two-phase handshake via `kv_transfer_params`:

```text
cmpl-___prefill_addr_10.0.1.2:21001___decode_addr_10.0.1.3:22001_93923d63113b4b338973f24d19d4bf11-0
```
1. **Prefill request**: The proxy generates a UUID `request_id` and injects `kv_transfer_params` containing the D instance's KV address (`remote_kv_addr`) into the request body. Streaming is disabled so the proxy can read the full JSON response.
2. **Prefill response**: The P instance's completion response includes `kv_transfer_params` with its `request_id` and KV address, which the proxy extracts from the JSON body.
3. **Decode request**: The proxy forwards the prefill's `kv_transfer_params` to the D instance, which uses them to coordinate the KV cache transfer.

Currently, to quickly verify whether xPyD can work, a round-robin selection of 1P1D is used. In the future, it is planned to use a trie combined with the load status of instances to select appropriate P and D.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,20 +104,34 @@ def random_uuid() -> str:
return str(uuid.uuid4().hex)


async def forward_request(url, data, request_id):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id,
}
async with session.post(url=url, json=data, headers=headers) as response:
if response.status == 200:
if True:
async for chunk_bytes in response.content.iter_chunked(1024):
yield chunk_bytes
else:
content = await response.read()
yield content
def _build_headers(request_id):
headers = {"X-Request-Id": request_id}
api_key = os.environ.get("OPENAI_API_KEY")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers


async def forward_request(url, data, headers):
    """POST *data* as JSON to *url* and stream the response body.

    Yields the backend response in 1 KiB chunks so the caller can relay
    a streaming completion without buffering it.

    NOTE(review): a non-200 status yields nothing, silently producing an
    empty downstream response — presumably deliberate best-effort for the
    streamed decode path; confirm with callers.
    """
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        async with session.post(url=url, json=data, headers=headers) as response:
            if response.status != 200:
                return
            async for chunk in response.content.iter_chunked(1024):
                yield chunk


async def run_prefill(url, data, headers):
    """POST the prefill request and return the backend's parsed JSON body.

    The decode stage needs the ``kv_transfer_params`` from this response,
    so the whole (non-streaming) JSON payload is read and returned.

    Raises:
        RuntimeError: when the prefill backend answers with a non-200
            status; the message embeds the status code and response text.
    """
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        async with session.post(url=url, json=data, headers=headers) as response:
            if response.status != 200:
                raise RuntimeError(
                    f"Prefill backend error {response.status}: {await response.text()}"
                )
            return await response.json()


@app.route("/v1/completions", methods=["POST"])
Expand Down Expand Up @@ -154,20 +168,28 @@ async def handle_request():
)
count += 1

request_id = (
f"___prefill_addr_{prefill_zmq_addr}___decode_addr_"
f"{decode_zmq_addr}_{random_uuid()}"
)
request_id = random_uuid()
headers = _build_headers(request_id)

prefill_request["stream"] = False
prefill_request["kv_transfer_params"] = {
"remote_kv_addr": decode_zmq_addr,
}

# finish prefill
async for _ in forward_request(
f"http://{prefill_addr}{request.path}", prefill_request, request_id
):
continue
prefill_response = await run_prefill(
f"http://{prefill_addr}{request.path}", prefill_request, headers
)

# forward kv_transfer_params from prefill to decode
kv_transfer_params = prefill_response.get("kv_transfer_params", {})
decode_request = original_request_data.copy()
if kv_transfer_params:
decode_request["kv_transfer_params"] = kv_transfer_params

# return decode
generator = forward_request(
f"http://{decode_addr}{request.path}", original_request_data, request_id
f"http://{decode_addr}{request.path}", decode_request, headers
)
response = await make_response(generator)
response.timeout = None
Expand Down
Loading
Loading