35 commits
28fa0e9
multimodal compile & piecewise graph
wangxingran222 Nov 12, 2025
52da312
hardcoded ViT piecewise cuda graph size without padding
wangxingran222 Nov 12, 2025
438b8eb
feat: add vit padding
HirokenOvo Nov 13, 2025
7eaac5c
fix: fix vit cuda graph weak ref issue and first graph gc issue
HirokenOvo Nov 13, 2025
330fe86
feat: add vit cudagraph capture sizes and related functionality
HirokenOvo Nov 18, 2025
fac98f9
ViT cuda graph dispatcher
wangxingran222 Nov 18, 2025
2762ba6
feat: update Qwen2.5-VL model to support dynamic buffer sizes based o…
HirokenOvo Nov 18, 2025
fb9225e
fix: Ordering vit_cudagraph capture sizes and disable vit dp mode
HirokenOvo Nov 19, 2025
c216a0c
chore: Optimize code structure and add documentation
HirokenOvo Nov 20, 2025
c85b49b
chore: rebase to v0.11.1
HirokenOvo Nov 20, 2025
f1f26d0
chore: ruff format
HirokenOvo Nov 20, 2025
ef26918
feat: Update vit_cudagraph capture size logic
HirokenOvo Nov 23, 2025
2872257
[Model][Qwen3VL] Add `torch.compile` support for Qwen3VL
lgeiger Oct 29, 2025
8bff371
feat: Enhance Qwen3VL with ViT CUDAGraph support
HirokenOvo Dec 26, 2025
7dc0fcf
feat: add vit dp mode cuda graph
HirokenOvo Dec 30, 2025
c0e8849
chore: remove ViT's useless persistent buffer at engine level
HirokenOvo Dec 30, 2025
ef7e45d
feat: add FA and sdpa wrappers to compilation config
HirokenOvo Dec 30, 2025
e23899d
fix: update dummy input type from image to video to avoid preprocess_…
HirokenOvo Dec 31, 2025
506f75b
feat: add max_vit_cudagraph_capture_size and simplify code
HirokenOvo Jan 5, 2026
ee80144
rebase to v0.13.0
HirokenOvo Jan 11, 2026
f8defd7
chore: Reduce unnecessary computations in ViT dp mode
HirokenOvo Jan 11, 2026
602c692
fix: truncate padded output in CUDA graph execution to prevent all_ga…
HirokenOvo Jan 16, 2026
c1746c1
fix: change padding init from empty to zeros to avoid FA3 issues
HirokenOvo Jan 21, 2026
99d8272
rebase to main 7ef587
HirokenOvo Jan 23, 2026
79ea240
rebase to ff6c1d
HirokenOvo Jan 26, 2026
7be22e7
feat: add test
HirokenOvo Jan 27, 2026
eb91c31
ruff
HirokenOvo Jan 27, 2026
f7e4ea9
fix review suggestion
HirokenOvo Jan 29, 2026
3f9950e
chore: rename vit to mm_encoder
HirokenOvo Jan 30, 2026
13c6422
feat: add MMEncoderCudagraphManager and update related components for…
HirokenOvo Feb 2, 2026
0731683
simplify cuda graph conditional judgments
HirokenOvo Feb 3, 2026
53814ec
rebase
HirokenOvo Feb 3, 2026
ae2e8e6
add a dedicated dispatcher for mm encoder
HirokenOvo Feb 4, 2026
9be3fa6
modify to be compatible with V1 design
HirokenOvo Feb 4, 2026
6da9076
simplify CudagraphDispatcher init and restore video logic
HirokenOvo Feb 5, 2026
40 changes: 39 additions & 1 deletion docs/design/torch_compile_multimodal.md
@@ -68,7 +68,45 @@ to alert torch.compile to the fact that this range cannot be inferred, and we de

### Cudagraphs

We have not yet explored compilation for multimodal encoders with CUDAGraph integration; behavior is currently unspecified.
vLLM now supports Piecewise CUDA Graph integration for the Vision Transformer (ViT) encoder in Qwen2.5-VL and Qwen3-VL models. This feature captures CUDA graphs at specified patch sizes to reduce kernel launch overhead and improve performance.

#### Enabling ViT CUDA Graphs

**Important**: This feature is **not enabled by default**. The Piecewise CUDA Graph implementation relies on `torch.compile` to trace the computation graph and split out the attention operators, so users must explicitly enable `mm_encoder` compilation via the `--compilation-config` argument to activate this feature.

To enable ViT CUDA graph compilation, use:

```bash
vllm serve <model> --compilation-config '{"compile_mm_encoder": true}'
```

#### Configuring Capture Sizes

You can specify custom patch sizes for CUDA graph capture using `mm_encoder_cudagraph_capture_sizes`. For models like `Qwen2.5-VL` and `Qwen3-VL`, the capture sizes should be multiples of the square of `merge_size`:

```bash
vllm serve <model> --compilation-config '{"compile_mm_encoder": true, "mm_encoder_cudagraph_capture_sizes": [512, 1024]}'
```
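To make the multiples-of-`merge_size`² rule concrete, here is a standalone sketch; the helper name is hypothetical and not part of the vLLM API:

```python
# Hypothetical helper, not part of vLLM: illustrates the rule that each
# mm_encoder capture size must be a positive multiple of merge_size ** 2.
def validate_capture_sizes(sizes: list[int], merge_size: int) -> list[int]:
    unit = merge_size**2
    bad = [s for s in sizes if s <= 0 or s % unit != 0]
    if bad:
        raise ValueError(f"capture sizes {bad} are not multiples of {unit}")
    return sorted(sizes)

# Qwen2.5-VL and Qwen3-VL use merge_size=2, so sizes must be multiples of 4.
print(validate_capture_sizes([512, 1024], merge_size=2))  # [512, 1024]
```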

Alternatively, you can specify `max_mm_encoder_cudagraph_capture_size` to generate a default list of capture sizes up to the given value:

```bash
vllm serve <model> --compilation-config '{"compile_mm_encoder": true, "max_mm_encoder_cudagraph_capture_size": 2048}'
```
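For intuition, one plausible way such a default list could be derived from the maximum is doubling from a small multiple of `merge_size`². This is a sketch only; the starting size and the list vLLM actually generates may differ:

```python
# Illustrative sketch only: derive a capture-size list from a maximum value.
# The exact sizes vLLM produces from max_mm_encoder_cudagraph_capture_size
# may differ; this just shows the shape of such a scheme.
def sketch_capture_sizes(max_size: int, merge_size: int = 2) -> list[int]:
    unit = merge_size**2        # keep every size a multiple of merge_size**2
    size, sizes = unit * 16, []  # hypothetical smallest size: 64 patches
    while size <= max_size:
        sizes.append(size)
        size *= 2               # double up to the configured maximum
    return sizes

print(sketch_capture_sizes(2048))  # [64, 128, 256, 512, 1024, 2048]
```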

#### Default Behavior

Once enabled, if `mm_encoder_cudagraph_capture_sizes` is not specified, vLLM will use a default set of sizes for capture. Since `compile_mm_encoder` is `False` by default, this feature remains inactive unless configured.

If you only want to enable `torch.compile` for ViT without using the CUDA Graph feature, you can explicitly set the capture sizes to an empty list:

```bash
vllm serve <model> --compilation-config '{"compile_mm_encoder": true, "mm_encoder_cudagraph_capture_sizes": []}'
```

#### Limitations & Notes

- **Image Only**: This feature currently only supports image inference. Video inference is not supported yet.

## Troubleshooting

260 changes: 260 additions & 0 deletions tests/compile/piecewise/test_qwenvl_vit_cudagraph.py
@@ -0,0 +1,260 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import weakref
from functools import partial

import pytest
import torch

from vllm import LLM
from vllm.config import CompilationConfig, CUDAGraphMode
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.forward_context import set_forward_context
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
from vllm.v1.worker.mm_cudagraph import MMEncoderCudagraphManager

# Format: (model_name, tp_size, mm_encoder_tp_mode)
TEST_CONFIGS = [
("Qwen/Qwen2.5-VL-3B-Instruct", 1, "weights"),
("Qwen/Qwen3-VL-4B-Instruct", 1, "weights"),
# TP/DP modes with 2 GPUs
("Qwen/Qwen2.5-VL-3B-Instruct", 2, "data"),
("Qwen/Qwen2.5-VL-3B-Instruct", 2, "weights"),
("Qwen/Qwen3-VL-4B-Instruct", 2, "data"),
("Qwen/Qwen3-VL-4B-Instruct", 2, "weights"),
]


@pytest.fixture(
params=TEST_CONFIGS, ids=lambda x: f"{x[0].split('/')[-1]}-tp{x[1]}-{x[2]}"
)
def llm(request):
model_name, tp_size, mm_mode = request.param

if torch.cuda.device_count() < tp_size:
pytest.skip(f"Not enough GPUs for tp_size={tp_size}")

os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
# Common configuration
common_args = {
"model": model_name,
"trust_remote_code": True,
"max_model_len": 4096,
"max_num_seqs": 16,
"gpu_memory_utilization": 0.2,
"tensor_parallel_size": tp_size,
"mm_encoder_tp_mode": mm_mode,
}

# Initialize LLM with ViT CUDA graph enabled (piecewise)
# We only need one LLM instance. For eager execution, we will force
# cudagraph_runtime_mode=NONE at runtime.
llm_instance = None
try:
llm_instance = LLM(
**common_args,
compilation_config=CompilationConfig(
cudagraph_mode="PIECEWISE",
compile_mm_encoder=True,
mm_encoder_cudagraph_capture_sizes=[64, 128, 256],
),
)
print(f"LLM initialized for {model_name} tp={tp_size} mode={mm_mode}")
yield weakref.proxy(llm_instance)
finally:
print("Cleaning up LLM after testing.")
if llm_instance:
# Ensure model executor and workers are properly shut down
# llm_instance.llm_engine is vllm.v1.engine.llm_engine.LLMEngine
# which has engine_core (InprocClient).
if hasattr(llm_instance.llm_engine, "engine_core"):
llm_instance.llm_engine.engine_core.shutdown()
del llm_instance

# Clean up distributed environment
cleanup_dist_env_and_memory()


def _worker_embed_multimodal(
worker, vllm_config, multi_modal_data, enforce_eager=False
):
"""Helper function to run multimodal embedding on a worker.
This function sets up the necessary forward context for tensor-parallel (TP)
execution and then calls the model's `embed_multimodal` method.
Note: For data-parallel (DP) mode, the forward context is typically
created and managed within the
vision.py:run_dp_sharded_mrope_vision_model(), which would override the
context set here.
This method manually constructs a MMEncoderCudagraphManager because accessing the
one within the GPU model runner is difficult.
Args:
worker: The worker instance containing the model runner.
vllm_config: The vLLM engine configuration.
multi_modal_data: A dictionary of keyword arguments to be passed to
the model's `embed_multimodal` method.
enforce_eager: If True, forces the execution to run in eager mode
Returns:
The output from the model's `embed_multimodal` method.
"""

# Access model via worker.model_runner.model
# Note: Accessing internal attributes. Assuming V1 worker structure.
model = worker.model_runner.model

# Move multi_modal_data to the model's device
target_device = next(model.parameters()).device
multi_modal_data = {
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
for k, v in multi_modal_data.items()
}

processor = MULTIMODAL_REGISTRY.create_processor(vllm_config.model_config)
dummy_inputs_builder = processor.dummy_inputs
mm_cudagraph_manager = MMEncoderCudagraphManager(
vllm_config,
dummy_inputs_builder,
)
mm_cudagraph_manager.initialize_cudagraph_keys(
CUDAGraphMode.PIECEWISE,
)

# Dispatch to get runtime mode and batch descriptor
(
cudagraph_runtime_mode,
batch_descriptor,
_,
multi_modal_data,
) = mm_cudagraph_manager.dispatch_and_pad_mm_input(multi_modal_data)
if enforce_eager:
cudagraph_runtime_mode = CUDAGraphMode.NONE
else:
multi_modal_data["mm_cudagraph_manager"] = mm_cudagraph_manager

with (
set_forward_context(
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
),
torch.inference_mode(),
):
ans = model.embed_multimodal(**multi_modal_data)
torch.cuda.synchronize()
return ans


class TestQwenVLCUDAGraph:
def _run_embed_multimodal(self, llm, multi_modal_data, enforce_eager=False):
"""Runs the multimodal embedding process, potentially with CUDA graphs.
The actual embedding is performed on the worker(s) via an RPC call.
Args:
llm: The LLM object containing the model engine and configuration.
multi_modal_data: A dictionary containing the multimodal data to be
processed.
enforce_eager: If True, forces the execution to run in eager mode,
bypassing CUDA graphs.
Returns:
The outputs from the multimodal embedding process executed on the
worker.
"""
vllm_config = llm.llm_engine.vllm_config
model_executor = llm.llm_engine.model_executor

rpc_kwargs = {}
# Use collective_rpc to execute on driver worker (rank 0)
if isinstance(model_executor, MultiprocExecutor):
rpc_kwargs["unique_reply_rank"] = 0

outputs = model_executor.collective_rpc(
partial(
_worker_embed_multimodal,
vllm_config=vllm_config,
multi_modal_data=multi_modal_data,
enforce_eager=enforce_eager,
),
**rpc_kwargs,
)

if isinstance(outputs, list) and len(outputs) == 1:
outputs = outputs[0]
return outputs

def test_vit_cudagraph_consistency(self, llm):
print("Starting test for ViT CUDA graph consistency.")

model_name = llm.llm_engine.vllm_config.model_config.model
# Qwen3-VL uses patch_size=16, temporal_patch_size=2 -> 16*16*3*2 = 1536
# Qwen2.5-VL uses patch_size=14, temporal_patch_size=2 -> 14*14*3*2 = 1176
input_dim = 1536 if "Qwen3-VL" in model_name else 1176

num_patches = 64
for num_imgs in [1, 2, 4]:
image_grid_thw = torch.tensor(
[[1, 2, num_patches // 2]] * num_imgs, dtype=torch.long, device="cpu"
)
pixel_values = torch.rand(
(num_patches * num_imgs, input_dim), dtype=torch.bfloat16, device="cpu"
)

multi_modal_data = {
"pixel_values": pixel_values,
"image_grid_thw": image_grid_thw,
}
print(
"Running inference with single LLM (Piecewise vs Eager via context)."
"num_imgs:",
num_imgs,
)

# Run with Piecewise CUDA Graph
piecewise_outputs = self._run_embed_multimodal(
llm, multi_modal_data, enforce_eager=False
)

# Run with Eager Mode (simulated by setting runtime mode to NONE)
eager_outputs = self._run_embed_multimodal(
llm, multi_modal_data, enforce_eager=True
)

if isinstance(piecewise_outputs, torch.Tensor):
assert torch.allclose(
piecewise_outputs, eager_outputs, atol=1e-3, rtol=1e-5
), (
f"num_imgs: {num_imgs}. Piecewise and Eager outputs do not match. "
"Max abs diff: "
f"{torch.max(torch.abs(piecewise_outputs - eager_outputs))}. "
"Max rel diff: "
f"{
torch.max(
torch.abs(piecewise_outputs - eager_outputs)
/ (torch.abs(eager_outputs) + 1e-8)
)
}"
)
elif isinstance(piecewise_outputs, tuple):
assert isinstance(eager_outputs, tuple), (
"Output types mismatch, piecewise is tuple but eager is not."
)
assert len(piecewise_outputs) == len(eager_outputs), (
"Output tuple lengths mismatch."
)
for i, (p_out, e_out) in enumerate(
zip(piecewise_outputs, eager_outputs)
):
assert torch.allclose(p_out, e_out, atol=1e-3, rtol=1e-5), (
f"num_imgs: {num_imgs}. "
f"Tuple element {i} does not match. "
"Max abs diff: "
f"{torch.max(torch.abs(p_out - e_out))}. "
"Max rel diff: "
f"{
torch.max(
torch.abs(p_out - e_out) / (torch.abs(e_out) + 1e-8)
)
}"
)
else:
raise TypeError(f"Unsupported output type: {type(piecewise_outputs)}")
57 changes: 54 additions & 3 deletions vllm/compilation/backends.py
@@ -11,7 +11,7 @@
import pprint
import time
from collections.abc import Callable, Generator, Sequence
from contextlib import contextmanager
from contextlib import AbstractContextManager, contextmanager
from copy import deepcopy
from functools import partial
from typing import Any
@@ -30,6 +30,7 @@
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import DynamicShapesType
from vllm.config.utils import Range, hash_factors
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.logging_utils import lazy
from vllm.platforms import current_platform
@@ -49,6 +50,46 @@
logger = init_logger(__name__)


@contextmanager
def _set_mm_encoder_sequence_flag(
attr_name: str, value: bool
) -> Generator[None, None, None]:
try:
ctx = get_forward_context()
original_value = getattr(ctx, attr_name)
setattr(ctx, attr_name, value)
except Exception:
yield
return

try:
yield
finally:
setattr(ctx, attr_name, original_value)


def set_is_last_graph_in_mm_encoder_sequence(
is_last: bool,
) -> AbstractContextManager[None]:
"""Context manager to indicate if the current graph being compiled
is the last one in a sequence of graphs (e.g., a sequence of blocks).
"""
return _set_mm_encoder_sequence_flag(
"is_last_graph_in_mm_encoder_sequence", is_last
)


def set_is_first_graph_in_mm_encoder_sequence(
is_first: bool,
) -> AbstractContextManager[None]:
"""Context manager to indicate if the current graph being compiled
is the first one in a sequence of graphs (e.g., a sequence of blocks).
"""
return _set_mm_encoder_sequence_flag(
"is_first_graph_in_mm_encoder_sequence", is_first
)


def make_copy_and_call(
sym_tensor_indices: list[int],
input_buffers: list[torch.Tensor | None],
@@ -443,14 +484,24 @@ def wrap_with_cudagraph_if_needed(
# CUDAGraphWrapper for piecewise_backend, to distinguish
# it from the FULL cudagraph runtime mode, no matter it
# is wrapped on a full or piecewise fx graph.

try:
fwd_ctx = get_forward_context()
is_first_graph_in_sequence = fwd_ctx.is_first_graph_in_mm_encoder_sequence
is_last_graph_in_sequence = fwd_ctx.is_last_graph_in_mm_encoder_sequence
except Exception:
# Fallback for when ForwardContext is not available
is_first_graph_in_sequence = True
is_last_graph_in_sequence = True

return static_graph_wrapper_class(
runnable=piecewise_backend,
vllm_config=vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
cudagraph_options=CUDAGraphOptions(
debug_log_enable=is_first_graph,
gc_disable=not is_first_graph,
weak_ref_output=is_last_graph,
gc_disable=not is_first_graph or not is_first_graph_in_sequence,
weak_ref_output=is_last_graph and is_last_graph_in_sequence,
),
)
