1 change: 1 addition & 0 deletions docs/design/cuda_graphs_multimodal.md
@@ -87,6 +87,7 @@ Models opt-in to encoder CUDA Graphs by implementing the [SupportsEncoderCudaGra
| ------------ | ------ | ------------ | ------------ |
| `Qwen3VLForConditionalGeneration` | `Qwen3-VL` | ✅︎ | ✅︎ |
| `Qwen2_5_VLForConditionalGeneration` | `Qwen2.5-VL` | ✅︎ | ✅︎ |
| `KimiVLForConditionalGeneration` | `Kimi-VL` | ✅︎ | ✗ |

!!! note
    Encoder CUDA Graphs have currently been tested with `--mm-encoder-attn-backend=FLASH_ATTN` and `--mm-encoder-attn-backend=FLASHINFER` on Blackwell GPUs.
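To try the new entry end to end, here is a minimal offline sketch (not part of this diff). The model name, `trust_remote_code=True`, and the prompt layout are taken from the test configuration added later in this PR; the `ImageAsset` helper and the `LLM`/`SamplingParams` calls follow vLLM's standard offline examples and may need adjusting for your setup. The `--mm-encoder-attn-backend` setting from the note above is shown there as a CLI flag and is not configured in this sketch.

```python
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

# Kimi-VL chat layout, mirroring kimi_vl_chat_template in the test below.
prompt = (
    "<|im_user|>user<|im_middle|>"
    "<|media_start|>image<|media_content|><|media_pad|><|media_end|>"
    "What is in this image?<|im_end|>"
    "<|im_assistant|>assistant<|im_middle|>"
)

llm = LLM(
    model="moonshotai/Kimi-VL-A3B-Instruct",
    trust_remote_code=True,  # required, per the test's vllm_runner_kwargs
)

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": ImageAsset("cherry_blossom").pil_image},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```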
1 change: 1 addition & 0 deletions examples/generate/multimodal/vision_language_offline.py
@@ -2467,6 +2467,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
    "qwen3_vl",
    "qwen3_vl_moe",
    "qwen2_5_vl",
    "kimi_vl",
]


18 changes: 18 additions & 0 deletions tests/models/multimodal/generation/test_vit_cudagraph.py
@@ -41,6 +41,13 @@ def qwen_vl_chat_template(content: str) -> str:
    return f"<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n"


def kimi_vl_chat_template(content: str) -> str:
    return (
        f"<|im_user|>user<|im_middle|>{content}<|im_end|>"
        "<|im_assistant|>assistant<|im_middle|>"
    )


MODEL_CONFIGS: dict[str, VitCudagraphTestConfig] = {
    "qwen3_vl": VitCudagraphTestConfig(
        model="Qwen/Qwen3-VL-2B-Instruct",
@@ -66,6 +73,17 @@ def qwen_vl_chat_template(content: str) -> str:
        needs_video_metadata=False,
        marks=[pytest.mark.core_model],
    ),
    "kimi_vl": VitCudagraphTestConfig(
        model="moonshotai/Kimi-VL-A3B-Instruct",
        modalities=["image"],
        image_prompt=kimi_vl_chat_template(
            "<|media_start|>image<|media_content|><|media_pad|><|media_end|>"
            "What is in this image?"
        ),
        needs_video_metadata=False,
        vllm_runner_kwargs={"trust_remote_code": True},
        marks=[pytest.mark.core_model],
    ),
}


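To run only the new case locally, something like `pytest tests/models/multimodal/generation/test_vit_cudagraph.py -k kimi_vl` should select it, assuming the test is parametrized over the `MODEL_CONFIGS` keys; adjust the `-k` expression if the parametrization differs.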
203 changes: 200 additions & 3 deletions vllm/model_executor/models/kimi_vl.py
@@ -45,7 +45,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import Annotated, Any, Literal
from typing import Annotated, Any, ClassVar, Literal

import torch
from torch import nn
@@ -56,7 +56,11 @@
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs import MultiModalDataDict
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP
from vllm.model_executor.models.interfaces import (
    SupportsEncoderCudaGraph,
    SupportsMultiModal,
    SupportsPP,
)
from vllm.model_executor.models.moonvit import MoonVitPretrainedModel
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
@@ -79,6 +83,7 @@
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig, MoonViTConfig
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.v1.worker.encoder_cudagraph_defs import EncoderCudaGraphReplayBuffers

from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
from .vision import is_vit_use_data_parallel, run_dp_sharded_mrope_vision_model
@@ -287,8 +292,11 @@ def get_replacement(item_idx: int):
    info=KimiVLProcessingInfo,
    dummy_inputs=KimiVLDummyInputsBuilder,
)
class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
class KimiVLForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsEncoderCudaGraph, SupportsPP
):
    supports_encoder_tp_data = True
    supports_encoder_cudagraph: ClassVar[Literal[True]] = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
@@ -340,6 +348,195 @@ def __init__(

        self.media_placeholder: int = self.config.media_placeholder_token_id

        self.model_config = model_config

    # -- SupportsEncoderCudaGraph protocol methods --

    def get_encoder_cudagraph_config(self):
        from vllm.v1.worker.encoder_cudagraph_defs import (
            EncoderCudaGraphConfig,
        )

        return EncoderCudaGraphConfig(
            modalities=["image"],
            input_key_by_modality={"image": "pixel_values"},
            buffer_keys=[
                "pos_embeds",
                "rope_freqs_cis",
                "cu_seqlens",
                "max_seqlen",
                "merge_gather_idx",
            ],
            out_hidden_size=self.hidden_size,
        )

    def get_input_modality(
        self,
        mm_kwargs: dict[str, Any],
    ) -> str:
        return "image"

    def get_max_frames_per_video(self) -> int:
        return 0

    def get_encoder_cudagraph_budget_range(
        self,
        vllm_config,
    ) -> tuple[int, int]:
        # Min: estimated smallest possible encoder input.
        # 224x224 image with patch_size=14 -> 16x16 patches, then merge
        # kernel (2,2) -> 8x8 = 64 output tokens.
        min_budget = 64
        max_budget = min(
            vllm_config.scheduler_config.max_num_batched_tokens,
            self.model_config.max_model_len,
        )
        return (min_budget, max_budget)

    def _get_grid_hws(
        self,
        mm_kwargs: dict[str, Any],
    ) -> list[tuple[int, int]]:
        grid_hws = mm_kwargs["image_grid_hws"]
        if not isinstance(grid_hws, list):
            grid_hws = grid_hws.tolist()
        return grid_hws

    def get_encoder_cudagraph_num_items(
        self,
        mm_kwargs: dict[str, Any],
    ) -> int:
        return len(self._get_grid_hws(mm_kwargs))

    def get_encoder_cudagraph_per_item_output_tokens(
        self,
        mm_kwargs: dict[str, Any],
    ) -> list[int]:
        kh, kw = self.config.vision_config.merge_kernel_size
        grid_hws = self._get_grid_hws(mm_kwargs)
        return [(h // kh) * (w // kw) for h, w in grid_hws]

    def get_encoder_cudagraph_per_item_input_sizes(
        self,
        mm_kwargs: dict[str, Any],
    ) -> list[int]:
        grid_hws = self._get_grid_hws(mm_kwargs)
        return [h * w for h, w in grid_hws]

    def select_encoder_cudagraph_items(
        self,
        mm_kwargs: dict[str, Any],
        indices: list[int],
    ) -> dict[str, Any]:
        grid_hws = self._get_grid_hws(mm_kwargs)
        pixel_values = mm_kwargs["pixel_values"]

        if len(indices) == 0:
            return {
                "pixel_values": pixel_values[:0],
                "image_grid_hws": pixel_values.new_zeros((0, 2), dtype=torch.long),
            }

        patches_per_item = [h * w for h, w in grid_hws]
        cum_patches = [0]
        for p in patches_per_item:
            cum_patches.append(cum_patches[-1] + p)

        selected_pv = torch.cat(
            [pixel_values[cum_patches[i] : cum_patches[i + 1]] for i in indices]
        )
        selected_grid = torch.tensor(
            [grid_hws[i] for i in indices],
            dtype=torch.long,
            device=pixel_values.device,
        )
        return {
            "pixel_values": selected_pv,
            "image_grid_hws": selected_grid,
        }

    def prepare_encoder_cudagraph_capture_inputs(
        self,
        token_budget: int,
        max_batch_size: int,
        max_frames_per_batch: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        from vllm.v1.worker.encoder_cudagraph_defs import (
            EncoderCudaGraphCaptureInputs,
        )

        kh, kw = self.config.vision_config.merge_kernel_size
        per_mm_item_output = token_budget // max_batch_size

        # Build a worst-case grid_hws that yields ``per_mm_item_output``
        # merged tokens per item (h=kh, w=kw*per_mm_item_output).
        grid_hws_list = [(kh, kw * per_mm_item_output) for _ in range(max_batch_size)]

        patch_size = self.config.vision_config.patch_size
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)

        total_patches = sum(h * w for h, w in grid_hws_list)
        in_channels = 3
        dummy_pixel_values = torch.randn(
            total_patches,
            in_channels,
            patch_size[0],
            patch_size[1],
            device=device,
            dtype=dtype,
        )

        buffers = self.vision_tower.prepare_encoder_metadata(
            grid_hws_list,
            max_batch_size=max_batch_size,
            max_seqlen_override=token_budget,
            device=device,
        )

        mm_kwargs = {"pixel_values": dummy_pixel_values}

        return EncoderCudaGraphCaptureInputs(
            mm_kwargs=mm_kwargs,
            buffers=buffers,
        )

    def prepare_encoder_cudagraph_replay_buffers(
        self,
        mm_kwargs: dict[str, Any],
        max_batch_size: int,
        max_frames_per_batch: int,
    ):
        grid_hws_list = self._get_grid_hws(mm_kwargs)
        buffers = self.vision_tower.prepare_encoder_metadata(
            grid_hws_list,
            max_batch_size=max_batch_size,
            device=mm_kwargs["pixel_values"].device,
        )
        return EncoderCudaGraphReplayBuffers(buffers=buffers)

    def encoder_cudagraph_forward(
        self,
        mm_kwargs: dict[str, Any],
        buffers: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        pixel_values = mm_kwargs["pixel_values"]
        image_features = self.vision_tower(
            pixel_values, grid_hw=None, encoder_metadata=buffers
        )
        return self.multi_modal_projector(image_features)

    def encoder_eager_forward(
        self,
        mm_kwargs: dict[str, Any],
    ) -> torch.Tensor:
        pixel_values = mm_kwargs["pixel_values"]
        image_grid_hws = mm_kwargs["image_grid_hws"]
        image_features = self.vision_tower(pixel_values, image_grid_hws)
        return self.multi_modal_projector(torch.cat(image_features))

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> KimiVLImageInputs | None:
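As a sanity check on the token arithmetic used by the budget and per-item hooks above, here is a small standalone sketch. The patch size (14) and the (2, 2) merge kernel are the example values from the inline comment in `get_encoder_cudagraph_budget_range`; the 32x48 grid is made up purely for illustration.

```python
# Standalone illustration of the per-item token math above (not model code).
patch_size = 14            # example value from the inline comment
merge_kh, merge_kw = 2, 2  # merge kernel assumed (2, 2), as in the comment

# 224x224 image -> 16x16 patch grid -> (16 // 2) * (16 // 2) = 64 merged
# output tokens, matching the min_budget of 64 chosen above.
h = w = 224 // patch_size
assert (h // merge_kh) * (w // merge_kw) == 64

# An arbitrary 32x48 patch grid:
h, w = 32, 48
input_patches = h * w                              # per-item input size: 1536
output_tokens = (h // merge_kh) * (w // merge_kw)  # per-item output tokens: 384
print(input_patches, output_tokens)
```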