Multiple CUDA Graphs are pre-captured at different **token budget** levels (e.g. 2048, 4096, 8192 tokens). Each captured graph is tracked with the following metadata:

```python
class BudgetGraphMetadata:
    token_budget: int
    max_batch_size: int
    max_frames_per_batch: int
    graph: torch.cuda.CUDAGraph
    input_buffer: torch.Tensor  # e.g. pixel_values
    metadata_buffers: dict[str, torch.Tensor]  # e.g. embeddings, seq metadata
```
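
For replay, the manager must map an incoming batch onto one of these pre-captured budgets. The sketch below illustrates the selection policy in simplified form (`CapturedGraph` and `select_budget` are illustrative names, not the actual manager API): pick the smallest captured budget that fits the batch, otherwise fall back to eager execution.

```python
from dataclasses import dataclass


@dataclass
class CapturedGraph:
    token_budget: int
    max_batch_size: int


def select_budget(
    graphs: list[CapturedGraph], num_tokens: int, batch_size: int
) -> CapturedGraph | None:
    """Return the smallest captured graph that fits the batch, or None (eager fallback)."""
    candidates = [
        g for g in sorted(graphs, key=lambda g: g.token_budget)
        if num_tokens <= g.token_budget and batch_size <= g.max_batch_size
    ]
    return candidates[0] if candidates else None


# Budgets 2048/4096/8192: a 3000-token batch replays the 4096 graph,
# while a 10000-token batch misses every budget and runs eagerly.
graphs = [CapturedGraph(2048, 8), CapturedGraph(4096, 8), CapturedGraph(8192, 8)]
assert select_budget(graphs, 3000, 2).token_budget == 4096
assert select_budget(graphs, 10000, 1) is None
```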

When `mm_encoder_tp_mode="data"`, the manager distributes images across TP ranks using load-balanced assignment via `get_load_balance_assignment`, executes locally on each rank, then gathers results back in the original order via `tensor_model_parallel_all_gather`.
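
The sketch below illustrates the load-balancing idea in isolation; `assign_by_load` is a hypothetical stand-in, not the real `get_load_balance_assignment` signature. Items are greedily assigned to the rank with the smallest accumulated token load, and after the all-gather the per-rank outputs are reordered back to the original item order.

```python
def assign_by_load(item_sizes: list[int], num_ranks: int) -> list[int]:
    """Greedy load balancing: each item goes to the rank with the smallest
    accumulated token load. Returns the assigned rank for every item in the
    original item order, so results can be re-ordered after the all-gather."""
    loads = [0] * num_ranks
    ranks = [0] * len(item_sizes)
    # Place the largest items first so the greedy choice balances better.
    for idx in sorted(range(len(item_sizes)), key=lambda i: -item_sizes[i]):
        rank = min(range(num_ranks), key=lambda r: loads[r])
        loads[rank] += item_sizes[idx]
        ranks[idx] = rank
    return ranks


# Five images with these per-image token counts, spread over two TP ranks.
print(assign_by_load([1024, 4096, 2048, 512, 3072], num_ranks=2))  # -> [0, 0, 1, 0, 1]
```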

### Video inference support (experimental)

Following <https://github.com/vllm-project/vllm/pull/35963> (ViT full CUDA graph support for image inference), <https://github.com/vllm-project/vllm/pull/38061> extends the encoder CUDA graph framework to support video inference for Qwen3-VL. Previously, the CUDA graph capture/replay path only handled image inputs (`pixel_values` + `image_grid_thw`). Video inputs use different keys (`pixel_values_videos` + `video_grid_thw`) and require larger `cu_seqlens` buffers because each video item contributes multiple frames (`T` attention sequences). This PR generalizes the protocol and manager to handle both modalities through a single shared graph manager.

!!! note
    Video CUDA graphs are automatically disabled when EVS (Efficient Video Sampling) pruning is enabled, since EVS makes the token count data-dependent and incompatible with CUDA graph capture.

Currently, only image-only or video-only batches are supported when encoder CUDA Graphs are enabled; mixed inputs (image + video) are not supported yet, though support is planned. For video-only workloads, it is therefore recommended to disable the image modality with `--limit-mm-per-prompt '{"image": 0}'`.
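
To make the `cu_seqlens` sizing concrete, the sketch below computes per-frame attention-sequence boundaries from a `video_grid_thw`-style layout. It is a simplified illustration that assumes each of the `T` frames contributes one attention sequence of `h * w` patch tokens; the real encoder path has additional details.

```python
import torch


def video_cu_seqlens(video_grid_thw: list[tuple[int, int, int]]) -> torch.Tensor:
    """Cumulative sequence lengths, one attention sequence per frame.

    Each (t, h, w) item contributes t sequences of h * w patch tokens, so a
    video with many frames needs a much larger cu_seqlens buffer than a
    single image (t == 1).
    """
    seqlens: list[int] = []
    for t, h, w in video_grid_thw:
        seqlens.extend([h * w] * t)  # one attention sequence per frame
    return torch.cumsum(torch.tensor([0] + seqlens, dtype=torch.int32), dim=0)


# Example: one 8-frame video on a 16x16 patch grid -> 9 boundary entries,
# versus only 2 entries for a single image of the same resolution.
print(video_cu_seqlens([(8, 16, 16)]))
print(video_cu_seqlens([(1, 16, 16)]))
```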

## Model integration via `SupportsEncoderCudaGraph`

Models opt in to encoder CUDA Graphs by implementing the [SupportsEncoderCudaGraph][vllm.model_executor.models.interfaces.SupportsEncoderCudaGraph] protocol. This protocol encapsulates all model-specific logic so that the manager remains model-agnostic. The protocol defines the following methods:
* `prepare_encoder_cudagraph_replay_buffers(...)` — computes new buffer values from actual batch inputs before replay.
* `encoder_cudagraph_forward(...)` — forward pass using precomputed buffers (called during capture and replay).
* `encoder_eager_forward(...)` — fallback eager forward when no graph fits.
* `get_input_modality(...)` — returns the modality of the input batch (image or video).

!!! note
    The `SupportsEncoderCudaGraph` protocol is designed to be model-agnostic. New vision encoder models can opt in by implementing the protocol methods without modifying the manager.
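
As a rough illustration only (the class, method signatures, and buffer layout below are simplified stand-ins, not the actual interface definitions in `vllm.model_executor.models.interfaces`), an opted-in model might look like this:

```python
import torch


class MyVisionLM:
    """Hypothetical model opting in to encoder CUDA Graphs (simplified signatures)."""

    def __init__(self) -> None:
        self.vision_encoder = torch.nn.Identity()  # stand-in for a real ViT

    def get_input_modality(self, mm_kwargs: dict[str, torch.Tensor]) -> str:
        # Report which modality this batch carries, based on the input keys.
        return "video" if "pixel_values_videos" in mm_kwargs else "image"

    def prepare_encoder_cudagraph_replay_buffers(
        self, buffers: dict[str, torch.Tensor], mm_kwargs: dict[str, torch.Tensor]
    ) -> None:
        # Copy the actual batch inputs into the static buffers that were
        # captured with the graph; the graph itself only ever reads buffers.
        pixels = mm_kwargs.get("pixel_values", mm_kwargs.get("pixel_values_videos"))
        buffers["pixel_values"][: pixels.shape[0]].copy_(pixels)

    def encoder_cudagraph_forward(self, buffers: dict[str, torch.Tensor]) -> torch.Tensor:
        # Forward pass that touches only pre-allocated buffers, so the same
        # kernel sequence is valid during capture and during every replay.
        return self.vision_encoder(buffers["pixel_values"])

    def encoder_eager_forward(self, mm_kwargs: dict[str, torch.Tensor]) -> torch.Tensor:
        # Fallback used when the batch does not fit any captured budget.
        pixels = mm_kwargs.get("pixel_values", mm_kwargs.get("pixel_values_videos"))
        return self.vision_encoder(pixels)
```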

**Supported models:**

| Architecture | Models | CG for Image | CG for Video |
| ------------ | ------ | ------------ | ------------ |
| `Qwen3VLForConditionalGeneration` | `Qwen3-VL` | ✅︎ | ✅︎ |

!!! note
    Encoder CUDA Graphs have so far been tested with `--mm-encoder-attn-backend=FLASH_ATTN` and `--mm-encoder-attn-backend=FLASHINFER` on Blackwell GPUs.

The following fields in `CompilationConfig` control encoder CUDA Graphs:

* `cudagraph_mm_encoder` (`bool`, default `False`) — enable CUDA Graph capture for multimodal encoder. When enabled, captures the full encoder forward as a CUDA Graph for each token budget level.
* `encoder_cudagraph_token_budgets` (`list[int]`, default `[]`) — token budget levels for capture. If empty (default), auto-inferred from model architecture as power-of-2 levels. User-provided values override auto-inference.
* `encoder_cudagraph_max_vision_items_per_batch` (`int`, default `0`) — maximum number of images/videos per batch during capture. If 0 (default), auto-inferred as `max_budget // min_budget`.
* `encoder_cudagraph_max_frames_per_batch` (`int`, default `0`) — maximum number of video frames per batch during capture. If 0 (default), auto-inferred as `encoder_cudagraph_max_vision_items_per_batch * 2` (a heuristic that may be refined).
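
A small sketch of how these defaults resolve (the power-of-2 budget inference is architecture-dependent in practice, so the inference loop below is only a stand-in):

```python
def resolve_encoder_cudagraph_defaults(
    token_budgets: list[int],
    max_vision_items: int,
    max_frames: int,
    model_max_tokens: int,
) -> tuple[list[int], int, int]:
    """Fill in the auto-inferred defaults for the three capture knobs."""
    if not token_budgets:
        # Power-of-2 levels up to the model's maximum encoder tokens
        # (stand-in for the architecture-specific inference).
        budget = 256
        token_budgets = []
        while budget <= model_max_tokens:
            token_budgets.append(budget)
            budget *= 2
    if max_vision_items == 0:
        max_vision_items = max(token_budgets) // min(token_budgets)
    if max_frames == 0:
        max_frames = max_vision_items * 2
    return token_budgets, max_vision_items, max_frames


print(resolve_encoder_cudagraph_defaults([], 0, 0, model_max_tokens=8192))
# -> ([256, 512, 1024, 2048, 4096, 8192], 32, 64)
```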

## Usage guide

### Image inference

Enable encoder CUDA Graphs via `compilation_config`:

```bash
vllm serve Qwen/Qwen3-VL-32B \
--compilation-config '{"cudagraph_mm_encoder": true}'
```

With explicit budgets:

```bash
vllm serve Qwen/Qwen3-VL-32B \
--compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [2048, 4096, 8192, 13824], "encoder_cudagraph_max_images_per_batch": 8}'
--compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [2048, 4096, 8192, 13824], "encoder_cudagraph_max_vision_items_per_batch": 8}'
```

Python example:

```python
import vllm

compilation_config = {
    "cudagraph_mm_encoder": True,
    # Optional: override auto-inferred budgets
    # "encoder_cudagraph_token_budgets": [2048, 4096, 8192, 13824],
    # "encoder_cudagraph_max_vision_items_per_batch": 8,
}

model = vllm.LLM(
    model="Qwen/Qwen3-VL-32B",
    compilation_config=compilation_config,
)
```

The manager tracks hit/miss statistics and logs them periodically. A "hit" means an image was processed via CUDA Graph replay; a "miss" means the image fell back to eager execution because it exceeded all captured budgets.

### Video inference

Enable encoder CUDA Graphs via `compilation_config`:

```bash
vllm serve Qwen/Qwen3-VL-32B \
--limit-mm-per-prompt '{"image": 0}' \
--compilation-config '{"cudagraph_mm_encoder": true}'
```

With explicit budgets:

```bash
vllm serve Qwen/Qwen3-VL-32B \
--limit-mm-per-prompt '{"image": 0}' \
--compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [2048, 4096, 8192, 13824], "encoder_cudagraph_max_vision_items_per_batch": 8, "encoder_cudagraph_max_frames_per_batch": 64}'
```

Python example:

```python
import vllm

compilation_config = {
    "cudagraph_mm_encoder": True,
    # Optional: override auto-inferred budgets
    # "encoder_cudagraph_token_budgets": [2048, 4096, 8192, 13824],
    # "encoder_cudagraph_max_vision_items_per_batch": 8,
    # "encoder_cudagraph_max_frames_per_batch": 64,
}

model = vllm.LLM(
    model="Qwen/Qwen3-VL-32B",
    limit_mm_per_prompt={"image": 0},
    compilation_config=compilation_config,
)
```

## Performance

The following benchmarks were run on Blackwell GPUs (GB200) using `vllm bench mm-processor`. See [#35963](https://github.com/vllm-project/vllm/pull/35963) for full details.

```bash
vllm bench mm-processor \
--num-prompts 3000 --num-warmups 300 \
--max-model-len 32768 --seed 42 \
--mm-encoder-attn-backend FLASH_ATTN \
--compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4864], "encoder_cudagraph_max_images_per_batch": 8}'
--compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4864], "encoder_cudagraph_max_vision_items_per_batch": 8}'
```

### Multi-GPU (4x GB200, TP=4, DP=4)

```bash
vllm bench mm-processor \
--max-model-len 8192 --seed 42 \
--mm-encoder-attn-backend FLASHINFER \
--tensor-parallel-size 4 --mm-encoder-tp-mode data \
--compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4864], "encoder_cudagraph_max_images_per_batch": 8}'
--compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4864], "encoder_cudagraph_max_vision_items_per_batch": 8}'
```

!!! note
    More details on video-inference benchmarks (run on A100 GPUs) can be found in [#38061](https://github.com/vllm-project/vllm/pull/38061).