diff --git a/examples/online_serving/separated_encode/launch_1e1pd.sh b/examples/online_serving/separated_encode/launch_1e1pd.sh
new file mode 100644
index 000000000000..3e6a3d6f658d
--- /dev/null
+++ b/examples/online_serving/separated_encode/launch_1e1pd.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+
+
+wait_for_server() {
+ local port=$1
+ timeout 12000 bash -c "
+ until curl -s localhost:$port/v1/chat/completions > /dev/null; do
+ sleep 1
+ done" && return 0 || return 1
+}
+
+MODEL="/workspace/helper/Qwen2.5-VL-3B-Instruct"
+LOG_PATH=$LOG_PATH
+ENCODE_PORT=19534
+ENCODE_RANK=0
+PREFILL_DECODE_PORT=19535
+PREFILL_DECODE_RANK=1
+PROXY_PORT=10001
+GPU_E="4"
+GPU_PD="5"
+export REDIS_HOST="localhost"
+export REDIS_PORT="6379"
+
+START_TIME=$(date +"%Y%m%d_%H%M%S")
+
+redis-server --bind "$REDIS_HOST" --port "$REDIS_PORT" &
+
+CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
+ --gpu-memory-utilization 0.9 \
+ --port "$ENCODE_PORT" \
+ --enable-request-id-headers \
+ --max-num-seqs 128 \
+ --instance-type "encode" \
+ --connector-workers-num 8 \
+ --epd-rank "$ENCODE_RANK" &
+
+
+CUDA_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \
+ --gpu-memory-utilization 0.9 \
+ --port "$PREFILL_DECODE_PORT" \
+ --enable-request-id-headers \
+ --max-num-seqs 128 \
+ --instance-type "prefill+decode" \
+ --connector-workers-num 8 \
+ --epd-rank "$PREFILL_DECODE_RANK" &
+
+wait_for_server $ENCODE_PORT
+wait_for_server $PREFILL_DECODE_PORT
+
+python examples/online_serving/separated_encode/proxy/proxy_aiohttp.py \
+ --host "0.0.0.0" \
+ --port "$PROXY_PORT" \
+ --encode-servers-urls "http://localhost:$ENCODE_PORT" \
+ --prefill-decode-servers-urls "http://localhost:$PREFILL_DECODE_PORT" \
+ --encode-servers-ranks "$ENCODE_RANK" \
+ --prefill-decode-servers-ranks "$PREFILL_DECODE_RANK" &
+
+wait_for_server $PROXY_PORT
\ No newline at end of file
diff --git a/examples/online_serving/separated_encode/launch_1e2pd.sh b/examples/online_serving/separated_encode/launch_1e2pd.sh
new file mode 100644
index 000000000000..5dd46c421efc
--- /dev/null
+++ b/examples/online_serving/separated_encode/launch_1e2pd.sh
@@ -0,0 +1,75 @@
+#!/bin/bash
+
+
+wait_for_server() {
+ local port=$1
+ timeout 12000 bash -c "
+ until curl -s localhost:$port/v1/chat/completions > /dev/null; do
+ sleep 1
+ done" && return 0 || return 1
+}
+
+MODEL="/workspace/helper/Qwen2.5-VL-3B-Instruct"
+LOG_PATH=$LOG_PATH
+
+ENCODE_PORT=19534
+PREFILL_DECODE_PORT_F=19535
+PREFILL_DECODE_PORT_S=19536
+
+ENCODE_RANK=0
+PREFILL_DECODE_RANK_F=1
+PREFILL_DECODE_RANK_S=2
+
+GPU_E="3"
+GPU_PD_F="4"
+GPU_PD_S="5"
+
+PROXY_PORT=10001
+
+export REDIS_HOST="localhost"
+export REDIS_PORT="6379"
+
+START_TIME=$(date +"%Y%m%d_%H%M%S")
+
+redis-server --bind "$REDIS_HOST" --port "$REDIS_PORT" &
+
+CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
+ --gpu-memory-utilization 0.9 \
+ --port "$ENCODE_PORT" \
+ --enable-request-id-headers \
+ --max-num-seqs 128 \
+ --instance-type "encode" \
+ --connector-workers-num 8 \
+ --epd-rank "$ENCODE_RANK" &
+
+CUDA_VISIBLE_DEVICES="$GPU_PD_F" vllm serve "$MODEL" \
+ --gpu-memory-utilization 0.9 \
+ --port "$PREFILL_DECODE_PORT_F" \
+ --enable-request-id-headers \
+ --max-num-seqs 128 \
+ --instance-type "prefill+decode" \
+ --connector-workers-num 8 \
+ --epd-rank "$PREFILL_DECODE_RANK_F" &
+
+CUDA_VISIBLE_DEVICES="$GPU_PD_S" vllm serve "$MODEL" \
+ --gpu-memory-utilization 0.9 \
+ --port "$PREFILL_DECODE_PORT_S" \
+ --enable-request-id-headers \
+ --max-num-seqs 128 \
+ --instance-type "prefill+decode" \
+ --connector-workers-num 8 \
+ --epd-rank "$PREFILL_DECODE_RANK_S" &
+
+wait_for_server $ENCODE_PORT
+wait_for_server $PREFILL_DECODE_PORT_F
+wait_for_server $PREFILL_DECODE_PORT_S
+
+python examples/online_serving/separated_encode/proxy/proxy_aiohttp.py \
+ --host "0.0.0.0" \
+ --port "$PROXY_PORT" \
+ --encode-servers-urls "http://localhost:$ENCODE_PORT" \
+ --prefill-decode-servers-urls "http://localhost:$PREFILL_DECODE_PORT_F,http://localhost:$PREFILL_DECODE_PORT_S" \
+ --encode-servers-ranks "$ENCODE_RANK" \
+ --prefill-decode-servers-ranks "$PREFILL_DECODE_RANK_F,$PREFILL_DECODE_RANK_S" &
+
+wait_for_server $PROXY_PORT
\ No newline at end of file
diff --git a/examples/online_serving/separated_encode/launch_2e1pd.sh b/examples/online_serving/separated_encode/launch_2e1pd.sh
new file mode 100644
index 000000000000..6aa54eddf4df
--- /dev/null
+++ b/examples/online_serving/separated_encode/launch_2e1pd.sh
@@ -0,0 +1,75 @@
+#!/bin/bash
+
+
+wait_for_server() {
+ local port=$1
+ timeout 12000 bash -c "
+ until curl -s localhost:$port/v1/chat/completions > /dev/null; do
+ sleep 1
+ done" && return 0 || return 1
+}
+
+MODEL="/workspace/helper/Qwen2.5-VL-3B-Instruct"
+LOG_PATH=$LOG_PATH
+
+ENCODE_PORT_F=19534
+ENCODE_PORT_S=19535
+PREFILL_DECODE_PORT=19536
+
+ENCODE_RANK_F=0
+ENCODE_RANK_S=1
+PREFILL_DECODE_RANK=2
+
+GPU_E_F="3"
+GPU_E_S="4"
+GPU_PD="5"
+
+PROXY_PORT=10001
+
+export REDIS_HOST="localhost"
+export REDIS_PORT="6379"
+
+START_TIME=$(date +"%Y%m%d_%H%M%S")
+
+redis-server --bind "$REDIS_HOST" --port "$REDIS_PORT" &
+
+CUDA_VISIBLE_DEVICES="$GPU_E_F" vllm serve "$MODEL" \
+ --gpu-memory-utilization 0.9 \
+ --port "$ENCODE_PORT_F" \
+ --enable-request-id-headers \
+ --max-num-seqs 128 \
+ --instance-type "encode" \
+ --connector-workers-num 8 \
+ --epd-rank "$ENCODE_RANK_F" &
+
+CUDA_VISIBLE_DEVICES="$GPU_E_S" vllm serve "$MODEL" \
+ --gpu-memory-utilization 0.9 \
+ --port "$ENCODE_PORT_S" \
+ --enable-request-id-headers \
+ --max-num-seqs 128 \
+ --instance-type "encode" \
+ --connector-workers-num 8 \
+ --epd-rank "$ENCODE_RANK_S" &
+
+CUDA_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \
+ --gpu-memory-utilization 0.9 \
+ --port "$PREFILL_DECODE_PORT" \
+ --enable-request-id-headers \
+ --max-num-seqs 128 \
+ --instance-type "prefill+decode" \
+ --connector-workers-num 8 \
+ --epd-rank "$PREFILL_DECODE_RANK" &
+
+wait_for_server $ENCODE_PORT_F
+wait_for_server $ENCODE_PORT_S
+wait_for_server $PREFILL_DECODE_PORT
+
+python examples/online_serving/separated_encode/proxy/proxy_aiohttp.py \
+ --host "0.0.0.0" \
+ --port "$PROXY_PORT" \
+ --encode-servers-urls "http://localhost:$ENCODE_PORT_F,http://localhost:$ENCODE_PORT_S" \
+ --prefill-decode-servers-urls "http://localhost:$PREFILL_DECODE_PORT" \
+ --encode-servers-ranks "$ENCODE_RANK_F,$ENCODE_RANK_S" \
+ --prefill-decode-servers-ranks "$PREFILL_DECODE_RANK" &
+
+wait_for_server $PROXY_PORT
\ No newline at end of file
diff --git a/examples/online_serving/separated_encode/launch_epd_serve.sh b/examples/online_serving/separated_encode/launch_epd_serve.sh
new file mode 100644
index 000000000000..ef45a96e9b81
--- /dev/null
+++ b/examples/online_serving/separated_encode/launch_epd_serve.sh
@@ -0,0 +1,54 @@
+#!/bin/bash
+
+
+wait_for_server() {
+ local port=$1
+ timeout 12000 bash -c "
+ until curl -s localhost:$port/v1/chat/completions > /dev/null; do
+ sleep 1
+ done" && return 0 || return 1
+}
+
+MODEL="/workspace/helper/Qwen2.5-VL-3B-Instruct"
+LOG_PATH=$LOG_PATH
+ENCODE_PORT=19534
+PREFILL_DECODE_PORT=19535
+PROXY_PORT=10001
+GPU="5"
+export REDIS_HOST="localhost"
+export REDIS_PORT="6379"
+START_TIME=$(date +"%Y%m%d_%H%M%S")
+
+redis-server --bind "$REDIS_HOST" --port "$REDIS_PORT" &
+
+CUDA_VISIBLE_DEVICES="$GPU" vllm serve "$MODEL" \
+ --gpu-memory-utilization 0.2 \
+ --port "$ENCODE_PORT" \
+ --enable-request-id-headers \
+ --max-num-seqs 32 \
+ --instance-type "encode" \
+ --connector-workers-num 8 \
+ --epd-rank 0 &
+
+wait_for_server $ENCODE_PORT
+
+CUDA_VISIBLE_DEVICES="$GPU" vllm serve "$MODEL" \
+ --gpu-memory-utilization 0.7 \
+ --port "$PREFILL_DECODE_PORT" \
+ --enable-request-id-headers \
+ --max-num-seqs 128 \
+ --instance-type "prefill+decode" \
+ --connector-workers-num 8 \
+ --epd-rank 1 &
+
+wait_for_server $PREFILL_DECODE_PORT
+
+python examples/online_serving/separated_encode/proxy/proxy_aiohttp.py \
+ --host "0.0.0.0" \
+ --port "$PROXY_PORT" \
+ --encode-servers-urls "http://localhost:$ENCODE_PORT" \
+ --prefill-decode-servers-urls "http://localhost:$PREFILL_DECODE_PORT" \
+ --encode-servers-ranks "0" \
+ --prefill-decode-servers-ranks "1" &
+
+wait_for_server $PROXY_PORT
\ No newline at end of file
diff --git a/examples/online_serving/separated_encode/proxy/proxy_aiohttp.py b/examples/online_serving/separated_encode/proxy/proxy_aiohttp.py
new file mode 100644
index 000000000000..365df8ac0e64
--- /dev/null
+++ b/examples/online_serving/separated_encode/proxy/proxy_aiohttp.py
@@ -0,0 +1,271 @@
+# proxy_aiohttp.py
+import asyncio
+import json
+import time
+import uuid
+from typing import AsyncIterator, Optional
+from fastapi import FastAPI, Request, HTTPException
+import aiohttp
+from fastapi.responses import StreamingResponse, JSONResponse
+import uvicorn
+import argparse
+import logging
+import random
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+app = FastAPI()
+
+encode_session: Optional[aiohttp.ClientSession] = None
+decode_session: Optional[aiohttp.ClientSession] = None
+
+@app.on_event("startup")
+async def startup_event():
+ global encode_session, decode_session
+ encode_session = aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(limit=0),
+ timeout=aiohttp.ClientTimeout(total=100000))
+ decode_session = aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(limit=0),
+ timeout=aiohttp.ClientTimeout(total=100000))
+
+@app.on_event("shutdown")
+async def shutdown_event():
+ global encode_session, decode_session
+ if encode_session:
+ await encode_session.close()
+ if decode_session:
+ await decode_session.close()
+
+
+def has_mm_input(request_data: dict):
+ if "messages" not in request_data:
+ return False
+ for message in request_data["messages"]:
+ if not isinstance(message.get("content"), list):
+ continue
+ for content_item in message["content"]:
+ if content_item.get("type") in ["image_url", "audio_url", "input_audio"]:
+ return True
+ return False
+
+async def forward_streaming_request(
+ request_data: dict,
+ request_id: str,
+ e_server_url: str,
+ pd_server_url: str,
+) -> AsyncIterator[str]:
+
+
+ headers = {"x-request-id": request_id}
+ # Skip request to encoder instance if we don't have mm input
+ if has_mm_input(request_data):
+ task1 = asyncio.create_task(
+ encode_session.post(
+ f"{e_server_url}/v1/chat/completions",
+ json=request_data,
+ headers=headers
+ )
+ )
+ try:
+ response = await task1
+ if response.status != 200:
+ error_text = await response.text()
+ raise HTTPException(
+ status_code=response.status,
+ detail={"error": "Request failed", "message": error_text}
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail={"error": "Internal server error", "message": str(e)}
+ )
+
+ try:
+ async with decode_session.post(
+ f"{pd_server_url}/v1/chat/completions",
+ json=request_data,
+ headers=headers
+ ) as response:
+ response.raise_for_status()
+ async for chunk in response.content.iter_chunked(128):
+ if chunk:
+ yield chunk.decode('utf-8', errors='ignore')
+ except Exception as e:
+ logger.error(f"Error in streaming: {e}")
+ raise
+
+async def forward_non_streaming_request(
+ request_data: dict,
+ request_id: str,
+ e_server_url: str,
+ pd_server_url: str,
+) -> dict:
+ headers = {"x-request-id": request_id}
+ # Skip request to encoder instance if we don't have mm input
+ if has_mm_input(request_data):
+ # Start request to encode server
+ task1 = asyncio.create_task(
+ encode_session.post(
+ f"{e_server_url}/v1/chat/completions",
+ json=request_data,
+ headers=headers
+ )
+ )
+
+ try:
+ response = await task1
+ if response.status != 200:
+ error_text = await response.text()
+ raise HTTPException(
+ status_code=response.status,
+ detail={"error": "Request failed", "message": error_text}
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail={"error": "Internal server error", "message": str(e)}
+ )
+
+ try:
+ # Make request to decode server
+ async with decode_session.post(
+ f"{pd_server_url}/v1/chat/completions",
+ json=request_data,
+ headers=headers
+ ) as response2:
+ response2.raise_for_status()
+ result = await response2.json()
+ return result
+ except Exception as e:
+ logger.error(f"Error in non-streaming: {e}")
+ raise
+
+@app.post("/v1/chat/completions")
+async def chat_completions(request: Request):
+ """Handle chat completion requests."""
+ try:
+ e_instance = random.randint(0, len(app.state.e_urls) - 1)
+ pd_instance = random.randint(0, len(app.state.pd_urls) - 1)
+ e_rank = app.state.e_ranks[e_instance]
+ pd_rank = app.state.pd_ranks[pd_instance]
+ e_server_url = app.state.e_urls[e_instance]
+ pd_server_url = app.state.pd_urls[pd_instance]
+
+
+ logger.info(f"Matched: E-{e_rank}, PD-{pd_rank}")
+
+ request_data = await request.json()
+ request_id = request.headers.get("x-request-id")
+ if not request_id:
+ request_id = str(uuid.uuid4())
+ request_id = f"{request_id}|{e_rank}|{pd_rank}"
+ is_streaming = request_data.get("stream", False)
+ if is_streaming:
+ return StreamingResponse(
+ forward_streaming_request(
+ request_data, request_id, e_server_url, pd_server_url),
+ media_type="text/event-stream"
+ )
+ else:
+ result = await forward_non_streaming_request(
+ request_data, request_id, e_server_url, pd_server_url)
+ return JSONResponse(content=result)
+ except Exception as e:
+ logger.error(f"Error processing request: {e}")
+ raise HTTPException(status_code=500, detail=str(e))
+
+@app.get("/v1/models")
+async def list_models():
+ try:
+ async with decode_session.get(f"{app.state.pd_urls[0]}/v1/models") as response:
+ response.raise_for_status()
+ return await response.json()
+ except Exception as e:
+ logger.error(f"Error fetching models: {e}")
+ raise HTTPException(status_code=500, detail=str(e))
+
+@app.get("/health")
+async def health_check():
+ """Health check endpoint."""
+ try:
+ async def check_encode():
+ try:
+ for e_url in app.state.e_urls:
+ async with encode_session.get(f"{e_url}/health") as response:
+ response.raise_for_status()
+ return True
+ except Exception:
+ return False
+
+ async def check_decode():
+ try:
+ for pd_url in app.state.pd_urls:
+ async with decode_session.get(f"{pd_url}/health") as response:
+ response.raise_for_status()
+ return True
+ except Exception:
+ return False
+
+ encode_healthy, decode_healthy = await asyncio.gather(
+ check_encode(), check_decode(), return_exceptions=True
+ )
+
+ health_status = {
+ "proxy": "healthy",
+ "encode_servers": "healthy" if encode_healthy is True else "unhealthy",
+ "prefill_decode_servers": "healthy" if decode_healthy is True else "unhealthy"
+ }
+
+ if not (encode_healthy is True and decode_healthy is True):
+ return JSONResponse(content=health_status, status_code=503)
+
+ return health_status
+
+ except Exception as e:
+ logger.error(f"Health check error: {e}")
+ return JSONResponse(
+ content={"proxy": "unhealthy", "error": str(e)},
+ status_code=503
+ )
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="API Proxy for distributed vLLM servers")
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Proxy host")
+ parser.add_argument("--port", type=int, default=8000, help="Proxy port")
+
+ parser.add_argument("--encode-servers-urls", type=str, required=True,
+ help="URLs of the encode server in comma separated format"
+ " (e.g., \"http://localhost:8001,http://localhost:8002\")")
+
+ parser.add_argument("--encode-servers-ranks", type=str, required=True,
+ help="Respective EPD ranks for encode servers in comma-separated format"
+ " (e.g., \"0,1\")")
+
+ parser.add_argument("--prefill-decode-servers-urls", type=str, required=True,
+ help="URLs of the prefill/decode servers in comma separated format"
+ " (e.g., \"http://localhost:8003,http://localhost:8004\")")
+
+ parser.add_argument("--prefill-decode-servers-ranks", type=str, required=True,
+ help="Respective EPD ranks for prefill/decode servers in comma-separated format"
+ " (e.g., \"2,3\")")
+
+ args = parser.parse_args()
+ app.state.e_urls = args.encode_servers_urls.split(",")
+ app.state.pd_urls = args.prefill_decode_servers_urls.split(",")
+ app.state.e_ranks = args.encode_servers_ranks.split(",")
+ app.state.pd_ranks = args.prefill_decode_servers_ranks.split(",")
+
+ logger.info(f"Starting API proxy on {args.host}:{args.port} with 1 worker")
+ logger.info(f"Encode servers: {app.state.e_urls} (respective ranks {app.state.e_ranks})")
+ logger.info(f"Prefill/Decode servers: {app.state.pd_urls} (respective ranks {app.state.pd_ranks})")
+
+ uvicorn.run(
+ app,
+ host=args.host,
+ port=args.port,
+ log_level="info",
+ access_log=False,
+ loop="uvloop"
+ )
\ No newline at end of file
diff --git a/examples/online_serving/separated_encode/proxy_request.py b/examples/online_serving/separated_encode/proxy_request.py
new file mode 100644
index 000000000000..e29be8a83333
--- /dev/null
+++ b/examples/online_serving/separated_encode/proxy_request.py
@@ -0,0 +1,81 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import asyncio
+import base64
+import time
+import uuid
+from openai import AsyncOpenAI
+
+
+async def async_query_openai(query, model_path, port):
+ aclient = AsyncOpenAI(
+ base_url=f"http://localhost:{str(port)}/v1",
+ api_key="EMPTY",
+ timeout = 100000,
+ )
+ completion = await aclient.chat.completions.create(
+ model=model_path,
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {
+ "role": "user",
+ "content": query,
+ },
+ ],
+ temperature=0.0,
+ top_p=0.1,
+ max_tokens=512,
+ )
+ return completion.choices[0].message.content
+
+
+async def async_process_queries(queries, model_path, port):
+ results = await asyncio.gather(
+ *(async_query_openai(query, model_path, port) for query in queries)
+ )
+ return results
+
+
+async def main(args):
+ # single query
+ image_path = args.image_path
+ with open(image_path, "rb") as f:
+ encoded_image = base64.b64encode(f.read())
+ encoded_image_text = encoded_image.decode("utf-8")
+ image_base64 = f"data:image/jpeg;base64,{encoded_image_text}"
+ query = [
+ {
+ "type": "image_url",
+ "image_url": {"url": image_base64},
+ },
+ {"type": "text", "text": "What is shown in the image?"},
+ ]
+ bs = args.batch_size
+ queries = [query for i in range(bs)]
+
+ start_time = time.time()
+ results = await async_process_queries(
+ queries,
+ args.model_path,
+ args.port
+ )
+ end_time = time.time()
+ for result in results:
+ print(result)
+ print("-" * 50)
+ print(f"Total time: {end_time - start_time:.2f} seconds")
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(description="test")
+ parser.add_argument(
+ "--model_path", type=str, default="Qwen/Qwen2.5-VL-3B-Instruct"
+ )
+ parser.add_argument("--image_path", type=str, default="./demo.jpeg")
+ parser.add_argument("--batch_size", type=int, default=1)
+ parser.add_argument("--port", type=int, default=10001)
+ args, _ = parser.parse_known_args()
+
+ asyncio.run(main(args))
\ No newline at end of file
diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index cfc5e07d8329..3415713be059 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -3339,6 +3339,35 @@ class KVEventsConfig:
this topic to receive events.
"""
+@config
+@dataclass
+class EPDDisaggConfig:
+ """Configuration for the encode-prefill-decode disaggregated execution"""
+
+ instance_type: Literal["NoEPD", "encode", "prefill",
+ "prefill+decode"] = "NoEPD"
+ """The type of the instance."""
+
+ connector_workers_num: int = 4
+ """Number of workers for receive & send."""
+
+ epd_rank: int = -1
+ """EPD Disagg rank"""
+
+ def compute_hash(self):
+ """
+ Provide a hash that uniquely identifies all the configs
+ that affect the structure of the computation
+ graph from input ids/embeddings to the final hidden states,
+ excluding anything before input ids/embeddings and after
+ the final hidden states.
+ """
+ factors: list[Any] = []
+ factors.append(self.instance_type)
+ factors.append(self.connector_workers_num)
+ factors.append(self.epd_rank)
+ return hashlib.sha256(str(factors).encode()).hexdigest()
+
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
@@ -3387,6 +3416,8 @@ class VllmConfig:
You can specify the full compilation config like so:
`{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
"""
+ epd_disagg_config: EPDDisaggConfig = field(default_factory=EPDDisaggConfig)
+ """The configuration for epd disaggregation"""
kv_transfer_config: Optional[KVTransferConfig] = None
"""The configurations for distributed KV cache transfer."""
kv_events_config: Optional[KVEventsConfig] = None
@@ -3475,6 +3506,11 @@ def compute_hash(self) -> str:
vllm_factors.append(self.kv_transfer_config.compute_hash())
else:
vllm_factors.append("None")
+ if self.epd_disagg_config:
+ vllm_factors.append(self.epd_disagg_config.compute_hash())
+ else:
+ vllm_factors.append("None")
+
if self.additional_config:
if isinstance(additional_config := self.additional_config, dict):
additional_config_hash = hashlib.md5(
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 7802802f138b..e641bb495c51 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -24,7 +24,8 @@
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
ConfigFormat, ConfigType, ConvertOption,
DecodingConfig, DetailedTraceModules, Device,
- DeviceConfig, DistributedExecutorBackend, EPLBConfig,
+ DeviceConfig, DistributedExecutorBackend, EPLBConfig,
+ EPDDisaggConfig,
GuidedDecodingBackend, HfOverrides, KVEventsConfig,
KVTransferConfig, LoadConfig, LogprobsMode,
LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
@@ -453,6 +454,11 @@ class EngineArgs:
str, type[LogitsProcessor]]]] = ModelConfig.logits_processors
"""Custom logitproc types"""
+ instance_type: Literal["NoEPD", "encode", "prefill",
+ "prefill+decode"] = EPDDisaggConfig.instance_type
+ connector_workers_num: int = EPDDisaggConfig.connector_workers_num
+ epd_rank: int = EPDDisaggConfig.epd_rank
+
async_scheduling: bool = SchedulerConfig.async_scheduling
kv_sharing_fast_prefill: bool = \
@@ -873,6 +879,21 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
scheduler_group.add_argument("--async-scheduling",
**scheduler_kwargs["async_scheduling"])
+ # EPD disagg. arguments
+ epd_disagg_kwargs = get_kwargs(EPDDisaggConfig)
+ epd_disagg_group = parser.add_argument_group(
+ title="EPDDisaggConfig",
+ description=EPDDisaggConfig.__doc__,
+ )
+ epd_disagg_group.add_argument("--instance-type",
+ **epd_disagg_kwargs["instance_type"])
+ epd_disagg_group.add_argument(
+ "--connector-workers-num",
+ **epd_disagg_kwargs["connector_workers_num"])
+ epd_disagg_group.add_argument(
+ "--epd-rank",
+ **epd_disagg_kwargs["epd_rank"])
+
# vLLM arguments
vllm_kwargs = get_kwargs(VllmConfig)
vllm_group = parser.add_argument_group(
@@ -1379,6 +1400,11 @@ def create_engine_config(
collect_detailed_traces=self.collect_detailed_traces,
)
+ epd_disagg_config = EPDDisaggConfig(
+ instance_type=self.instance_type,
+ connector_workers_num=self.connector_workers_num,
+ epd_rank=self.epd_rank)
+
config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
@@ -1393,6 +1419,7 @@ def create_engine_config(
compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config,
+ epd_disagg_config=epd_disagg_config,
additional_config=self.additional_config,
)
diff --git a/vllm/separated_encode/1e1pd.png b/vllm/separated_encode/1e1pd.png
new file mode 100644
index 000000000000..7fbab8215232
Binary files /dev/null and b/vllm/separated_encode/1e1pd.png differ
diff --git a/vllm/separated_encode/README.md b/vllm/separated_encode/README.md
new file mode 100644
index 000000000000..7735978dbaf7
--- /dev/null
+++ b/vllm/separated_encode/README.md
@@ -0,0 +1,481 @@
+# MLLM Encode separation and E-P Encoder Cache Transfer
+
+
+Encode-Prefill-Decode disaggregation provides greater flexibility in distributed MLLM inference, enables better resource utilization under fixed TTFT and TPOT Service Level Objectives, and allows for the application of stage-level optimizations. To implement full EPD disaggregation, the visual encoder needs to be moved to a separate instance.
+
+This update introduces the implementation of MLLM visual encoder separation, an abstraction for inter-instance (E-P) communication for encoder cache transfer, and a concrete example implementation of disaggregated E+PD serving.
+
+
+
+## Motivation
+
+1. Encoder separation is a critical part of EPD disaggregation, as it allows the visual encoder to be decoupled from the Prefill and Decode stages. To implement and use EPD disaggregation in the future, visual encoder separation is required.
+
+2. Better serving for MLLM. Consider a mixed-input continuous serving scenario: every 10th request includes a large multimodal input, while all others are text-only. In the current vLLM, all requests in a batch in model execution will wait for the multimodal input embedding generation to complete. When a new request with an image arrives, all requests will again wait for the multimodal (MM) encoder to run, causing significant performance degradation. By separating the encoding stage, this bottleneck can be avoided. This scenario is used only to highlight the bottleneck; the same bottleneck appears in other serving scenarios and can be fixed by a separated encoder.
+
+```mermaid
+flowchart LR
+ subgraph "Current vLLM Architecture"
+ direction LR
+
+
+ subgraph I2 ["Iteration 2"]
+ direction LR
+ R2["Requests 11-19: Text
🖼️ Request 20: MM + Text"]
+ R3["Running Batch 1-10"]
+ S2[Scheduler]
+ B2["Batch 1-20"]
+ MR2[ModelRunner]
+ MM2["execute_mm_encoder
ALL requests blocked AGAIN"]
+ LM2["Language Model
1 generation step"]
+
+ R2 --> S2
+ R3 <--> S2
+ S2 --> B2 --> MR2 --> MM2 --> LM2
+ end
+
+ subgraph I1 ["Iteration 1"]
+ direction LR
+ R1["Requests 1-9: Text
Request 10: MM + Text"]
+ S1[Scheduler]
+ B1["Batch 1-10"]
+ MR1[ ModelRunner]
+ MM1["execute_mm_encoder
ALL requests blocked"]
+ LM1["Language Model
1 generation step"]
+
+ R1 --> S1 --> B1 --> MR1 --> MM1 --> LM1
+ end
+
+ end
+```
+
+## Overall Process (Sequential requests)
+
+The separated encode overall process in 1E1PD with proxy scenario:
+
+1. The client sends an HTTP request to the Proxy/Router's /v1/completions interface.
+2. The Proxy/Router selects a 1E1PD (1 Encode instance + 1 Prefill+Decode instance) pair and generates a request_id.
+3. Proxy sends request to E instance.
+4. When request is scheduled on the E instance, the request's metadata is sent to PD instance through Encoder Cache connector.
+5. Then multimodal embeddings are generated for request's multimodal inputs, the multimodal embeddings are scheduled for the transfer in encoder cache connector on E instance and empty response is returned to proxy immediately after generation of the multimodal embeddings is completed.
+6. After receiving response from E instance, the Proxy/Router forwards the request to the PD instance. As encoder budget is set to 0 on PD instance, it can't schedule chunks with multimodal input and waits the injection of encoder cache.
+7. When there is enough space in the encoder cache manager on the PD instance, instance preallocates the space for request and sends a preallocation notification to the E instance. If the encoder cache is not required (i.e., tokens are reused via the KV cache), it sends a "failed" preallocation notification, indicating that encoder cache transfer is not required.
+8. After preallocation notification is received on E instance's EC connector the encoder cache is sent from E instance to the PD instance.
+9. PD instance receives cache and finalize allocation of the multimodal input data and now is able to schedule the request. As encoder cache is already calculated PD instance skips MLLM's encode step and uses the value from cache.
+10. After completing Prefill and Decode, the PD instance returns the result to the Proxy/Router, which then forwards it to the client.
+
+```mermaid
+sequenceDiagram
+ participant P as Proxy
+ participant ES as Encode Scheduler
+ participant EVW as DisaggEncodeGPURunnerWrapper
+ participant Redis as Redis
+ participant PDW as DisaggPrefillDecodeGPURunnerWrapper
+ participant PDS as Prefill+Decode Scheduler
+
+ P->>ES: Request
+ ES->>ES: Schedule request
+ ES->>Redis: Send encoder cache metadata
+ ES->>EVW: SchedulerOutput
+ EVW->>EVW: _execute_mm_encoder() - Process vision/MM inputs
+ Redis->>PDS: Receive metadata
+ EVW->>EVW: Store in encoder_cache temporarily
+ EVW->>ES: ModelRunnerOutput
+ ES->>P: Empty response
+ P->>PDS: Request
+ alt Multimodal input tokens KV values are not obtained from KV cache
+ PDS->>PDS: Preallocate encoder cache space
+ else
+ end
+ PDS->>Redis: Send preallocation response
+ Redis->>EVW: Receive preallocation
+ alt Multimodal input tokens KV values are not obtained from KV cache
+ EVW->>Redis: Send actual encoder cache data
+ Redis->>PDW: Receive encoder cache
+ PDW->>PDW: Inject into encoder_cache dict
+ else Multimodal input tokens KV values are obtained from KV cache
+ end
+ EVW->>ES: ModelRunnerOutput
+ ES->>ES: Free slots in Encoder Cache Manager
+ PDW->>PDS:ModelRunnerOutput
+ alt Multimodal input tokens KV values are not obtained from KV cache
+ PDS->>PDS: Finalize allocation
+ else Multimodal input tokens KV values are obtained from KV cache
+ PDS->>PDS: Free preallocated encoder cache space
+ end
+ PDS->>PDW: SchedulerOutput
+ Note over EVW,PDW: Phase 3: Normal LM Inference
+ PDW->>PDW: Execute prefill + decode
+ PDW->>PDW: Generate text output
+ PDW->>PDS: ModelRunnerOutput
+ PDS->>P: Response
+
+```
+NOTE: The current implementation allows the E and PD instances to execute the same request in parallel, but it doesn't provide significant performance improvement, therefore the scheme doesn't illustrate this.
+
+
+
+
+
+# Implementation
+
+Code changes in the implementation are made to solve three practical problems. The implementation must provide clear control of the request lifecycle in the EPD scenario so memory can be reserved for cache between instances, encoder outputs can be transferred, and requests keep using the key caching features without issues. Also, the code must be compatible enough to keep working as vLLM changes, so deep vLLM core changes are avoided and most logic is put behind small adapters/wrappers. Finally, the implementation must provide proper EPD disaggregation functionality.
+
+Forward-compatibility is preserved by wrapping existing components instead of changing them. The GPU model runner is used as-is, with thin wrappers that add EPD tracking and connector calls. The scheduler gets small hooks to talk to a preallocator and to reconcile injected data. The transport is hidden behind a connector interface so implementations can be switched without touching schedulers or runners. The request lifecycle control and "inter-instance" encoder cache space management are handled by EncoderCachePreallocator, EncoderCacheConnector, and a two-stage allocation flow in the EncoderCacheManager.
+
+## vLLM minor changes
+
+### Scheduler routing
+**Files:** `vllm/v1/core/core.py`
+
+During execution with EPD disaggregation, EngineCore now uses the Encoder Scheduler for the Encode instance.
+
+### EPD Disaggregation Configuration
+**Files:** `vllm/config/__init__.py`, `vllm/core/arg_utils.py`
+
+Added a new configuration class for EPD disaggregation. Currently supports configuration of instance type, instance's EPD rank and the number of connector workers.
+
+### Additional ModelRunnerOutput Data Fields
+**Files:** `vllm/v1/outputs.py`
+
+The model runner output now includes two additional data fields: `transfered_mm_data` and `injected_mm_data`.
+
+The `transfered_mm_data` field passes a list of transferred encoder cache input IDs from the model runner to the scheduler on the encode instance. After receiving the transferred data IDs, the scheduler will free up space in the encoder cache manager.
+
+The `injected_mm_data` field passes a list of injected encoder cache input IDs with `mm_hash` from the model runner to the scheduler on the prefill instance. After receiving the injected data IDs, the scheduler will finalize allocations.
+
+### Model Runner Wrapper Integration in GPUWorker
+**Files:** `vllm/v1/worker/gpu_worker.py`
+
+When EPD disaggregation is enabled, the system uses wrapper classes of GPUModelRunner class.
+
+### GPU Model Runner sanity check in encoder execution
+**Files:** `vllm/v1/worker/gpu_model_runner.py`
+
+If EPD disaggregated serving is enabled, an additional attribute is added to indicate whether encoder execution is allowed. This attribute is used to perform a sanity check on each execution of the encoder. Also encoder cache lock is added to ensure encoder cache injection safety.
+
+## Major Changes
+
+### EncoderCacheManager new Allocation Logic
+**Files:** `vllm/v1/core/encoder_cache_manager.py`
+
+In disaggregated settings, encoder outputs need to be transferred between instances. This creates timing challenges: we must reserve cache space before receiving the data, but we can't mark it as cached until the transfer completes. For that the EncoderCacheManager now supports a two-stage allocation process for handling encoder outputs receiving. The two-stage process addresses this by separating reservation from finalization.
+
+The `preallocate()` method reserves cache space for incoming encoder outputs. It tracks which requests need which multimodal inputs and prevents premature eviction of entries that will be reused. This method returns whether the encoder output needs to be computed and transferred `(true)` or is already available in cache `(false)`.
+
+The `finalize_allocation()` method completes the allocation after encoder outputs are received. It converts the preallocation into an actual cached entry or releases the reservation if another request provided the data or mm input tokens prefilling is skipped due to prefix caching.
+
+To prevent race conditions between eviction and incoming transfers, we don't evict caches that have active preallocations (i.e., allocations that are not yet finalized).
+
+### EncoderCachePreallocator
+**Files:** `vllm/separated_encode/sched/encoder_cache_preallocator.py`
+The EncoderCachePreallocator schedules preallocation requests. It synchronizes incoming requests with encoder metadata received asynchronously from encoding instances. The preallocator serves three essential purposes:
+
+It synchronizes request arrival with encoder metadata reception. Since encoder metadata can arrive before or after the corresponding request, the preallocator buffers metadata for requests that haven't arrived yet and processes waiting metadata when requests do arrive.
+
+It manages the preallocation queue to determine when, and for which encoder output, cache space should be reserved. The system maintains a queue of pending preallocations and validates each candidate against available cache capacity before proceeding with reservation.
+
+It tracks multimodal input processing progress to avoid unnecessary data transfers. When inputs are obtained from existing caches (KV cache or encoder cache), the preallocator sends notifications to cancel pending transfers and ignores subsequent metadata for those inputs.
+
+The system provides an abstract base class `EncoderCachePreallocatorTemplate` that defines the interface for preallocation strategies. This template initializes the encoder cache connector for receiving metadata and defines abstract methods that concrete implementations must provide. And one concrete synchronous implementation example.
+
+#### Request flow
+
+When encoder metadata arrives via the `_receive_encoder_cache_metadata` callback, the preallocator checks if the request is active. For active requests, it immediately schedules preallocation. For inactive requests, it stores the metadata in `waiting_preallocs` for later processing.
+
+Request addition triggers processing of any waiting metadata. The `add_request` method initializes tracking structures and schedules preallocations for any metadata that arrived early.
+
+As the tokens are processed, `update_mm_inputs_done` tracks which multimodal inputs are complete. When a pending input is covered by existing cache, the system sends a cancellation notification and marks it as ignored.
+
+The `get_prealloc_candidate` method provides the interface for retrieving candidates from the queue. It validates each candidate against available space and skips ignored entries. The method returns whether to continue processing and the candidate data if valid. In the scheduler this method is called from `_perform_preallocations` method.
+
+```mermaid
+sequenceDiagram
+ participant E as Encode Instance
+ participant PD as Prefill+Decode Instance
+
+ E->>PD: "I have 2048 tokens of encoder cache to send"
+ Note over PD: Check available memory
+ PD->>PD: encoder_cache_manager.can_preallocate(2048 tokens)
+ alt Multimodal input tokens KV values are not obtained from KV cache
+ PD->>PD: encoder_cache_manager.preallocate(req_id, input_id, 2048 tokens)
+ PD->>E: "I reserved 2048 tokens for you"
+ Note over E: Now safe to send
+ E->>PD: Send actual 2048 tokens data
+ Note over PD: After successful injection
+ PD->>PD: encoder_cache_manager.finalize_allocation(req_id, input_id)
+ else Multimodal input tokens KV values are obtained from KV cache
+ PD->>PD: encoder_cache_manager.depreallocate(req_id, input_id)
+ PD->>E: "No need to send"
+ end
+```
+
+
+### Encoder Scheduler
+
+#### Encoder Scheduler (encode)
+**Files:** `vllm/separated_encode/sched/encoder_scheduler.py`
+
+Separate EncoderScheduler class implementation is provided for encode instance scheduling.
+
+The EncoderScheduler is a specialized scheduler for encode instances that focuses on only multimodal input scheduling. It maintains an `_allocated` dictionary to track allocated encoder cache entries, their sizes and hashes. This dictionary is used to allow us to free up logical space without storing the request itself, which enables us to end the request before the data is transferred.
+
+Currently the encode scheduler schedules all multimodal inputs for a request at once in the `schedule()` method. It checks if there's sufficient encoder cache space and budget before allocating all inputs together. Note that if an input is already cached we will still add it into the `scheduled_encoder_inputs`, but we will not allocate space for it, and in the model runner we will skip the encoder execution for such elements; we need to do that because in `model_runner` a signal needs to be sent to `ECConnector` for each `mm_input`.
+
+A request on the encode instance is considered finished when all its multimodal embeddings have been computed, so all requests are finished in 1 iteration after scheduling, transfer is handled separately in encoder cache connectors, space allocated for encoder cache is deallocated only after transfers, not after request finish.
+
+In the `update_from_output()` method, the scheduler goes through transferred multimodal data IDs and frees the mm inputs in encoder cache manager.
+
+#### Main Scheduler (prefill and prefill+decode instances)
+**Files:** `vllm/v1/core/sched/scheduler.py`
+
+For prefill and prefill+decode instances, the main scheduler is changed for multimodal inputs encode separation.
+
+If encoder separation is turned on, we instantiate `encoder_cache_preallocator` object in scheduler, this preallocator handles communication through `ec_connector` in it and `preallocation` scheduling and synchronization, also we set `max_num_encoder_input_tokens` to 0 to avoid the usage of the multimodal data encoder on P or PD instance.
+
+Mostly main scheduler has 3 changes, integration of `encoder_cache_preallocator`, `_perform_preallocations()` and `injected_mm_data` allocation.
+
+The `encoder_cache_preallocator` is described in the corresponding part of the documentation. The `_perform_preallocations()` function is used to connect the `encoder_cache_preallocator`, which manages which requests will be preallocated, and the encoder cache manager, in which we actually perform preallocations. This function just gets preallocation candidates from the `encoder_cache_preallocator` while there are enough slots in the `encoder_cache_manager`. `_perform_preallocations()` is called 2 times: in `update_after_schedule()` after some cache can potentially become freeable, and in `update_from_output()` after handling injected data.
+
+The injected data handling is performed via `injected_mm_data` attribute from `ModelRunnerOutput`, scheduler just going through injected data and decides whether the allocation needs to be finalized or we don't need the obtained data anymore and we can just say that this injected data is freeable.
+
+Such implementation allows us to achieve motivation described in changes for encoder cache manager, encoder cache preallocator and also more efficiently use caching techniques in EPD disaggregation.
+
+### Instance-Specific Model Runner Wrappers
+**Files:** `vllm/separated_encode/worker/gpu_epd_lm_wrapper.py`, `vllm/separated_encode/worker/gpu_epd_vm_wrapper.py`
+
+The implementation introduces specialized GPU model runner wrappers for disaggregated architecture, focusing on distinct roles for multimodal encoding and text generation. These wrappers are built on top of the GPUModelRunner for better compatibility with future changes in GPUModelRunner. As long as the v1 interface for GPU Model Runner remains unchanged, the wrappers do not require updates, wrapper simply call the original methods, instantiate the encoder cache connector, track information, and modify the model runner output with EPD-related information.
+
+#### DisaggEncodeGPURunnerWrapper (Encode Instance)
+
+This wrapper runs on encode instances and processes multimodal inputs. It executes encoder models and sends the results to other instances through encoder cache connector.
+
+The encode instance doesn't need KV cache since it only runs vision part of MLLM. The wrapper overrides `initialize_kv_cache_tensors` and `initialize_kv_cache` to return empty results, freeing up GPU memory for larger encoder cache storage.
+
+During execution, the wrapper executes encoding for scheduled multimodal inputs and inserts the encoder output into the encoder cache connector. Due to the nature of the encode scheduler, `scheduled_output.scheduled_encoder_inputs` can contain already-cached inputs or multiple identical multimodal inputs; as the cache is already present or going to be present, we can just skip the encoding process for such `mm_inputs`. So we temporarily remove cached inputs and inputs whose `mm_hash` is already present elsewhere in `scheduled_encoder_inputs`; after execution we return all removed entries back to `scheduler_output`. Motivation for sending all multimodal inputs to `model_runner` is provided in the `EncoderScheduler` section.
+
+Since no text generation happens here, it returns almost empty ModelRunnerOutput with additional transfered encoder outputs information in ModelRunnerOutput, this information is used in encoder scheduler to free the space in encoder cache manager.
+
+#### DisaggPrefillDecodeGPURunnerWrapper (Prefill/(Prefill+Decode) Instance)
+
+This wrapper runs on prefill or (prefill+decode) instances where the Language Model is executed. It receives encoder caches from encode instances and injects them into the encoder cache stored in the normal GPUModelRunner.
+
+The wrapper uses a callback function `receive_encoder_cache` to handle incoming encoder data. It asynchronously injects encoder_output into the encoder cache and updates injected_ids list.
+
+During `execute_model`, the wrapper simply calls `execute_model` from original GPUModelRunner, and also adds tracking of injected encoder caches. It reports successful injections back to the scheduler through the model output, allowing the scheduler to finalize allocations of preallocated inputs.
+
+### Encoder Cache Connector
+**Files:** `vllm/separated_encode/ec_transfer/connector/template.py`, `vllm/separated_encode/ec_transfer/connector/redis.py`
+
+The Encoder Cache Connector provides an abstraction layer for transferring encoder caches between encode and prefill instances in disaggregated vLLM deployments. The abstract base class ECConnectorTemplate defines the communication logic.
+
+The connector addresses several critical challenges in encoder output transfer in encoder separation. First, it ensures reliable data transfer by verifying that receiver instances have sufficient memory allocated and adequate logical space for encoder outputs before initiating transfers (therefore the first E->P/PD transfer is needed). This prevents out-of-memory errors and failed transfers that could disrupt the inference pipeline. Second, it provides a flexible abstraction layer that allows simple implementation of different connector approaches (e.g., Redis, nccl, ...) without modifying the EPD disaggregation / vLLM core components. Third, it helps to manage the lifecycle of the request in EPD disaggregation.
+
+The connector operates using a thread-based architecture with separate send and receive event loops. Communication is handled asynchronously through configurable worker pools. It maintains separate queues for send and receive operations, with each operation executed by dedicated worker threads.
+
+The encoder connector operates in four distinct states based on instance type and its component:
+
+**State for Encode Scheduler** - Pure sender functionality that handles encoder cache metadata transfer. When multimodal input is scheduled, metadata sending tasks are added to the send queue for processing by the send event loop.
+
+**State for Prefill Scheduler** - Receives encoder cache metadata from encode instances and manages preallocation through scheduler callbacks. The preallocation logic is described in scheduler updates. After successful preallocation, sends completion notifications back to encode instances from which it received the metadata.
+
+**State for Encode Model Runner** - Manages cache storage, transfer, and lifecycle. It maintains:
+
+- `encoder_cache`: Dictionary storing computed encoder outputs. *NOTE: The values are not copies of encoder cache, therefore we don't use additional GPU memory to store this encoder cache dictionary.*
+- `cache_to_send`: Set of pending encoder outputs transfers awaiting preallocation confirmation
+- `cache_to_avoid`: Set of encoder_output IDs that don't need to be sent
+- `transfered_ids`: List tracking successfully transferred cache IDs
+
+When encoder output is generated, `add_encoder_cache()` either adds the cache to local (`input_id`, `req_id`) map or immediately schedules transfer(or skips transfer) if a preallocation notification was already received before.
+
+Upon receiving successful preallocation notifications via `_maybe_send_encoder_cache()`, it either sends the cache immediately or adds the request to the pending set. It can receive failed preallocation notification, it means that we don't need to send encoder cache to this instance and can delete the encoder cache for this (req_id, input_id) from the Encoder instance.
+
+So encoder outputs are scheduled for transfer to PD instance as soon as both conditions are met.
+
+**State for Prefill Model Runner** - Receive-only state that accepts encoder cache data and calls injection callbacks to add the cache into the model runner's encoder cache dictionary.
+
+The communication flow follows this sequence:
+
+- Encode Scheduler sends metadata to Prefill Scheduler for cache preallocation
+- Prefill Scheduler attempts preallocation and sends success/failure notifications to Encode Model Runner
+- Upon successful preallocation, Encode Model Runner transfers the actual encoder cache data to Prefill Model Runner
+
+Transfer completion tracking is built into the class. Through the connector's `get_transferred_ids` method, the model runner can retrieve which request data has already been received.
+
+```mermaid
+graph LR
+ subgraph "Encode Instance"
+ ES[Encode Scheduler]
+ EMR[Encode Model Runner]
+ end
+
+ subgraph "Prefill Instance"
+ PS[Prefill Scheduler]
+ PMR[Prefill Model Runner]
+ end
+
+
+ ES -.->|metadata| PS
+ EMR -.->|cache data| PMR
+ PS -.->|notification| EMR
+ PMR -.->|injected IDs| PS
+ EMR -.->|transfered IDs| ES
+```
+
+
+
+#### Extension Example
+
+The included `RedisECConnector` demonstrates a concrete implementation using Redis as the communication backend. To use other communication backends, implement the abstract methods `_send_prealloc_notification`, `_send_encoder_cache_metas`, `_send_encoder_cache`, `_recv_prealloc_notification`, `_recv_encoder_cache_metas`, and `_recv_encoder_cache` according to your chosen transport mechanism. This connection extension supports multiple E instances and multiple PD or P instances.
+
+# Usage Instructions
+
+*To use E+PD disaggregation install `redis`, `msgpack_numpy` python packages and `redis-server` to your system. Because currently only RedisECConnector is implemented*.
+
+The update provides a toy proxy implementation and scripts to start up the EPD Disaggregated vLLM. There are multiple scripts to run EPD disaggregation: one of them is used to run EPD Disaggregated vLLM on 1 GPU, and all other scripts are used to run it on multiple GPUs. Any number of GPUs is supported by the components (you can run xEyPD disagg.), but you need to write your own script, using the other scripts as reference.
+
+To start the EPD instances and proxy server, select one of the provided scripts and modify the arguments as needed before execution. You can run the deployment using any of these commands:
+
+```bash
+# 1 GPU, 2 Instances on the same GPU
+bash examples/online_serving/separated_encode/launch_epd_serve.sh
+```
+
+```bash
+# 2 GPUs, 1 E instance, 1 PD instance
+bash examples/online_serving/separated_encode/launch_1e1pd.sh
+```
+
+```bash
+# 3 GPUs, 1 E instance, 2 PD instances
+bash examples/online_serving/separated_encode/launch_1e2pd.sh
+```
+
+```bash
+# 3 GPUs, 2 E instance, 1 PD instances
+bash examples/online_serving/separated_encode/launch_2e1pd.sh
+```
+
+After the server starts running, you can interact with it using OpenAI-compatible API requests to send queries and receive responses. Sample Python code for sending requests is available in the examples/online_serving/separated_encode/ directory. You can send testing request like this:
+
+```bash
+python examples/online_serving/separated_encode/proxy_request.py --port $PORT --model_path $MODEL --image_path docs/assets/design/arch_overview/entrypoints.excalidraw.png &
+```
+
+# Benchmark
+
+Performance evaluation was conducted using Qwen2.5-VL-3B-Instruct on an NVIDIA A100-SXM4-80GB GPU. The following configurations were compared:
+
+- default vllm(1 GPU)
+- default vllm(2 GPU, Tensor Par.)
+- default vllm(2 GPU, Data Par.)
+- 1E1PD disaggregated(1 GPU),
+- 1E1PD disaggregated(2 GPU).
+
+Testing was conducted on the lmarena-ai/VisionArena-Chat dataset with varying prompt loads from 100 to 1000 requests to assess scalability characteristics.
+
+```
+python benchmarks/benchmark_serving.py \
+ --backend openai-chat \
+ --endpoint /v1/chat/completions \
+ --model $MODEL \
+ --dataset-name hf \
+ --dataset-path $DATASET \
+ --hf-split train \
+ --num-prompts $NUM_PROMPTS \
+ --seed 40 \
+ --save-result \
+ --save-detailed \
+ --result-dir $LOG_PATH/vision_arena_results \
+ --result-filename vision_arena_outputs$(date +"%Y%m%d_%H%M%S").json \
+ --port 10001 > $LOG_PATH/benchmark_VisionArena_$(date +"%Y%m%d_%H%M%S").log 2>&1
+```
+
+We benchmarked the EPD approaches and Default (1GPU) three times across 4 workloads. Data parallel and tensor parallel approaches were tested at different times on the same server. The table below shows the mean statistics across all benchmark runs, with detailed individual results provided at the end of this document.
+
+| Approach | \# prompts | Benchmark duration (s) | Req throughput (req/s) | Mean TTFT (ms) | P99 TTFT (ms) | Mean TPOT (ms) | P99 TPOT (ms) | Mean ITL (ms) | P99 ITL (ms) |
+| ----------------------------- | ---------- | ---------------------- | -------------------------- | -------------- | ------------- | -------------- | ------------- | ------------- | ------------ |
+| E+PD [2GPU,1E,1PD] (Mean) | 1000 | 63.36 | 15.82 | 32002.81 | 60608.7 | 62.31 | 111.02 | 152.28 | 1639.21 |
+| Default [2GPU, Data Par.]\* | 1000 | 74.94 | 13.34 | 37626.68 | 70951.04 | 146.56 | 342.69 | 195.13 | 1816.9 |
+| Default [2GPU, Tensor Par.]\* | 1000 | 91.69 | 10.91 | 42369.61 | 87645.92 | 159.2 | 294.61 | 157.89 | 1100.4 |
+| E+PD [1GPU] (Mean) | 1000 | 93.54 | 10.69 | 50211.56 | 90882.21 | 98.28 | 186.53 | 128.71 | 924.02 |
+| Default [1GPU] (Mean) | 1000 | 106.71 | 9.39 | 50789.65 | 102740.3 | 115.47 | 202.19 | 139.82 | 1123.19 |
+| E+PD [2GPU,1E,1PD] (Mean) | 500 | 34.4 | 14.56 | 17151.17 | 32094.03 | 60.62 | 104.12 | 97.91 | 789.28 |
+| Default [2GPU, Data Par.]\* | 500 | 40.23 | 12.43 | 19294.31 | 36177.82 | 128.46 | 378.29 | 173.43 | 1485.35 |
+| Default [2GPU, Tensor Par.]\* | 500 | 48.99 | 10.21 | 22218.02 | 45730.88 | 157.56 | 234.1 | 158.67 | 833.73 |
+| E+PD [1GPU] (Mean) | 500 | 49.25 | 10.16 | 26199.06 | 46978.92 | 93.11 | 196.85 | 114.67 | 651.61 |
+| Default [1GPU] (Mean) | 500 | 54.67 | 9.15 | 23792.12 | 51304.22 | 108.05 | 239.63 | 129.38 | 1054.42 |
+| E+PD [2GPU,1E,1PD] (Mean) | 200 | 15.46 | 12.98 | 6666.79 | 12372.37 | 45.57 | 76.24 | 68.81 | 367.75 |
+| Default [2GPU, Data Par.]\* | 200 | 18.04 | 11.09 | 7791.15 | 4029.51 | 89.08 | 312.61 | 82.42 | 906.56 |
+| Default [2GPU, Tensor Par.]\* | 200 | 21.76 | 9.19 | 9143.71 | 18484.54 | 101.74 | 221.54 | 95.91 | 455.9 |
+| E+PD [1GPU] (Mean) | 200 | 21.83 | 9.17 | 9780.98 | 19271.01 | 79.11 | 147.9 | 85.97 | 354.24 |
+| Default [1GPU] (Mean) | 200 | 22.97 | 8.71 | 9420.3 | 20477.48 | 88.53 | 216.12 | 86.12 | 667.83 |
+| E+PD [2GPU,1E,1PD] (Mean) | 100 | 7.49 | 13.4 | 3867.56 | 6057.55 | 21.8 | 44.44 | 28.63 | 264.65 |
+| Default [2GPU, Data Par.]\* | 100 | 8.26 | 12.1 | 3898.85 | 6745.78 | 45.53 | 163.51 | 36.75 | 577.1 |
+| Default [2GPU, Tensor Par.]\* | 100 | 10.87 | 9.2 | 4920.22 | 9505.53 | 60.54 | 212.46 | 50.51 | 653.17 |
+| E+PD [1GPU] (Mean) | 100 | 10.84 | 9.27 | 5925.45 | 9402.15 | 36.09 | 122.15 | 35.19 | 262.08 |
+| Default [1GPU] (Mean) | 100 | 11.05 | 9.11 | 5153.35 | 9354.25 | 57.95 | 195.97 | 49.15 | 478.4 |
+
+* The TP and DP benchmarking were conducted at different time, the observed differences may be attributed to other factors
+
+The benchmark results demonstrate that E+PD can provide improved request throughput and latency compared to default configurations across both 1 and 2 GPU setups, as evaluated on NVIDIA A100-SXM4-80GB hardware using the Qwen2.5-VL-3B-Instruct model. These performance improvements highlight the effectiveness of the Encoder Separation technique for optimizing multimodal model serving workloads.
+
+
+Detailed results:
+
+| Approach | Trial | num_prompts | Benchmark duration (s) | Total input tokens | Total generated tokens | Request throughput (req/s) | Output token throughput (tok/s) | Total Token throughput (tok/s) | Mean TTFT (ms) | Median TTFT (ms) | P99 TTFT (ms) | Mean TPOT (ms) | Median TPOT (ms) | P99 TPOT (ms) | Mean ITL (ms) | Median ITL (ms) | P99 ITL (ms) |
+| ------------------ | ---------- | ----------- | ---------------------- | ------------------ | ---------------------- | -------------------------- | ------------------------------- | ------------------------------ | -------------- | ---------------- | ------------- | -------------- | ---------------- | ------------- | ------------- | --------------- | ------------ |
+| Default [1GPU] | 1 | 100 | 12.31013 | 8122 | 10917 | 8.123391 | 886.8306 | 1546.612 | 6213.311 | 5776.914 | 10610 | 60.43373 | 50.78793 | 216.0569 | 51.2556 | 14.30057 | 613.0175 |
+| Default [1GPU] | 2 | 100 | 10.85676 | 8122 | 10957 | 9.210854 | 1009.233 | 1757.339 | 4779.188 | 4088.457 | 9124.295 | 59.50579 | 52.69066 | 191.1305 | 50.71513 | 14.28988 | 607.7266 |
+| Default [1GPU] | 3 | 100 | 9.991029 | 8122 | 10989 | 10.00898 | 1099.887 | 1912.816 | 4467.538 | 3962.17 | 8328.441 | 53.90317 | 47.2205 | 180.7076 | 45.4647 | 14.11477 | 214.4418 |
+| Default [1GPU] | mean_stats | | 11.05264 | 8122 | 10954.33 | 9.114408 | 998.6502 | 1738.922 | 5153.346 | 4609.18 | 9354.245 | 57.94756 | 50.23303 | 195.965 | 49.14514 | 14.23507 | 478.3953 |
+| Default [1GPU] | 1 | 200 | 22.98617 | 32055 | 22152 | 8.700882 | 963.7097 | 2358.244 | 9865.168 | 8039.06 | 20380.57 | 83.9027 | 87.92922 | 200.0653 | 83.52419 | 22.85806 | 752.5209 |
+| Default [1GPU] | 2 | 200 | 22.89445 | 32055 | 22060 | 8.735741 | 963.5523 | 2363.673 | 9200.236 | 7245.377 | 20568.84 | 90.94091 | 98.33245 | 224.7042 | 87.7061 | 20.40263 | 630.2986 |
+| Default [1GPU] | 3 | 200 | 23.03203 | 32055 | 22073 | 8.683559 | 958.361 | 2350.118 | 9195.481 | 7298.334 | 20483.04 | 90.75096 | 96.2287 | 223.5768 | 87.12716 | 22.04948 | 620.6832 |
+| Default [1GPU] | mean_stats | | 22.97089 | 32055 | 22095 | 8.706728 | 961.8743 | 2357.345 | 9420.295 | 7527.59 | 20477.48 | 88.53152 | 94.16346 | 216.1154 | 86.11915 | 21.77006 | 667.8342 |
+| Default [1GPU] | 1 | 500 | 55.16338 | 60877 | 54268 | 9.063984 | 983.7685 | 2087.345 | 24032.54 | 20746.58 | 51869.53 | 108.6506 | 98.5411 | 250.0097 | 132.4853 | 106.2516 | 1066.694 |
+| Default [1GPU] | 2 | 500 | 54.1481 | 60877 | 54413 | 9.233935 | 1004.892 | 2129.161 | 23329.79 | 20071.22 | 50569.95 | 106.4999 | 105.1927 | 237.7959 | 127.4949 | 89.68485 | 1064.654 |
+| Default [1GPU] | 3 | 500 | 54.69779 | 60877 | 53811 | 9.141137 | 983.7874 | 2096.757 | 24014.02 | 21001.2 | 51473.18 | 109.0092 | 110.4562 | 231.0923 | 128.1639 | 101.1179 | 1031.919 |
+| Default [1GPU] | mean_stats | | 54.66976 | 60877 | 54164 | 9.146352 | 990.816 | 2104.421 | 23792.12 | 20606.33 | 51304.22 | 108.0533 | 104.73 | 239.6326 | 129.3814 | 99.0181 | 1054.422 |
+| Default [1GPU] | 1 | 1000 | 105.1188 | 92971 | 106904 | 9.513042 | 1016.982 | 1901.419 | 50431.95 | 51536 | 101359.1 | 115.7794 | 117.7622 | 222.7387 | 142.6449 | 106.7326 | 1131.579 |
+| Default [1GPU] | 2 | 1000 | 102.5482 | 92971 | 106982 | 9.751516 | 1043.237 | 1949.845 | 48095.11 | 46742.29 | 98620.85 | 109.9126 | 107.0314 | 190.3935 | 129.0594 | 96.81393 | 1035.998 |
+| Default [1GPU] | 3 | 1000 | 112.4488 | 92971 | 107016 | 8.892937 | 951.6866 | 1778.472 | 53841.88 | 54992.64 | 108241 | 120.7289 | 126.1713 | 193.4413 | 147.7406 | 110.6369 | 1201.979 |
+| Default [1GPU] | mean_stats | | 106.7053 | 92971 | 106967.3 | 9.385832 | 1003.969 | 1876.579 | 50789.65 | 51090.31 | 102740.3 | 115.4736 | 116.9883 | 202.1912 | 139.815 | 104.7278 | 1123.185 |
+
+| Approach | Trial | num_prompts | Benchmark duration (s) | Total input tokens | Total generated tokens | Request throughput (req/s) | Output token throughput (tok/s) | Total Token throughput (tok/s) | Mean TTFT (ms) | Median TTFT (ms) | P99 TTFT (ms) | Mean TPOT (ms) | Median TPOT (ms) | P99 TPOT (ms) | Mean ITL (ms) | Median ITL (ms) | P99 ITL (ms) |
+| ------------------ | ---------- | ----------- | ---------------------- | ------------------ | ---------------------- | -------------------------- | ------------------------------- | ------------------------------ | -------------- | ---------------- | ------------- | -------------- | ---------------- | ------------- | ------------- | --------------- | ------------ |
+| E+PD [1GPU] | 1 | 100 | 12.03229 | 8122 | 10952 | 8.310967 | 910.2171 | 1585.234 | 6994.752 | 6383.804 | 10612.88 | 34.84306 | 35.47549 | 97.73405 | 34.04824 | 13.6935 | 276.8582 |
+| E+PD [1GPU] | 2 | 100 | 10.18145 | 8122 | 10913 | 9.821786 | 1071.852 | 1869.577 | 5035.682 | 5188.816 | 8728.965 | 39.86719 | 36.29167 | 156.4304 | 38.67256 | 13.81485 | 194.8142 |
+| E+PD [1GPU] | 3 | 100 | 10.32014 | 8122 | 10892 | 9.689795 | 1055.412 | 1842.418 | 5745.909 | 5403.587 | 8864.612 | 33.5687 | 34.10522 | 112.2893 | 32.83498 | 13.1678 | 314.57 |
+| E+PD [1GPU] | mean_stats | | 10.84463 | 8122 | 10919 | 9.274183 | 1012.494 | 1765.743 | 5925.448 | 5658.735 | 9402.153 | 36.09298 | 35.29079 | 122.1513 | 35.18526 | 13.55872 | 262.0808 |
+| E+PD [1GPU] | 1 | 200 | 22.53692 | 32055 | 22210 | 8.874328 | 985.4941 | 2407.827 | 10164.43 | 8225.639 | 19737.15 | 78.71672 | 83.81666 | 149.1261 | 87.53348 | 42.34258 | 349.4325 |
+| E+PD [1GPU] | 2 | 200 | 21.15139 | 32055 | 22113 | 9.455644 | 1045.463 | 2560.967 | 9476.332 | 8594.147 | 18818.54 | 80.28189 | 83.99973 | 165.4138 | 79.32952 | 41.26098 | 366.8436 |
+| E+PD [1GPU] | 3 | 200 | 21.79525 | 32055 | 22335 | 9.17631 | 1024.764 | 2495.498 | 9702.187 | 8326.214 | 19257.32 | 78.32132 | 92.71767 | 129.1543 | 91.03399 | 43.73197 | 346.4315 |
+| E+PD [1GPU] | mean_stats | | 21.82785 | 32055 | 22219.33 | 9.168761 | 1018.574 | 2488.097 | 9780.984 | 8382 | 19271.01 | 79.10664 | 86.84469 | 147.8981 | 85.96566 | 42.44518 | 354.2359 |
+| E+PD [1GPU] | 1 | 500 | 51.44422 | 60877 | 54201 | 9.719265 | 1053.588 | 2236.947 | 25547.28 | 23111.25 | 49217.38 | 98.04562 | 95.03304 | 177.3768 | 138.6956 | 96.89536 | 1005.7 |
+| E+PD [1GPU] | 2 | 500 | 48.27857 | 60877 | 53854 | 10.35656 | 1115.485 | 2376.437 | 27904.17 | 30772.62 | 46019.26 | 88.95721 | 78.50464 | 206.927 | 98.23104 | 77.5184 | 438.898 |
+| E+PD [1GPU] | 3 | 500 | 48.0366 | 60877 | 54391 | 10.40873 | 1132.282 | 2399.587 | 25145.74 | 25104.45 | 45700.12 | 92.33802 | 91.74887 | 206.2593 | 107.086 | 80.41453 | 510.236 |
+| E+PD [1GPU] | mean_stats | | 49.25313 | 60877 | 54148.67 | 10.16152 | 1100.452 | 2337.657 | 26199.06 | 26329.44 | 46978.92 | 93.11362 | 88.42885 | 196.8544 | 114.6709 | 84.94276 | 651.6112 |
+| E+PD [1GPU] | 1 | 1000 | 92.8552 | 92971 | 106896 | 10.76946 | 1151.212 | 2152.459 | 49511.54 | 52686.3 | 90166.05 | 99.1985 | 99.80546 | 166.7989 | 124.7338 | 90.17502 | 742.4403 |
+| E+PD [1GPU] | 2 | 1000 | 93.60591 | 92971 | 107227 | 10.68309 | 1145.515 | 2138.732 | 49018.63 | 45659.07 | 90859.77 | 97.11694 | 93.49865 | 189.8387 | 132.8123 | 87.51951 | 1008.385 |
+| E+PD [1GPU] | 3 | 1000 | 94.16372 | 92971 | 106684 | 10.6198 | 1132.963 | 2120.297 | 52104.5 | 56606.76 | 91620.82 | 98.5272 | 90.3015 | 202.946 | 128.5848 | 83.8681 | 1021.238 |
+| E+PD [1GPU] | mean_stats | | 93.54161 | 92971 | 106935.7 | 10.69078 | 1143.23 | 2137.163 | 50211.56 | 51650.71 | 90882.21 | 98.28088 | 94.5352 | 186.5278 | 128.7103 | 87.18754 | 924.021 |
+
+| Approach | Trial | num_prompts | Benchmark duration (s) | Total input tokens | Total generated tokens | Request throughput (req/s) | Output token throughput (tok/s) | Total Token throughput (tok/s) | Mean TTFT (ms) | Median TTFT (ms) | P99 TTFT (ms) | Mean TPOT (ms) | Median TPOT (ms) | P99 TPOT (ms) | Mean ITL (ms) | Median ITL (ms) | P99 ITL (ms) |
+| ------------------ | ---------- | ----------- | ---------------------- | ------------------ | ---------------------- | -------------------------- | ------------------------------- | ------------------------------ | -------------- | ---------------- | ------------- | -------------- | ---------------- | ------------- | ------------- | --------------- | ------------ |
+| E+PD [2GPU,1E,1PD] | 1 | 100 | 8.129621 | 8122 | 10841 | 12.3007 | 1333.518 | 2332.581 | 4531.025 | 4842.273 | 6653.774 | 20.48578 | 19.92048 | 48.2179 | 24.76422 | 11.64389 | 229.7043 |
+| E+PD [2GPU,1E,1PD] | 2 | 100 | 7.196747 | 8122 | 10820 | 13.89517 | 1503.457 | 2632.022 | 3723.191 | 3591.878 | 5752.759 | 22.33815 | 21.74556 | 39.84977 | 28.97187 | 11.99492 | 348.9685 |
+| E+PD [2GPU,1E,1PD] | 3 | 100 | 7.147378 | 8122 | 10879 | 13.99114 | 1522.097 | 2658.457 | 3348.468 | 3238.785 | 5766.111 | 22.58321 | 23.11859 | 45.2545 | 32.16768 | 11.61273 | 215.2784 |
+| E+PD [2GPU,1E,1PD] | mean_stats | | 7.491249 | 8122 | 10846.67 | 13.39567 | 1453.024 | 2541.02 | 3867.562 | 3890.979 | 6057.548 | 21.80238 | 21.59488 | 44.44073 | 28.63459 | 11.75051 | 264.6504 |
+| E+PD [2GPU,1E,1PD] | 1 | 200 | 16.66409 | 32055 | 21998 | 12.00186 | 1320.084 | 3243.681 | 7707.448 | 7896.659 | 13593.91 | 43.1641 | 44.17912 | 70.99135 | 67.42118 | 34.18124 | 329.3577 |
+| E+PD [2GPU,1E,1PD] | 2 | 200 | 14.63482 | 32055 | 22037 | 13.66604 | 1505.792 | 3696.116 | 6035.327 | 6079.81 | 11616.76 | 45.66734 | 47.01402 | 79.61327 | 70.14829 | 33.69296 | 358.0308 |
+| E+PD [2GPU,1E,1PD] | 3 | 200 | 15.0796 | 32055 | 22402 | 13.26295 | 1485.583 | 3611.302 | 6257.599 | 5918.491 | 11906.45 | 47.87966 | 50.82737 | 78.11015 | 68.863 | 36.27602 | 415.8599 |
+| E+PD [2GPU,1E,1PD] | mean_stats | | 15.45951 | 32055 | 22145.67 | 12.97695 | 1437.153 | 3517.033 | 6666.791 | 6631.654 | 12372.37 | 45.57036 | 47.34017 | 76.23825 | 68.81082 | 34.71674 | 367.7494 |
+| E+PD [2GPU,1E,1PD] | 1 | 500 | 33.87189 | 60877 | 54103 | 14.7615 | 1597.283 | 3394.555 | 17079.91 | 16107.95 | 31689.53 | 55.98321 | 62.27903 | 93.5782 | 101.1234 | 59.71449 | 723.1812 |
+| E+PD [2GPU,1E,1PD] | 2 | 500 | 36.32494 | 60877 | 54199 | 13.76465 | 1492.06 | 3167.961 | 17724.86 | 20153.1 | 33779.83 | 65.33763 | 61.43517 | 126.8949 | 95.23513 | 63.6576 | 727.3881 |
+| E+PD [2GPU,1E,1PD] | 3 | 500 | 32.99449 | 60877 | 54261 | 15.15405 | 1644.547 | 3489.613 | 16648.73 | 17887.73 | 30812.72 | 60.54447 | 62.02508 | 91.89419 | 97.36272 | 58.35994 | 917.2843 |
+| E+PD [2GPU,1E,1PD] | mean_stats | | 34.39711 | 60877 | 54187.67 | 14.56007 | 1577.964 | 3350.71 | 17151.17 | 18049.59 | 32094.03 | 60.62177 | 61.91309 | 104.1224 | 97.90707 | 60.57734 | 789.2845 |
+| E+PD [2GPU,1E,1PD] | 1 | 1000 | 60.60446 | 92971 | 106831 | 16.50044 | 1762.758 | 3296.82 | 32204.67 | 33271.47 | 58176.87 | 57.46417 | 56.7792 | 106.0846 | 141.4603 | 74.10948 | 1447.025 |
+| E+PD [2GPU,1E,1PD] | 2 | 999 | 61.5361 | 92962 | 107012 | 16.23437 | 1739.012 | 3249.702 | 28927.87 | 27331.07 | 59006.99 | 60.83669 | 58.55945 | 104.2718 | 155.272 | 91.53153 | 1701.048 |
+| E+PD [2GPU,1E,1PD] | 3 | 1000 | 67.93865 | 92971 | 106216 | 14.71916 | 1563.411 | 2931.866 | 34875.88 | 35455.82 | 64642.23 | 68.64089 | 69.29424 | 122.7087 | 160.1047 | 71.93224 | 1769.555 |
+| E+PD [2GPU,1E,1PD] | mean_stats | | 63.35973 | 92968 | 106686.3 | 15.81799 | 1688.394 | 3159.463 | 32002.81 | 32019.45 | 60608.7 | 62.31392 | 61.5443 | 111.0217 | 152.279 | 79.19108 | 1639.209 |
\ No newline at end of file
diff --git a/vllm/separated_encode/ec_transfer/connector/redis.py b/vllm/separated_encode/ec_transfer/connector/redis.py
new file mode 100644
index 000000000000..94852de167ae
--- /dev/null
+++ b/vllm/separated_encode/ec_transfer/connector/redis.py
@@ -0,0 +1,208 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Callable, Literal, Optional
+
+import msgpack_numpy
+import redis
+
+from vllm.config import VllmConfig
+from vllm.separated_encode.ec_transfer.connector.template import (
+ ECConnectorTemplate)
+from vllm.logger import init_logger
+import torch
+
+logger = init_logger(__name__)
+
class RedisECConnector(ECConnectorTemplate):
    """Redis-backed implementation of the encoder-cache connector.

    Uses Redis lists as blocking FIFO queues. Queue names are suffixed with
    the destination instance's EPD rank (e.g. ``prealloc0``, ``cache_metas1``,
    ``cache1``), so each instance only pops from queues addressed to its own
    rank. Payloads are serialized with ``msgpack_numpy`` so numpy arrays
    (the encoder cache) can travel alongside plain metadata.
    """

    def __init__(self,
                 vllm_config: "VllmConfig",
                 device: Optional[torch.device],
                 intra_instance_type: Literal["scheduler", "model-runner"],
                 preallocate_callback: Optional[Callable[[str, int, int, str],
                                                         None]],
                 injection_callback: Optional[Callable[
                     [str, int, torch.Tensor, str], None]],
                 redis_host: str = "localhost",
                 redis_port: int = 6379):

        if redis_host is None or redis_port is None:
            raise RuntimeError("Redis Encoder Cache Connector is used, "
                               "but redis_host or redis_port is not specified")

        # The port may be supplied as a string (callers pass
        # os.getenv("REDIS_PORT")); coerce it so redis-py always receives
        # an integer port number.
        self.redis_client = redis.StrictRedis(host=redis_host,
                                              port=int(redis_port))
        # self.rank must be assigned before super().__init__(), which starts
        # the worker threads that read it inside the receive loops.
        self.rank = vllm_config.epd_disagg_config.epd_rank
        super().__init__(
            vllm_config,
            device,
            intra_instance_type,
            preallocate_callback,
            injection_callback,
        )

    def _get_request_ranks(self, request_id: str):
        """Extract E_RANK and PD_RANK from a proxy-formatted request ID.

        Parses a request_id of the form $ACTUAL_REQUEST_ID|$E_RANK|$PD_RANK.

        Args:
            request_id: The formatted request ID string from the proxy.

        Returns:
            Tuple containing (E_RANK, PD_RANK).
        """
        parts = request_id.split("|")
        return int(parts[-2]), int(parts[-1])

    def _send_prealloc_notification(self, request_id: str, input_id: int,
                                    successful: bool, mm_hash: str) -> None:
        """
        Send pre-allocation notification from PD to E instance via Redis.

        Notifies the encoder instance whether pre-allocation was successful
        and whether the encoder cache should be sent.

        Args:
            request_id: The formatted request ID containing rank information.
            input_id: Index of the multimodal input within the request.
            successful: Whether pre-allocation succeeded and cache should be sent.
            mm_hash: Hash of the multimodal input.
        """
        payload = {
            "request_id": request_id,
            "input_id": input_id,
            "successful": successful,
            "mm_hash": mm_hash
        }
        # Notification goes back to the encode instance -> use E rank.
        rank = self._get_request_ranks(request_id)[0]
        # Lazy %-formatting: only rendered when DEBUG logging is enabled.
        logger.debug("Sent prealloc notification -> %s, %s, %s", rank,
                     request_id, successful)
        self.redis_client.lpush(f"prealloc{rank}",
                                msgpack_numpy.packb(payload))

    def _send_encoder_cache_metas(
        self, request_id: str, input_id: int,
        num_encoder_tokens: int, mm_hash: str
    ) -> None:
        """
        Send encoder cache metadata from E to PD instance via Redis.

        Transfers metadata needed for pre-allocating space for the encoder cache
        on the prefill/decode instance.

        Args:
            request_id: The formatted request ID containing rank information.
            input_id: Index of the multimodal input within the request.
            num_encoder_tokens: Number of tokens in the encoder cache.
            mm_hash: Hash of the multimodal input.
        """
        payload = {
            "request_id": request_id,
            "input_id": input_id,
            "num_encoder_tokens": num_encoder_tokens,
            "mm_hash": mm_hash
        }
        # Metadata goes to the prefill/decode instance -> use PD rank.
        rank = self._get_request_ranks(request_id)[1]
        logger.debug("Sent encode cache metadata -> %s, %s", rank, request_id)
        self.redis_client.lpush(f"cache_metas{rank}",
                                msgpack_numpy.packb(payload))

    def _send_encoder_cache(
            self, request_id: str, input_id: int,
            encoder_cache: torch.Tensor, mm_hash: str) -> None:
        """
        Send encoder cache tensor from E to PD instance via Redis.

        Converts the encoder cache to a CPU float16 numpy array before sending
        to reduce transfer size.

        Args:
            request_id: The formatted request ID containing rank information.
            input_id: Index of the multimodal input within the request.
            encoder_cache: The encoder output tensor to transfer.
            mm_hash: Hash of the multimodal input.
        """
        encoder_cache_numpy = encoder_cache.to("cpu",
                                               dtype=torch.float16).numpy()
        payload = msgpack_numpy.packb({
            "request_id": request_id,
            "input_id": input_id,
            "encoder_cache": encoder_cache_numpy,
            "mm_hash": mm_hash
        })
        # The cache itself goes to the prefill/decode instance -> PD rank.
        rank = self._get_request_ranks(request_id)[1]
        logger.debug("Sent encode cache -> %s, %s", rank, request_id)
        self.redis_client.lpush(f"cache{rank}", payload)

    def _recv_prealloc_notification(
            self, maybe_send_cache_callback: Callable[[str, int, bool, str],
                                                      None]) -> None:
        """
        Receive pre-allocation notification on E instance from Redis.

        Blocks until a notification is received, then unpacks the data and
        invokes the callback to handle cache sending logic.

        Args:
            maybe_send_cache_callback: Callback to determine whether to send
                the encoder cache based on the pre-allocation result.
        """
        # blpop blocks until an item is available; [1] drops the queue name.
        packed = self.redis_client.blpop(f"prealloc{self.rank}")[1]
        data = msgpack_numpy.unpackb(packed, raw=False)
        request_id, input_id, successful, mm_hash = (
            data["request_id"],
            data["input_id"],
            data["successful"],
            data["mm_hash"]
        )
        logger.debug("Received prealloc notif -> %s, %s", self.rank,
                     request_id)
        maybe_send_cache_callback(request_id, input_id, successful, mm_hash)

    def _recv_encoder_cache_metas(
            self, preallocate_callback: Callable[[str, int, int, str],
                                                 None]) -> None:
        """
        Receive encoder cache metadata on PD instance from Redis.

        Blocks until metadata is received, then unpacks the data and invokes
        the callback to pre-allocate space in the scheduler.

        Args:
            preallocate_callback: Scheduler callback to pre-allocate space
                for the incoming encoder cache.
        """
        packed = self.redis_client.blpop(f"cache_metas{self.rank}")[1]
        data = msgpack_numpy.unpackb(packed, raw=False)
        request_id, input_id, num_encoder_tokens, mm_hash = (
            data["request_id"],
            data["input_id"],
            data["num_encoder_tokens"],
            data["mm_hash"]
        )
        logger.debug("Received encoder metadata -> %s, %s", self.rank,
                     request_id)
        preallocate_callback(request_id, input_id, num_encoder_tokens, mm_hash)

    def _recv_encoder_cache(
        self,
        injection_callback: Callable[[str, int, torch.Tensor, str], None]
    ) -> None:
        """
        Receive encoder cache tensor on PD instance from Redis.

        Blocks until cache data is received, converts it from numpy back to
        a torch tensor on the model-runner's device/dtype, then invokes the
        injection callback.

        Args:
            injection_callback: Model runner callback to inject the encoder
                cache into the cache dictionary.
        """
        packed = self.redis_client.blpop(f"cache{self.rank}")[1]
        data = msgpack_numpy.unpackb(packed, raw=False)
        request_id, input_id, encoder_cache, mm_hash = (
            data["request_id"],
            data["input_id"],
            data["encoder_cache"],
            data["mm_hash"]
        )
        # Restore the wire-format float16 array to the runner's dtype/device.
        encoder_cache = torch.from_numpy(encoder_cache).to(
            device=self.device, dtype=self.dtype)
        logger.debug("Received encoder cache -> %s, %s", self.rank, request_id)
        injection_callback(request_id, input_id, encoder_cache, mm_hash)
diff --git a/vllm/separated_encode/ec_transfer/connector/template.py b/vllm/separated_encode/ec_transfer/connector/template.py
new file mode 100644
index 000000000000..e0972b8050ba
--- /dev/null
+++ b/vllm/separated_encode/ec_transfer/connector/template.py
@@ -0,0 +1,398 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import queue
+import threading
+from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
+from typing import Callable, Literal, Optional
+
+import torch
+
+from vllm.config import EPDDisaggConfig, VllmConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
class ECConnectorTemplate(ABC):
    """
    Abstraction for the communication between the E instance and
    P (or PD) instance; all encoder cache communication is handled
    by this class.

    ECConnector communication handling is executed in separate threads;
    the cache injection, preallocation and allocation logic is handled
    by the gpu model runner / scheduler callbacks.

    Send and receive logic are provided by concrete implementations such
    as RedisECConnector. All _recv tasks are created eagerly by the
    receive event loop and their number is bounded, so _recv
    implementations should block rather than time out.

    The connector also owns an encoder_cache dict (on the encode
    model runner) used to hold outputs until the peer has preallocated
    space for them.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        device: Optional[torch.device],
        intra_instance_type: Literal["scheduler", "model-runner"],
        preallocate_callback: Optional[Callable[[str, int, int, str], None]],
        injection_callback: Optional[Callable[[str, int, torch.Tensor, str],
                                              None]],
    ):
        # Maps (inter_instance_type, intra_instance_type) to the pair
        # (receive function, callback handed to it). (encode, scheduler)
        # has nothing to receive, so its receive loop exits immediately.
        callback_mapping = {
            ("encode", "scheduler"): (None, None),
            ("encode", "model-runner"):
            (self._recv_prealloc_notification, self._maybe_send_encoder_cache),
            ("prefill", "scheduler"):
            (self._recv_encoder_cache_metas, preallocate_callback),
            ("prefill", "model-runner"):
            (self._recv_encoder_cache, injection_callback),
            ("prefill+decode", "scheduler"):
            (self._recv_encoder_cache_metas, preallocate_callback),
            ("prefill+decode", "model-runner"): (self._recv_encoder_cache,
                                                injection_callback)
        }
        self.device = device
        self.dtype = vllm_config.model_config.dtype

        self.epd_disagg_config: EPDDisaggConfig
        self.intra_instance_type: Literal["scheduler", "model-runner"]
        self.inter_instance_type: Literal["encode", "prefill",
                                          "prefill+decode"]
        self.encoder_cache: dict[str, dict[int, torch.Tensor]]
        self.send_executors: ThreadPoolExecutor
        self.recv_executors: ThreadPoolExecutor

        # Instance type and configs:
        self.epd_disagg_config = vllm_config.epd_disagg_config
        self.inter_instance_type = self.epd_disagg_config.instance_type
        self.intra_instance_type = intra_instance_type

        # Initialize main transfer processing components:
        self.send_tasks_queue: queue.Queue = queue.Queue()
        self.send_executors = ThreadPoolExecutor(
            max_workers=self.epd_disagg_config.connector_workers_num
        )

        # Sanity check: a zero/negative worker count would dead-lock both
        # executors and the limiting semaphore below.
        assert self.epd_disagg_config.connector_workers_num > 0

        # NOTE: max_workers must match the limiting semaphore value,
        # otherwise the receive busy loop will create tasks for
        # self.recv_executors without bound.
        self.recv_executors = ThreadPoolExecutor(
            max_workers=self.epd_disagg_config.connector_workers_num + 1
        )
        # Both event loops block forever (queue.get / blocking recv), so the
        # threads are daemonized to keep them from preventing interpreter
        # shutdown.
        self.send_worker = threading.Thread(target=self._send_event_loop,
                                            daemon=True)
        self.recv_worker = threading.Thread(target=self._recv_event_loop,
                                            daemon=True)
        self.target_recv_callback = callback_mapping.get(
            (self.inter_instance_type, self.intra_instance_type))

        self.limiting_semaphore = threading.Semaphore(
            self.epd_disagg_config.connector_workers_num + 1
        )

        # Used on model runner of encode instance only:
        if (self.intra_instance_type == "model-runner"
                and self.inter_instance_type == "encode"):
            # Guards cache_to_send / cache_to_skip / encoder_cache.
            self.use_cache_lock: threading.Lock = threading.Lock()
            self.cache_to_send: set = set()
            self.cache_to_skip: set = set()
            self.encoder_cache = {}
            # Guards transfered_ids, which the scheduler drains via
            # get_transfered_ids().
            self.transfered_ids_lock: threading.Lock = threading.Lock()
            self.transfered_ids = []

        self.send_worker.start()
        self.recv_worker.start()

    @abstractmethod
    def _send_prealloc_notification(self, request_id: str, input_id: int,
                                    successful: bool, mm_hash: str) -> None:
        """Send a pre-allocation completion notification.

        This method sends a notification to signal that the pre-allocation of
        space for an encoder cache, identified by request_id and input_id,
        has been completed on the P (or PD) instance.

        Args:
            request_id: id of the encoder cache's request.
            input_id: index of the mm input among request's mm inputs
            successful: indicates whether we need to send the encoder cache
            mm_hash: hash of the mm input
        """
        pass

    @abstractmethod
    def _send_encoder_cache_metas(self, request_id: str, input_id: int,
                                  num_encoder_tokens: int,
                                  mm_hash: str) -> None:
        """Send the metadata of an encoder cache.

        This method is used to transfer the encoder cache's metadata.

        Args:
            request_id: id of the encoder cache's request.
            input_id: index of the mm input among request's mm inputs
            num_encoder_tokens: size of the encoder cache
            mm_hash: hash of the mm input
        """
        pass

    @abstractmethod
    def _send_encoder_cache(
        self, request_id: str, input_id: int,
        encoder_cache: torch.Tensor, mm_hash: str
    ) -> None:
        """Send the encoder cache.

        This method sends the computed encoder cache.

        Args:
            request_id: id of the encoder cache's request.
            input_id: index of the mm input among request's mm inputs
            encoder_cache: encoder output
            mm_hash: hash of the mm input
        """
        pass

    @abstractmethod
    def _recv_prealloc_notification(
            self, maybe_send_cache_callback: Callable[[str, int, bool, str],
                                                      None]) -> None:
        """Receive a pre-allocation completion notification.

        This method invokes maybe_send_cache_callback for any received
        pre-allocation notification. The callback need not be invoked
        immediately; this function is called repeatedly by the receive
        event loop starting at init.

        Check the receiving logic of RedisECConnector and recv event loop
        for more details.

        Args:
            maybe_send_cache_callback: A callback function within the ec
                connector. This function either schedules encoder cache
                sending or adds the requested encoder cache to the set of
                pending/ignored requests.
        """
        pass

    @abstractmethod
    def _recv_encoder_cache_metas(
            self, preallocate_callback: Callable[[str, int, int, str],
                                                 None]) -> None:
        """Receives encoder cache metadata and calls preallocate callback.

        This method invokes the preallocate callback for any received
        encoder cache metadata. The callback need not be invoked
        immediately; this function is called repeatedly by the receive
        event loop starting at init.

        Check the receiving logic of RedisECConnector and recv event loop
        for more details.

        Args:
            preallocate_callback: A callback function within the scheduler.
                This function preallocates space for encoder cache in the
                encoder cache manager within the scheduler.
        """
        pass

    @abstractmethod
    def _recv_encoder_cache(
        self,
        injection_callback: Callable[[str, int, torch.Tensor, str], None]
    ) -> None:
        """Receives the encoder cache and calls injection callback.

        This method invokes the injection callback for any received
        encoder cache. The callback need not be invoked immediately;
        this function is called repeatedly by the receive event loop
        starting at init.

        Check the receiving logic of RedisECConnector and recv event loop
        for more details.

        Args:
            injection_callback: A callback function within the model runner.
                This function injects encoder cache into the encoder_cache
                dictionary within the model runner.
        """
        pass

    def add_encoder_cache(self, request_id: str, input_id: int,
                          encoder_cache: torch.Tensor, mm_hash: str):
        """Add an encoder cache to the EC connector.

        If the peer already asked for this (request_id, input_id), the cache
        is sent (or skipped) right away; otherwise it is stored in
        self.encoder_cache until a pre-allocation notification arrives.

        Args:
            request_id: id of the encoder cache's request.
            input_id: index of the mm input among request's mm inputs
            encoder_cache: encoder output tensor
            mm_hash: hash of the mm input
        """
        with self.use_cache_lock:
            if (request_id, input_id) in self.cache_to_send:
                # Peer preallocated before we finished encoding: send now.
                self.schedule_send_encoder_cache(request_id=request_id,
                                                 input_id=input_id,
                                                 encoder_cache=encoder_cache,
                                                 mm_hash=mm_hash)
                self.cache_to_send.remove((request_id, input_id))
            elif (request_id, input_id) in self.cache_to_skip:
                # Peer declined the transfer: just mark it as completed.
                with self.transfered_ids_lock:
                    self.transfered_ids.append((request_id, input_id))
                self.cache_to_skip.remove((request_id, input_id))
            else:
                # No decision from the peer yet; hold the cache locally.
                if request_id not in self.encoder_cache:
                    self.encoder_cache[request_id] = {}
                self.encoder_cache[request_id][input_id] = encoder_cache

    def _maybe_send_encoder_cache(
        self, request_id: str, input_id: int, successful: bool, mm_hash: str
    ):
        """Sends the encoder cache or records the pending decision.

        If the cache is already computed, either schedule its transfer
        (successful preallocation) or drop it and mark the id as
        transferred (failed preallocation). If it is not computed yet,
        remember the decision in cache_to_send / cache_to_skip so
        add_encoder_cache can act on it later.

        Args:
            request_id: id of the encoder cache's request.
            input_id: index of the mm input among request's mm inputs
            successful: indicates whether we need to send the encoder cache
            mm_hash: hash of the mm input
        """
        with self.use_cache_lock:
            if (request_id in self.encoder_cache
                    and input_id in self.encoder_cache[request_id]):
                if successful:
                    # NOTE(review): the entry is kept in encoder_cache after
                    # scheduling the send — presumably freed elsewhere once
                    # the transfer completes; confirm this is not a leak.
                    self.schedule_send_encoder_cache(
                        request_id, input_id,
                        self.encoder_cache[request_id][input_id],
                        mm_hash)
                else:
                    with self.transfered_ids_lock:
                        self.transfered_ids.append((request_id, input_id))
                    self.encoder_cache[request_id].pop(input_id)
                    if not self.encoder_cache[request_id]:
                        self.encoder_cache.pop(request_id)
            else:
                if successful:
                    self.cache_to_send.add((request_id, input_id))
                else:
                    self.cache_to_skip.add((request_id, input_id))

    def _send_event_loop(self):
        """Run the send event loop.

        Forever pops (callback, args) pairs from send_tasks_queue and
        dispatches them to the send executor.
        """
        try:
            while True:
                callback, args = self.send_tasks_queue.get()
                self.send_executors.submit(callback, *args)
        except Exception as e:
            raise ConnectionError("Error during send event loop.") from e

    def _limiting_wrapper(self, callback: Callable, arg: Callable):
        """Wrapper that holds a semaphore permit while a recv task runs."""
        with self.limiting_semaphore:
            callback(arg)

    def _recv_event_loop(self):
        """Run the receive event loop.

        Repeatedly submits blocking receive tasks, using
        limiting_semaphore to bound the number of in-flight receives to
        connector_workers_num + 1.

        NOTE(review): the loop releases its permit before the submitted
        _limiting_wrapper re-acquires one, so queued-but-not-started tasks
        hold no permit — under load the executor queue can briefly grow
        beyond the intended bound; confirm this is acceptable.
        """
        try:
            if self.target_recv_callback[0] is None:
                # Nothing to receive for this instance role.
                return
            while True:
                callback, arg = self.target_recv_callback
                with self.limiting_semaphore:
                    self.recv_executors.submit(self._limiting_wrapper,
                                               callback, arg)
        except Exception as e:
            raise ConnectionError("Error during recv event loop") from e

    def schedule_send_prealloc_notification(self, request_id: str,
                                            input_id: int,
                                            successful: bool,
                                            mm_hash: str) -> None:
        """Schedule preallocate completion notification sending.

        This method schedules the task of sending a preallocate completion
        notification to the encoder model runner (instance E).

        Args:
            request_id: id of the encoder cache's request.
            input_id: index of the mm input among request's mm inputs
            successful: indicates whether we need to send the encoder cache
            mm_hash: hash of the mm input
        """
        self.send_tasks_queue.put_nowait(
            (self._send_prealloc_notification,
             (request_id, input_id, successful, mm_hash)))

    def schedule_send_encoder_cache_metadata(self, request_id: str,
                                             input_id: int,
                                             num_encoder_tokens: int,
                                             mm_hash: str) -> None:
        """Schedule encoder cache metadata sending.

        This method schedules the task of sending encoder cache metadata
        for the encoder cache space preallocation.

        Args:
            request_id: id of the encoder cache's request.
            input_id: index of the mm input among request's mm inputs
            num_encoder_tokens: size of the encoder cache
            mm_hash: hash of the mm input
        """
        self.send_tasks_queue.put_nowait(
            (self._send_encoder_cache_metas, (request_id, input_id,
                                              num_encoder_tokens, mm_hash)))

    def schedule_send_encoder_cache(
        self, request_id: str, input_id: int,
        encoder_cache: torch.Tensor, mm_hash: str
    ) -> None:
        """Schedule encoder cache sending.

        The send is wrapped with _finish_wrapper so the id is recorded in
        transfered_ids once the transfer completes.

        Args:
            request_id: id of the encoder cache's request.
            input_id: index of the mm input among request's mm inputs
            encoder_cache: encoder output
            mm_hash: hash of the mm input
        """
        self.send_tasks_queue.put_nowait(
            (self._finish_wrapper, (self._send_encoder_cache, request_id,
                                    input_id, encoder_cache, mm_hash)))

    def _finish_wrapper(
        self, callback: Callable, request_id: str, input_id: int,
        encoder_cache: torch.Tensor, mm_hash: str
    ):
        """
        Wrapper that appends the id to transfered_ids after the send
        callback completes.
        """
        callback(request_id, input_id, encoder_cache, mm_hash)
        with self.transfered_ids_lock:
            self.transfered_ids.append((request_id, input_id))

    def get_transfered_ids(self):
        """
        Atomically drain and return the list of transferred
        (request_id, input_id) pairs.
        """
        with self.transfered_ids_lock:
            transfered_ids = self.transfered_ids
            self.transfered_ids = []
        return transfered_ids
\ No newline at end of file
diff --git a/vllm/separated_encode/epd.png b/vllm/separated_encode/epd.png
new file mode 100644
index 000000000000..5c059fa70e9f
Binary files /dev/null and b/vllm/separated_encode/epd.png differ
diff --git a/vllm/separated_encode/sched/encoder_cache_preallocator.py b/vllm/separated_encode/sched/encoder_cache_preallocator.py
new file mode 100644
index 000000000000..8a4154f135b1
--- /dev/null
+++ b/vllm/separated_encode/sched/encoder_cache_preallocator.py
@@ -0,0 +1,314 @@
+
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from __future__ import annotations
+from abc import ABC, abstractmethod
+import os
+import queue
+import threading
+from typing import Callable
+from collections import defaultdict
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.separated_encode.ec_transfer.connector.redis import (
+ RedisECConnector)
+from vllm.v1.request import Request
+
+logger = init_logger(__name__)
+
+
+class EncoderCachePreallocatorTemplate(ABC):
+ """Abstract base class for encoder cache preallocation strategies.
+
+ Defines the interface for managing encoder cache preallocation in
+ disaggregated deployments. Concrete implementations handle the coordination
+ between encoding and prefill instances, ensuring cache space is reserved
+ before encoder outputs are transferred.
+
+ This template provides the connection infrastructure through RedisECConnector
+ and defines the required methods that subclasses must implement to handle
+ request lifecycle, preallocation scheduling, and metadata reception.
+
+ Attributes:
+ ec_connector: Redis-based connector for communication between instances.
+ Handles metadata reception and preallocation notifications.
+ """
+ def __init__(self, vllm_config: VllmConfig):
+ """Initialize the preallocator with ECConnector."""
+ # NOTE(review): os.getenv returns None when REDIS_HOST/REDIS_PORT are
+ # unset — presumably RedisECConnector applies defaults; confirm.
+ self.ec_connector = RedisECConnector(
+ vllm_config = vllm_config,
+ device = None,
+ # no need to pass device if intra_instance_type scheduler
+ intra_instance_type = "scheduler",
+ preallocate_callback = self._receive_encoder_cache_metadata,
+ injection_callback = None,
+ redis_host=os.getenv("REDIS_HOST"),
+ redis_port=os.getenv("REDIS_PORT"),
+ )
+
+ @abstractmethod
+ def is_empty(self,) -> bool:
+ """Check if there are pending preallocations to process.
+
+ Returns:
+ True when no preallocation work is queued.
+ """
+ pass
+
+ @abstractmethod
+ def add_request(self, request: Request):
+ """Register a new request for preallocation tracking.
+
+ Called when a request arrives at the prefill instance. Implementations
+ should initialize tracking structures and process any metadata that
+ arrived before the request.
+ """
+ pass
+
+ @abstractmethod
+ def finish_request(self, request: Request):
+ """Clean up resources for a finished or cancelled request.
+
+ Called when a request completes, is cancelled, or is aborted.
+ Implementations should remove all tracking information and handle
+ any pending preallocations for this request.
+ """
+ pass
+
+ @abstractmethod
+ def update_mm_inputs_done(self, request: Request):
+ """Update multimodal input processing progress for a request.
+ """
+ pass
+
+ @abstractmethod
+ def _receive_encoder_cache_metadata(
+ self, req_id: str, input_id: int, size: int, mm_hash: str
+ ):
+ """Handle incoming encoder cache metadata from encoding instance.
+
+ Callback method invoked by ec_connector when metadata arrives.
+ Implementations should handle both cases where the request is
+ already active and where metadata arrives before the request.
+ """
+ pass
+
+class SyncEncoderCachePreallocator(EncoderCachePreallocatorTemplate):
+ """Synchronous preallocation manager for encoder cache in disaggregated systems.
+
+ This class coordinates the preallocations of encoder cache space between
+ encoding and prefill instances in a disaggregated deployment. It ensures
+ that cache space is reserved before encoder outputs are transferred between
+ instances, manages the lifecycle of multimodal inputs across distributed
+ processing stages, handles out-of-order arrival of encoder metadata and
+ ensures proper synchronization between request arrival and encoder cache
+ metadata receiving.
+
+ Key responsibilities:
+ - Queue and schedule preallocation requests
+ - Track multimodal input processing progress for each request
+ - Handle metadata that arrives before or after request initialization
+ - Send notification through EC Connector
+ - Provide requests with inputs that are ready for preallocations
+ - Clean up resources for finished & cancelled requests
+
+ Attributes:
+ active_requests: Set of request IDs currently being processed.
+ mm_inputs_done: Maps request ID to number of processed multimodal inputs.
+ mm_inputs_total: Maps request ID to total number of multimodal inputs.
+ preallocs_queue: Queue of pending preallocation requests.
+ prealloc_candidate: Current preallocation being considered for processing.
+ pending_preallocs: Maps request ID to set of input IDs awaiting preallocation.
+ waiting_preallocs: Stores preallocation data for requests not yet received on instance.
+ ignored_preallocs: Set of (request_id, input_id) pairs to skip.
+ received_metas_reqs: Set of (request_id, input_id) pairs with received metadata.
+ """
+ def __init__(
+ self,
+ vllm_config: VllmConfig,
+ perform_allocation: Callable,
+ ):
+ # NOTE(review): perform_allocation is accepted but never stored or
+ # called in this class — confirm whether the parameter is still needed.
+ super().__init__(vllm_config)
+
+ self.active_requests: set[str] = set()
+
+ # Per-request progress counters (see update_mm_inputs_done).
+ self.mm_inputs_done = defaultdict(int)
+ self.mm_inputs_total = defaultdict(int)
+
+ self.preallocs_queue = queue.Queue()
+ self.prealloc_candidate = None
+ self.pending_preallocs: dict[str, set[int]] = {}
+
+ # req_id -> list of (input_id, size) tuples that arrived before the
+ # request itself (see _receive_encoder_cache_metadata / add_request).
+ self.waiting_preallocs: dict[str, list[tuple[int, int]]] = {}
+
+ self.ignored_preallocs = set()
+ self.received_metas_reqs = set()
+
+ # recv_lock guards request-arrival bookkeeping; scheduling_lock guards
+ # the prealloc queue/candidate state.
+ self.recv_lock = threading.Lock()
+ self.scheduling_lock = threading.Lock()
+
+ def is_empty(self, ):
+ # True when no preallocation requests are queued.
+ return (self.preallocs_queue.qsize() == 0)
+
+ def finish_request(self, request: Request):
+ # Drop all tracking state for the request; inputs whose metadata has
+ # not arrived yet are marked ignored so late arrivals are rejected.
+ with self.recv_lock:
+ if request.request_id in self.waiting_preallocs:
+ self.waiting_preallocs.pop(request.request_id)
+ if request.request_id in self.pending_preallocs:
+ self.pending_preallocs.pop(request.request_id)
+ # `_` is the input_id here, iterating over all mm inputs.
+ for _ in range(self.mm_inputs_total[request.request_id]):
+ # Clean ignored_preallocs later, currently we assume that
+ # all mm_inputs will come to the instance at some moment
+ if (request.request_id, _) in self.received_metas_reqs:
+ self.received_metas_reqs.remove((request.request_id, _))
+ continue
+ self.ignored_preallocs.add((request.request_id, _))
+ self.mm_inputs_done.pop(request.request_id)
+ self.mm_inputs_total.pop(request.request_id)
+ self.active_requests.remove(request.request_id)
+
+ def _schedule_prealloc_request(self, req_id: str, input_id: int,
+ size: int, mm_hash: str):
+ """Schedule a preallocation request for processing.
+
+ Internal method that adds a preallocation request to the queue and
+ tracks it in pending_preallocs.
+ """
+ if req_id not in self.pending_preallocs:
+ self.pending_preallocs[req_id] = set()
+ self.pending_preallocs[req_id].add(input_id)
+ self.preallocs_queue.put_nowait((req_id, input_id, size, mm_hash))
+
+ def _receive_encoder_cache_metadata(self, req_id: str, input_id: int,
+ size: int, mm_hash: str):
+ """Handle incoming encoder cache metadata from encoding instance.
+
+ This callback processes metadata about encoder outputs that need to be
+ transferred. If the request is active, it schedules preallocation. If not,
+ it stores the metadata for later processing when the request arrives.
+ """
+
+ with self.scheduling_lock:
+ with self.recv_lock:
+ if (req_id, input_id) in self.ignored_preallocs:
+ # if request is not active/data is obtained from KV cache
+ self.ignored_preallocs.remove((req_id, input_id))
+ self.ec_connector.schedule_send_prealloc_notification(
+ req_id, input_id, False, mm_hash
+ )
+ return
+ self.received_metas_reqs.add((req_id, input_id))
+ if req_id not in self.active_requests:
+ # Request not seen yet: stash (input_id, size); mm_hash is
+ # recovered from request.mm_hashes in add_request.
+ if req_id not in self.waiting_preallocs:
+ self.waiting_preallocs[req_id] = []
+ self.waiting_preallocs[req_id].append((input_id, size))
+ return
+
+ self._schedule_prealloc_request(req_id, input_id, size, mm_hash)
+
+ def add_request(self, request: Request):
+ """Register a new request and process any waiting preallocations.
+
+ When a request arrives, this method initializes tracking structures and
+ processes any encoder metadata that arrived before the request.
+ """
+ with self.recv_lock:
+ req_id = request.request_id
+ self.active_requests.add(req_id)
+ self.mm_inputs_done[req_id] = 0
+ self.mm_inputs_total[req_id] = len(request.mm_hashes)
+ if req_id not in self.waiting_preallocs:
+ return
+ # Replay metadata that arrived before the request.
+ for (input_id, size) in self.waiting_preallocs[req_id]:
+ mm_hash = request.mm_hashes[input_id]
+ self._schedule_prealloc_request(req_id, input_id, size, mm_hash)
+ self.waiting_preallocs.pop(req_id)
+
+ def update_mm_inputs_done(self, request: Request):
+ """Update the progress of multimodal input processing for a request.
+
+ Tracks which multimodal inputs have been fully processed based on the
+ number of computed tokens. For inputs that were prealloc candidates but
+ are now obtained from cache, sends notifications to cancel transfers.
+ """
+
+ if not request.has_encoder_inputs:
+ return
+
+ with self.scheduling_lock:
+ req_id = request.request_id
+ mm_inputs_done_local = self.mm_inputs_done[req_id]
+
+ # Advance past every mm input fully covered by computed tokens.
+ while mm_inputs_done_local < self.mm_inputs_total[req_id]:
+ mm_hash = request.mm_hashes[mm_inputs_done_local]
+ pos_info = request.mm_positions[mm_inputs_done_local]
+ mm_inputs_end = pos_info.offset + pos_info.length
+ if mm_inputs_end > request.num_computed_tokens:
+ break
+
+ if (req_id in self.pending_preallocs
+ and mm_inputs_done_local in self.pending_preallocs[req_id]
+ ):
+ # Input satisfied locally — cancel the remote transfer.
+ self.pending_preallocs[req_id].remove(mm_inputs_done_local)
+ self.ec_connector.schedule_send_prealloc_notification(
+ req_id, mm_inputs_done_local, False, mm_hash
+ )
+ self.ignored_preallocs.add((req_id, mm_inputs_done_local))
+ mm_inputs_done_local += 1
+
+ self.mm_inputs_done[req_id] = mm_inputs_done_local
+
+ def get_prealloc_candidate(
+ self, free_space: int, fill_next: bool
+ ) -> tuple[bool, tuple[str, int, int, str] | None]:
+ """Validate the preallocation candidate, fill the next preallocation
+ candidate
+
+ Validate current preallocation candidate, retrieves the next
+ preallocation request from the queue. Skips ignored preallocations
+ and checks whether preallocated data will fit in space constraints.
+
+ Args:
+ free_space: Available cache space in encoder tokens.
+ fill_next: Whether to fetch the next candidate after processing.
+
+ Returns:
+ Tuple of (should_continue, candidate_data) where:
+ - should_continue: True if caller should continue preallocations,
+ False if caller should stop.
+ - candidate_data: None or tuple of (request_id, input_id,
+ num_encoder_tokens, mm_hash)
+ """
+ # NOTE(review): Queue.get() blocks until an item is available —
+ # presumably callers only pass fill_next=True when more preallocs
+ # are known to be queued; confirm this cannot deadlock.
+ with self.scheduling_lock:
+ if self.prealloc_candidate is None:
+ if fill_next is True:
+ self.prealloc_candidate = self.preallocs_queue.get()
+ return (True, None) # No candidate, just get next candidate
+
+ (request_id, input_id, num_encoder_tokens, mm_hash) = \
+ self.prealloc_candidate
+ if num_encoder_tokens > free_space:
+ return (False, None)
+
+ if fill_next is True:
+ self.prealloc_candidate = self.preallocs_queue.get()
+ else:
+ self.prealloc_candidate = None
+
+ if (request_id, input_id) in self.ignored_preallocs:
+ self.ignored_preallocs.remove((request_id, input_id))
+ return (True, None) # Skip and get next
+
+ self.pending_preallocs[request_id].remove(input_id)
+
+ return (True, (request_id, input_id, num_encoder_tokens, mm_hash))
+
+ def send_prealloc_notification(
+ self,
+ req_id: str,
+ input_id: int,
+ is_receiving_required: bool,
+ mm_hash: str
+ ):
+ """Send a preallocation notification to the encoding instance."""
+ self.ec_connector.schedule_send_prealloc_notification(
+ req_id, input_id, is_receiving_required, mm_hash
+ )
\ No newline at end of file
diff --git a/vllm/separated_encode/sched/encoder_scheduler.py b/vllm/separated_encode/sched/encoder_scheduler.py
new file mode 100644
index 000000000000..453666e0b4cf
--- /dev/null
+++ b/vllm/separated_encode/sched/encoder_scheduler.py
@@ -0,0 +1,389 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from __future__ import annotations
+
+import os
+import time
+from collections import defaultdict
+from collections.abc import Iterable
+from typing import Any, Optional, Union
+
+from vllm.config import VllmConfig
+
+from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
+ KVConnectorRole)
+from vllm.logger import init_logger
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+from vllm.separated_encode.ec_transfer.connector.redis import (
+ RedisECConnector)
+from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
+ compute_encoder_budget)
+from vllm.v1.core.sched.interface import SchedulerInterface
+from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
+ SchedulerOutput)
+from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
+ create_request_queue)
+from vllm.v1.core.sched.utils import check_stop
+from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
+ EngineCoreOutputs)
+from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.metrics.stats import SchedulerStats
+from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
+from vllm.v1.request import Request, RequestStatus
+from vllm.v1.spec_decode.metrics import SpecDecodingStats
+from vllm.v1.structured_output import StructuredOutputManager
+
+logger = init_logger(__name__)
+
+
+class EncoderScheduler(SchedulerInterface):
+ """Scheduler for encode-only instances in EPD disaggregation.
+
+ Schedules only multimodal encoder work: each waiting request has all of
+ its mm inputs scheduled in one step, the resulting encoder caches are
+ sent to the prefill/decode instance via RedisECConnector, and requests
+ are finished immediately after the batch runs (no token generation).
+ """
+
+ def __init__(
+ self,
+ vllm_config: VllmConfig,
+ kv_cache_config: KVCacheConfig,
+ structured_output_manager: StructuredOutputManager,
+ mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
+ include_finished_set: bool = False,
+ log_stats: bool = False,
+ ) -> None:
+ self.vllm_config = vllm_config
+ self.scheduler_config = vllm_config.scheduler_config
+ self.cache_config = vllm_config.cache_config
+ self.parallel_config = vllm_config.parallel_config
+ self.log_stats = log_stats
+ self.structured_output_manager = structured_output_manager
+ self.epd_disagg_config = vllm_config.epd_disagg_config
+
+ # include_finished_set controls whether a separate set of finished
+ # request ids should be included in the EngineCoreOutputs returned
+ # by update_from_outputs(). This is currently used in the multi-engine
+ # case to track request lifetimes efficiently.
+ self.finished_req_ids_dict: Optional[dict[int, set[str]]] = (
+ defaultdict(set) if include_finished_set else None)
+
+ # Scheduling constraints.
+ self.max_num_running_reqs = self.scheduler_config.max_num_seqs
+ self.max_num_scheduled_tokens = \
+ self.scheduler_config.max_num_batched_tokens
+ self.max_model_len = self.scheduler_config.max_model_len
+
+ # req_id -> Request
+ self.requests: dict[str, Request] = {}
+ self.policy = SchedulingPolicy.FCFS
+ self.waiting = create_request_queue(self.policy)
+ self.running: list[Request] = []
+
+ self.finished_req_ids: set[str] = set()
+
+ encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
+ model_config=vllm_config.model_config,
+ scheduler_config=vllm_config.scheduler_config,
+ mm_registry=mm_registry,
+ )
+
+ self.max_num_encoder_input_tokens = encoder_compute_budget
+
+ # NOTE(review): cache size is scaled by a magic factor of 128 for the
+ # encode instance — confirm the sizing rationale and make it a named
+ # constant or config value.
+ self.encoder_cache_manager = EncoderCacheManager(
+ cache_size=encoder_cache_size*128)
+ self.use_pp = self.parallel_config.pipeline_parallel_size > 1
+
+ self.separated_encode = True
+ self.instance_type = self.epd_disagg_config.instance_type
+ if self.instance_type != "encode":
+ raise RuntimeError("Incorrect instance initialization")
+
+ self.ec_connector = RedisECConnector(
+ vllm_config = self.vllm_config,
+ device=None,
+ # no need to pass device if intra_instance_type scheduler
+ intra_instance_type = "scheduler",
+ preallocate_callback = None,
+ injection_callback = None,
+ redis_host=os.getenv("REDIS_HOST"),
+ redis_port=os.getenv("REDIS_PORT"),
+ )
+ # req_id -> {input_id: (num_encoder_tokens, mm_hash)} for caches that
+ # are allocated locally and awaiting transfer completion.
+ self._allocated: dict[str, dict[int, tuple[int, str]]] = {}
+
+ def schedule(self) -> SchedulerOutput:
+ """Build a SchedulerOutput containing only encoder work.
+
+ Pops waiting requests FCFS while the encoder compute budget and the
+ running-request cap allow; all mm inputs of a request are scheduled
+ together or not at all.
+ """
+ scheduled_new_reqs: list[Request] = []
+
+ # NOTE(review): token_budget is never decremented in this loop, so
+ # `token_budget > 0` is always true — confirm whether decode-token
+ # budgeting was intentionally dropped for the encode instance.
+ token_budget = self.max_num_scheduled_tokens
+
+ # Encoder-related.
+ scheduled_encoder_inputs: dict[str, list[int]] = {}
+ encoder_compute_budget = self.max_num_encoder_input_tokens
+
+ # For logging.
+ scheduled_timestamp = time.monotonic()
+ # mm input is processed in 1 step.
+ while self.waiting and token_budget > 0:
+ if len(self.running) == self.max_num_running_reqs:
+ break
+
+ request = self.waiting.peek_request()
+ if not request.has_encoder_inputs:
+ raise RuntimeError("Request without encoder input")
+
+ new_encoder_compute_budget = encoder_compute_budget
+ #Schedule all mm inputs at once:
+ mm_hashes_to_schedule = set()
+ mm_positions = request.mm_positions
+
+ num_tokens_to_schedule = 0
+ can_allocate_all = True
+ encoder_inputs_to_schedule = []
+ is_cached = []
+
+ for input_id, pos_info in enumerate(mm_positions):
+ num_encoder_tokens = pos_info.length
+ if (
+ request.mm_hashes[input_id] in mm_hashes_to_schedule
+ or self.encoder_cache_manager.check_and_update_cache(
+ request, input_id
+ )
+ ):
+ # On Encoder instance we need to send all inputs to model runner
+ # because we need to pass (req_id, input_id) to model runner's
+ # ec connector, to send the cache to PD instance, so we will add
+ # it to the scheduled encoder inputs without changing budget
+ # and in model runner we will just skip all calculated values
+ encoder_inputs_to_schedule.append(input_id)
+ is_cached.append(True)
+ continue
+ if not self.encoder_cache_manager.can_allocate(
+ request=request,
+ input_id=input_id,
+ encoder_compute_budget=new_encoder_compute_budget,
+ num_tokens_to_schedule=num_tokens_to_schedule,
+ ):
+ can_allocate_all = False
+ break
+ num_tokens_to_schedule += num_encoder_tokens
+ new_encoder_compute_budget -= num_encoder_tokens
+ encoder_inputs_to_schedule.append(input_id)
+ is_cached.append(False)
+
+ # NOTE: Note that all updates from loop above are not applied
+ # if we can't allocate all mm_inputs
+ if not can_allocate_all:
+ break
+
+ request = self.waiting.pop_request()
+ self.running.append(request)
+
+ if self.log_stats:
+ request.record_event(EngineCoreEventType.SCHEDULED,
+ scheduled_timestamp)
+ if request.status == RequestStatus.WAITING:
+ scheduled_new_reqs.append(request)
+ else:
+ raise RuntimeError(
+ f"Invalid request status: {request.status}")
+
+ request.status = RequestStatus.RUNNING
+ req_id = request.request_id
+ scheduled_encoder_inputs[req_id] = encoder_inputs_to_schedule
+
+ # Allocate the encoder cache.
+ for input_id, is_cached_input in zip(encoder_inputs_to_schedule, is_cached):
+ mm_hash = request.mm_hashes[input_id]
+ num_encoder_tokens = request.get_num_encoder_tokens(input_id)
+ if not is_cached_input:
+ self.encoder_cache_manager.allocate(request, input_id)
+ # Tell the PD instance to preallocate space for this cache.
+ self.ec_connector.schedule_send_encoder_cache_metadata(
+ req_id,
+ input_id,
+ num_encoder_tokens,
+ mm_hash
+ )
+ if not req_id in self._allocated:
+ self._allocated[req_id] = {}
+ self._allocated[req_id][input_id] = (num_encoder_tokens, mm_hash)
+ encoder_compute_budget = new_encoder_compute_budget
+
+
+ assert len(self.running) <= self.max_num_running_reqs
+
+ new_reqs_data = [
+ NewRequestData.from_request(req, ([],))
+ for req in scheduled_new_reqs
+ ]
+
+ scheduler_output = SchedulerOutput(
+ scheduled_new_reqs=new_reqs_data,
+ scheduled_cached_reqs=CachedRequestData.make_empty(),
+ num_scheduled_tokens={},
+ total_num_scheduled_tokens=0,
+ scheduled_spec_decode_tokens={},
+ scheduled_encoder_inputs=scheduled_encoder_inputs,
+ num_common_prefix_blocks=0,
+ finished_req_ids=self.finished_req_ids,
+ free_encoder_mm_hashes=self.encoder_cache_manager.\
+ get_freed_mm_hashes(),
+ structured_output_request_ids={},
+ grammar_bitmask=None,
+ )
+
+ self.finished_req_ids = set()
+ return scheduler_output
+
+ def update_from_output(
+ self,
+ scheduler_output: SchedulerOutput,
+ model_runner_output: ModelRunnerOutput,
+ ) -> dict[int, EngineCoreOutputs]:
+ """Free transferred caches and finish every request in the batch.
+
+ Each request emits a single empty FINISHED_STOPPED output since the
+ encode instance never generates tokens.
+ """
+
+ # clean up the logic space of mm_data that was transfered
+ transfered_mm_data = model_runner_output.transfered_mm_data
+
+ for (req_id, input_id) in transfered_mm_data:
+ assert req_id in self._allocated
+ assert input_id in self._allocated[req_id]
+ cache_size, mm_hash = self._allocated[req_id][input_id]
+ self._allocated[req_id].pop(input_id)
+ if not self._allocated[req_id]:
+ self._allocated.pop(req_id)
+ self.encoder_cache_manager.free_encoder_input_after_finish(
+ req_id, cache_size, mm_hash
+ )
+
+ outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
+
+ # stop all requests from the current batch
+ model_finished = []
+ for request in self.running:
+ req_id = request.request_id
+ model_finished.append(req_id)
+ outputs[request.client_index].append(
+ EngineCoreOutput(request_id=req_id,
+ new_token_ids=[],
+ finish_reason=RequestStatus.get_finished_reason(
+ RequestStatus.FINISHED_STOPPED
+ ),
+ stop_reason="stop",
+ kv_transfer_params={}
+ )
+ )
+ self.finish_requests(model_finished, RequestStatus.FINISHED_STOPPED)
+ # Create EngineCoreOutputs for all clients that have requests with
+ # outputs in this step.
+ engine_core_outputs = {
+ client_index: EngineCoreOutputs(outputs=outs)
+ for client_index, outs in outputs.items()
+ }
+
+ if engine_core_outputs:
+ # Return stats to only one of the front-ends.
+ next(iter(engine_core_outputs.values())).scheduler_stats = (
+ self.make_stats(None))
+
+ return engine_core_outputs
+
+ def add_request(self, request: Request) -> None:
+ # Enqueue the request for encoder-only scheduling.
+ self.waiting.add_request(request)
+ self.requests[request.request_id] = request
+ if self.log_stats:
+ request.record_event(EngineCoreEventType.QUEUED)
+
+ def finish_requests(
+ self,
+ request_ids: Union[str, Iterable[str]],
+ finished_status: RequestStatus,
+ ) -> None:
+ """Handles the finish signal from outside the scheduler.
+
+ For example, the API server can abort a request when the client
+ disconnects.
+ """
+ assert RequestStatus.is_finished(finished_status)
+
+ if isinstance(request_ids, str):
+ request_ids = (request_ids, )
+ else:
+ request_ids = set(request_ids)
+
+ running_requests_to_remove = []
+ waiting_requests_to_remove = []
+ valid_requests = []
+
+ # First pass: collect requests to remove from queues
+ for req_id in request_ids:
+ request = self.requests.get(req_id)
+ if request is None:
+ # Invalid request ID.
+ continue
+
+ valid_requests.append(request)
+ if request.status == RequestStatus.RUNNING:
+ running_requests_to_remove.append(request)
+ else:
+ waiting_requests_to_remove.append(request)
+
+ # Remove all requests from queues at once for better efficiency
+ for request in running_requests_to_remove:
+ self.running.remove(request)
+ if waiting_requests_to_remove:
+ self.waiting.remove_requests(waiting_requests_to_remove)
+
+ # Second pass: set status and free requests
+ for request in valid_requests:
+ request.status = finished_status
+ self._free_request(request)
+
+ def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
+ # Record the id as finished and drop the Request object.
+ assert request.is_finished()
+ request_id = request.request_id
+ self.finished_req_ids.add(request_id)
+ if self.finished_req_ids_dict is not None:
+ self.finished_req_ids_dict[request.client_index].add(request_id)
+ del self.requests[request.request_id]
+ return None
+
+# no changes v
+
+ def get_request_counts(self) -> tuple[int, int]:
+ """Returns (num_running_reqs, num_waiting_reqs)."""
+ return len(self.running), len(self.waiting)
+
+ def get_num_unfinished_requests(self) -> int:
+ return len(self.waiting) + len(self.running)
+
+ def has_finished_requests(self) -> bool:
+ return len(self.finished_req_ids) > 0
+
+ def make_stats(
+ self,
+ spec_decoding_stats: Optional[SpecDecodingStats] = None,
+ ) -> Optional[SchedulerStats]:
+ if not self.log_stats:
+ return None
+ return SchedulerStats(
+ num_running_reqs=len(self.running),
+ num_waiting_reqs=len(self.waiting),
+ )
+
+# Placeholder functions v
+ def make_spec_decoding_stats(
+ self,
+ spec_decoding_stats: Optional[SpecDecodingStats],
+ num_draft_tokens: int,
+ num_accepted_tokens: int,
+ ) -> Optional[SpecDecodingStats]:
+ return None
+
+ def shutdown(self) -> None:
+ pass
+
+ def get_kv_connector(self) -> Optional[KVConnectorBase_V1]:
+ return None
+
+ def reset_prefix_cache(self) -> bool:
+ # NOTE(review): annotated -> bool but falls through returning None;
+ # callers doing truthiness checks get False implicitly — confirm.
+ pass
+
+ def update_draft_token_ids(
+ self,
+ draft_token_ids: "DraftTokenIds",
+ ) -> None:
+ pass
\ No newline at end of file
diff --git a/vllm/separated_encode/worker/gpu_epd_lm_wrapper.py b/vllm/separated_encode/worker/gpu_epd_lm_wrapper.py
new file mode 100644
index 000000000000..ba6ff9b03c26
--- /dev/null
+++ b/vllm/separated_encode/worker/gpu_epd_lm_wrapper.py
@@ -0,0 +1,114 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
+from typing import TYPE_CHECKING, Literal, Optional, Union
+
+import torch
+
+from vllm.config import EPDDisaggConfig, VllmConfig
+from vllm.logger import init_logger
+from vllm.separated_encode.ec_transfer.connector.redis import (
+ RedisECConnector)
+from vllm.separated_encode.ec_transfer.connector.template import (
+ ECConnectorTemplate)
+from vllm.sequence import IntermediateTensors
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+
+if TYPE_CHECKING:
+ from vllm.v1.core.sched.output import SchedulerOutput
+
+logger = init_logger(__name__)
+
+class DisaggPrefillDecodeGPURunnerWrapper(GPUModelRunner):
+ """
+ GPU model runner wrapper for disaggregated Language Model processing.
+
+ This class extends GPUModelRunner to support Encode-Prefill-Decode (EPD)
+ disaggregation by handling remote encoder cache injection and management.
+ It integrates with encoder cache connectors to receive and process encoder
+ outputs from remote encoder instances.
+
+ The runner maintains encoder cache state and coordinates with the scheduler
+ to track successful encoder cache injections for multimodal processing.
+
+ Attributes:
+ epd_disagg_config : Configuration for EPD disaggregation
+ ec_connector : Connector for encoder cache transfer
+ injected_encoder_cache_ids : List of successfully injected encoder
+ cache identifiers
+ instance_type : Type of processing instance this runner handles
+ """
+
+ def __init__(
+ self,
+ vllm_config: VllmConfig,
+ device: torch.device,
+ ):
+ super().__init__(vllm_config, device)
+ self.epd_disagg_config: EPDDisaggConfig
+ self.ec_connector: ECConnectorTemplate
+ # NOTE(review): duplicate annotation of epd_disagg_config (also two
+ # lines above) — harmless but should be deduplicated.
+ self.epd_disagg_config: EPDDisaggConfig
+ self.injected_encoder_cache_ids: list[tuple[str, int]]
+ self.instance_type: Literal["NoEPD", "prefill+decode", "prefill",
+ "encode"]
+ self.epd_disagg_config = vllm_config.epd_disagg_config
+
+ assert self.epd_disagg_config.instance_type != "NoEPD",\
+ "Can't use LM instance without EPD disaggregation"
+
+ self.instance_type = vllm_config.epd_disagg_config.instance_type
+ # The connector invokes receive_encoder_cache from its own worker
+ # context when a remote encoder cache arrives.
+ self.ec_connector = RedisECConnector(
+ vllm_config=vllm_config,
+ device=device,
+ intra_instance_type="model-runner",
+ preallocate_callback=None,
+ injection_callback=self.receive_encoder_cache,
+ redis_host=os.getenv("REDIS_HOST"),
+ redis_port=os.getenv("REDIS_PORT"),
+ )
+ self.injected_encoder_cache_ids = []
+
+ @torch.inference_mode()
+ def execute_model(
+ self,
+ scheduler_output: "SchedulerOutput",
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ ) -> Union[ModelRunnerOutput, IntermediateTensors]:
+ """
+ Executes the model and includes injected encoder cache information.
+
+ Extends the base execute_model functionality to track and report
+ encoder cache injections that occurred during model execution.
+ The injected encoder cache IDs are included in the model output
+ to inform the scheduler about successful cache injections.
+ """
+ model_runner_output = super().execute_model(scheduler_output,
+ intermediate_tensors)
+ injected_encoder_cache_ids = None
+ # NOTE(review): encoder_cache_lock is not created in this class —
+ # presumably provided by GPUModelRunner; confirm it exists there.
+ with self.encoder_cache_lock:
+ # Swap out the accumulated list so each injection is reported once.
+ injected_encoder_cache_ids = self.injected_encoder_cache_ids
+ self.injected_encoder_cache_ids = []
+ model_runner_output.injected_mm_data = injected_encoder_cache_ids
+ return model_runner_output
+
+ def receive_encoder_cache(
+ self,
+ request_id: str,
+ input_id: int,
+ encoder_cache: torch.Tensor,
+ mm_hash: str
+ ):
+ """
+ Callback function for receiving encoder cache from remote instances.
+
+ This method is invoked by the encoder cache connector when encoder
+ cache data is received from remote encoder instances, then It stores
+ received tensor in the local encoder_cache dictionary.
+
+ The method updates the injected encoder cache IDs list to inform the
+ scheduler about successful cache injections.
+ """
+ with self.encoder_cache_lock:
+ self.encoder_cache[mm_hash] = encoder_cache
+ self.injected_encoder_cache_ids.append((request_id, input_id, mm_hash))
diff --git a/vllm/separated_encode/worker/gpu_epd_vm_wrapper.py b/vllm/separated_encode/worker/gpu_epd_vm_wrapper.py
new file mode 100644
index 000000000000..9f81678d308f
--- /dev/null
+++ b/vllm/separated_encode/worker/gpu_epd_vm_wrapper.py
@@ -0,0 +1,168 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
+from typing import TYPE_CHECKING, Literal, Optional, Union
+
+import torch
+
+from vllm.config import EPDDisaggConfig, VllmConfig
+from vllm.logger import init_logger
+from vllm.separated_encode.ec_transfer.connector.redis import (
+ RedisECConnector)
+from vllm.separated_encode.ec_transfer.connector.template import (
+ ECConnectorTemplate)
+from vllm.sequence import IntermediateTensors
+from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
+from vllm.v1.worker.gpu_input_batch import CachedRequestState
+from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+
+if TYPE_CHECKING:
+ from vllm.v1.core.sched.output import SchedulerOutput
+
+logger = init_logger(__name__)
+
+
+class DisaggEncodeGPURunnerWrapper(GPUModelRunner):
+ """
+ GPU model runner wrapper for disaggregated Vision/Encoder processing.
+
+ This class extends GPUModelRunner to support encoder-only processing in
+ Encode-Prefill-Decode (EPD) disaggregation. It handles multimodal encoder
+ execution and transfers the resulting encoder caches to remote instances
+ for further processing.
+
+ This wrapper focuses on encoder processing and does not initialize KV cache
+ since it doesn't perform language model inference.
+
+ Attributes:
+ epd_disagg_config: Configuration for EPD disaggregation
+ ec_connector: Connector for encoder cache transfer
+ instance_type: Type of processing instance this runner handles
+ """
+
+ def __init__(
+ self,
+ vllm_config: VllmConfig,
+ device: torch.device,
+ ):
+ super().__init__(vllm_config, device)
+
+ self.epd_disagg_config: EPDDisaggConfig
+ self.ec_connector: ECConnectorTemplate
+ self.epd_disagg_config: EPDDisaggConfig = vllm_config.epd_disagg_config
+ self.instance_type: Literal["NoEPD", "prefill+decode", "prefill",
+ "encode"]
+
+ assert self.epd_disagg_config.instance_type != "NoEPD",\
+ "Can't use Encode instance without EPD disaggregation"
+
+ self.instance_type = vllm_config.epd_disagg_config.instance_type
+ # Encode side only sends caches; no injection callback is needed.
+ self.ec_connector = RedisECConnector(
+ vllm_config=vllm_config,
+ device=device,
+ intra_instance_type="model-runner",
+ preallocate_callback=None,
+ injection_callback=None,
+ redis_host=os.getenv("REDIS_HOST"),
+ redis_port=os.getenv("REDIS_PORT"),
+ )
+
+ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
+ """
+ Updates internal request states based on scheduler output.
+
+ Manages the lifecycle of requests by removing finished
+ requests and adding newly scheduled requests. This method maintains
+ the requests needed for encoder processing.
+ """
+ for req_id in scheduler_output.finished_req_ids:
+ self.requests.pop(req_id, None)
+
+ # NOTE(review): pop(mm_hash) without a default raises KeyError if the
+ # hash is absent — confirm the scheduler guarantees presence.
+ for mm_hash in scheduler_output.free_encoder_mm_hashes:
+ self.encoder_cache.pop(mm_hash)
+
+ for new_req_data in scheduler_output.scheduled_new_reqs:
+ self.requests[new_req_data.req_id] = CachedRequestState(
+ req_id=new_req_data.req_id,
+ prompt_token_ids=new_req_data.prompt_token_ids,
+ mm_kwargs=new_req_data.mm_kwargs,
+ mm_positions=new_req_data.mm_positions,
+ mm_hashes = new_req_data.mm_hashes,
+ sampling_params=None,
+ pooling_params=None,
+ generator=None,
+ block_ids=[],
+ num_computed_tokens=0,
+ output_token_ids=[],
+ )
+
+ @torch.inference_mode()
+ def execute_model(
+ self,
+ scheduler_output: "SchedulerOutput",
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ ) -> Union[ModelRunnerOutput, IntermediateTensors]:
+ """
+ Executes encoder processing and schedules results' transfer.
+
+ This method handles the core encoder execution workflow by updating
+ internal request states, executing multimodal encoders for scheduled
+ inputs, and transferring computed encoder caches to remote instances
+ via a connector, while providing transfer status information to the
+ scheduler.
+ """
+ self._update_states(scheduler_output)
+ old_scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
+ new_scheduled_encoder_inputs = {}
+
+ # Erase cached inputs to execute mm encoder without repeated cache inputs
+ going_to_be_executed = set()
+ for req_id, mm_input_ids in old_scheduled_encoder_inputs.items():
+ mm_hashes = self.requests[req_id].mm_hashes
+ uncached_inputs = []
+ for input_id in mm_input_ids:
+ mm_hash = mm_hashes[input_id]
+ if ((not mm_hash in self.encoder_cache)
+ and (mm_hash not in going_to_be_executed)):
+ uncached_inputs.append(input_id)
+ going_to_be_executed.add(mm_hash)
+ new_scheduled_encoder_inputs[req_id] = uncached_inputs
+
+ # Temporarily swap in the de-duplicated inputs for encoder execution,
+ # then restore the original mapping afterwards.
+ scheduler_output.scheduled_encoder_inputs = new_scheduled_encoder_inputs
+
+ self._execute_mm_encoder(scheduler_output)
+
+ scheduler_output.scheduled_encoder_inputs = old_scheduled_encoder_inputs
+ del new_scheduled_encoder_inputs
+
+ scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
+
+ # Queue every scheduled (req_id, input_id) for transfer, including
+ # inputs whose cache was already present.
+ for req_id, mm_input_ids in scheduled_encoder_inputs.items():
+ mm_hashes = self.requests[req_id].mm_hashes
+ for input_id in mm_input_ids:
+ mm_hash = mm_hashes[input_id]
+ self.ec_connector.add_encoder_cache(
+ req_id,
+ input_id,
+ self.encoder_cache[mm_hash],
+ mm_hash
+ )
+
+ transfered_ids = self.ec_connector.get_transfered_ids()
+
+ # Initialize the model runner output with default values
+ # provides better compatibility with vLLM ModelRunnerOutput changes
+ # NOTE(review): this binds (and below mutates) the shared
+ # EMPTY_MODEL_RUNNER_OUTPUT constant rather than a copy — other users
+ # of the constant would observe the mutation; confirm intent.
+ model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
+ # Assign transferred mm data
+ model_runner_output.transfered_mm_data = transfered_ids
+
+ return model_runner_output
+
+ # Don't initialize of KV cache on encode instance
+ def initialize_kv_cache_tensors(
+ self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
+ return {}
+
+ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
+ return None
diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py
index bd2ec036834b..a7ae2c5d8499 100644
--- a/vllm/v1/core/encoder_cache_manager.py
+++ b/vllm/v1/core/encoder_cache_manager.py
@@ -66,6 +66,9 @@ def __init__(self, cache_size: int):
# mm_hash of mm_data => ids of requests that reference the mm_data
self.cached: dict[str, set[str]] = {}
+
+        # Entries preallocated for pending requests; blocked from eviction.
+        self.preallocated: dict[str, dict[str, dict[int, int]]] = {}
# mm_hash of mm_data => num_encoder_tokens of the mm_data
self.freeable: OrderedDict[str, int] = OrderedDict()
@@ -92,56 +95,35 @@ def check_and_update_cache(self, request: Request, input_id: int) -> bool:
return False
# Cached but currently not referenced by any request
- if not self.cached[mm_hash]:
+ # If mm_hash is in preallocated then it will not be in freeable
+ if not self.cached[mm_hash] and mm_hash not in self.preallocated:
num_tokens = self.freeable.pop(mm_hash)
self.num_freeable_slots -= num_tokens
self.cached[mm_hash].add(request.request_id)
return True
- def can_allocate(self, request: Request, input_id: int,
- encoder_compute_budget: int,
- num_tokens_to_schedule: int) -> bool:
- """Check if there's sufficient cache space for a multimodal input.
- If there is, return True and update EncoderCacheManager state.
-
- If there is not enough free space in `num_free_slots` but there is
- enough reclaimable space in `num_freeable_slots`, entries will be
- evicted from `freeable` (their mm_hash appended to `freed`) until
- enough space is available, and then this method returns True.
- Older entries are evicted first.
-
- Returns False only if the requested number of tokens exceeds both
- the free and reclaimable capacities combined.
+    def can_allocate_tokens(self, num_tokens: int) -> bool:
+ """Check if the specified number of tokens can be allocated in the cache.
+
+ This method determines whether there is sufficient cache capacity to store
+ the requested number of encoder tokens. If there isn't enough free space
+ but there is enough reclaimable space, it will evict entries from the
+ freeable list to make room.
Args:
- request: The request containing the multimodal input.
- input_id: Index of the multimodal input within the request.
- encoder_compute_budget: Number of encoder tokens allowed to be
- computed when this method is invoked.
- num_tokens_to_schedule: Number of tokens already scheduled to be
- allocated with cache space when this method is invoked.
+ num_tokens: The number of encoder tokens to allocate.
Returns:
- True if there's enough capacity to hold the encoder output for this
- input (possibly after reclaiming `freeable` entries); otherwise
- False.
-
- Note: This method does not allocate physical memory for the encoder
- output but only the state of EncoderCacheManager.
+ True if the tokens can be allocated (either immediately or after
+ eviction); False if there isn't enough total capacity even after
+ reclaiming all freeable entries.
"""
- num_tokens = request.get_num_encoder_tokens(input_id)
-
- # Not enough compute budget
- if num_tokens > encoder_compute_budget:
- return False
-
- num_tokens += num_tokens_to_schedule
# Enough free slots
if num_tokens <= self.num_free_slots:
return True
-
+
# Not enough reclaimable slots
if num_tokens > self.num_freeable_slots:
return False
@@ -156,6 +138,48 @@ def can_allocate(self, request: Request, input_id: int,
self.num_free_slots += num_free_token
return True
+
+    def can_allocate(self, request: Request, input_id: int,
+                     encoder_compute_budget: int,
+                     num_tokens_to_schedule: int) -> bool:
+        """Check if there's sufficient cache space for a multimodal input.
+
+        This method verifies that the encoder output for the specified
+        input fits within both the encoder compute budget and the cache
+        capacity, accounting for tokens already scheduled in this step.
+
+        The method checks two constraints:
+        1. Compute budget: whether the input's encoder tokens fit within
+           the available encoder compute budget. This check is pure.
+        2. Cache capacity: delegated to can_allocate_tokens(). Note that
+           this step MAY evict freeable cache entries to reclaim space,
+           so this call can modify the EncoderCacheManager state (it does
+           not, however, allocate any physical memory).
+
+        Args:
+            request: The request containing the multimodal input.
+            input_id: Index of the multimodal input within the request.
+            encoder_compute_budget: Number of encoder tokens allowed to be
+                computed when this method is invoked.
+            num_tokens_to_schedule: Number of tokens already scheduled to
+                be allocated with cache space when this method is invoked.
+
+        Returns:
+            True if both compute budget and cache capacity constraints are
+            satisfied; False otherwise.
+
+        Note:
+            Eviction (if needed) happens inside can_allocate_tokens();
+            allocate() must still be called afterwards to reserve the
+            space once this check passes.
+        """
+        num_tokens = request.get_num_encoder_tokens(input_id)
+
+        # Not enough compute budget
+        if num_tokens > encoder_compute_budget:
+            return False
+
+        return self.can_allocate_tokens(num_tokens + num_tokens_to_schedule)
+
def allocate(self, request: Request, input_id: int) -> None:
"""Allocate cache space for a multimodal input's encoder output.
@@ -196,7 +220,7 @@ def get_cached_input_ids(self, request: Request) -> set[int]:
for input_id in range(len(request.mm_hashes))
if request.mm_hashes[input_id] in self.cached
}
-
+
def free_encoder_input(self, request: Request, input_id: int) -> None:
"""Free the request's reference to the encoder input (`mm_data`)
@@ -213,7 +237,7 @@ def free_encoder_input(self, request: Request, input_id: int) -> None:
if not self.cached.get(mm_hash, None):
return
self.cached[mm_hash].discard(req_id)
- if not self.cached[mm_hash]:
+ if not self.cached[mm_hash] and mm_hash not in self.preallocated:
num_tokens = request.get_num_encoder_tokens(input_id)
self.freeable[mm_hash] = num_tokens
self.num_freeable_slots += num_tokens
@@ -243,6 +267,125 @@ def get_freed_mm_hashes(self) -> list[str]:
freed = self.freed
self.freed = []
return freed
+
+ ########################################################################
+ # Encode-Prefill-Decode Disaggregation Related Methods
+ ########################################################################
+
+    def free_encoder_input_after_finish(
+        self, req_id: str, num_tokens: int, mm_hash: str
+    ) -> None:
+        """Free a request's reference to a cached mm input after the
+        request has already finished.
+
+        Removes the request ID from the cached mm input's reference set. When
+        the reference set becomes empty AND the entry is not preallocated by
+        any pending requests, the entry is marked as freeable.
+
+        This method is used in disaggregated settings where the request object
+        may not be available, requiring explicit parameters for the multimodal
+        input metadata.
+
+        Args:
+            req_id: ID of the request releasing the reference.
+            num_tokens: Number of encoder tokens associated with this input.
+            mm_hash: Hash identifier of the multimodal input data.
+
+        Note:
+            The entry is NOT physically freed until capacity is needed (e.g., by
+            `can_allocate_tokens`). Entries that are preallocated remain
+            unfreeable even with zero references to prevent premature eviction.
+        """
+        # No-op if mm_hash is not cached or its reference set is empty.
+        if not self.cached.get(mm_hash, None):
+            return
+        self.cached[mm_hash].discard(req_id)
+        if not self.cached[mm_hash] and mm_hash not in self.preallocated:
+            self.freeable[mm_hash] = num_tokens
+            self.num_freeable_slots += num_tokens
+
+    def preallocate(self, req_id: str, input_id: int,
+                    num_tokens: int, mm_hash: str) -> bool:
+        """Reserve cache space for an encoder input before actual allocation.
+
+        Used in disaggregated settings to coordinate cache allocation across
+        different processing-stage instances. It prevents premature
+        eviction of entries that will still be needed and tracks which
+        requests will use which inputs.
+
+        Args:
+            req_id: ID of the request making the preallocation.
+            input_id: Index of the multimodal input within the request.
+            num_tokens: Number of encoder tokens to preallocate.
+            mm_hash: Hash identifier of the multimodal input data.
+
+        Returns:
+            True if encoder cache needs to be received (entry not cached),
+            False if entry is already cached or will be provided by another
+            request.
+        """
+
+        is_mm_hash_preallocated = (mm_hash in self.preallocated)
+        is_cached = (mm_hash in self.cached)
+        is_referenced = (bool(self.cached[mm_hash]) if is_cached else False)
+
+        # Record this (req_id, input_id) reservation in self.preallocated.
+        if not is_mm_hash_preallocated:
+            self.preallocated[mm_hash] = {}
+
+        preallocated_reqs = self.preallocated[mm_hash]
+        if req_id not in preallocated_reqs:
+            preallocated_reqs[req_id] = {}
+        preallocated_reqs[req_id][input_id] = num_tokens
+
+        if is_cached:
+            # Entry is cached: pin it by removing it from the freeable list.
+            if not (is_referenced or is_mm_hash_preallocated):
+                num_tokens = self.freeable.pop(mm_hash)
+                self.num_freeable_slots -= num_tokens
+            return False
+        elif not is_mm_hash_preallocated:
+            self.num_free_slots -= num_tokens
+            self.num_freeable_slots -= num_tokens
+            return True
+
+        # Already preallocated earlier but not yet cached: the encoder
+        # cache will be injected by some other (req_id, input_id) pair.
+        return False
+
+    def finalize_allocation(
+        self, req_id: str, input_id: int, mm_hash: str, skipped: bool
+    ) -> None:
+        """Complete the allocation process for a preallocated encoder input.
+
+        Converts a preallocation into an actual allocation or releases the
+        preallocation if it was skipped. This method is called after the
+        encoder cache is injected.
+
+        Args:
+            req_id: ID of the request finalizing allocation.
+            input_id: Index of the multimodal input within the request.
+            mm_hash: Hash identifier of the multimodal input data.
+            skipped: True if this request skipped encoding (e.g., another
+                request provided the cached data), False otherwise.
+        """
+        preallocated_reqs = self.preallocated[mm_hash]
+        num_tokens = preallocated_reqs[req_id].pop(input_id)
+        is_preallocated = True  # other pairs may still hold reservations
+
+        if not preallocated_reqs[req_id]:
+            preallocated_reqs.pop(req_id)
+            if not self.preallocated[mm_hash]:
+                self.preallocated.pop(mm_hash)
+                is_preallocated = False
+
+        if mm_hash not in self.cached:
+            self.cached[mm_hash] = set()
+        if not skipped:
+            self.cached[mm_hash].add(req_id)
+        elif not self.cached[mm_hash] and not is_preallocated:
+            self.freeable[mm_hash] = num_tokens  # unreferenced: reclaimable
+            self.num_freeable_slots += num_tokens
def compute_encoder_budget(
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 3bd2fe2f0515..3b6258a9e5a0 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -4,6 +4,7 @@
from __future__ import annotations
import itertools
+import threading
import time
from collections import defaultdict
from collections.abc import Iterable
@@ -17,6 +18,9 @@
KVConnectorRole)
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+from vllm.separated_encode.sched.encoder_cache_preallocator import (
+ SyncEncoderCachePreallocator
+)
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
@@ -59,6 +63,7 @@ def __init__(
self.log_stats = log_stats
self.structured_output_manager = structured_output_manager
self.is_encoder_decoder = vllm_config.model_config.is_encoder_decoder
+ self.epd_disagg_config = vllm_config.epd_disagg_config
# include_finished_set controls whether a separate set of finished
# request ids should be included in the EngineCoreOutputs returned
@@ -142,7 +147,7 @@ def __init__(
# the encoder cache will not be initialized because cache size is 0
# for these models.
self.encoder_cache_manager = EncoderCacheManager(
- cache_size=encoder_cache_size)
+            cache_size=encoder_cache_size*10)  # FIXME(review): blanket 10x scaling hits every instance type, incl. NoEPD — confirm intended
speculative_config = vllm_config.speculative_config
self.use_eagle = False
@@ -163,6 +168,20 @@ def __init__(
enable_kv_cache_events=self.enable_kv_cache_events,
)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
+        self.separated_encode: bool = False
+
+        # Prefill/decode instances never execute the encoder themselves;
+        # they receive encoder outputs from the encode instance instead.
+        if self.epd_disagg_config.instance_type in ("prefill+decode",
+                                                    "prefill"):
+            # A zero encoder budget keeps the scheduler from ever
+            # scheduling encoder execution on this instance.
+            self.max_num_encoder_input_tokens = 0
+            self.separated_encode = True
+            self.ec_preallocator = SyncEncoderCachePreallocator(
+                vllm_config, self._perform_preallocations)
+            # Guards _perform_preallocations against concurrent entry.
+            self.mutex = threading.Lock()
def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
@@ -222,7 +241,6 @@ def schedule(self) -> SchedulerOutput:
) = self._try_schedule_encoder_inputs(
request, request.num_computed_tokens, num_new_tokens,
encoder_compute_budget)
-
if num_new_tokens == 0:
# The request cannot be scheduled because one of the following
# reasons:
@@ -615,6 +633,9 @@ def _update_after_schedule(
request = self.requests[req_id]
request.num_computed_tokens += num_scheduled_token
+ if self.separated_encode:
+ self.ec_preallocator.update_mm_inputs_done(request)
+
# NOTE: _free_encoder_inputs relies on num_computed_tokens, which
# may be updated again in _update_from_output for speculative
# decoding. However, it is safe to call the method here because
@@ -623,6 +644,10 @@ def _update_after_schedule(
if request.has_encoder_inputs:
self._free_encoder_inputs(request)
+ if self.separated_encode:
+ # Finalize allocations or get rid of them
+ self._perform_preallocations()
+
# Clear the finished request IDs.
# NOTE: We shouldn't do self.finished_req_ids.clear() here because
# it will also affect the scheduler output.
@@ -788,7 +813,6 @@ def _try_schedule_encoder_inputs(
encoder_compute_budget -= num_encoder_tokens
mm_hashes_to_schedule.add(request.mm_hashes[i])
encoder_inputs_to_schedule.append(i)
-
return (
encoder_inputs_to_schedule,
num_new_tokens,
@@ -836,6 +860,16 @@ def update_from_output(
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
pooler_outputs = model_runner_output.pooler_output
num_nans_in_logits = model_runner_output.num_nans_in_logits
+ injected_mm_data = model_runner_output.injected_mm_data
+
+ if self.separated_encode:
+ for (req_id, input_id, mm_hash) in injected_mm_data:
+ is_skipped = not (self.ec_preallocator.mm_inputs_done[req_id] <= input_id)
+ self.encoder_cache_manager.finalize_allocation(
+ req_id, input_id, mm_hash, is_skipped
+ )
+ # Finalize allocations or get rid of them
+ self._perform_preallocations()
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: Optional[SpecDecodingStats] = None
@@ -1052,6 +1086,8 @@ def get_request_counts(self) -> tuple[int, int]:
return len(self.running), len(self.waiting)
def add_request(self, request: Request) -> None:
+ if self.separated_encode:
+ self.ec_preallocator.add_request(request)
self.waiting.add_request(request)
self.requests[request.request_id] = request
if self.log_stats:
@@ -1103,7 +1139,8 @@ def finish_requests(
def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
assert request.is_finished()
-
+ if self.separated_encode:
+ self.ec_preallocator.finish_request(request)
delay_free_blocks, kv_xfer_params = self._connector_finished(request)
self.encoder_cache_manager.free(request)
request_id = request.request_id
@@ -1243,3 +1280,43 @@ def _update_from_kv_xfer_finished(self,
for req_id in (kv_connector_output.finished_sending or ()):
logger.debug("Finished sending KV transfer for request %s", req_id)
self._free_blocks(self.requests[req_id])
+
+ ########################################################################
+ # Encoder Cache related methods
+ ########################################################################
+
+    def _perform_preallocations(self) -> None:
+        """Drain pending encoder-cache preallocation candidates.
+        Non-blocking: skips if another thread is already draining."""
+        # Try-acquire instead of locked()-then-with: avoids the race where
+        # another thread grabs the lock between the check and the acquire.
+        if not self.mutex.acquire(blocking=False):
+            return
+        try:
+            def handle(candidate) -> None:
+                # Evicts freeable entries as needed; result ignored.
+                self.encoder_cache_manager.can_allocate_tokens(candidate[2])
+                is_receiving_required = (
+                    self.encoder_cache_manager.preallocate(*candidate))
+                self.ec_preallocator.send_prealloc_notification(
+                    candidate[0], candidate[1],
+                    is_receiving_required, candidate[3])
+
+            while not self.ec_preallocator.is_empty():
+                prealloc, candidate = \
+                    self.ec_preallocator.get_prealloc_candidate(
+                        self.encoder_cache_manager.num_free_slots,
+                        fill_next=True)
+                if not prealloc:  # Out of capacity; retry on next call.
+                    return
+                if candidate is not None:
+                    handle(candidate)
+
+            # Last element: fetched with fill_next=False (no advance).
+            prealloc, candidate = \
+                self.ec_preallocator.get_prealloc_candidate(
+                    self.encoder_cache_manager.num_free_slots,
+                    fill_next=False)
+            if candidate is not None:
+                handle(candidate)
+        finally:
+            self.mutex.release()
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index a7038e2d2c26..5a80e9ea6ff5 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -23,6 +23,7 @@
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import receiver_cache_from_config
+from vllm.separated_encode.sched.encoder_scheduler import EncoderScheduler
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
@@ -112,6 +113,9 @@ def __init__(self,
"This scheduler interface is not public and "
"compatibility may not be maintained.",
vllm_config.scheduler_config.scheduler_cls)
+
+ if vllm_config.epd_disagg_config.instance_type == "encode":
+ Scheduler = EncoderScheduler
if len(kv_cache_config.kv_cache_groups) == 0:
# Encoder models without KV cache don't support
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index f8d6b24702f3..504688b59b93 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -113,6 +113,12 @@ class ModelRunnerOutput:
# req_id -> num_nans_in_logits
num_nans_in_logits: Optional[dict[str, int]] = None
+ # EPD transfered mm data: (req_id, input_id)
+ transfered_mm_data: Optional[list[tuple[str, int]]] = None
+
+ # EPD Injected mm data: (req_id, input_id, mm_hash)
+ injected_mm_data: Optional[list[tuple[str, int, str]]] = None
+
@dataclass
class DraftTokenIds:
@@ -129,4 +135,6 @@ class DraftTokenIds:
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
- num_nans_in_logits=None)
+ num_nans_in_logits=None,
+ transfered_mm_data=None,
+ injected_mm_data=None)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 0250a4e19a02..1733d43a5922 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -4,6 +4,7 @@
import gc
import itertools
import time
+import threading
from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
@@ -182,6 +183,8 @@ def __init__(
# mm_hash -> encoder_output
self.encoder_cache: dict[str, torch.Tensor] = {}
+ # EPD: encoder cache lock for safety during encoder cache injections
+ self.encoder_cache_lock: threading.Lock = threading.Lock()
self.use_aux_hidden_state_outputs = False
# Set up speculative decoding.
@@ -323,6 +326,14 @@ def __init__(
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
+
+        # EPD disaggregation: the multimodal encoder may run only on a
+        # non-disaggregated ("NoEPD") instance or on an "encode" instance.
+        # Prefill/decode instances receive encoder outputs instead of
+        # computing them locally.
+        self.epd_disagg_config = vllm_config.epd_disagg_config
+        self.is_mm_encoder_exec_allowed = (
+            self.epd_disagg_config.instance_type in ("NoEPD", "encode"))
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(*args,
@@ -412,8 +423,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
new/resumed/paused/finished request in the batch.
"""
# Remove finished requests from the cached states.
- for req_id in scheduler_output.finished_req_ids:
- self.requests.pop(req_id, None)
+ with self.encoder_cache_lock:
+ for req_id in scheduler_output.finished_req_ids:
+ self.requests.pop(req_id, None)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
# scheduled_req_ids overlap. This happens when a request is aborted and
@@ -1113,6 +1125,8 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
return
+ assert self.is_mm_encoder_exec_allowed, \
+ "Encoder execution is not allowed on this instance"
# Batch the multi-modal inputs.
mm_kwargs = list[MultiModalKwargsItem]()
# list of tuple (mm_hash, position_info)
@@ -1202,7 +1216,8 @@ def _gather_mm_embeddings(
assert start_idx < end_idx
mm_hash = mm_hashes[i]
- encoder_output = self.encoder_cache.get(mm_hash, None)
+ with self.encoder_cache_lock:
+ encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None,\
f"Encoder cache miss for {mm_hash}."
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index c25219331334..e28629acfcf4 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -5,7 +5,7 @@
import gc
import os
from contextlib import AbstractContextManager, nullcontext
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, Union
import torch
import torch.distributed
@@ -23,6 +23,10 @@
from vllm.model_executor import set_random_seed
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
from vllm.platforms import current_platform
+from vllm.separated_encode.worker.gpu_epd_lm_wrapper import (
+ DisaggPrefillDecodeGPURunnerWrapper)
+from vllm.separated_encode.worker.gpu_epd_vm_wrapper import (
+ DisaggEncodeGPURunnerWrapper)
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
@@ -93,6 +97,11 @@ def __init__(
torch_profiler_trace_dir, use_gzip=True))
else:
self.profiler = None
+ self.separated_encode: bool = False
+ if self.vllm_config.epd_disagg_config.instance_type != "NoEPD":
+ self.separated_encode = True
+ self.instance_type = \
+ self.vllm_config.epd_disagg_config.instance_type
def sleep(self, level: int = 1) -> None:
from vllm.device_allocator.cumem import CuMemAllocator
@@ -198,8 +207,18 @@ def init_device(self):
set_random_seed(self.model_config.seed)
# Construct the model runner
- self.model_runner: GPUModelRunner = GPUModelRunner(
- self.vllm_config, self.device)
+ self.model_runner: Union[DisaggEncodeGPURunnerWrapper, GPUModelRunner,
+ DisaggPrefillDecodeGPURunnerWrapper]
+ model_runner_class = GPUModelRunner
+ # When EPD disaggregation is enabled, the system uses wrapper classes of GPUModelRunner
+ if self.separated_encode:
+ if self.instance_type == "encode":
+ model_runner_class = DisaggEncodeGPURunnerWrapper
+ elif (self.instance_type == "prefill+decode"
+ or self.instance_type == "prefill"):
+ model_runner_class = DisaggPrefillDecodeGPURunnerWrapper
+
+ self.model_runner = model_runner_class(self.vllm_config, self.device)
if self.rank == 0:
# If usage stat is enabled, collect relevant info.