diff --git a/.github/workflows/schedule_nightly_test_a3.yaml b/.github/workflows/schedule_nightly_test_a3.yaml
index ed0b4383616..5f7dce93d9a 100644
--- a/.github/workflows/schedule_nightly_test_a3.yaml
+++ b/.github/workflows/schedule_nightly_test_a3.yaml
@@ -126,6 +126,9 @@ jobs:
           - name: qwen2-5-vl-7b
             os: linux-aarch64-a3-4
             tests: tests/e2e/nightly/single_node/models/test_qwen2_5_vl_7b.py
+          - name: qwen2-5-vl-7b-epd
+            os: linux-aarch64-a3-4
+            tests: tests/e2e/nightly/single_node/models/test_qwen2_5_vl_7b_epd.py
           - name: qwen2-5-vl-32b
             os: linux-aarch64-a3-4
             tests: tests/e2e/nightly/single_node/models/test_qwen2_5_vl_32b.py
diff --git a/.github/workflows/scripts/config.yaml b/.github/workflows/scripts/config.yaml
index 9323a20eae0..a7f0cf490c5 100644
--- a/.github/workflows/scripts/config.yaml
+++ b/.github/workflows/scripts/config.yaml
@@ -126,6 +126,8 @@ e2e-multicard-2-cards:
       estimated_time: 1050
     - name: tests/e2e/multicard/2-cards/test_single_request_aclgraph.py
       estimated_time: 215
+    - name: tests/e2e/multicard/2-cards/test_disaggregated_encoder.py
+      estimated_time: 90
 
 e2e-multicard-4-cards:
   # TODO: recover skipped tests
diff --git a/examples/disaggregated_encoder/disagg_1e1pd_example.sh b/examples/disaggregated_encoder/disagg_1e1pd_example.sh
new file mode 100644
index 00000000000..55da96b26f4
--- /dev/null
+++ b/examples/disaggregated_encoder/disagg_1e1pd_example.sh
@@ -0,0 +1,206 @@
+#!/bin/bash
+set -euo pipefail
+
+declare -a PIDS=()
+
+###############################################################################
+# Configuration -- override via env before running
+###############################################################################
+MODEL="${MODEL:-Qwen/Qwen2.5-VL-7B-Instruct}"
+LOG_PATH="${LOG_PATH:-./logs}"
+mkdir -p "$LOG_PATH"
+
+ENCODE_PORT="${ENCODE_PORT:-19534}"
+PREFILL_DECODE_PORT="${PREFILL_DECODE_PORT:-19535}"
+PROXY_PORT="${PROXY_PORT:-10001}"
+
+CARD_E="${CARD_E:-0}"
+CARD_PD="${CARD_PD:-1}"
+
+EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache}"
+TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}"  # wait_for_server timeout
+
+NUM_PROMPTS="${NUM_PROMPTS:-100}"  # number of prompts to send in benchmark
+
+###############################################################################
+# Helpers
+###############################################################################
+# Root of the vLLM source tree (hardcoded for the CI image; adjust as needed)
+VLLM_ROOT="/vllm-workspace/vllm"
+
+START_TIME=$(date +"%Y%m%d_%H%M%S")
+ENC_LOG=$LOG_PATH/encoder_${START_TIME}.log
+PD_LOG=$LOG_PATH/pd_${START_TIME}.log
+PROXY_LOG=$LOG_PATH/proxy_${START_TIME}.log
+
+wait_for_server() {
+  local port=$1
+  timeout "$TIMEOUT_SECONDS" bash -c "
+    until curl -s localhost:$port/v1/chat/completions > /dev/null; do
+      sleep 1
+    done" && return 0 || return 1
+}
+
+# Cleanup function
+cleanup() {
+  echo "Stopping everything…"
+  trap - INT TERM USR1   # prevent re-entrancy
+
+  # Kill all tracked PIDs
+  for pid in "${PIDS[@]}"; do
+    if kill -0 "$pid" 2>/dev/null; then
+      echo "Killing process $pid"
+      kill "$pid" 2>/dev/null
+    fi
+  done
+
+  # Wait a moment for graceful shutdown
+  sleep 2
+
+  # Force kill any remaining processes
+  for pid in "${PIDS[@]}"; do
+    if kill -0 "$pid" 2>/dev/null; then
+      echo "Force killing process $pid"
+      kill -9 "$pid" 2>/dev/null
+    fi
+  done
+
+  # Kill the entire process group as backup
+  kill -- -$$ 2>/dev/null
+
+  echo "All processes stopped."
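+  # Exit from inside the signal handler so an interrupted run does not fall
+  # through and keep executing the remainder of the script.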
+  exit 0
+}
+
+trap cleanup INT
+trap cleanup USR1
+trap cleanup TERM
+
+# clear previous cache
+echo "remove previous ec cache folder"
+rm -rf "$EC_SHARED_STORAGE_PATH"
+
+echo "make ec cache folder"
+mkdir -p "$EC_SHARED_STORAGE_PATH"
+
+###############################################################################
+# Encoder worker
+###############################################################################
+ASCEND_RT_VISIBLE_DEVICES="$CARD_E" vllm serve "$MODEL" \
+  --gpu-memory-utilization 0.01 \
+  --port "$ENCODE_PORT" \
+  --enforce-eager \
+  --enable-request-id-headers \
+  --no-enable-prefix-caching \
+  --max-num-batched-tokens 114688 \
+  --max-num-seqs 128 \
+  --ec-transfer-config '{
+    "ec_connector": "ECExampleConnector",
+    "ec_role": "ec_producer",
+    "ec_connector_extra_config": {
+      "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
+    }
+  }' \
+  >"${ENC_LOG}" 2>&1 &
+
+PIDS+=($!)
+
+###############################################################################
+# Prefill+Decode worker
+###############################################################################
+ASCEND_RT_VISIBLE_DEVICES="$CARD_PD" vllm serve "$MODEL" \
+  --gpu-memory-utilization 0.9 \
+  --port "$PREFILL_DECODE_PORT" \
+  --enforce-eager \
+  --enable-request-id-headers \
+  --max-num-seqs 128 \
+  --ec-transfer-config '{
+    "ec_connector": "ECExampleConnector",
+    "ec_role": "ec_consumer",
+    "ec_connector_extra_config": {
+      "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
+    }
+  }' \
+  >"${PD_LOG}" 2>&1 &
+
+PIDS+=($!)
+
+# Wait for workers
+wait_for_server $ENCODE_PORT
+wait_for_server $PREFILL_DECODE_PORT
+
+###############################################################################
+# Proxy
+###############################################################################
+python ./disagg_epd_proxy.py \
+  --host "0.0.0.0" \
+  --port "$PROXY_PORT" \
+  --encode-servers-urls "http://localhost:$ENCODE_PORT" \
+  --prefill-servers-urls "disable" \
+  --decode-servers-urls "http://localhost:$PREFILL_DECODE_PORT" \
+  >"${PROXY_LOG}" 2>&1 &
+
+PIDS+=($!)
+
+wait_for_server $PROXY_PORT
+echo "All services are up!"
+
+###############################################################################
+# Single request with local image
+###############################################################################
+echo "Running single request with local image (non-stream)..."
+
+base64_image=$(base64 -w 0 "${VLLM_ROOT}/tests/v1/ec_connector/integration/hato.jpg")
+
+cat > /tmp/request.json << EOF
+{
+  "model": "${MODEL}",
+  "messages": [
+    {
+      "role": "system",
+      "content": "You are a helpful assistant."
+    },
+    {
+      "role": "user",
+      "content": [
+        {
+          "type": "image_url",
+          "image_url": {
+            "url": "data:image/jpeg;base64,${base64_image}"
+          }
+        },
+        {
+          "type": "text",
+          "text": "What is in this image?"
+        }
+      ]
+    }
+  ]
+}
+EOF
+
+curl http://127.0.0.1:${PROXY_PORT}/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d @/tmp/request.json
+
+rm -f /tmp/request.json
+
+###############################################################################
+# Benchmark
+###############################################################################
+echo "Running benchmark (stream)..."
+vllm bench serve \
+  --model $MODEL \
+  --backend openai-chat \
+  --endpoint /v1/chat/completions \
+  --dataset-name random-mm \
+  --seed 0 \
+  --num-prompts $NUM_PROMPTS \
+  --port $PROXY_PORT
+
+# cleanup
+echo "cleanup..."
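+# Usage sketch (assumed values; adjust to your environment): every variable in
+# the configuration block at the top can be overridden per run, e.g.
+#   MODEL=Qwen/Qwen2.5-VL-32B-Instruct ENCODE_PORT=19634 \
+#     PREFILL_DECODE_PORT=19635 PROXY_PORT=10002 bash disagg_1e1pd_example.sh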
+cleanup
\ No newline at end of file
diff --git a/examples/disaggregated_encoder/disagg_epd_proxy.py b/examples/disaggregated_encoder/disagg_epd_proxy.py
new file mode 100644
index 00000000000..78840001451
--- /dev/null
+++ b/examples/disaggregated_encoder/disagg_epd_proxy.py
@@ -0,0 +1,749 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+disagg_epd_proxy.py
+
+Proxy that routes OpenAI-compatible “/v1/chat/completions” requests across
+up to three clusters:
+    • encode  (multimodal feature extraction)
+    • prefill (optional, KV-cache producer)
+    • decode  (language-model inference)
+
+For MM input in fanout mode we:
+    1. Extract *every* image/audio item.
+    2. Fire N concurrent requests to the encoder cluster
+       (one request per item, with **all text removed**).
+    3. Wait for all of them to succeed.
+    4. Forward the *original* request to a decode server.
+
+In the default single mode, steps 1–3 collapse into one encoder request that
+carries the whole original input.
+"""
+
+from __future__ import annotations
+
+import argparse
+import asyncio
+import copy
+import logging
+import os
+import random
+import uuid
+from collections.abc import AsyncIterator
+from enum import Enum
+
+import aiohttp
+import uvicorn
+from aiohttp import ClientResponse
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.responses import JSONResponse, StreamingResponse
+
+###############################################################################
+# FastAPI app & global state
+###############################################################################
+
+logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(levelname)s: %(message)s")
+logger = logging.getLogger("proxy")
+
+app = FastAPI()
+encode_session: aiohttp.ClientSession | None = None
+prefill_session: aiohttp.ClientSession | None = None
+decode_session: aiohttp.ClientSession | None = None
+
+###############################################################################
+# Utils
+###############################################################################
+
+
+MM_TYPES = {"image_url", "audio_url", "input_audio"}
+
+
+class EncoderDispatchMode(str, Enum):
+    SINGLE = "single"
+    FANOUT = "fanout"
+
+
+def extract_mm_items(request_data: dict) -> list[dict]:
+    """
+    Return *all* image/audio items that appear anywhere in `messages`.
+
+    Each returned dict looks like:
+        { "type": "image_url", "image_url": {...} }
+    """
+    items: list[dict] = []
+    for msg in request_data.get("messages", []):
+        content = msg.get("content")
+        if not isinstance(content, list):
+            continue
+
+        for item in content:
+            if item.get("type") in MM_TYPES:
+                items.append(item)
+    return items
+
+
+async def _encode_fanout(
+    orig_request: dict,
+    e_urls: list[str],
+    req_id: str,
+):
+    logger.info("[%s] Processing multimodal items...", req_id)
+
+    mm_items = extract_mm_items(orig_request)
+    if not mm_items:
+        logger.info("[%s] No multimodal items, skipping encoder", req_id)
+        return  # nothing to do
+
+    logger.info("[%s] got %d multimodal items...", req_id, len(mm_items))
+
+    tasks = []
+
+    # Round-robin over encode servers to distribute load a bit
+    url_cycle = (e_urls[i % len(e_urls)] for i in range(len(mm_items)))
+
+    for idx, (item, target_url) in enumerate(zip(mm_items, url_cycle)):
+        # Derive a *child* request id: <parent-id>:<item-idx>:<random-suffix>
+        child_req_id = f"{req_id}:{idx}:{uuid.uuid4().hex[:6]}"
+        headers = {"x-request-id": child_req_id}
+
+        encoder_req = {
+            # You *may* need to keep additional fields
+            "model": orig_request.get("model"),
+            "messages": [
+                {"role": "user", "content": [item]},
+            ],
+            # Only need 1 token so the server actually runs the encoder path
+            "max_tokens": 1,
+            "stream": False,
+        }
+        if encode_session is None:
+            raise HTTPException(status_code=500, detail="Encode session not initialized")
+        tasks.append(
+            encode_session.post(
+                f"{target_url}/v1/chat/completions",
+                json=encoder_req,
+                headers=headers,
+            )
+        )
+
+    results = await asyncio.gather(*tasks, return_exceptions=True)
+
+    # Fail fast if any sub-request failed
+    for idx, r in enumerate(results):
+        if isinstance(r, Exception):
+            logger.error(
+                "[%s] Encoder request #%d raised exception: %s",
+                req_id,
+                idx,
+                r,
+                exc_info=r,
+            )
+            error_detail = str(r)
+            if hasattr(r, "status"):
+                error_detail = f"Status: {r.status}, Error: {error_detail}"
+            elif hasattr(r, "status_code"):
+                error_detail = f"Status: {r.status_code}, Error: {error_detail}"
+            raise HTTPException(status_code=502, detail=f"Encoder request failed: {error_detail}")
+        if isinstance(r, ClientResponse) and r.status != 200:
+            try:
+                detail = await r.text()
+            except Exception:
+                detail = ""
+            logger.error(
+                "[%s] Encoder request #%d returned status %s: %s",
+                req_id,
+                idx,
+                r.status,
+                detail,
+            )
+            raise HTTPException(
+                status_code=r.status,
+                detail=f"Encoder request failed: {detail}",
+            )
+
+    logger.info("[%s] All %d encoder requests completed successfully", req_id, len(mm_items))
+
+
+async def _encode_single_request(
+    orig_request: dict,
+    e_url: str,
+    req_id: str,
+) -> ClientResponse:
+    """
+    Forward the original request to a single encode server with generation
+    clamped to one token, so only the encoder path does real work.
+    Raise if the request fails.
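+
+    Unlike ``_encode_fanout`` above, the request is not split per MM item.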
+ """ + logger.info("[%s] Processing multimodal items...", req_id) + + request_data = copy.deepcopy(orig_request) + headers = {"x-request-id": req_id} + request_data["max_tokens"] = 1 + request_data["stream"] = False + request_data.pop("stream_options", None) + if "max_completion_tokens" in request_data: + request_data["max_completion_tokens"] = 1 + + try: + if encode_session is None: + raise HTTPException(status_code=500, detail="Encode session not initialized") + + encode_response = await encode_session.post(f"{e_url}/v1/chat/completions", json=request_data, headers=headers) + encode_response.raise_for_status() + + if encode_response.status != 200: + encode_text = await encode_response.text() + raise HTTPException( + status_code=encode_response.status, + detail={"error": "Encoder request failed", "message": encode_text}, + ) + logger.debug("Encoder processing completed successfully for req_id: %s", req_id) + + return encode_response + + except Exception as e: + logger.error("Encoder processing failed: %s", str(e)) + raise HTTPException( + status_code=500, + detail={"error": "Encoder processing error", "message": str(e)}, + ) from e + + logger.info("[%s] Encoder request completed successfully", req_id) + + +async def fanout_encoder_primer( + orig_request: dict, + req_id: str, +): + mode = app.state.encoder_dispatch_mode + + if mode == EncoderDispatchMode.SINGLE: + e_url = random.choice(app.state.e_urls) + await _encode_single_request(orig_request, e_url, req_id) + + elif mode == EncoderDispatchMode.FANOUT: + await _encode_fanout(orig_request, app.state.e_urls, req_id) + + else: + raise RuntimeError(f"Unknown encoder dispatch mode: {mode}") + + +async def maybe_prefill( + req_data: dict, + p_url: str, + req_id: str, +) -> dict: + """ + - Do prefill-only task if p_url exist; + - Return modified request data with kv transfer params (for nixl connector) + - Else, skip and return the original request data for decode + """ + if p_url: + logger.info("[%s] Processing through prefill: %s", req_id, p_url) + + prefill_response = await process_prefill_stage(req_data, p_url, req_id) + if isinstance(prefill_response, ClientResponse): + # for nixl connector to facilitate kv transfer... 
+            prefill_response_json = await prefill_response.json()
+            kv_transfer_params = prefill_response_json.get("kv_transfer_params", {})
+            if kv_transfer_params:
+                req_data["kv_transfer_params"] = kv_transfer_params
+        return req_data
+    else:
+        return req_data
+
+
+async def process_prefill_stage(
+    req_data: dict,
+    p_url: str,
+    req_id: str,
+) -> ClientResponse:
+    """Process the request through the prefill stage and return the response
+    carrying kv_transfer_params."""
+    logger.info("[%s] Sending prefill request to: %s", req_id, p_url)
+
+    prefill_request = req_data.copy()
+    prefill_request["kv_transfer_params"] = {
+        "do_remote_decode": True,
+        "do_remote_prefill": False,
+        "remote_engine_id": None,
+        "remote_block_ids": None,
+        "remote_host": None,
+        "remote_port": None,
+    }
+    prefill_request["stream"] = False
+    prefill_request["max_tokens"] = 1
+    if "max_completion_tokens" in prefill_request:
+        prefill_request["max_completion_tokens"] = 1
+    if "stream_options" in prefill_request:
+        del prefill_request["stream_options"]
+
+    headers = {"x-request-id": req_id}
+    try:
+        if prefill_session is None:
+            raise HTTPException(status_code=500, detail="Prefill session not initialized")
+
+        prefill_response = await prefill_session.post(
+            f"{p_url}/v1/chat/completions", json=prefill_request, headers=headers
+        )
+
+        if prefill_response.status != 200:
+            error_text = await prefill_response.text()
+            logger.error(
+                "[%s] Prefill request failed with status %d: %s",
+                req_id,
+                prefill_response.status,
+                error_text,
+            )
+            raise HTTPException(
+                status_code=prefill_response.status,
+                detail={"error": "Prefill request failed", "message": error_text},
+            )
+        logger.info("[%s] Prefill request completed successfully", req_id)
+
+        return prefill_response
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error("Prefill processing failed: %s", str(e))
+        raise HTTPException(
+            status_code=500,
+            detail={"error": "Prefill processing error", "message": str(e)},
+        ) from e
+
+
+def has_mm_input(request_data: dict) -> bool:
+    if "messages" not in request_data:
+        return False
+    for message in request_data["messages"]:
+        if not isinstance(message.get("content"), list):
+            continue
+        for content_item in message["content"]:
+            if content_item.get("type") in MM_TYPES:
+                return True
+    return False
+
+
+###############################################################################
+# Middleware for request/response logging
+###############################################################################
+
+
+@app.middleware("http")
+async def log_requests(request: Request, call_next):
+    """Middleware to log all incoming requests and responses"""
+    req_id = request.headers.get("x-request-id", str(uuid.uuid4()))
+
+    # Log incoming request
+    logger.info(
+        ">>> [%s] %s %s from %s",
+        req_id,
+        request.method,
+        request.url.path,
+        request.client.host if request.client else "unknown",
+    )
+
+    try:
+        # Process request
+        response = await call_next(request)
+
+        # Log response
+        logger.info(
+            "<<< [%s] %s %s completed with status %d",
+            req_id,
+            request.method,
+            request.url.path,
+            response.status_code,
+        )
+
+        return response
+    except Exception as e:
+        # Log errors
+        logger.exception(
+            "!!! [%s] %s %s failed with error: %s",
+            req_id,
+            request.method,
+            request.url.path,
+            str(e),
+        )
+        raise
+
+
+###############################################################################
+# FastAPI lifecycle
+###############################################################################
+
+
+@app.on_event("startup")
+async def on_startup() -> None:
+    global encode_session, prefill_session, decode_session
+    timeout = aiohttp.ClientTimeout(total=100_000)
+    connector = aiohttp.TCPConnector(limit=0, force_close=False, keepalive_timeout=0)
+    encode_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
+    if app.state.p_urls:
+        # only setup if prefill instance(s) exist
+        prefill_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
+    decode_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
+
+
+@app.on_event("shutdown")
+async def on_shutdown() -> None:
+    global encode_session, prefill_session, decode_session
+    if encode_session:
+        await encode_session.close()
+    if prefill_session:
+        await prefill_session.close()
+    if decode_session:
+        await decode_session.close()
+
+
+###############################################################################
+# Core forwarding
+###############################################################################
+
+
+async def forward_non_stream(req_data: dict, req_id: str, p_url: str, d_url: str) -> dict:
+    try:
+        # Step 1: Process through Encoder instance (if has MM input)
+        async def run_encoder():
+            await fanout_encoder_primer(req_data, req_id)
+
+        if has_mm_input(req_data):
+            await non_stream_retry_wrap(run_encoder)
+
+        # Step 2: Process through Prefill instance
+        async def run_prefill():
+            return await maybe_prefill(req_data, p_url, req_id)
+
+        req_data = await non_stream_retry_wrap(run_prefill)
+
+        async def run_decode_non_stream():
+            # Step 3: Process through Decode instance
+            logger.info("[%s] Forwarding to decode: %s", req_id, d_url)
+            headers = {"x-request-id": req_id}
+
+            # Non-streaming response
+            if decode_session is None:
+                raise HTTPException(status_code=500, detail="Decode session not initialized")
+
+            async with decode_session.post(f"{d_url}/v1/chat/completions", json=req_data, headers=headers) as resp:
+                resp.raise_for_status()
+                return await resp.json()
+
+        return await non_stream_retry_wrap(run_decode_non_stream)
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.exception("[%s] Error in forward_non_stream: %s", req_id, str(e))
+        raise HTTPException(status_code=500, detail=f"Proxy error: {str(e)}") from e
+
+
+async def stream_retry_wrap(forward_func, max_retries: int = 3, delay: float = 0.001):
+    last_exc = None
+    first_chunk_sent = False
+    for attempt in range(max_retries):
+        try:
+            async for chunk in forward_func():
+                first_chunk_sent = True
+                yield chunk
+            return
+        except Exception as e:
+            if first_chunk_sent:
+                raise
+            if isinstance(e, HTTPException) and e.status_code < 500:
+                raise
+            last_exc = e
+            logger.warning(
+                "attempt %s / %s failed, retrying...",
+                attempt + 1,
+                max_retries,
+            )
+            await asyncio.sleep(delay * (attempt + 1))
+
+    raise RuntimeError(f"all {max_retries} retries failed.") from last_exc
+
+
+async def non_stream_retry_wrap(forward_func, max_retries: int = 3, delay: float = 0.001):
+    last_exc = None
+    for attempt in range(max_retries):
+        try:
+            result = await forward_func()
+            return result
+        except Exception as e:
+            if isinstance(e, HTTPException) and e.status_code < 500:
+                raise
+            last_exc = e
+            logger.warning(
+                "attempt %s / %s failed, retrying...",
+                attempt + 1,
+                max_retries,
+            )
+            await asyncio.sleep(delay * (attempt + 1))
+    raise RuntimeError(f"all {max_retries} retries failed.") from last_exc
+
+
+async def forward_stream(req_data: dict, req_id: str, p_url: str, d_url: str) -> AsyncIterator[str]:
+    try:
+        # Step 1: Process through Encoder instance (if has MM input)
+        async def run_encoder():
+            await fanout_encoder_primer(req_data, req_id)
+
+        if has_mm_input(req_data):
+            await non_stream_retry_wrap(run_encoder)
+
+        # Step 2: Process through Prefill instance
+        async def run_prefill():
+            return await maybe_prefill(req_data, p_url, req_id)
+
+        req_data = await non_stream_retry_wrap(run_prefill)
+
+        async def run_decode_stream():
+            # Step 3: Process through Decode instance
+            logger.info("[%s] Starting streaming from decode: %s", req_id, d_url)
+            headers = {"x-request-id": req_id}
+
+            # Streaming response
+            if decode_session is None:
+                raise HTTPException(status_code=500, detail="Decode session not initialized")
+
+            async with decode_session.post(
+                f"{d_url}/v1/chat/completions",
+                json=req_data,
+                headers=headers,
+            ) as resp:
+                resp.raise_for_status()
+                async for chunk in resp.content.iter_chunked(1024):
+                    if chunk:
+                        yield chunk.decode("utf-8", errors="ignore")
+
+            logger.info("[%s] Streaming completed", req_id)
+
+        async for chunk in stream_retry_wrap(run_decode_stream):
+            yield chunk
+
+    except HTTPException:
+        logger.exception("[%s] HTTPException in forward_stream", req_id)
+        raise
+    except Exception as e:
+        logger.exception("[%s] Error in forward_stream: %s", req_id, str(e))
+        raise HTTPException(status_code=500, detail=f"Proxy streaming error: {str(e)}") from e
+
+
+###############################################################################
+# Public routes
+###############################################################################
+
+
+@app.post("/v1/chat/completions")
+async def chat_completions(request: Request):
+    try:
+        req_data = await request.json()
+        req_id = request.headers.get("x-request-id", str(uuid.uuid4()))
+
+        p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
+        d_url = random.choice(app.state.d_urls)
+
+        is_streaming = req_data.get("stream", False)
+
+        if is_streaming:
+            return StreamingResponse(
+                forward_stream(req_data, req_id, p_url, d_url),
+                media_type="text/event-stream",
+            )
+        result = await forward_non_stream(req_data, req_id, p_url, d_url)
+        return JSONResponse(content=result)
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.exception("Error in chat_completions endpoint: %s", str(e))
+        raise HTTPException(status_code=500, detail=f"Request processing error: {str(e)}") from e
+
+
+@app.get("/v1/models")
+async def list_models():
+    if decode_session is None:
+        raise HTTPException(status_code=500, detail="Decode session not initialized")
+    async with decode_session.get(f"{app.state.d_urls[0]}/v1/models") as resp:
+        resp.raise_for_status()
+        return await resp.json()
+
+
+@app.get("/health")
+async def health_check():
+    async def healthy(urls, session):
+        if not urls:
+            return "empty"
+        for u in urls:
+            try:
+                if session is None:
+                    return "unhealthy"
+                async with session.get(f"{u}/health") as resp:
+                    resp.raise_for_status()
+            except Exception:
+                return "unhealthy"
+        return "healthy"
+
+    e_status, p_status, d_status = await asyncio.gather(
+        healthy(app.state.e_urls, encode_session),
+        healthy(app.state.p_urls, prefill_session),
+        healthy(app.state.d_urls, decode_session),
+    )
+
+    overall_healthy = all(status != "unhealthy" for status in (e_status, p_status, d_status))
+
+    status_code = 200 if overall_healthy else 503
+
+    return JSONResponse(
+        {
+            "proxy": "healthy",
+            "encode_cluster": e_status,
+            "prefill_cluster": p_status,
+            "decode_cluster": d_status,
+        },
+        status_code=status_code,
+    )
+
+
+###############################################################################
+# Simple profiler fan-out
+###############################################################################
+
+
+async def _post_if_available(
+    session: aiohttp.ClientSession,
+    url: str,
+    payload: dict,
+    headers: dict,
+) -> dict | None:
+    """
+    POST `payload` to `url`.
+
+    Returns
+    -------
+    • The decoded JSON body on success (2xx)
+    • None if the endpoint does not exist (404)
+    • Raises for anything else.
+    """
+    try:
+        if session is None:
+            return None
+        resp = await session.post(url, json=payload, headers=headers)
+        if resp.status == 404:  # profiling disabled on that server
+            logger.warning("Profiling endpoint missing on %s", url)
+            return None
+        resp.raise_for_status()
+        return await resp.json(content_type=None)
+    except aiohttp.ClientResponseError as exc:
+        # Treat 404 like the branch above; re-raise everything else
+        if exc.status == 404:
+            logger.warning("Profiling endpoint missing on %s", url)
+            return None
+        raise
+    except Exception:
+        # Network errors etc.: propagate
+        raise
+
+
+async def _profile_cmd(cmd: str, payload: dict, e_url: str, p_url: str, d_url: str):
+    """
+    Fire-and-forget to all three clusters, tolerating 404s.
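+
+    Returns a dict with one entry per cluster ("encode", "prefill", "decode");
+    an entry is None when that cluster skipped the command or has profiling
+    disabled.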
+ """ + headers = {"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY', '')}"} + + encode_task = _post_if_available(encode_session, f"{e_url}/{cmd}_profile", payload, headers) + prefill_task = ( + _post_if_available(prefill_session, f"{p_url}/{cmd}_profile", payload, headers) + if p_url is not None + else asyncio.sleep(0) + ) + decode_task = _post_if_available(decode_session, f"{d_url}/{cmd}_profile", payload, headers) + + encode_res, prefill_res, decode_res = await asyncio.gather(encode_task, prefill_task, decode_task) + + # If *all* clusters said “I don’t have that route”, surface an error + if encode_res is prefill_res is decode_res is None: + raise HTTPException( + status_code=503, + detail="Profiling endpoints are disabled on all clusters", + ) + + return { + "encode": encode_res, # may be None + "prefill": prefill_res, # may be None + "decode": decode_res, # may be None + } + + +@app.post("/start_profile") +async def start_profile(request: Request): + body = await request.json() + # TODO: handle multi urls properly + e_url = random.choice(app.state.e_urls) + p_url = random.choice(app.state.p_urls) if app.state.p_urls else None + d_url = random.choice(app.state.d_urls) + return await _profile_cmd("start", body, e_url, p_url, d_url) + + +@app.post("/stop_profile") +async def stop_profile(request: Request): + body = await request.json() + # TODO: handle multi urls properly + e_url = random.choice(app.state.e_urls) + p_url = random.choice(app.state.p_urls) if app.state.p_urls else None + d_url = random.choice(app.state.d_urls) + return await _profile_cmd("stop", body, e_url, p_url, d_url) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", default="0.0.0.0") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--encode-servers-urls", + required=True, + help='Comma-separated encode URLs ("http://e1:8001,http://e2:8001")', + ) + parser.add_argument( + "--prefill-servers-urls", + required=True, + help='Comma-separated prefill URLs ("http://p1:8003,http://p2:8004") to enable E->P->D, ' + 'set "disable" or "none" to enable E->PD', + ) + parser.add_argument( + "--decode-servers-urls", + required=True, + help='Comma-separated decode URLs ("http://d1:8005,http://d2:8006")', + ) + + parser.add_argument( + "--encoder-dispatch-mode", + choices=["single", "fanout"], + default="single", + help="Encoder dispatch mode: single (one request) or fanout (per-MM-item)", + ) + + args = parser.parse_args() + app.state.e_urls = [u.strip() for u in args.encode_servers_urls.split(",") if u.strip()] + app.state.d_urls = [u.strip() for u in args.decode_servers_urls.split(",") if u.strip()] + # handle prefill instances + if args.prefill_servers_urls.lower() in ("disable", "none", ""): + app.state.p_urls = [] + logger.info("Disaggregated prefill phase explicitly disabled by user. Running E + PD...") + else: + app.state.p_urls = [u.strip() for u in args.prefill_servers_urls.split(",") if u.strip()] + logger.info("Disaggregated prefill phase is enabled. 
Running E + P + D...") + + app.state.encoder_dispatch_mode = EncoderDispatchMode(args.encoder_dispatch_mode) + + logger.info("Proxy listening on %s:%s", args.host, args.port) + logger.info("Encode servers: %s", app.state.e_urls) + logger.info("Prefill instances %s", app.state.p_urls) + logger.info("Decode servers: %s", app.state.d_urls) + + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + loop="uvloop", + access_log=True, + ) diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index fdad87dfd28..af9c8dec3c7 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -18,6 +18,7 @@ # import contextlib +import copy import functools import gc import json @@ -27,11 +28,15 @@ import shlex import subprocess import sys +import threading import time +import traceback +from pathlib import Path from typing import Any, Optional, Tuple, TypeVar, Union import numpy as np import openai +import psutil import pytest import requests import torch @@ -80,6 +85,10 @@ _TEST_DIR = os.path.dirname(__file__) _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "long_prompt.txt")] +DISAGG_EPD_PROXY_SCRIPT = Path( + __file__ +).parent.parent.parent / "examples" / "disaggregated_encoder" / "disagg_epd_proxy.py" + def _check_npu_memory_worker(target_free_percentage: float, max_wait_seconds: float): import torch_npu # type: ignore @@ -441,6 +450,216 @@ def get_async_client(self, **kwargs): **kwargs) +class RemoteEPDServer(RemoteOpenAIServer): + def _start_server(self, model: str, server_cmd: list[str], + env_dict: Optional[dict[str, str]]) -> None: + """Subclasses override this method to customize server process launch + """ + raise NotImplementedError("RemoteEPDServer should use _start_server_with_prefix instead") + + def __init__(self, + vllm_serve_args: Union[list[str], list[list[str]]], + server_host: str = '0.0.0.0', + env_dict: Optional[dict[str, str]] = None, + max_wait_seconds: Optional[float] = 2800) -> None: + + self._proc_list = [] + + self.env_dict: dict[str, str] = {} + if env_dict is not None: + self.env_dict.update(env_dict) + + self.env_dict['VLLM_ALLOW_LONG_MAX_MODEL_LEN'] = "1" + self.env_dict['VLLM_USE_V1'] = "1" + self.env_dict['PYTORCH_NPU_ALLOC_CONF'] = "expandable_segments:True" + self.env_dict['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' + + self.vllm_serve_args_list = [] + self.health_url_list = [] + self.host = server_host + + if isinstance(vllm_serve_args, list): + if not all(isinstance(item, list) for item in vllm_serve_args): + args_copy = copy.deepcopy(vllm_serve_args) + self.vllm_serve_args_list.append([str(arg) for arg in args_copy]) + else: + self.vllm_serve_args_list = [ + [str(arg) for arg in sublist] + for sublist in copy.deepcopy(vllm_serve_args) + ] + else: + raise RuntimeError("vllm_serves_args must be a list") + + serve_arg_cmd = ["vllm", "serve"] + + for i, vllm_serve_arg in enumerate(self.vllm_serve_args_list): + self.env_dict['ASCEND_RT_VISIBLE_DEVICES'] = str(i) + if isinstance(vllm_serve_arg, list): + if "--port" not in vllm_serve_arg: + raise ValueError("You have manually specified the port ") + else: + port_arg = "--port" + try: + index = vllm_serve_arg.index(port_arg) + except ValueError: + raise ValueError(f"--port not found in args: {vllm_serve_arg}") + port_str = vllm_serve_arg[index + 1] + self.port = int(port_str) + else: + vllm_serve_arg_str = str(vllm_serve_arg) + if "--port" not in vllm_serve_arg_str: + raise ValueError("You have manually specified the port ") + else: + raise ValueError(f"Unexpected type for vllm_serve_arg: 
{type(vllm_serve_arg)}") + + self.health_url_list.append(super().url_for("health")) + vllm_serve_arg = [*serve_arg_cmd, *vllm_serve_arg] + proc = self._start_server_with_prefix(vllm_serve_arg, self.env_dict, + f"[VLLM_{i}] ") + self._proc_list.append(proc) + + timeout_value = float(max_wait_seconds) if max_wait_seconds is not None else 2800.0 + super()._wait_for_multiple_servers([(self.host, url) + for url in self.health_url_list], + timeout=timeout_value) + + def _poll(self) -> Optional[int]: + return None + + def _delete_shm(self) -> None: + for i, arg in enumerate(self.vllm_serve_args_list): + if "--ec-transfer-config" in arg: + index = arg.index("--ec-transfer-config") + config_str = arg[index + 1] + config_dict = json.loads(config_str) + ec_connector_extra_config = config_dict.get("ec_connector_extra_config", {}) + shm_path = ec_connector_extra_config.get("shared_storage_path") + if shm_path: + args = ["rm", "-r", "-f", str(shm_path)] + print(f"delete shm_path is: {shm_path}") + self._start_server_with_prefix(args, None, "[DELETE] ") + + def _read_output(self, pipe, prefix): + try: + with pipe: + for line in iter(pipe.readline, ''): + if line: + print(f"{prefix}: {line}", end='') + + except Exception as e: + print(f"error: {e}") + traceback.print_exc() + + def _start_server_with_prefix(self, server_cmd: list[str], + env_dict: Optional[dict[str, str]], log_prefix: str): + env = os.environ.copy() + if env_dict is not None: + env.update(env_dict) + proc = subprocess.Popen(server_cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + bufsize=1) + stdout_thread = threading.Thread(target=self._read_output, + args=(proc.stdout, log_prefix), + daemon=True) + stderr_thread = threading.Thread(target=self._read_output, + args=(proc.stderr, log_prefix), + daemon=True) + + stdout_thread.start() + stderr_thread.start() + return proc + + def _terminate_server(self) -> None: + """kill process and its children""" + print("vllm instance is stopping") + for proc in self._proc_list: + parent = psutil.Process(proc.pid) + children = parent.children(recursive=True) + for child in children: + try: + child.terminate() + except psutil.NoSuchProcess: + pass + + gone, still_alive = psutil.wait_procs(children, timeout=10) + + for child in still_alive: + try: + child.kill() + except psutil.NoSuchProcess: + pass + + try: + parent.terminate() + parent.wait(timeout=10) + except (psutil.NoSuchProcess, psutil.TimeoutExpired): + try: + parent.kill() + except psutil.NoSuchProcess: + pass + + def __enter__(self): + """Context manager entry point.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit point - clean up all processes.""" + self._terminate_server() + + +class DisaggEpdProxy(RemoteEPDServer): + + def __init__(self, + proxy_args: Optional[Union[list[str], str]] = None, + env_dict: Optional[dict[str, str]] = None, + server_host: str = '0.0.0.0', + max_wait_seconds: Optional[float] = 2800) -> None: + + if proxy_args is None: + proxy_args_list: list[str] = [] + elif isinstance(proxy_args, str): + proxy_args_list = shlex.split(proxy_args) + else: + proxy_args_list = proxy_args + + self.proxy_args = proxy_args_list + self.env_dict: dict[str, str] = {} + if env_dict is not None: + self.env_dict.update(env_dict) + self._proc_list = list() + self.host = server_host + + print(f"proxy param is: {self.proxy_args}") + proxy_cmd = ["python", str(DISAGG_EPD_PROXY_SCRIPT), *self.proxy_args] + proc = self._start_server_with_prefix(proxy_cmd, 
+                                              self.env_dict, "[PROXY] ")
+        self._proc_list.append(proc)
+
+        if "--port" not in self.proxy_args:
+            raise ValueError("Proxy args must explicitly specify --port")
+        try:
+            index = self.proxy_args.index("--port")
+        except ValueError:
+            raise ValueError("--port not found in proxy args")
+        port_str = self.proxy_args[index + 1]
+        self.port = int(port_str)
+
+        timeout_value = float(max_wait_seconds) if max_wait_seconds is not None else 2800.0
+        super()._wait_for_multiple_servers(
+            [(self.host, super().url_for("health"))], timeout=timeout_value)
+
+    def __enter__(self):
+        """Context manager entry point."""
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Context manager exit point - clean up all processes."""
+        super()._terminate_server()
+
+
 class VllmRunner:
 
     def __init__(
diff --git a/tests/e2e/multicard/2-cards/test_disaggregated_encoder.py b/tests/e2e/multicard/2-cards/test_disaggregated_encoder.py
new file mode 100644
index 00000000000..6d54a318b41
--- /dev/null
+++ b/tests/e2e/multicard/2-cards/test_disaggregated_encoder.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+# + +import pytest +from vllm.utils.network_utils import get_open_port + +from tests.e2e.conftest import DisaggEpdProxy, RemoteEPDServer +from tools.send_mm_request import send_image_request + +MODELS = [ + "Qwen/Qwen2.5-VL-7B-Instruct", +] +SHARED_STORAGE_PATH = "/dev/shm/epd/storage" +TENSOR_PARALLELS = [1] + +@pytest.mark.asyncio +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS) +async def test_models(model: str, tp_size: int) -> None: + encode_port = get_open_port() + pd_port = get_open_port() + vllm_server_args = [ + [ + "--port", + str(encode_port), "--model", model, "--gpu-memory-utilization", + "0.01", "--tensor-parallel-size", + str(tp_size), "--enforce-eager", "--no-enable-prefix-caching", + "--max-model-len", "10000", "--max-num-batched-tokens", "10000", + "--max-num-seqs", "1", "--ec-transfer-config", + '{"ec_connector_extra_config":{"shared_storage_path":"' + + SHARED_STORAGE_PATH + + '"},"ec_connector":"ECExampleConnector","ec_role": "ec_producer"}' + ], + [ + "--port", + str(pd_port), "--model", model, "--gpu-memory-utilization", "0.95", + "--tensor-parallel-size", + str(tp_size), "--enforce-eager", "--max-model-len", "10000", + "--max-num-batched-tokens", "10000", "--max-num-seqs", "128", + "--ec-transfer-config", + '{"ec_connector_extra_config":{"shared_storage_path":"' + + SHARED_STORAGE_PATH + + '"},"ec_connector":"ECExampleConnector","ec_role": "ec_consumer"}' + ] + ] + proxy_port = get_open_port() + proxy_args = [ + "--host", "127.0.0.1", "--port", + str(proxy_port), "--encode-servers-urls", + f"http://localhost:{encode_port}", "--decode-servers-urls", + f"http://localhost:{pd_port}", "--prefill-servers-urls", "disable" + ] + + with RemoteEPDServer(vllm_serve_args=vllm_server_args) as _: + with DisaggEpdProxy(proxy_args=proxy_args) as proxy: + send_image_request(model, proxy) + diff --git a/tests/e2e/nightly/single_node/models/test_qwen2_5_vl_7b_epd.py b/tests/e2e/nightly/single_node/models/test_qwen2_5_vl_7b_epd.py new file mode 100644 index 00000000000..fee7a705868 --- /dev/null +++ b/tests/e2e/nightly/single_node/models/test_qwen2_5_vl_7b_epd.py @@ -0,0 +1,110 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# + +import pytest +from vllm.utils.network_utils import get_open_port + +from tests.e2e.conftest import DisaggEpdProxy, RemoteEPDServer +from tools.aisbench import run_aisbench_cases + +MODELS = [ + "Qwen/Qwen2.5-VL-7B-Instruct", +] +SHARED_STORAGE_PATH = "/dev/shm/epd/storage" +TENSOR_PARALLELS = [1] + +warmup_cases = [{ + "case_type": "performance", + "dataset_path": "vllm-ascend/textvqa-perf-1080p", + "request_conf": "vllm_api_stream_chat", + "dataset_conf": "textvqa/textvqa_gen_base64", + "num_prompts": 50, + "max_out_len": 20, + "batch_size": 32, + "request_rate": 0, + "baseline": 1, + "threshold": 0.97 +}] +aisbench_cases = [{ + "case_type": "accuracy", + "dataset_path": "vllm-ascend/textvqa-lite", + "request_conf": "vllm_api_stream_chat", + "dataset_conf": "textvqa/textvqa_gen_base64", + "max_out_len": 2048, + "batch_size": 128, + "baseline": 82.05, + "threshold": 5 +}, { + "case_type": "performance", + "dataset_path": "vllm-ascend/textvqa-perf-1080p", + "request_conf": "vllm_api_stream_chat", + "dataset_conf": "textvqa/textvqa_gen_base64", + "num_prompts": 512, + "max_out_len": 256, + "batch_size": 128, + "request_rate": 0, + "baseline": 1, + "threshold": 0.97 +}] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS) +async def test_models(model: str, tp_size: int) -> None: + encode_port = get_open_port() + pd_port = get_open_port() + vllm_server_args = [ + [ + "--port", + str(encode_port), "--model", model, "--gpu-memory-utilization", + "0.01", "--tensor-parallel-size", + str(tp_size), "--enforce-eager", "--no-enable-prefix-caching", + "--max-model-len", "10000", "--max-num-batched-tokens", "10000", + "--max-num-seqs", "1", "--ec-transfer-config", + '{"ec_connector_extra_config":{"shared_storage_path":"' + + SHARED_STORAGE_PATH + + '"},"ec_connector":"ECExampleConnector","ec_role": "ec_producer"}' + ], + [ + "--port", + str(pd_port), "--model", model, "--gpu-memory-utilization", "0.95", + "--tensor-parallel-size", + str(tp_size), "--enforce-eager", "--max-model-len", "10000", + "--max-num-batched-tokens", "10000", "--max-num-seqs", "128", + "--ec-transfer-config", + '{"ec_connector_extra_config":{"shared_storage_path":"' + + SHARED_STORAGE_PATH + + '"},"ec_connector":"ECExampleConnector","ec_role": "ec_consumer"}' + ] + ] + proxy_port = get_open_port() + proxy_args = [ + "--host", "127.0.0.1", "--port", + str(proxy_port), "--encode-servers-urls", + f"http://localhost:{encode_port}", "--decode-servers-urls", + f"http://localhost:{pd_port}", "--prefill-servers-urls", "disable" + ] + + with RemoteEPDServer(vllm_serve_args=vllm_server_args) as _: + with DisaggEpdProxy(proxy_args=proxy_args) as _: + # warm up + run_aisbench_cases(model=model, + port=proxy_port, + aisbench_cases=warmup_cases) + # aisbench test + run_aisbench_cases(model, proxy_port, aisbench_cases) diff --git a/tools/check_python_src_init.py b/tools/check_python_src_init.py index ab0fa21f477..ff3f97e5161 100644 --- a/tools/check_python_src_init.py +++ b/tools/check_python_src_init.py @@ -67,7 +67,7 @@ def main(): print(f" - {pkg}") sys.exit(1) else: - print("✅ All Python packages have __init__.py files.") + print("All Python packages have __init__.py files.") if __name__ == "__main__":