Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 12 additions & 16 deletions docs/design/cuda_graphs_multimodal.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The encoder CUDA Graph system uses a **budget-based capture/replay** strategy, m

* [EncoderCudaGraphManager][vllm.v1.worker.encoder_cudagraph.EncoderCudaGraphManager]: orchestrates capture, replay, greedy packing, and data-parallel execution for encoder CUDA Graphs.
* [SupportsEncoderCudaGraph][vllm.model_executor.models.interfaces.SupportsEncoderCudaGraph]: a runtime-checkable protocol that models implement to opt-in to encoder CUDA Graphs.
* [EncoderItemSpec][vllm.v1.worker.encoder_cudagraph_defs.EncoderItemSpec]: describes a single encoder input item (image or video) with its input size and output token count.
* [BudgetGraphMetadata][vllm.v1.worker.encoder_cudagraph.BudgetGraphMetadata]: holds the captured CUDA Graph and its associated I/O buffers for a single token budget level.

### Budget-based graph capture
Expand All @@ -30,8 +31,7 @@ class BudgetGraphMetadata:
max_batch_size: int
max_frames_per_batch: int
graph: torch.cuda.CUDAGraph
input_buffer: torch.Tensor # e.g. pixel_values
metadata_buffers: dict[str, torch.Tensor] # e.g. embeddings, seq metadata
input_buffers: dict[str, torch.Tensor] # e.g. pixel_values, embeddings, seq metadata
output_buffer: torch.Tensor # encoder hidden states
```

Expand All @@ -43,8 +43,8 @@ When a batch of images arrives, the manager sorts images by output token count (

For each graph replay:

1. Zero the pre-allocated `input_buffer`, then copy input tensors (e.g., `pixel_values`) into it.
2. Zero `metadata_buffers`, then slice-copy precomputed values (e.g., rotary embeddings, sequence metadata).
1. Call `prepare_encoder_cudagraph_replay_buffers()` to compute buffer values (including `pixel_values` and precomputed metadata) from actual batch inputs.
2. Zero the pre-allocated `input_buffers`, then slice-copy the replay values into them.
3. Replay the CUDA Graph.
4. Clone outputs from `output_buffer` (cloning is necessary since the buffer is reused across replays).

Expand All @@ -65,19 +65,15 @@ Following <https://github.com/vllm-project/vllm/pull/35963> (ViT full CUDA graph

Models opt-in to encoder CUDA Graphs by implementing the [SupportsEncoderCudaGraph][vllm.model_executor.models.interfaces.SupportsEncoderCudaGraph] protocol. This protocol encapsulates all model-specific logic so that the manager remains model-agnostic. The protocol defines the following methods:

* `get_encoder_cudagraph_config()` — returns static configuration (supported modalities, input key, buffer keys, output hidden size).
* `get_encoder_cudagraph_config()` — returns static configuration (supported modalities, buffer keys, output hidden size, padding logics, max frames per video).
* `get_encoder_cudagraph_budget_range(vllm_config)` — returns `(min_budget, max_budget)` for auto-inference of token budgets.
* `get_encoder_cudagraph_num_items(mm_kwargs)` — returns the number of items (e.g. images) in the batch.
* `get_encoder_cudagraph_per_item_output_tokens(mm_kwargs)` — returns per-item output token counts, used for greedy packing.
* `get_encoder_cudagraph_per_item_input_sizes(mm_kwargs)` — returns per-item input sizes (e.g. patch counts), used for DP load balancing.
* `get_encoder_cudagraph_item_specs(mm_kwargs)` — returns `list[EncoderItemSpec]` describing each item with its input size and output token count. Replaces the former three separate methods (`get_num_items`, `get_per_item_output_tokens`, `get_per_item_input_sizes`).
* `select_encoder_cudagraph_items(mm_kwargs, indices)` — extracts a sub-batch of items by index, used during greedy packing and DP sharding.
* `prepare_encoder_cudagraph_capture_inputs(...)` — creates dummy inputs for graph capture.
* `prepare_encoder_cudagraph_replay_buffers(...)` — computes new buffer values from actual batch inputs before replay.
* `encoder_cudagraph_forward(...)` — forward pass using precomputed buffers (called during capture and replay).
* `encoder_eager_forward(...)` — fallback eager forward when no graph fits.
* `get_input_modality(...)` - return the modality of the inputs.
* `get_max_frames_per_video()` - return model-specific max frames per video.
* `postprocess_encoder_output(...)` - post process encoder output, directly call scatter_output_slices by default
* `prepare_encoder_cudagraph_capture_inputs(...)` — creates dummy inputs for graph capture. Returns `EncoderCudaGraphCaptureInputs` with a single `values: dict[str, torch.Tensor]` that contains all buffers to be recorded into the graph.
* `prepare_encoder_cudagraph_replay_buffers(mm_kwargs, max_batch_size, max_frames_per_batch)` — computes buffer values from actual batch inputs. Returns `EncoderCudaGraphReplayBuffers` with a `values` dict whose keys match `buffer_keys` in the config.
* `encoder_cudagraph_forward(inputs: dict[str, torch.Tensor])` — forward pass accepting only fixed-shaped input tensors (the captured `values` dict). Called during both capture and replay. The `pixel_values` tensor is included in `inputs` alongside metadata buffers.
* `encoder_eager_forward(mm_kwargs)` — fallback eager forward when no graph fits.
* `postprocess_encoder_output(...)` — post-process encoder output, delegates to `scatter_output_slices` by default.

!!! note
The `SupportsEncoderCudaGraph` protocol is designed to be model-agnostic. New vision encoder models can opt-in by implementing the protocol methods without modifying the manager.
Expand All @@ -103,7 +99,7 @@ Three fields in `CompilationConfig` control encoder CUDA Graphs:
* `cudagraph_mm_encoder` (`bool`, default `False`) — enable CUDA Graph capture for multimodal encoder. When enabled, captures the full encoder forward as a CUDA Graph for each token budget level.
* `encoder_cudagraph_token_budgets` (`list[int]`, default `[]`) — token budget levels for capture. If empty (default), auto-inferred from model architecture as power-of-2 levels. User-provided values override auto-inference.
* `encoder_cudagraph_max_vision_items_per_batch` (`int`, default `0`) — maximum number of images/videos per batch during capture. If 0 (default), auto-inferred as `max_budget // min_budget`.
* `encoder_cudagraph_max_frames_per_batch` (`int`, default `None`) — maximum number of video frames per batch during capture. If `None` (default), auto-inferred as `encoder_cudagraph_max_vision_items_per_batch * max_frames_per_video` (`max_frames_per_video` is a model-specific value according to its `processing_info`). If we limit the video count per prompt to `0`, it will also be set to `0` (i.e., fall back to image-only mode).
* `encoder_cudagraph_max_frames_per_batch` (`int`, default `None`) — maximum number of video frames per batch during capture. If `None` (default), auto-inferred as `encoder_cudagraph_max_vision_items_per_batch * max_frames_per_video` (`max_frames_per_video` is a model-specific value from `EncoderCudaGraphConfig`, computed by `get_max_frames_per_video()` on the model). If we limit the video count per prompt to `0`, it will also be set to `0` (i.e., fall back to image-only mode).

## Usage guide

Expand Down
4 changes: 2 additions & 2 deletions rust/src/chat/src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ use std::convert::Infallible;
use std::fmt;
use std::str::FromStr;

use serde_with::DeserializeFromStr;
use serde_with::{DeserializeFromStr, SerializeDisplay};

/// Specify which reasoning or tool-call parser implementation to use.
#[derive(Debug, Clone, PartialEq, Eq, Default, DeserializeFromStr)]
#[derive(Debug, Clone, PartialEq, Eq, Default, DeserializeFromStr, SerializeDisplay)]
pub enum ParserSelection {
/// Use model-based auto-detection.
#[default]
Expand Down
4 changes: 2 additions & 2 deletions rust/src/chat/src/renderer/hf/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::str::FromStr;
use minijinja::machinery::ast::{Expr, ForLoop, Set, Stmt};
use minijinja::machinery::{WhitespaceConfig, parse};
use minijinja::syntax::SyntaxConfig;
use serde_with::DeserializeFromStr;
use serde_with::{DeserializeFromStr, SerializeDisplay};

/// Chat template content format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
Expand All @@ -18,7 +18,7 @@ pub enum ChatTemplateContentFormat {
}

/// Configurable chat-template content format selection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, DeserializeFromStr)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, DeserializeFromStr, SerializeDisplay)]
pub enum ChatTemplateContentFormatOption {
/// Detect the format from the template source.
#[default]
Expand Down
4 changes: 2 additions & 2 deletions rust/src/chat/src/renderer/selection.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::fmt;
use std::str::FromStr;

use serde_with::DeserializeFromStr;
use serde_with::{DeserializeFromStr, SerializeDisplay};

/// Specify which chat renderer implementation to use.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, DeserializeFromStr)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, DeserializeFromStr, SerializeDisplay)]
pub enum RendererSelection {
/// Use model-based auto-detection.
#[default]
Expand Down
3 changes: 2 additions & 1 deletion rust/src/engine-core-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::sync::Arc;
use std::time::Duration;

use futures::future::{join_all, try_join_all};
use serde::Serialize;
use tokio::sync::mpsc;
use tokio_util::task::AbortOnDropHandle;
use tracing::{debug, info, trace};
Expand All @@ -23,7 +24,7 @@ pub use stream::{EngineCoreOutputStream, EngineCoreStreamOutput};

/// How the frontend acquires its request/response transport with Python
/// `EngineCoreProc`s.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub enum TransportMode {
/// The Rust process owns the startup handshake and allocates or binds the
/// frontend transport addresses itself before replying to engine
Expand Down
7 changes: 4 additions & 3 deletions rust/src/server/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ use std::collections::HashMap;
use std::time::Duration;

use anyhow::Result;
use serde::Serialize;
use serde_json::Value;
use vllm_chat::{ChatTemplateContentFormatOption, ParserSelection, RendererSelection};
use vllm_engine_core_client::{CoordinatorMode as EngineCoreCoordinatorMode, TransportMode};

/// How the HTTP server obtains its listening socket.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub enum HttpListenerMode {
/// Bind a fresh TCP listener on the given host/port.
BindTcp { host: String, port: u16 },
Expand All @@ -20,7 +21,7 @@ pub enum HttpListenerMode {

/// Which coordinator implementation should be active when one is present for a
/// frontend client.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub enum CoordinatorMode {
/// Do not run a coordinator at all.
None,
Expand All @@ -32,7 +33,7 @@ pub enum CoordinatorMode {
}

/// Normalized runtime configuration for the minimal OpenAI-compatible server.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct Config {
/// Frontend-to-engine transport setup.
pub transport_mode: TransportMode,
Expand Down
5 changes: 4 additions & 1 deletion rust/src/server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod listener;
mod lora;
mod middleware;
mod routes;
mod server_info;
mod state;
mod utils;

Expand All @@ -30,6 +31,7 @@ use vllm_text::TextLlm;

use crate::listener::Listener;
use crate::routes::build_router;
use crate::server_info::ServerInfoSnapshot;
use crate::state::AppState;

/// Build the shared application state for one configured model and one engine
Expand Down Expand Up @@ -88,7 +90,8 @@ async fn build_state(config: &Config) -> Result<Arc<AppState>> {
Ok(Arc::new(
AppState::new(served_model_names, chat)
.with_log_requests(config.enable_log_requests)
.with_request_id_headers(config.enable_request_id_headers),
.with_request_id_headers(config.enable_request_id_headers)
.with_server_info(ServerInfoSnapshot::from_config(config)),
))
}

Expand Down
2 changes: 2 additions & 0 deletions rust/src/server/src/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod load;
mod lora;
mod metrics;
pub(crate) mod openai;
mod server_info;
mod sleep;
mod version;

Expand Down Expand Up @@ -89,6 +90,7 @@ fn build_router_with_options(
.route("/sleep", post(sleep::sleep))
.route("/wake_up", post(sleep::wake_up))
.route("/is_sleeping", get(sleep::is_sleeping))
.route("/server_info", get(server_info::server_info))
}

let enable_request_id_headers = state.enable_request_id_headers;
Expand Down
47 changes: 47 additions & 0 deletions rust/src/server/src/routes/server_info.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use std::sync::Arc;

use axum::Json;
use axum::extract::{Query, State};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use serde::Deserialize;

use crate::server_info::ServerInfoConfigFormat;
use crate::state::AppState;

#[derive(Debug, Clone, Copy, Deserialize)]
#[serde(rename_all = "lowercase")]
enum ConfigFormat {
Text,
Json,
}

impl From<ConfigFormat> for ServerInfoConfigFormat {
fn from(value: ConfigFormat) -> Self {
match value {
ConfigFormat::Text => Self::Text,
ConfigFormat::Json => Self::Json,
}
}
}

fn default_config_format() -> ConfigFormat {
ConfigFormat::Text
}

#[derive(Debug, Deserialize)]
pub(crate) struct ServerInfoParams {
#[serde(default = "default_config_format")]
config_format: ConfigFormat,
}

/// Get server configuration and environment metadata.
pub async fn server_info(
State(state): State<Arc<AppState>>,
Query(params): Query<ServerInfoParams>,
) -> Response {
match state.server_info_response(params.config_format.into()) {
Some(response) => Json(response).into_response(),
None => StatusCode::NOT_FOUND.into_response(),
}
}
49 changes: 45 additions & 4 deletions rust/src/server/src/routes/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -742,16 +742,23 @@ async fn test_chat_with_engine_outputs(
}

async fn test_app() -> axum::Router {
test_app_with_dev_mode(false).await
}

async fn test_app_with_dev_mode(dev_mode_enabled: bool) -> axum::Router {
let (chat, _engine_task) = test_models_with_engine_outputs_and_backend(
b"engine-openai",
default_stream_output_specs(),
Arc::new(FakeChatBackend::new()),
)
.await;
build_router(Arc::new(AppState::new(
vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()],
chat,
)))
build_router_with_dev_mode(
Arc::new(AppState::new(
vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()],
chat,
)),
dev_mode_enabled,
)
}

async fn test_app_with_request_id_headers() -> (axum::Router, MockEngineTask) {
Expand Down Expand Up @@ -1087,6 +1094,23 @@ async fn version_returns_engine_vllm_version() {
);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[serial]
async fn server_info_endpoint_is_dev_mode_only() {
let mut app = test_app().await;
let response = app
.call(
Request::builder()
.uri("/server_info")
.body(Body::empty())
.expect("build request"),
)
.await
.expect("call app");

assert_eq!(response.status(), StatusCode::NOT_FOUND);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[serial]
async fn load_lora_adapter_registers_model_and_forwards_lora_request() {
Expand Down Expand Up @@ -1302,6 +1326,23 @@ async fn load_lora_adapter_registers_model_and_forwards_lora_request() {
engine_task.finish().await;
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[serial]
async fn server_info_endpoint_returns_not_found_without_snapshot() {
let mut app = test_app_with_dev_mode(true).await;
let response = app
.call(
Request::builder()
.uri("/server_info")
.body(Body::empty())
.expect("build request"),
)
.await
.expect("call app");

assert_eq!(response.status(), StatusCode::NOT_FOUND);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[serial]
async fn unload_lora_adapter_rejects_mismatched_lora_int_id() {
Expand Down
Loading
Loading