diff --git a/docs/design/cuda_graphs_multimodal.md b/docs/design/cuda_graphs_multimodal.md
index 5515c91a8b69..9c15f02858c2 100644
--- a/docs/design/cuda_graphs_multimodal.md
+++ b/docs/design/cuda_graphs_multimodal.md
@@ -28,6 +28,7 @@ Multiple CUDA Graphs are pre-captured at different **token budget** levels (e.g
 class BudgetGraphMetadata:
     token_budget: int
     max_batch_size: int
+    max_frames_per_batch: int
     graph: torch.cuda.CUDAGraph
     input_buffer: torch.Tensor  # e.g. pixel_values
     metadata_buffers: dict[str, torch.Tensor]  # e.g. embeddings, seq metadata
@@ -51,6 +52,15 @@ For each graph replay:
 When `mm_encoder_tp_mode="data"`, the manager distributes images across TP ranks using load-balanced assignment via `get_load_balance_assignment`, executes locally on each rank, then gathers results back in the original order via `tensor_model_parallel_all_gather`.
 
+### Video inference support (experimental)
+
+Following the earlier ViT full CUDA graph support for image inference, this change extends the encoder CUDA graph framework to support video inference for Qwen3-VL. Previously, the CUDA graph capture/replay path only handled image inputs (`pixel_values` + `image_grid_thw`). Video inputs use different keys (`pixel_values_videos` + `video_grid_thw`) and require larger `cu_seqlens` buffers because each video item contributes multiple frames (`T` attention sequences). The protocol and manager are generalized so that both modalities are handled by a single shared graph manager.
+
+!!! note
+    Video CUDA graphs are automatically disabled when EVS (Efficient Video Sampling) pruning is enabled, since EVS makes the token count data-dependent and incompatible with CUDA graph capture.
+
+    Currently, only image-only or video-only inputs are supported when CUDA Graphs are enabled; mixed inputs (image + video) are not supported yet (support is planned). It is therefore recommended to disable the image modality with `--limit-mm-per-prompt '{"image": 0}'` for video-only workloads.
+
 ## Model integration via `SupportsEncoderCudaGraph`
 
 Models opt-in to encoder CUDA Graphs by implementing the [SupportsEncoderCudaGraph][vllm.model_executor.models.interfaces.SupportsEncoderCudaGraph] protocol. This protocol encapsulates all model-specific logic so that the manager remains model-agnostic. The protocol defines the following methods:
@@ -65,12 +75,17 @@ Models opt-in to encoder CUDA Graphs by implementing the [SupportsEncoderCudaGra
 * `prepare_encoder_cudagraph_replay_buffers(...)` — computes new buffer values from actual batch inputs before replay.
 * `encoder_cudagraph_forward(...)` — forward pass using precomputed buffers (called during capture and replay).
 * `encoder_eager_forward(...)` — fallback eager forward when no graph fits.
-
-Currently supported: **Qwen3-VL** (see `vllm/model_executor/models/qwen3_vl.py`).
+* `get_input_modality(...)` — returns the modality of the inputs based on the keys present in `mm_kwargs` (sketched below).
 
 !!! note
     The `SupportsEncoderCudaGraph` protocol is designed to be model-agnostic. New vision encoder models can opt-in by implementing the protocol methods without modifying the manager.
 
+**Supported models:**
+
+| Architecture | Models | CG for Image | CG for Video |
+| ------------ | ------ | ------------ | ------------ |
+| `Qwen3VLForConditionalGeneration` | `Qwen3-VL` | ✅︎ | ✅︎ |
+
 !!! note
     Encoder CUDA Graphs have currently been tested with `--mm-encoder-attn-backend=FLASH_ATTN` and `--mm-encoder-attn-backend=FLASHINFER` on Blackwell GPUs.
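For reference, the modality routing that `get_input_modality` performs can be as simple as inspecting which grid key is present in `mm_kwargs`. A minimal sketch, mirroring the Qwen3-VL implementation later in this diff (shown as a standalone function for clarity; in practice it is a method on the model):

```python
from typing import Any


def get_input_modality(mm_kwargs: dict[str, Any]) -> str:
    # Image inputs carry "pixel_values" + "image_grid_thw";
    # video inputs carry "pixel_values_videos" + "video_grid_thw".
    return "image" if "image_grid_thw" in mm_kwargs else "video"
```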
@@ -80,10 +95,13 @@ Three fields in `CompilationConfig` control encoder CUDA Graphs:
 
 * `cudagraph_mm_encoder` (`bool`, default `False`) — enable CUDA Graph capture for multimodal encoder. When enabled, captures the full encoder forward as a CUDA Graph for each token budget level.
 * `encoder_cudagraph_token_budgets` (`list[int]`, default `[]`) — token budget levels for capture. If empty (default), auto-inferred from model architecture as power-of-2 levels. User-provided values override auto-inference.
-* `encoder_cudagraph_max_images_per_batch` (`int`, default `0`) — maximum number of images per batch during capture. If 0 (default), auto-inferred as `max_budget // min_budget`.
+* `encoder_cudagraph_max_vision_items_per_batch` (`int`, default `0`) — maximum number of images/videos per batch during capture. If 0 (default), auto-inferred as `max_budget // min_budget`.
+* `encoder_cudagraph_max_frames_per_batch` (`int`, default `0`) — maximum number of video frames per batch during capture. If 0 (default), auto-inferred as `encoder_cudagraph_max_vision_items_per_batch * 2` (a heuristic, to be optimized).
 
 ## Usage guide
 
+### Image inference
+
 Enable encoder CUDA Graphs via `compilation_config`:
 
 ```bash
@@ -95,7 +113,7 @@ With explicit budgets:
 
 ```bash
 vllm serve Qwen/Qwen3-VL-32B \
-    --compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [2048, 4096, 8192, 13824], "encoder_cudagraph_max_images_per_batch": 8}'
+    --compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [2048, 4096, 8192, 13824], "encoder_cudagraph_max_vision_items_per_batch": 8}'
 ```
 
 Python example:
@@ -107,7 +125,7 @@ compilation_config = {
     "cudagraph_mm_encoder": True,
     # Optional: override auto-inferred budgets
     # "encoder_cudagraph_token_budgets": [2048, 4096, 8192, 13824],
-    # "encoder_cudagraph_max_images_per_batch": 8,
+    # "encoder_cudagraph_max_vision_items_per_batch": 8,
 }
 
 model = vllm.LLM(
@@ -118,6 +136,44 @@ model = vllm.LLM(
 
 The manager tracks hit/miss statistics and logs them periodically. A "hit" means an image was processed via CUDA Graph replay; a "miss" means eager fallback (image exceeded all budgets).
 
+### Video inference
+
+Enable encoder CUDA Graphs via `compilation_config`:
+
+```bash
+vllm serve Qwen/Qwen3-VL-32B \
+    --limit-mm-per-prompt '{"image": 0}' \
+    --compilation-config '{"cudagraph_mm_encoder": true}'
+```
+
+With explicit budgets:
+
+```bash
+vllm serve Qwen/Qwen3-VL-32B \
+    --limit-mm-per-prompt '{"image": 0}' \
+    --compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [2048, 4096, 8192, 13824], "encoder_cudagraph_max_vision_items_per_batch": 8, "encoder_cudagraph_max_frames_per_batch": 64}'
+```
+
+Python example:
+
+```python
+import vllm
+
+compilation_config = {
+    "cudagraph_mm_encoder": True,
+    # Optional: override auto-inferred budgets
+    # "encoder_cudagraph_token_budgets": [2048, 4096, 8192, 13824],
+    # "encoder_cudagraph_max_vision_items_per_batch": 8,
+    # "encoder_cudagraph_max_frames_per_batch": 64,
+}
+
+model = vllm.LLM(
+    model="Qwen/Qwen3-VL-32B",
+    limit_mm_per_prompt={"image": 0},
+    compilation_config=compilation_config,
+)
+```
+
 ## About the Performance
 
 The following benchmarks were run on Blackwell GPUs (GB200) using `vllm bench mm-processor`. See [#35963](https://github.com/vllm-project/vllm/pull/35963) for full details.
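To make the auto-inference rules above concrete, here is a minimal sketch of how the defaults are derived. The helper name is illustrative; the actual logic lives in `EncoderCudaGraphManager.__init__` further down in this diff:

```python
def infer_encoder_cudagraph_defaults(
    min_budget: int,  # smallest encoder output size, e.g. 64 tokens for Qwen3-VL
    max_budget: int,  # capped by max_num_batched_tokens
    user_max_items: int = 0,  # encoder_cudagraph_max_vision_items_per_batch
    user_max_frames: int = 0,  # encoder_cudagraph_max_frames_per_batch
) -> tuple[int, int]:
    # 0 means auto-infer; a positive user-provided value always wins.
    max_items = user_max_items if user_max_items > 0 else max_budget // min_budget
    # The frame budget defaults to twice the item count; this is a heuristic
    # (see the TODO in the manager) and is subject to future tuning.
    max_frames = user_max_frames if user_max_frames > 0 else max_items * 2
    return max_items, max_frames
```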
@@ -140,7 +196,7 @@ vllm bench mm-processor \
     --num-prompts 3000 --num-warmups 300 \
     --max-model-len 32768 --seed 42 \
     --mm-encoder-attn-backend FLASH_ATTN \
-    --compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4864], "encoder_cudagraph_max_images_per_batch": 8}'
+    --compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4864], "encoder_cudagraph_max_vision_items_per_batch": 8}'
 ```
 
 ### Multi-GPU (4x GB200, TP=4, DP=4)
@@ -165,5 +221,8 @@ vllm bench mm-processor \
     --max-model-len 8192 --seed 42 \
     --mm-encoder-attn-backend FLASHINFER \
     --tensor-parallel-size 4 --mm-encoder-tp-mode data \
-    --compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4864], "encoder_cudagraph_max_images_per_batch": 8}'
+    --compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4864], "encoder_cudagraph_max_vision_items_per_batch": 8}'
 ```
+
+!!! note
+    More details on video inference benchmarks (on A100 GPUs) can be found in [#38061](https://github.com/vllm-project/vllm/pull/38061).
diff --git a/tests/v1/cudagraph/test_encoder_cudagraph.py b/tests/v1/cudagraph/test_encoder_cudagraph.py
index 543dfc8bb316..3ad18cf9a786 100644
--- a/tests/v1/cudagraph/test_encoder_cudagraph.py
+++ b/tests/v1/cudagraph/test_encoder_cudagraph.py
@@ -6,8 +6,10 @@
 No GPU required:
   - TestFindBudgetGraph — greedy budget selection logic
   - TestGetCumulativeStats — hit/miss rate statistics
+  - TestGetInputModality — modality routing from mm_kwargs keys
 GPU required:
   - TestEncoderCudaGraphCaptureReplay — capture, replay, fallback, counters, chunking
+  - TestEncoderCudaGraphVideoReplay — video modality capture, replay
 """
 
 from typing import Any
@@ -205,11 +207,19 @@ def __init__(self):
     def get_encoder_cudagraph_config(self) -> EncoderCudaGraphConfig:
         return EncoderCudaGraphConfig(
             modalities=["image"],
-            input_key="pixel_values",
+            input_key_by_modality={
+                "image": "pixel_values",
+            },
             buffer_keys=["dummy_buf"],
             out_hidden_size=_HIDDEN,
         )
 
+    def get_input_modality(
+        self,
+        mm_kwargs: dict[str, Any],
+    ) -> str:
+        return "image"
+
     def get_encoder_cudagraph_budget_range(
         self,
         vllm_config,
@@ -268,6 +278,7 @@ def prepare_encoder_cudagraph_capture_inputs(
         self,
         token_budget: int,
         max_batch_size: int,
+        max_frames_per_batch: int,
         device: torch.device,
         dtype: torch.dtype,
     ) -> EncoderCudaGraphCaptureInputs:
@@ -294,6 +305,7 @@ def prepare_encoder_cudagraph_replay_buffers(
         self,
         mm_kwargs: dict[str, Any],
         max_batch_size: int,
+        max_frames_per_batch: int,
     ) -> EncoderCudaGraphReplayBuffers:
         grid_thw = mm_kwargs["image_grid_thw"]
         n_out = _count_output_tokens(grid_thw, _SPATIAL_MERGE)
@@ -327,11 +339,16 @@ def _make_manager_for_gpu(
     max_batch_size: int,
     device: torch.device,
     dtype: torch.dtype,
+    *,
+    max_frames_per_batch: int | None = None,
 ) -> EncoderCudaGraphManager:
     """Create EncoderCudaGraphManager bypassing VllmConfig for GPU tests."""
     mgr = object.__new__(EncoderCudaGraphManager)
     mgr.token_budgets = sorted(token_budgets)
     mgr.max_batch_size = max_batch_size
+    mgr.max_frames_per_batch = (
+        max_frames_per_batch if max_frames_per_batch is not None else max_batch_size * 2
+    )
     mgr.use_dp = False
     mgr.budget_graphs = {}
     mgr.graph_hits = 0
@@ -366,6 +383,18 @@ def _make_mm_kwargs(
     }
 
 
+def _make_video_mm_kwargs(
+    grid_thw_list:
list[list[int]], + device: torch.device, + dtype: torch.dtype, +) -> dict[str, Any]: + """Create video mm_kwargs (pixel_values_videos / video_grid_thw) for testing.""" + return { + "pixel_values_videos": _make_pixel_values(grid_thw_list, device, dtype), + "video_grid_thw": grid_thw_list, + } + + # --------------------------------------------------------------------------- # GPU tests — capture, replay, fallback, counters, chunking # --------------------------------------------------------------------------- @@ -449,3 +478,285 @@ def test_chunking_when_images_exceed_max_batch(self): assert len(result) == n_images for out in result: assert out.shape == (4, _HIDDEN) + + +# --------------------------------------------------------------------------- +# SimpleMockViTVideoModel — extends SimpleMockViTModel with video support +# --------------------------------------------------------------------------- + + +class SimpleMockViTVideoModel(SimpleMockViTModel): + """ViT mock that supports both image and video modalities. + + Reuses SimpleMockViTModel's NN weights and _forward() logic. + Only the protocol methods that are key-dependent are overridden. + """ + + def get_encoder_cudagraph_config(self) -> EncoderCudaGraphConfig: + return EncoderCudaGraphConfig( + modalities=["image", "video"], + input_key_by_modality={ + "image": "pixel_values", + "video": "pixel_values_videos", + }, + buffer_keys=["dummy_buf"], + out_hidden_size=_HIDDEN, + ) + + def get_input_modality(self, mm_kwargs: dict[str, Any]) -> str: + return "video" if "video_grid_thw" in mm_kwargs else "image" + + # ------------------------------------------------------------------ + # Private helpers — route to the correct mm_kwargs keys + # ------------------------------------------------------------------ + + def _get_grid_thw(self, mm_kwargs: dict[str, Any]) -> list[list[int]]: + key = ( + "video_grid_thw" + if self.get_input_modality(mm_kwargs) == "video" + else "image_grid_thw" + ) + return mm_kwargs[key] + + def _get_pixel_values(self, mm_kwargs: dict[str, Any]) -> torch.Tensor: + key = ( + "pixel_values_videos" + if self.get_input_modality(mm_kwargs) == "video" + else "pixel_values" + ) + return mm_kwargs[key] + + # ------------------------------------------------------------------ + # Protocol overrides that depend on modality keys + # ------------------------------------------------------------------ + + def get_encoder_cudagraph_num_items(self, mm_kwargs: dict[str, Any]) -> int: + return len(self._get_grid_thw(mm_kwargs)) + + def get_encoder_cudagraph_per_item_output_tokens( + self, mm_kwargs: dict[str, Any] + ) -> list[int]: + m = _SPATIAL_MERGE + return [t * (h // m) * (w // m) for t, h, w in self._get_grid_thw(mm_kwargs)] + + def get_encoder_cudagraph_per_item_input_sizes( + self, mm_kwargs: dict[str, Any] + ) -> list[int]: + return [t * h * w for t, h, w in self._get_grid_thw(mm_kwargs)] + + def select_encoder_cudagraph_items( + self, mm_kwargs: dict[str, Any], indices: list[int] + ) -> dict[str, Any]: + modality = self.get_input_modality(mm_kwargs) + pv_key = "pixel_values_videos" if modality == "video" else "pixel_values" + grid_key = "video_grid_thw" if modality == "video" else "image_grid_thw" + + grid_thw = self._get_grid_thw(mm_kwargs) + pixel_values = self._get_pixel_values(mm_kwargs) + + if len(indices) == 0: + return {pv_key: pixel_values[:0], grid_key: []} + + patches_per_item = [t * h * w for t, h, w in grid_thw] + cum_patches = [0] + for p in patches_per_item: + cum_patches.append(cum_patches[-1] + p) + + selected_pv 
= torch.cat( + [pixel_values[cum_patches[i] : cum_patches[i + 1]] for i in indices] + ) + return {pv_key: selected_pv, grid_key: [grid_thw[i] for i in indices]} + + def prepare_encoder_cudagraph_capture_inputs( + self, + token_budget: int, + max_batch_size: int, + max_frames_per_batch: int, + device: torch.device, + dtype: torch.dtype, + ) -> EncoderCudaGraphCaptureInputs: + per_item_output = token_budget // max_batch_size + frames_per_item = max_frames_per_batch // max_batch_size + if frames_per_item > 1: + # Video-format capture: size cu_seqlens for T frames per item. + tokens_per_frame = ( + per_item_output + frames_per_item - 1 + ) // frames_per_item + grid_config = [ + [frames_per_item, _SPATIAL_MERGE, tokens_per_frame * _SPATIAL_MERGE] + for _ in range(max_batch_size) + ] + else: + grid_config = [ + [1, _SPATIAL_MERGE, per_item_output * _SPATIAL_MERGE] + for _ in range(max_batch_size) + ] + total_patches = _count_input_patches(grid_config) + # Use pixel_values (image key) for capture — same patch shape as video. + dummy_pixel_values = torch.randn( + total_patches, _FLAT, device=device, dtype=dtype + ) + n_out = _count_output_tokens(grid_config, _SPATIAL_MERGE) + dummy_buf = torch.zeros(n_out, _HIDDEN, device=device, dtype=dtype) + return EncoderCudaGraphCaptureInputs( + mm_kwargs={ + "pixel_values": dummy_pixel_values, + "image_grid_thw": grid_config, + }, + buffers={"dummy_buf": dummy_buf}, + ) + + def prepare_encoder_cudagraph_replay_buffers( + self, + mm_kwargs: dict[str, Any], + max_batch_size: int, + max_frames_per_batch: int, + ) -> EncoderCudaGraphReplayBuffers: + n_out = _count_output_tokens(self._get_grid_thw(mm_kwargs), _SPATIAL_MERGE) + p = next(self.parameters()) + dummy_buf = torch.zeros(n_out, _HIDDEN, device=p.device, dtype=p.dtype) + return EncoderCudaGraphReplayBuffers(buffers={"dummy_buf": dummy_buf}) + + def encoder_cudagraph_forward( + self, mm_kwargs: dict[str, Any], buffers: dict[str, torch.Tensor] + ) -> torch.Tensor: + return self._forward(self._get_pixel_values(mm_kwargs)) + + def encoder_eager_forward(self, mm_kwargs: dict[str, Any]) -> torch.Tensor: + return self._forward(self._get_pixel_values(mm_kwargs)) + + +# --------------------------------------------------------------------------- +# No-GPU tests — get_input_modality routing +# --------------------------------------------------------------------------- + + +class TestGetInputModality: + """get_input_modality returns correct modality based on mm_kwargs keys.""" + + def test_image_only_model_always_returns_image(self): + model = SimpleMockViTModel() + mm_kwargs = { + "pixel_values": torch.zeros(1, _FLAT), + "image_grid_thw": [[1, 4, 4]], + } + assert model.get_input_modality(mm_kwargs) == "image" + + def test_video_model_returns_image_for_image_kwargs(self): + model = SimpleMockViTVideoModel() + mm_kwargs = { + "pixel_values": torch.zeros(1, _FLAT), + "image_grid_thw": [[1, 4, 4]], + } + assert model.get_input_modality(mm_kwargs) == "image" + + def test_video_model_returns_video_for_video_kwargs(self): + model = SimpleMockViTVideoModel() + mm_kwargs = { + "pixel_values_videos": torch.zeros(8, _FLAT), + "video_grid_thw": [[2, 4, 4]], + } + assert model.get_input_modality(mm_kwargs) == "video" + + def test_video_model_config_has_both_modalities(self): + model = SimpleMockViTVideoModel() + cfg = model.get_encoder_cudagraph_config() + assert "image" in cfg.modalities + assert "video" in cfg.modalities + assert cfg.input_key_by_modality["image"] == "pixel_values" + assert cfg.input_key_by_modality["video"] 
== "pixel_values_videos" + + +# --------------------------------------------------------------------------- +# GPU tests — video capture, replay, fallback, and mixed image+video +# --------------------------------------------------------------------------- + +_VIDEO_MAX_BATCH = 4 +_VIDEO_MAX_FRAMES = 8 # 2 frames per item at max_batch_size=4 + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") +class TestEncoderCudaGraphVideoReplay: + def setup_method(self): + self.device = torch.device("cuda:0") + self.dtype = torch.float16 + self.model = SimpleMockViTVideoModel().to(self.device).half() + self.mgr = _make_manager_for_gpu( + self.model, + _BUDGETS, + _VIDEO_MAX_BATCH, + self.device, + self.dtype, + max_frames_per_batch=_VIDEO_MAX_FRAMES, + ) + self.mgr.capture() + + # --- capture --- + + def test_capture_creates_one_graph_per_budget(self): + assert len(self.mgr.budget_graphs) == len(_BUDGETS) + assert set(self.mgr.budget_graphs.keys()) == set(_BUDGETS) + + # --- output shape --- + + def test_video_execute_returns_one_tensor_per_video(self): + # T=2, 4x4 → 2*(4//2)*(4//2) = 8 tokens per video + grid_thw = [[2, 4, 4], [2, 4, 4]] + mm_kwargs = _make_video_mm_kwargs(grid_thw, self.device, self.dtype) + result = self.mgr.execute(mm_kwargs) + assert result is not None + assert len(result) == 2 + + def test_video_output_tokens_per_item(self): + # T=2,4x4 → 8 tokens; T=1,4x4 → 4 tokens + grid_thw = [[2, 4, 4], [1, 4, 4]] + mm_kwargs = _make_video_mm_kwargs(grid_thw, self.device, self.dtype) + result = self.mgr.execute(mm_kwargs) + assert result is not None + assert result[0].shape == (8, _HIDDEN) + assert result[1].shape == (4, _HIDDEN) + + # --- budget fallback --- + + def test_video_eager_fallback_when_tokens_exceed_all_budgets(self): + # T=2, 18x18 → 2*(18//2)*(18//2) = 162 tokens > max budget 64 + grid_thw = [[2, 18, 18]] + mm_kwargs = _make_video_mm_kwargs(grid_thw, self.device, self.dtype) + result = self.mgr.execute(mm_kwargs) + assert result is not None + assert len(result) == 1 + assert result[0].shape == (162, _HIDDEN) + assert self.mgr.graph_misses == 1 + + # --- counters --- + + def test_video_hit_counter_increments_by_num_videos(self): + grid_thw = [[2, 4, 4], [1, 4, 4]] + mm_kwargs = _make_video_mm_kwargs(grid_thw, self.device, self.dtype) + self.mgr.execute(mm_kwargs) + assert self.mgr.graph_hits == 2 + + def test_video_miss_counter_increments_for_oversized_video(self): + grid_thw = [[2, 18, 18]] # 162 tokens > 64 + mm_kwargs = _make_video_mm_kwargs(grid_thw, self.device, self.dtype) + self.mgr.execute(mm_kwargs) + assert self.mgr.graph_misses == 1 + + # --- image and video sharing the same manager --- + + def test_image_and_video_share_manager(self): + """Image and video inputs can both be executed through the same manager.""" + img_grid = [[1, 4, 4], [1, 4, 4]] + img_result = self.mgr.execute( + _make_mm_kwargs(img_grid, self.device, self.dtype) + ) + + vid_grid = [[2, 4, 4]] + vid_result = self.mgr.execute( + _make_video_mm_kwargs(vid_grid, self.device, self.dtype) + ) + + assert len(img_result) == 2 + assert len(vid_result) == 1 + assert img_result[0].shape == (4, _HIDDEN) + assert vid_result[0].shape == (8, _HIDDEN) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 11c933fc72f5..7a361f9fdb70 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -519,13 +519,21 @@ class CompilationConfig: User-provided values override auto-inference. 
        Example: [2048, 4096, 8192, 13824]"""
 
-    encoder_cudagraph_max_images_per_batch: int = 0
-    """Maximum number of images per batch for encoder CUDA graph capture.
+    encoder_cudagraph_max_vision_items_per_batch: int = 0
+    """Maximum number of images/videos per batch for encoder CUDA graph capture.
 
     Determines the fixed batch size used during graph capture.
     If 0 (default), auto-inferred as max_budget // min_budget from the
     model's budget range. User-provided positive value overrides
     auto-inference."""
 
+    encoder_cudagraph_max_frames_per_batch: int = 0
+    """Maximum total video frames per batch for encoder CUDA graph capture.
+
+    Controls the cu_seqlens buffer size (one entry per attention sequence,
+    i.e. one per video frame). If 0 (default), auto-inferred as
+    encoder_cudagraph_max_vision_items_per_batch * 2 (a heuristic; see the
+    TODO in EncoderCudaGraphManager). A positive value overrides
+    auto-inference and applies to all budget levels."""
+
     # Inductor capture
     compile_sizes: list[int | str] | None = None
     """Sizes to compile for inductor. In addition
@@ -964,10 +972,18 @@ def __post_init__(self) -> None:
         # Validate encoder CUDA graph configuration
         if (
             self.cudagraph_mm_encoder
-            and self.encoder_cudagraph_max_images_per_batch < 0
+            and self.encoder_cudagraph_max_vision_items_per_batch < 0
+        ):
+            raise ValueError(
+                "encoder_cudagraph_max_vision_items_per_batch must be "
+                "non-negative (0 = auto-infer)"
+            )
+        if (
+            self.cudagraph_mm_encoder
+            and self.encoder_cudagraph_max_frames_per_batch < 0
         ):
             raise ValueError(
-                "encoder_cudagraph_max_images_per_batch must be "
+                "encoder_cudagraph_max_frames_per_batch must be "
                 "non-negative (0 = auto-infer)"
             )
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index d03205689790..c24798e08402 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -1524,6 +1524,13 @@ class SupportsEncoderCudaGraph(Protocol):
 
     def get_encoder_cudagraph_config(self) -> "EncoderCudaGraphConfig": ...
 
+    def get_input_modality(
+        self,
+        mm_kwargs: dict[str, Any],
+    ) -> str:
+        """Return the modality ("image" or "video") of the given inputs."""
+        ...
+
     def get_encoder_cudagraph_budget_range(
         self,
         vllm_config: "VllmConfig",
@@ -1536,7 +1543,7 @@ def get_encoder_cudagraph_budget_range(
         (e.g. max_num_batched_tokens)
 
         Used when ``encoder_cudagraph_token_budgets`` and/or
-        ``encoder_cudagraph_max_images_per_batch`` are not explicitly
+        ``encoder_cudagraph_max_vision_items_per_batch`` are not explicitly
         specified by the user.
         """
         ...
@@ -1590,6 +1597,7 @@ def prepare_encoder_cudagraph_capture_inputs(
         self,
         token_budget: int,
         max_batch_size: int,
+        max_frames_per_batch: int,
         device: torch.device,
         dtype: torch.dtype,
     ) -> "EncoderCudaGraphCaptureInputs":
@@ -1600,6 +1608,7 @@ def prepare_encoder_cudagraph_replay_buffers(
         self,
         mm_kwargs: dict[str, Any],
         max_batch_size: int,
+        max_frames_per_batch: int,
     ) -> "EncoderCudaGraphReplayBuffers":
         """Compute buffer values from actual batch inputs for replay."""
         ...
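The `max_frames_per_batch` argument threaded through the two `prepare_*` methods above feeds the synthetic grid used at capture time. A sketch of that grid construction, simplified from the Qwen3-VL and test-mock implementations later in this diff (`make_capture_grid` is an illustrative standalone name; `merge` is the spatial merge size):

```python
def make_capture_grid(
    token_budget: int,
    max_batch_size: int,
    max_frames_per_batch: int,
    merge: int,
) -> list[list[int]]:
    per_item_output = token_budget // max_batch_size
    frames_per_item = max_frames_per_batch // max_batch_size
    if frames_per_item > 1:
        # Video-format layout: T > 1 gives cu_seqlens one entry per frame.
        # Ceiling division so frames_per_item * tokens_per_frame covers the
        # per-item output budget.
        tokens_per_frame = -(-per_item_output // frames_per_item)
        return [
            [frames_per_item, merge, tokens_per_frame * merge]
            for _ in range(max_batch_size)
        ]
    # Image-format layout (T=1): exactly per_item_output tokens per item.
    return [[1, merge, per_item_output * merge] for _ in range(max_batch_size)]
```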
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 1aa5dec5390b..fee0b937b7da 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -99,6 +99,7 @@ from vllm.triton_utils import HAS_TRITON, tl, triton from vllm.utils.collection_utils import is_list_of from vllm.utils.math_utils import round_up +from vllm.v1.worker.encoder_cudagraph_defs import EncoderCudaGraphReplayBuffers from .interfaces import ( MultiModalEmbeddings, @@ -689,6 +690,7 @@ def prepare_encoder_metadata( grid_thw_list: list[list[int]], *, max_batch_size: int | None = None, + max_frames_per_batch: int | None = None, max_seqlen_override: int | None = None, device: torch.device | None = None, ) -> dict[str, torch.Tensor | None]: @@ -701,6 +703,10 @@ def prepare_encoder_metadata( grid_thw_list: Grid configurations as list of [t, h, w]. max_batch_size: If set, pad cu_seqlens to this size (needed for CUDA graph capture/replay). + max_frames_per_batch: If set, overrides max_batch_size for + cu_seqlens padding. For video inputs each item contributes + T attention sequences (frames); this sizes the buffer to + the total frame budget so video replays never overflow. max_seqlen_override: If set, use this value for max_seqlen instead of computing from cu_seqlens (needed for CUDA graph capture to cover worst-case replay scenarios). @@ -725,15 +731,21 @@ def prepare_encoder_metadata( ) cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens]) - # Pad cu_seqlens if max_batch_size specified - if max_batch_size is not None: + # Pad cu_seqlens to the required number of sequences. + # For videos each item contributes T frames = T attention sequences, + # so the total can exceed max_batch_size. max_frames_per_batch + # overrides the pad target when set. + pad_to = ( + max_frames_per_batch if max_frames_per_batch is not None else max_batch_size + ) + if pad_to is not None: num_seqs = len(cu_seqlens) - 1 - if num_seqs < max_batch_size: + if num_seqs < pad_to: cu_seqlens = np.concatenate( [ cu_seqlens, np.full( - max_batch_size - num_seqs, + pad_to - num_seqs, cu_seqlens[-1], dtype=np.int32, ), @@ -1737,9 +1749,21 @@ def get_encoder_cudagraph_config(self): EncoderCudaGraphConfig, ) + modalities = ["image"] + # NOTE: When EVS (Efficient Video Sampling) pruning is enabled, the number + # of tokens becomes data-dependent (i.e., the retained tokens are + # dynamically selected based on inter-frame differences) and therefore + # cannot be captured by CUDA Graphs. As a result, video CUDA Graphs are + # only enabled when EVS is disabled. + if not self.is_multimodal_pruning_enabled: + modalities.append("video") + return EncoderCudaGraphConfig( - modalities=["image"], - input_key="pixel_values", + modalities=modalities, + input_key_by_modality={ + "image": "pixel_values", + "video": "pixel_values_videos", + }, buffer_keys=[ "pos_embeds", "rotary_pos_emb_cos", @@ -1751,49 +1775,86 @@ def get_encoder_cudagraph_config(self): out_hidden_size=self.visual.out_hidden_size, ) + def get_input_modality( + self, + mm_kwargs: dict[str, Any], + ) -> str: + if "image_grid_thw" in mm_kwargs: + return "image" + return "video" + def get_encoder_cudagraph_budget_range( self, vllm_config, ) -> tuple[int, int]: # Min: estimated smallest possible encoder input. 
-        # 224x224 image → 16x16 patches, spatial_merge_size=2 → 8x8 = 64 tokens
+        # 224x224 image → 16x16 patches (patch_size=14)
+        # spatial_merge_size=2 → 8x8 = 64 tokens
         min_budget = 64
         # Max: capped by max_num_batched_tokens
         max_budget = vllm_config.scheduler_config.max_num_batched_tokens
         return (min_budget, max_budget)
 
+    def _get_pixel_values_by_modality(
+        self,
+        mm_kwargs: dict[str, Any],
+    ) -> torch.Tensor:
+        if self.get_input_modality(mm_kwargs) == "image":
+            pixel_values = mm_kwargs["pixel_values"]
+        else:
+            pixel_values = mm_kwargs["pixel_values_videos"]
+        return pixel_values
+
+    def _get_grid_thw_by_modality(
+        self,
+        mm_kwargs: dict[str, Any],
+    ) -> list[list[int]]:
+        grid_thw_key = f"{self.get_input_modality(mm_kwargs)}_grid_thw"
+        grid_thw = mm_kwargs[grid_thw_key]
+        if not isinstance(grid_thw, list):
+            grid_thw = grid_thw.tolist()
+        return grid_thw
+
     def get_encoder_cudagraph_num_items(
         self,
         mm_kwargs: dict[str, Any],
     ) -> int:
-        return len(mm_kwargs["image_grid_thw"])
+        return len(self._get_grid_thw_by_modality(mm_kwargs))
 
     def get_encoder_cudagraph_per_item_output_tokens(
         self,
         mm_kwargs: dict[str, Any],
     ) -> list[int]:
         m = self.visual.spatial_merge_size
-        return [t * (h // m) * (w // m) for t, h, w in mm_kwargs["image_grid_thw"]]
+        grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
+        return [t * (h // m) * (w // m) for t, h, w in grid_thw]
 
     def get_encoder_cudagraph_per_item_input_sizes(
         self,
         mm_kwargs: dict[str, Any],
     ) -> list[int]:
-        return [t * h * w for t, h, w in mm_kwargs["image_grid_thw"]]
+        grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
+        return [t * h * w for t, h, w in grid_thw]
 
     def select_encoder_cudagraph_items(
         self,
         mm_kwargs: dict[str, Any],
         indices: list[int],
     ) -> dict[str, Any]:
-        grid_thw = mm_kwargs["image_grid_thw"]
-        pixel_values = mm_kwargs["pixel_values"]
+        grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
+        pixel_values = self._get_pixel_values_by_modality(mm_kwargs)
 
         if len(indices) == 0:
-            return {
-                "pixel_values": pixel_values[:0],
-                "image_grid_thw": [],
-            }
+            if self.get_input_modality(mm_kwargs) == "image":
+                return {
+                    "pixel_values": pixel_values[:0],
+                    "image_grid_thw": [],
+                }
+            else:
+                return {
+                    "pixel_values_videos": pixel_values[:0],
+                    "video_grid_thw": [],
+                }
 
         # Compute cumulative patch offsets for slicing pixel_values
         patches_per_item = [t * h * w for t, h, w in grid_thw]
@@ -1806,15 +1867,22 @@ def select_encoder_cudagraph_items(
         )
         selected_grid = [grid_thw[i] for i in indices]
 
-        return {
-            "pixel_values": selected_pv,
-            "image_grid_thw": selected_grid,
-        }
+        if self.get_input_modality(mm_kwargs) == "image":
+            return {
+                "pixel_values": selected_pv,
+                "image_grid_thw": selected_grid,
+            }
+        else:
+            return {
+                "pixel_values_videos": selected_pv,
+                "video_grid_thw": selected_grid,
+            }
 
     def prepare_encoder_cudagraph_capture_inputs(
         self,
         token_budget: int,
         max_batch_size: int,
+        max_frames_per_batch: int,
         device: torch.device,
         dtype: torch.dtype,
     ):
@@ -1823,14 +1891,35 @@ def prepare_encoder_cudagraph_capture_inputs(
         )
 
         spatial_merge_size = self.visual.spatial_merge_size
-        per_image_output = token_budget // max_batch_size
-
-        # Synthetic rectangular grid: [1, merge, per_image_output * merge]
-        # produces exactly per_image_output tokens per image.
-        grid_config = [
-            [1, spatial_merge_size, per_image_output * spatial_merge_size]
-            for _ in range(max_batch_size)
-        ]
+        per_mm_item_output = token_budget // max_batch_size
+
+        frames_per_item = max_frames_per_batch // max_batch_size
+        if frames_per_item > 1:
+            # Build the capture grid using a video-format layout so that
+            # cu_seqlens is sized for video replays from the start.
+            # cu_seqlens has one entry per attention sequence (one per frame),
+            # so using T > 1 per item makes the buffer large enough without
+            # relying solely on padding.
+            # Ceiling ensures frames_per_item * tokens_per_frame >= per_mm_item_output
+            # so the pixel_values buffer covers any valid single-item replay.
+            tokens_per_frame = (
+                per_mm_item_output + frames_per_item - 1
+            ) // frames_per_item
+            # Video-format grid_config (T=frames_per_item).
+            grid_config = [
+                [
+                    frames_per_item,
+                    spatial_merge_size,
+                    tokens_per_frame * spatial_merge_size,
+                ]
+                for _ in range(max_batch_size)
+            ]
+        else:
+            # Image-format grid_config (T=1).
+            grid_config = [
+                [1, spatial_merge_size, per_mm_item_output * spatial_merge_size]
+                for _ in range(max_batch_size)
+            ]
 
         # Create dummy pixel_values
         patch_embed = self.visual.patch_embed
@@ -1848,15 +1937,18 @@ def prepare_encoder_cudagraph_capture_inputs(
         # Override max_seqlen with a safe upper bound for capture.
         # max_seqlen.item() gets baked into the CUDA graph (not replayed),
         # so the capture value must cover any replay scenario.
-        # Worst case: 1 image consuming the full budget ->
+        # Worst case: 1 item consuming the full budget ->
         # seq_len = token_budget * spatial_merge_size^2.
         buffers = self.visual.prepare_encoder_metadata(
             grid_config,
             max_batch_size=max_batch_size,
+            max_frames_per_batch=max_frames_per_batch,
             max_seqlen_override=token_budget * (spatial_merge_size**2),
             device=device,
         )
 
+        # Use the image-modality dummy input_buffer for capture, since it is
+        # also compatible with video inputs (same shape: [num_patches, C*T*P*P]).
mm_kwargs = { "pixel_values": dummy_pixel_values, "image_grid_thw": grid_config, @@ -1871,17 +1963,21 @@ def prepare_encoder_cudagraph_replay_buffers( self, mm_kwargs: dict[str, Any], max_batch_size: int, + max_frames_per_batch: int, ): - from vllm.v1.worker.encoder_cudagraph_defs import ( - EncoderCudaGraphReplayBuffers, - ) + modality = self.get_input_modality(mm_kwargs) + grid_thw_list = self._get_grid_thw_by_modality(mm_kwargs) - grid_thw_list = mm_kwargs["image_grid_thw"] - - buffers = self.visual.prepare_encoder_metadata( - grid_thw_list, - max_batch_size=max_batch_size, - ) + if modality == "image": + buffers = self.visual.prepare_encoder_metadata( + grid_thw_list, + max_batch_size=max_batch_size, + ) + else: + buffers = self.visual.prepare_encoder_metadata( + grid_thw_list, + max_frames_per_batch=max_frames_per_batch, + ) return EncoderCudaGraphReplayBuffers(buffers=buffers) @@ -1890,16 +1986,16 @@ def encoder_cudagraph_forward( mm_kwargs: dict[str, Any], buffers: dict[str, torch.Tensor], ) -> torch.Tensor: - pixel_values = mm_kwargs["pixel_values"] - grid_thw = mm_kwargs["image_grid_thw"] + pixel_values = self._get_pixel_values_by_modality(mm_kwargs) + grid_thw = self._get_grid_thw_by_modality(mm_kwargs) return self.visual(pixel_values, grid_thw, encoder_metadata=buffers) def encoder_eager_forward( self, mm_kwargs: dict[str, Any], ) -> torch.Tensor: - pixel_values = mm_kwargs["pixel_values"] - grid_thw = mm_kwargs["image_grid_thw"] + pixel_values = self._get_pixel_values_by_modality(mm_kwargs) + grid_thw = self._get_grid_thw_by_modality(mm_kwargs) return self.visual(pixel_values, grid_thw) def _parse_and_validate_image_input( diff --git a/vllm/v1/worker/encoder_cudagraph.py b/vllm/v1/worker/encoder_cudagraph.py index 0fabbc77c07a..65ad8e2cec16 100644 --- a/vllm/v1/worker/encoder_cudagraph.py +++ b/vllm/v1/worker/encoder_cudagraph.py @@ -36,6 +36,7 @@ class BudgetGraphMetadata: token_budget: int max_batch_size: int # Max number of images/videos per batch + max_frames_per_batch: int # Max total frames per batch (for video) graph: torch.cuda.CUDAGraph # The input tensor updated before replay (e.g. pixel_values) input_buffer: torch.Tensor @@ -66,12 +67,13 @@ def __init__( comp_config = vllm_config.compilation_config user_budgets = comp_config.encoder_cudagraph_token_budgets - user_max_images = comp_config.encoder_cudagraph_max_images_per_batch + user_max_mm_items = comp_config.encoder_cudagraph_max_vision_items_per_batch + user_max_frames = comp_config.encoder_cudagraph_max_frames_per_batch - if user_budgets and user_max_images > 0: + if user_budgets and user_max_mm_items > 0: # Fully user-specified self.token_budgets = sorted(user_budgets) - self.max_batch_size = user_max_images + self.max_batch_size = user_max_mm_items else: # Auto-infer missing values from model min_budget, max_budget = model.get_encoder_cudagraph_budget_range( @@ -83,9 +85,15 @@ def __init__( else self._generate_budgets(min_budget, max_budget) ) self.max_batch_size = ( - user_max_images if user_max_images > 0 else max_budget // min_budget + user_max_mm_items if user_max_mm_items > 0 else max_budget // min_budget ) + if user_max_frames > 0: + self.max_frames_per_batch = user_max_frames + else: + # TODO(shen-shanshan): optimize this auto-infer for max_frames_per_batch. 
+            self.max_frames_per_batch = self.max_batch_size * 2
+
         mm_config = vllm_config.model_config.multimodal_config
         self.use_dp = (
             mm_config is not None
@@ -100,9 +108,10 @@ def __init__(
 
         logger.info(
             "EncoderCudaGraphManager initialized with "
-            "budgets=%s, max_batch_size=%d, use_dp=%s",
+            "budgets=%s, max_batch_size=%d, max_frames_per_batch=%s, use_dp=%s",
             self.token_budgets,
             self.max_batch_size,
+            self.max_frames_per_batch if self.max_frames_per_batch > 0 else "auto",
             self.use_dp,
         )
 
@@ -136,13 +145,19 @@ def capture(self):
 
     def _capture_budget_graph(self, token_budget: int):
         """Capture CUDA graph for a single token budget."""
         logger.debug(
-            "Capturing encoder cudagraph for budget=%d, max_batch_size=%d",
+            "Capturing encoder cudagraph for budget=%d, max_batch_size=%d, "
+            "max_frames_per_batch=%d",
             token_budget,
             self.max_batch_size,
+            self.max_frames_per_batch,
         )
 
         capture_inputs = self.model.prepare_encoder_cudagraph_capture_inputs(
-            token_budget, self.max_batch_size, self.device, self.dtype
+            token_budget,
+            self.max_batch_size,
+            self.max_frames_per_batch,
+            self.device,
+            self.dtype,
         )
         mm_kwargs = capture_inputs.mm_kwargs
@@ -157,10 +172,14 @@ def _capture_budget_graph(self, token_budget: int):
             output = self.model.encoder_cudagraph_forward(mm_kwargs, buffers)
             output_buffer.copy_(output)
 
-        input_key = self.config.input_key
+        # The image and video modalities share the same per-patch shape,
+        # so we can use the image dummy inputs to capture the CUDA graph
+        # for both image and video.
+        input_key = self.config.input_key_by_modality["image"]
         self.budget_graphs[token_budget] = BudgetGraphMetadata(
             token_budget=token_budget,
             max_batch_size=self.max_batch_size,
+            max_frames_per_batch=self.max_frames_per_batch,
             graph=graph,
             input_buffer=mm_kwargs[input_key],
             metadata_buffers=buffers,
@@ -230,10 +249,11 @@ def _run_budget_graph(
         # Copy the input tensor. Buffers are sized for the full budget;
-        # actual inputs may be smaller. Zero then slice-copy so padded
-        # positions are invisible to attention (cu_seqlens masks them out).
+        # actual inputs may be smaller. Slice-copy only the valid region;
+        # padded positions remain invisible to attention because cu_seqlens
+        # masks them out, so zeroing the buffer is unnecessary.
-        input_key = self.config.input_key
+        input_key = self.config.input_key_by_modality[
+            self.model.get_input_modality(mm_kwargs)
+        ]
         src = mm_kwargs[input_key]
         n = src.shape[0]
-        graph_meta.input_buffer.zero_()
         graph_meta.input_buffer[:n].copy_(src)
 
         # Copy metadata buffers using keys from config.buffer_keys.
@@ -362,7 +382,9 @@ def _execute_local(
                 (token_budget - batch_out_tokens) / token_budget * 100,
             )
             replay = self.model.prepare_encoder_cudagraph_replay_buffers(
-                batch_mm_kwargs, self.max_batch_size
+                batch_mm_kwargs,
+                self.max_batch_size,
+                self.max_frames_per_batch,
             )
 
         # graph_hits counted inside _run_budget_graph after replay.
diff --git a/vllm/v1/worker/encoder_cudagraph_defs.py b/vllm/v1/worker/encoder_cudagraph_defs.py
index 455786682059..00ab97b3cd3b 100644
--- a/vllm/v1/worker/encoder_cudagraph_defs.py
+++ b/vllm/v1/worker/encoder_cudagraph_defs.py
@@ -20,8 +20,10 @@ class EncoderCudaGraphConfig:
     modalities: list[str]
     """Supported modalities (e.g. ["image"])."""
 
-    input_key: str
-    """Key in mm_kwargs for the input tensor (e.g. "pixel_values")."""
+    input_key_by_modality: dict[str, str]
+    """Per-modality input tensor key mapping, e.g.
+    {"image": "pixel_values", "video": "pixel_values_videos"}.
+    """
 
     buffer_keys: list[str]
     """Keys for the tensor buffers recorded into the CUDA graph.
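For reference, the cu_seqlens padding behavior that this diff adds to `prepare_encoder_metadata` (see the qwen3_vl.py hunk above) boils down to the following. A standalone NumPy sketch; `pad_cu_seqlens` is an illustrative name:

```python
import numpy as np


def pad_cu_seqlens(
    cu_seqlens: np.ndarray,
    max_batch_size: int | None,
    max_frames_per_batch: int | None,
) -> np.ndarray:
    # Videos contribute T attention sequences (frames) per item, so the
    # frame budget overrides the item budget as the pad target when set.
    pad_to = max_frames_per_batch if max_frames_per_batch is not None else max_batch_size
    if pad_to is None:
        return cu_seqlens
    num_seqs = len(cu_seqlens) - 1
    if num_seqs < pad_to:
        # Repeat the last offset: padded sequences have zero length, so
        # attention treats them as empty.
        pad = np.full(pad_to - num_seqs, cu_seqlens[-1], dtype=np.int32)
        cu_seqlens = np.concatenate([cu_seqlens, pad])
    return cu_seqlens
```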