diff --git a/docs/design/torch_compile_multimodal.md b/docs/design/torch_compile_multimodal.md index 674ddd801d65..260f4e136a58 100644 --- a/docs/design/torch_compile_multimodal.md +++ b/docs/design/torch_compile_multimodal.md @@ -68,7 +68,45 @@ to alert torch.compile to the fact that this range cannot be inferred, and we de ### Cudagraphs -We have not yet explored compilation for multimodal encoders with CUDAGraph integration; behavior is currently unspecified. +vLLM now supports Piecewise CUDA Graph integration for the Vision Transformer (ViT) encoder in Qwen2.5-VL and Qwen3-VL models. This feature captures CUDA graphs at specified patch sizes to reduce kernel launch overhead and improve performance. + +#### Enabling ViT CUDA Graphs + +**Important**: This feature is **not enabled by default**. The Piecewise CUDA Graph implementation relies on `torch.compile` to trace the computation graph and separate the attention operators. Therefore, users must explicitly enable mm_encoder compilation via the `--compilation-config` argument to activate this feature. + +To enable ViT CUDA graph compilation, use: + +```bash +vllm serve --compilation-config '{"compile_mm_encoder": true}' +``` + +#### Configuring Capture Sizes + +You can specify custom patch sizes for CUDA graph capture using `mm_encoder_cudagraph_capture_sizes`. For models like `Qwen2.5-VL` and `Qwen3-VL`, the capture sizes should be multiples of the square of `merge_size`: + +```bash +vllm serve --compilation-config '{"compile_mm_encoder": true, "mm_encoder_cudagraph_capture_sizes": [512, 1024]}' +``` + +Alternatively, you can specify `max_mm_encoder_cudagraph_capture_size` to generate a default list of capture sizes up to the given value: + +```bash +vllm serve --compilation-config '{"compile_mm_encoder": true, "max_mm_encoder_cudagraph_capture_size": 2048}' +``` + +#### Default Behavior + +Once enabled, if `mm_encoder_cudagraph_capture_sizes` is not specified, vLLM will use a default set of sizes for capture. 
Since `compile_mm_encoder` is `False` by default, this feature remains inactive unless configured. + +If you only want to enable `torch.compile` for ViT without using the CUDA Graph feature, you can explicitly set the capture sizes to empty: + +```bash +vllm serve --compilation-config '{"compile_mm_encoder": true, "mm_encoder_cudagraph_capture_sizes": []}' +``` + +#### Limitations & Notes + +- **Image Only**: This feature currently only supports image inference. Video inference is not supported yet. ## Troubleshooting diff --git a/tests/compile/piecewise/test_qwenvl_vit_cudagraph.py b/tests/compile/piecewise/test_qwenvl_vit_cudagraph.py new file mode 100644 index 000000000000..f59368fcbd1c --- /dev/null +++ b/tests/compile/piecewise/test_qwenvl_vit_cudagraph.py @@ -0,0 +1,260 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os +import weakref +from functools import partial + +import pytest +import torch + +from vllm import LLM +from vllm.config import CompilationConfig, CUDAGraphMode +from vllm.distributed import cleanup_dist_env_and_memory +from vllm.forward_context import set_forward_context +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.v1.executor.multiproc_executor import MultiprocExecutor +from vllm.v1.worker.mm_cudagraph import MMEncoderCudagraphManager + +# Format: (model_name, tp_size, mm_encoder_tp_mode) +TEST_CONFIGS = [ + ("Qwen/Qwen2.5-VL-3B-Instruct", 1, "weights"), + ("Qwen/Qwen3-VL-4B-Instruct", 1, "weights"), + # TP/DP modes with 2 GPUs + ("Qwen/Qwen2.5-VL-3B-Instruct", 2, "data"), + ("Qwen/Qwen2.5-VL-3B-Instruct", 2, "weights"), + ("Qwen/Qwen3-VL-4B-Instruct", 2, "data"), + ("Qwen/Qwen3-VL-4B-Instruct", 2, "weights"), +] + + +@pytest.fixture( + params=TEST_CONFIGS, ids=lambda x: f"{x[0].split('/')[-1]}-tp{x[1]}-{x[2]}" +) +def llm(request): + model_name, tp_size, mm_mode = request.param + + if torch.cuda.device_count() < tp_size: + pytest.skip(f"Not enough GPUs for 
tp_size={tp_size}") + + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + # Common configuration + common_args = { + "model": model_name, + "trust_remote_code": True, + "max_model_len": 4096, + "max_num_seqs": 16, + "gpu_memory_utilization": 0.2, + "tensor_parallel_size": tp_size, + "mm_encoder_tp_mode": mm_mode, + } + + # Initialize LLM with ViT CUDA graph enabled (piecewise) + # We only need one LLM instance. For eager execution, we will force + # cudagraph_runtime_mode=NONE at runtime. + llm_instance = None + try: + llm_instance = LLM( + **common_args, + compilation_config=CompilationConfig( + cudagraph_mode="PIECEWISE", + compile_mm_encoder=True, + mm_encoder_cudagraph_capture_sizes=[64, 128, 256], + ), + ) + print(f"LLM initialized for {model_name} tp={tp_size} mode={mm_mode}") + yield weakref.proxy(llm_instance) + finally: + print("Cleaning up LLM after testing.") + if llm_instance: + # Ensure model executor and workers are properly shut down + # llm_instance.llm_engine is vllm.v1.engine.llm_engine.LLMEngine + # which has engine_core (InprocClient). + if hasattr(llm_instance.llm_engine, "engine_core"): + llm_instance.llm_engine.engine_core.shutdown() + del llm_instance + + # Clean up distributed environment + cleanup_dist_env_and_memory() + + +def _worker_embed_multimodal( + worker, vllm_config, multi_modal_data, enforce_eager=False +): + """Helper function to run multimodal embedding on a worker. + This function sets up the necessary forward context for tensor-parallel (TP) + execution and then calls the model's `embed_multimodal` method. + Note: For data-parallel (DP) mode, the forward context is typically + created and managed within the + vision.py:run_dp_sharded_mrope_vision_model(), which would override the + context set here. + This method manually constructs a MMEncoderCudagraphManager because accessing the + one within the GPU model runner is difficult. + Args: + worker: The worker instance containing the model runner. 
+ vllm_config: The vLLM engine configuration. + multi_modal_data: A dictionary of keyword arguments to be passed to + the model's `embed_multimodal` method. + enforce_eager: If True, forces the execution to run in eager mode + Returns: + The output from the model's `embed_multimodal` method. + """ + + # Access model via worker.model_runner.model + # Note: Accessing internal attributes. Assuming V1 worker structure. + model = worker.model_runner.model + + # Move multi_modal_data to the model's device + target_device = next(model.parameters()).device + multi_modal_data = { + k: v.to(target_device) if isinstance(v, torch.Tensor) else v + for k, v in multi_modal_data.items() + } + + processor = MULTIMODAL_REGISTRY.create_processor(vllm_config.model_config) + dummy_inputs_builder = processor.dummy_inputs + mm_cudagraph_manager = MMEncoderCudagraphManager( + vllm_config, + dummy_inputs_builder, + ) + mm_cudagraph_manager.initialize_cudagraph_keys( + CUDAGraphMode.PIECEWISE, + ) + + # Dispatch to get runtime mode and batch descriptor + ( + cudagraph_runtime_mode, + batch_descriptor, + _, + multi_modal_data, + ) = mm_cudagraph_manager.dispatch_and_pad_mm_input(multi_modal_data) + if enforce_eager: + cudagraph_runtime_mode = CUDAGraphMode.NONE + else: + multi_modal_data["mm_cudagraph_manager"] = mm_cudagraph_manager + + with ( + set_forward_context( + None, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ), + torch.inference_mode(), + ): + ans = model.embed_multimodal(**multi_modal_data) + torch.cuda.synchronize() + return ans + + +class TestQwenVLCUDAGraph: + def _run_embed_multimodal(self, llm, multi_modal_data, enforce_eager=False): + """Runs the multimodal embedding process, potentially with CUDA graphs. + The actual embedding is performed on the worker(s) via an RPC call. + Args: + llm: The LLM object containing the model engine and configuration. 
+ multi_modal_data: A dictionary containing the multimodal data to be + processed. + enforce_eager: If True, forces the execution to run in eager mode, + bypassing CUDA graphs. + Returns: + The outputs from the multimodal embedding process executed on the + worker. + """ + vllm_config = llm.llm_engine.vllm_config + model_executor = llm.llm_engine.model_executor + + rpc_kwargs = {} + # Use collective_rpc to execute on driver worker (rank 0) + if isinstance(model_executor, MultiprocExecutor): + rpc_kwargs["unique_reply_rank"] = 0 + + outputs = model_executor.collective_rpc( + partial( + _worker_embed_multimodal, + vllm_config=vllm_config, + multi_modal_data=multi_modal_data, + enforce_eager=enforce_eager, + ), + **rpc_kwargs, + ) + + if isinstance(outputs, list) and len(outputs) == 1: + outputs = outputs[0] + return outputs + + def test_vit_cudagraph_consistency(self, llm): + print("Starting test for ViT CUDA graph consistency.") + + model_name = llm.llm_engine.vllm_config.model_config.model + # Qwen3-VL uses patch_size=16, temporal_patch_size=2 -> 16*16*3*2 = 1536 + # Qwen2.5-VL uses patch_size=14, temporal_patch_size=2 -> 14*14*3*2 = 1176 + input_dim = 1536 if "Qwen3-VL" in model_name else 1176 + + num_patches = 64 + for num_imgs in [1, 2, 4]: + image_grid_thw = torch.tensor( + [[1, 2, num_patches // 2]] * num_imgs, dtype=torch.long, device="cpu" + ) + pixel_values = torch.rand( + (num_patches * num_imgs, input_dim), dtype=torch.bfloat16, device="cpu" + ) + + multi_modal_data = { + "pixel_values": pixel_values, + "image_grid_thw": image_grid_thw, + } + print( + "Running inference with single LLM (Piecewise vs Eager via context)." 
+ "num_imgs:", + num_imgs, + ) + + # Run with Piecewise CUDA Graph + piecewise_outputs = self._run_embed_multimodal( + llm, multi_modal_data, enforce_eager=False + ) + + # Run with Eager Mode (simulated by setting runtime mode to NONE) + eager_outputs = self._run_embed_multimodal( + llm, multi_modal_data, enforce_eager=True + ) + + if isinstance(piecewise_outputs, torch.Tensor): + assert torch.allclose( + piecewise_outputs, eager_outputs, atol=1e-3, rtol=1e-5 + ), ( + f"num_imgs: {num_imgs}. Piecewise and Eager outputs do not match. " + "Max abs diff: " + f"{torch.max(torch.abs(piecewise_outputs - eager_outputs))}. " + "Max rel diff: " + f"{ + torch.max( + torch.abs(piecewise_outputs - eager_outputs) + / (torch.abs(eager_outputs) + 1e-8) + ) + }" + ) + elif isinstance(piecewise_outputs, tuple): + assert isinstance(eager_outputs, tuple), ( + "Output types mismatch, piecewise is tuple but eager is not." + ) + assert len(piecewise_outputs) == len(eager_outputs), ( + "Output tuple lengths mismatch." + ) + for i, (p_out, e_out) in enumerate( + zip(piecewise_outputs, eager_outputs) + ): + assert torch.allclose(p_out, e_out, atol=1e-3, rtol=1e-5), ( + f"num_imgs: {num_imgs}. " + f"Tuple element {i} does not match. " + "Max abs diff: " + f"{torch.max(torch.abs(p_out - e_out))}. 
" + "Max rel diff: " + f"{ + torch.max( + torch.abs(p_out - e_out) / (torch.abs(e_out) + 1e-8) + ) + }" + ) + else: + raise TypeError(f"Unsupported output type: {type(piecewise_outputs)}") diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 89981fc29963..63bf3690891a 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -11,7 +11,7 @@ import pprint import time from collections.abc import Callable, Generator, Sequence -from contextlib import contextmanager +from contextlib import AbstractContextManager, contextmanager from copy import deepcopy from functools import partial from typing import Any @@ -30,6 +30,7 @@ from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.config.compilation import DynamicShapesType from vllm.config.utils import Range, hash_factors +from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.logging_utils import lazy from vllm.platforms import current_platform @@ -49,6 +50,46 @@ logger = init_logger(__name__) +@contextmanager +def _set_mm_encoder_sequence_flag( + attr_name: str, value: bool +) -> Generator[None, None, None]: + try: + ctx = get_forward_context() + original_value = getattr(ctx, attr_name) + setattr(ctx, attr_name, value) + except Exception: + yield + return + + try: + yield + finally: + setattr(ctx, attr_name, original_value) + + +def set_is_last_graph_in_mm_encoder_sequence( + is_last: bool, +) -> AbstractContextManager[None]: + """Context manager to indicate if the current graph being compiled + is the last one in a sequence of graphs (e.g., a sequence of blocks). + """ + return _set_mm_encoder_sequence_flag( + "is_last_graph_in_mm_encoder_sequence", is_last + ) + + +def set_is_first_graph_in_mm_encoder_sequence( + is_first: bool, +) -> AbstractContextManager[None]: + """Context manager to indicate if the current graph being compiled + is the first one in a sequence of graphs (e.g., a sequence of blocks). 
+ """ + return _set_mm_encoder_sequence_flag( + "is_first_graph_in_mm_encoder_sequence", is_first + ) + + def make_copy_and_call( sym_tensor_indices: list[int], input_buffers: list[torch.Tensor | None], @@ -443,14 +484,24 @@ def wrap_with_cudagraph_if_needed( # CUDAGraphWrapper for piecewise_backend, to distinguish # it from the FULL cudagraph runtime mode, no matter it # is wrapped on a full or piecewise fx graph. + + try: + fwd_ctx = get_forward_context() + is_first_graph_in_sequence = fwd_ctx.is_first_graph_in_mm_encoder_sequence + is_last_graph_in_sequence = fwd_ctx.is_last_graph_in_mm_encoder_sequence + except Exception: + # Fallback for when ForwardContext is not available + is_first_graph_in_sequence = True + is_last_graph_in_sequence = True + return static_graph_wrapper_class( runnable=piecewise_backend, vllm_config=vllm_config, runtime_mode=CUDAGraphMode.PIECEWISE, cudagraph_options=CUDAGraphOptions( debug_log_enable=is_first_graph, - gc_disable=not is_first_graph, - weak_ref_output=is_last_graph, + gc_disable=not is_first_graph or not is_first_graph_in_sequence, + weak_ref_output=is_last_graph and is_last_graph_in_sequence, ), ) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 7a69629f707c..70ba6e68ec5d 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -530,6 +530,13 @@ class CompilationConfig: """Sizes to capture cudagraph. - None (default): capture sizes are inferred from vllm config. - list[int]: capture sizes are specified as given.""" + mm_encoder_cudagraph_capture_sizes: list[int] | None = None + """Sizes to capture mm_encoder cudagraph. + - None (default): capture sizes are inferred from vllm config. + - list[int]: capture sizes are specified as given.""" + max_mm_encoder_cudagraph_capture_size: int = field(default=None) + """The maximum mm_encoder cudagraph capture size. + """ cudagraph_copy_inputs: bool = False """Whether to copy input tensors for cudagraph. 
If the caller can guarantee that the same input buffers @@ -648,6 +655,8 @@ class CompilationConfig: "vllm::kda_attention", "vllm::sparse_attn_indexer", "vllm::rocm_aiter_sparse_attn_indexer", + "vllm::flash_attn_maxseqlen_wrapper", + "vllm::torch_sdpa_wrapper", ] def compute_hash(self) -> str: diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index ea133856360d..137ab56b65b9 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -257,6 +257,13 @@ class VllmConfig: performance, with -O0 having the best startup time and -O3 having the best performance. -02 is used by defult. See OptimizationLevel for full description.""" + in_mm_encoder_tracing: bool = False + """Flag for mm_encoder compilation or mm_encoder CUDA graph capture. + + If true, mm_encoder in DP mode will execute the mm_encoder model directly instead of + `run_dp_sharded_mrope_vision_model` to ensure correct memory profiling + and compilation for each rank. + """ def compute_hash(self) -> str: """ @@ -813,6 +820,7 @@ def has_blocked_weights(): self.compilation_config.cudagraph_num_of_warmups = 1 self._set_cudagraph_sizes() + self._set_mm_encoder_cudagraph_sizes() else: self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE @@ -1332,6 +1340,140 @@ def _set_compile_ranges(self): computed_compile_ranges_split_points ) + def _set_mm_encoder_cudagraph_sizes(self): + """Sets the CUDA graph capture sizes for the multimodal encoder (MM Encoder). + + This method determines the batch sizes (in terms of number of patches) + for which MM Encoder CUDA graphs will be captured. CUDA graphs improve + performance by reducing kernel launch overhead for the multimodal encoder. + + The logic is as follows: + 1. The feature is only enabled if all of the following conditions are met: + - A model is configured (`model_config` is not None). + - Eager mode is not enforced (`enforce_eager` is False). + - CUDA graph mode is enabled (`cudagraph_mode` is not NONE). 
+ - Multimodal encoder compilation is enabled (`compile_mm_encoder` is True). + If these conditions are not met, the list of capture sizes will be empty, + effectively disabling mm_encoder CUDA graphs. + + 2. If the user has explicitly provided `mm_encoder_cudagraph_capture_sizes` + in the compilation config, those sizes are used. The list is + de-duplicated and sorted in ascending order. + + 3. If no sizes are provided by the user, a default list of sizes is + generated. The maximum size for this list is determined automatically + by `compute_encoder_budget` (capped at 8192), or by the user-provided + `max_mm_encoder_cudagraph_capture_size`. The default sizes are: + [512, 1024, 1536] + list(range(2048, 4096, 128)) + list( + range(4096, max_size + 1, 256)) + + 4. The final list of sizes is stored in + `self.compilation_config.mm_encoder_cudagraph_capture_sizes`. The + `max_mm_encoder_cudagraph_capture_size` is also updated to be consistent + with the largest value in this final list. + + At runtime: + - If a batch's size matches or is smaller than a captured size, the + closest captured graph is used. + - If a batch's size is larger than the largest captured size, a CUDA + graph will not be used for that batch (fallback to eager execution). 
+ """ + if ( + self.model_config is not None + and not self.model_config.enforce_eager + and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and self.compilation_config.compile_mm_encoder + ): + # determine the initial max_mm_encoder_cudagraph_capture_size + max_mm_encoder_cudagraph_capture_size = ( + self.compilation_config.max_mm_encoder_cudagraph_capture_size + ) + if max_mm_encoder_cudagraph_capture_size is None: + from vllm.multimodal import MULTIMODAL_REGISTRY + from vllm.multimodal.budget import MultiModalBudget + + mm_budget = MultiModalBudget(self, MULTIMODAL_REGISTRY) + encoder_compute_budget = ( + mm_budget.encoder_compute_budget if mm_budget else 0 + ) + max_mm_encoder_cudagraph_capture_size = min( + encoder_compute_budget, 8192 + ) + + # determine the mm_encoder_cudagraph_capture_sizes + if self.compilation_config.mm_encoder_cudagraph_capture_sizes is not None: + # de-duplicate the sizes provided by the config + dedup_sizes = list( + set(self.compilation_config.mm_encoder_cudagraph_capture_sizes) + ) + mm_encoder_cudagraph_capture_sizes = dedup_sizes + # sort to make sure the sizes are in ascending order + mm_encoder_cudagraph_capture_sizes.sort() + else: + mm_encoder_cudagraph_capture_sizes = [ + i + for i in [512, 1024, 1536] + if i <= max_mm_encoder_cudagraph_capture_size + ] + if max_mm_encoder_cudagraph_capture_size >= 2048: + # Step size 128 for larger batch sizes + mm_encoder_cudagraph_capture_sizes += list( + range( + 2048, + min(max_mm_encoder_cudagraph_capture_size + 1, 4096), + 128, + ) + ) + if max_mm_encoder_cudagraph_capture_size >= 4096: + # Step size 256 for largest batch sizes + mm_encoder_cudagraph_capture_sizes += list( + range(4096, max_mm_encoder_cudagraph_capture_size + 1, 256) + ) + + # user-specific compilation_config.max_mm_encoder_cudagraph_capture_size get + # truncated to valid_max_size when they are inconsistent. 
+ valid_max_size = ( + mm_encoder_cudagraph_capture_sizes[-1] + if mm_encoder_cudagraph_capture_sizes + else 0 + ) + if ( + self.compilation_config.max_mm_encoder_cudagraph_capture_size + is not None + and self.compilation_config.max_mm_encoder_cudagraph_capture_size + != valid_max_size + ): + # raise error only when both two flags are user-specified + # and they are inconsistent with each other + if ( + self.compilation_config.mm_encoder_cudagraph_capture_sizes + is not None + ): + raise ValueError( + "customized max_mm_encoder_cudagraph_capture_size(=" + f"{ + self.compilation_config.max_mm_encoder_cudagraph_capture_size + }" + ") should be consistent with the max value of " + f"mm_encoder_cudagraph_capture_sizes(={valid_max_size})" + ) + + logger.warning( + "Truncating max_mm_encoder_cudagraph_capture_size to %d", + valid_max_size, + ) + # always set the final max_mm_encoder_cudagraph_capture_size + self.compilation_config.max_mm_encoder_cudagraph_capture_size = ( + valid_max_size + ) + self.compilation_config.mm_encoder_cudagraph_capture_sizes = ( + mm_encoder_cudagraph_capture_sizes + ) + else: + # no cudagraph in use + self.compilation_config.max_mm_encoder_cudagraph_capture_size = 0 + self.compilation_config.mm_encoder_cudagraph_capture_sizes = [] + def try_verify_and_update_config(self): if self.model_config is None: return diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f3e7729f64e3..7a8599a29e43 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -378,6 +378,9 @@ class EngineArgs: max_cudagraph_capture_size: int | None = get_field( CompilationConfig, "max_cudagraph_capture_size" ) + mm_encoder_cudagraph_capture_sizes: list[int] | None = ( + CompilationConfig.mm_encoder_cudagraph_capture_sizes + ) # Note: Specifying a custom executor backend by passing a class # is intended for expert use only. The API may change without # notice. 
@@ -1148,6 +1151,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: compilation_group.add_argument( "--cudagraph-capture-sizes", **compilation_kwargs["cudagraph_capture_sizes"] ) + compilation_group.add_argument( + "--mm_encoder-cudagraph-capture-sizes", + **compilation_kwargs["mm_encoder_cudagraph_capture_sizes"], + ) compilation_group.add_argument( "--max-cudagraph-capture-size", **compilation_kwargs["max_cudagraph_capture_size"], @@ -1737,6 +1744,17 @@ def create_engine_config( "cudagraph_capture_sizes are mutually exclusive" ) compilation_config.cudagraph_capture_sizes = self.cudagraph_capture_sizes + + if self.mm_encoder_cudagraph_capture_sizes is not None: + if compilation_config.mm_encoder_cudagraph_capture_sizes is not None: + raise ValueError( + "mm_encoder_cudagraph_capture_sizes and compilation_config." + "mm_encoder_cudagraph_capture_sizes are mutually exclusive" + ) + compilation_config.mm_encoder_cudagraph_capture_sizes = ( + self.mm_encoder_cudagraph_capture_sizes + ) + if self.max_cudagraph_capture_size is not None: if compilation_config.max_cudagraph_capture_size is not None: raise ValueError( diff --git a/vllm/forward_context.py b/vllm/forward_context.py index e308c05bc669..7d5c48a2e506 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -250,6 +250,10 @@ class ForwardContext: all_moe_layers: list[str] | None = None moe_layer_index: int = 0 + # mm_encoder Multi-Modal Encoder flags used by backend compiler + is_first_graph_in_mm_encoder_sequence: bool = True + is_last_graph_in_mm_encoder_sequence: bool = True + additional_kwargs: dict[str, Any] = field(default_factory=dict) def __post_init__(self): diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 0310c5415dc9..ca13b8c096b1 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -42,11 +42,23 @@ Qwen2_5_VLVisionConfig, ) +from 
vllm.compilation.backends import ( + set_is_first_graph_in_mm_encoder_sequence, + set_is_last_graph_in_mm_encoder_sequence, +) from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig +from vllm.config import ( + CUDAGraphMode, + VllmConfig, + get_current_vllm_config, +) from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils -from vllm.forward_context import set_forward_context +from vllm.forward_context import ( + get_forward_context, + is_forward_context_available, + set_forward_context, +) from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.attention import MMEncoderAttention @@ -640,6 +652,30 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.merger", ) + vllm_config: VllmConfig = get_current_vllm_config() + self._persistent_hidden_states_buffer: torch.Tensor | None = None + self._persistent_rotary_pos_emb_cos_buffer: torch.Tensor | None = None + self._persistent_rotary_pos_emb_sin_buffer: torch.Tensor | None = None + if vllm_config.compilation_config.mm_encoder_cudagraph_capture_sizes: + max_compile_size = ( + vllm_config.compilation_config.mm_encoder_cudagraph_capture_sizes[-1] + ) + self._persistent_hidden_states_buffer = torch.empty( + (max_compile_size, self.patch_embed.proj.input_size), + device=self.device, + dtype=self.dtype, + ) + ( + self._persistent_rotary_pos_emb_cos_buffer, + self._persistent_rotary_pos_emb_sin_buffer, + ) = [ + torch.empty( + (max_compile_size, head_dim // 2), + device=self.device, + dtype=torch.bfloat16, + ) + for _ in range(2) + ] @property def dtype(self) -> torch.dtype: @@ -771,6 +807,17 @@ def invert_permutation(perm: torch.Tensor) -> torch.Tensor: inv[perm] = torch.arange(perm.numel(), device=perm.device, dtype=perm.dtype) return inv + def _use_piecewise_cudagraph(self) -> bool: + if self._persistent_hidden_states_buffer is None: + return False + if not 
is_forward_context_available(): + return False + fwd_ctx = get_forward_context() + return ( + fwd_ctx is not None + and fwd_ctx.cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE + ) + def forward( self, x: torch.Tensor, @@ -784,8 +831,19 @@ def forward( cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int32)] cu_seqlens: list = [] - hidden_states = x.to(device=self.device, dtype=self.dtype) - hidden_states = self.patch_embed(hidden_states) + is_cudagraph_mode = self._use_piecewise_cudagraph() + + if is_cudagraph_mode: + hidden_states = self._persistent_hidden_states_buffer[:seq_len] + hidden_states.copy_(x, non_blocking=True) + else: + hidden_states = x.to(device=self.device, dtype=self.dtype) + + with ( + set_is_first_graph_in_mm_encoder_sequence(True), + set_is_last_graph_in_mm_encoder_sequence(False), + ): + hidden_states = self.patch_embed(hidden_states) window_index_id = 0 cu_window_seqlens_last = 0 @@ -838,34 +896,53 @@ def forward( rotary_pos_emb_sin = rotary_pos_emb_sin.to( device=self.device, non_blocking=True ) + if is_cudagraph_mode: + rotary_pos_emb_sin = self._persistent_rotary_pos_emb_sin_buffer[ + :seq_len + ].copy_(rotary_pos_emb_sin) + rotary_pos_emb_cos = self._persistent_rotary_pos_emb_cos_buffer[ + :seq_len + ].copy_(rotary_pos_emb_cos) window_index = window_index.to(device=hidden_states.device, non_blocking=True) reverse_indices = reverse_indices.to( device=hidden_states.device, non_blocking=True ) + original_hidden_states = hidden_states hidden_states = hidden_states.reshape( seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 ) hidden_states = hidden_states[window_index, :, :] hidden_states = hidden_states.reshape(seq_len, -1) - hidden_states = hidden_states.unsqueeze(1) - for layer_num, blk in enumerate(self.blocks): - if layer_num in self.fullatt_block_indexes: - cu_seqlens_now = cu_seqlens - max_seqlen_now = max_seqlen_full - else: - cu_seqlens_now = cu_window_seqlens - max_seqlen_now = max_seqlen_window - - 
hidden_states = blk( - hidden_states, - cu_seqlens=cu_seqlens_now, - rotary_pos_emb_cos=rotary_pos_emb_cos, - rotary_pos_emb_sin=rotary_pos_emb_sin, - max_seqlen=max_seqlen_now, - ) + if is_cudagraph_mode: + # The above operations will produce temporary new tensors. + # That is not friendly to cudagraphs, + # so we need to copy them back to the persistent buffer + original_hidden_states = original_hidden_states.view(hidden_states.shape) + original_hidden_states.copy_(hidden_states) + hidden_states = original_hidden_states + + with ( + set_is_first_graph_in_mm_encoder_sequence(False), + set_is_last_graph_in_mm_encoder_sequence(False), + ): + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + max_seqlen_now = max_seqlen_full + else: + cu_seqlens_now = cu_window_seqlens + max_seqlen_now = max_seqlen_window + + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + max_seqlen=max_seqlen_now, + ) # For Qwen2.5-VL-3B, float16 will overflow at last block # for long visual tokens sequences. 
@@ -873,7 +950,11 @@ def forward( hidden_states = cast_overflow_tensors(hidden_states) # adapter - hidden_states = self.merger(hidden_states) + with ( + set_is_first_graph_in_mm_encoder_sequence(False), + set_is_last_graph_in_mm_encoder_sequence(True), + ): + hidden_states = self.merger(hidden_states) hidden_states = hidden_states[reverse_indices, :] return hidden_states @@ -1197,7 +1278,9 @@ def _parse_and_validate_video_input( ) def _process_image_input( - self, image_input: Qwen2_5_VLImageInputs + self, + image_input: Qwen2_5_VLImageInputs, + mm_cudagraph_manager: Any | None = None, ) -> tuple[torch.Tensor, ...]: grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 @@ -1207,13 +1290,16 @@ def _process_image_input( image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: pixel_values = image_input["pixel_values"] - with set_forward_context(None, self.vllm_config): - if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model( - self.visual, pixel_values, grid_thw_list, rope_type="rope_3d" - ) - else: - image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) + if self.use_data_parallel and not self.vllm_config.in_mm_encoder_tracing: + return run_dp_sharded_mrope_vision_model( + self.visual, + pixel_values, + grid_thw_list, + rope_type="rope_3d", + mm_cudagraph_manager=mm_cudagraph_manager, + ) + else: + image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) # Split concatenated embeddings for each image item. 
merge_size = self.visual.spatial_merge_size @@ -1263,7 +1349,10 @@ def _process_video_input( else: pixel_values_videos = video_input["pixel_values_videos"] with set_forward_context(None, self.vllm_config): - if self.use_data_parallel: + if ( + self.use_data_parallel + and not self.vllm_config.in_mm_encoder_tracing + ): return run_dp_sharded_mrope_vision_model( self.visual, pixel_values_videos, @@ -1418,6 +1507,7 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: return mm_input_by_modality def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: + mm_cudagraph_manager = kwargs.pop("mm_cudagraph_manager", None) mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return [] @@ -1431,7 +1521,9 @@ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: for modality in mm_input_by_modality: multimodal_input = mm_input_by_modality[modality] if modality == "image": - image_embeddings = self._process_image_input(multimodal_input) + image_embeddings = self._process_image_input( + multimodal_input, mm_cudagraph_manager=mm_cudagraph_manager + ) if self.is_multimodal_pruning_enabled: image_embeddings = self._postprocess_image_embeds_evs( image_embeddings, multimodal_input diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index c7c26c206726..94348a77b55a 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1043,6 +1043,25 @@ def get_dummy_mm_data( ), } + def _calculate_patch_size(self, patches: int) -> tuple[int, int]: + vision_config = self.info.get_hf_config().vision_config + merge_size = vision_config.spatial_merge_size + + assert patches % (merge_size * merge_size) == 0, ( + f"Qwen2-VL: Number of patches ({patches}) must be multiple of " + f"merge_size squared ({merge_size}^2)" + ) + h_patches = merge_size + w_patches = patches // merge_size + return h_patches, w_patches + + def 
_get_img_feature_dim(self) -> int: + vision_config = self.info.get_hf_config().vision_config + in_channels = vision_config.in_channels + temporal_patch_size = vision_config.temporal_patch_size + patch_size = vision_config.patch_size + return in_channels * temporal_patch_size * patch_size * patch_size + class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]): def _get_prompt_updates( diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 97754833953f..3f51ce90fbd2 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -48,10 +48,23 @@ ) from transformers.video_utils import VideoMetadata +from vllm.compilation.backends import ( + set_is_first_graph_in_mm_encoder_sequence, + set_is_last_graph_in_mm_encoder_sequence, +) from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig +from vllm.config import ( + CUDAGraphMode, + VllmConfig, + get_current_vllm_config, +) from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.distributed import get_pp_group +from vllm.forward_context import ( + get_forward_context, + is_forward_context_available, + set_forward_context, +) from vllm.logger import init_logger from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY from vllm.model_executor.layers.conv import Conv3dLayer @@ -65,6 +78,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.vision import should_torch_compile_mm_vit from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.evs import ( compute_mrope_for_media, @@ -139,6 +153,7 @@ DUMMY_VIDEO_NUM_FRAMES = 2048 +@support_torch_compile(dynamic_arg_dims={"x": 0}, enable_if=should_torch_compile_mm_vit) class 
Qwen3_VisionPatchEmbed(nn.Module): def __init__( self, @@ -205,6 +220,15 @@ def forward(self, x: torch.Tensor): return mlp_output +@support_torch_compile( + dynamic_arg_dims={ + "x": 0, + "cu_seqlens": 0, + "rotary_pos_emb_cos": 0, + "rotary_pos_emb_sin": 0, + }, + enable_if=should_torch_compile_mm_vit, +) class Qwen3_VisionBlock(nn.Module): def __init__( self, @@ -257,6 +281,7 @@ def forward( return x +@support_torch_compile(dynamic_arg_dims={"x": 0}, enable_if=should_torch_compile_mm_vit) class Qwen3_VisionPatchMerger(nn.Module): def __init__( self, @@ -286,6 +311,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.linear_fc1", disable_tp=use_data_parallel, + return_bias=False, ) self.act_fn = nn.GELU() self.linear_fc2 = RowParallelLinear( @@ -295,6 +321,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.linear_fc2", disable_tp=use_data_parallel, + return_bias=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -303,9 +330,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: else: x = self.norm(x).view(-1, self.hidden_size) - x_parallel, _ = self.linear_fc1(x) + x_parallel = self.linear_fc1(x) x_parallel = self.act_fn(x_parallel) - out, _ = self.linear_fc2(x_parallel) + out = self.linear_fc2(x_parallel) return out @@ -333,13 +360,18 @@ def __init__( self.out_hidden_size = vision_config.out_hidden_size * ( 1 + len(self.deepstack_visual_indexes) ) - - self.patch_embed = Qwen3_VisionPatchEmbed( - patch_size=self.patch_size, - temporal_patch_size=self.temporal_patch_size, - in_channels=vision_config.in_channels, - hidden_size=self.hidden_size, - ) + # TODO[@lucaskabela]: Investigate fixing this usage + # see https://github.com/vllm-project/vllm/issues/27044 + # DO NOT MOVE THIS IMPORT + from vllm.compilation.backends import set_model_tag + + with set_model_tag("Qwen3_VisionPatchEmbed", is_encoder=True): + self.patch_embed = Qwen3_VisionPatchEmbed( + patch_size=self.patch_size, + temporal_patch_size=self.temporal_patch_size, + 
in_channels=vision_config.in_channels, + hidden_size=self.hidden_size, + ) self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size) @@ -352,29 +384,31 @@ def __init__( rope_parameters={"partial_rotary_factor": 0.5}, ) - self.merger = Qwen3_VisionPatchMerger( - d_model=vision_config.out_hidden_size, - context_dim=self.hidden_size, - norm_layer=norm_layer, - spatial_merge_size=self.spatial_merge_size, - quant_config=quant_config, - prefix=f"{prefix}.merger", - ) + with set_model_tag("Qwen3_VisionPatchMerger", is_encoder=True): + self.merger = Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + norm_layer=norm_layer, + spatial_merge_size=self.spatial_merge_size, + quant_config=quant_config, + prefix=f"{prefix}.merger", + ) - self.deepstack_merger_list = nn.ModuleList( - [ - Qwen3_VisionPatchMerger( - d_model=vision_config.out_hidden_size, - context_dim=self.hidden_size, - spatial_merge_size=self.spatial_merge_size, - use_postshuffle_norm=True, - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.deepstack_merger_list.{layer_idx}", - ) - for layer_idx in range(len(self.deepstack_visual_indexes)) - ] - ) + with set_model_tag("Qwen3_VisionPatchMerger_postshuffle_norm", is_encoder=True): + self.deepstack_merger_list = nn.ModuleList( + [ + Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + spatial_merge_size=self.spatial_merge_size, + use_postshuffle_norm=True, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.deepstack_merger_list.{layer_idx}", + ) + for layer_idx in range(len(self.deepstack_visual_indexes)) + ] + ) self.attn_backend = get_vit_attn_backend( head_size=head_dim, @@ -389,20 +423,45 @@ def __init__( raise RuntimeError( f"Qwen3-VL does not support {self.attn_backend} backend now." 
) - self.blocks = nn.ModuleList( - [ - Qwen3_VisionBlock( - dim=self.hidden_size, - num_heads=self.num_heads, - mlp_hidden_dim=vision_config.intermediate_size, - act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}", + with set_model_tag("Qwen3_VisionBlock", is_encoder=True): + self.blocks = nn.ModuleList( + [ + Qwen3_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + ) + for layer_idx in range(vision_config.depth) + ] + ) + vllm_config: VllmConfig = get_current_vllm_config() + self._persistent_hidden_states_buffer: torch.Tensor | None = None + self._persistent_rotary_pos_emb_cos_buffer: torch.Tensor | None = None + self._persistent_rotary_pos_emb_sin_buffer: torch.Tensor | None = None + if vllm_config.compilation_config.mm_encoder_cudagraph_capture_sizes: + max_compile_size = ( + vllm_config.compilation_config.mm_encoder_cudagraph_capture_sizes[-1] + ) + self._persistent_hidden_states_buffer = torch.empty( + (max_compile_size, self.patch_embed.proj.input_size), + device=self.device, + dtype=self.dtype, + ) + ( + self._persistent_rotary_pos_emb_cos_buffer, + self._persistent_rotary_pos_emb_sin_buffer, + ) = [ + torch.empty( + (max_compile_size, head_dim // 2), + device=self.device, + dtype=torch.bfloat16, ) - for layer_idx in range(vision_config.depth) + for _ in range(2) ] - ) @property def dtype(self) -> torch.dtype: @@ -529,13 +588,38 @@ def compute_attn_mask_seqlen( max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() return max_seqlen + def _use_piecewise_cudagraph(self) -> bool: + if self._persistent_hidden_states_buffer is None: + return False + if not is_forward_context_available(): + return False + fwd_ctx = get_forward_context() + return ( + 
fwd_ctx is not None + and fwd_ctx.cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE + ) + def forward( self, x: torch.Tensor, grid_thw: torch.Tensor | list[list[int]], ) -> torch.Tensor: - hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True) - hidden_states = self.patch_embed(hidden_states) + seq_len, _ = x.size() + is_cudagraph_mode = self._use_piecewise_cudagraph() + + if is_cudagraph_mode: + hidden_states = self._persistent_hidden_states_buffer[:seq_len] + hidden_states.copy_(x, non_blocking=True) + else: + hidden_states = x.to( + device=self.device, dtype=self.dtype, non_blocking=True + ) + + with ( + set_is_first_graph_in_mm_encoder_sequence(True), + set_is_last_graph_in_mm_encoder_sequence(False), + ): + hidden_states = self.patch_embed(hidden_states) if isinstance(grid_thw, list): grid_thw_list = grid_thw @@ -545,9 +629,18 @@ def forward( grid_thw = grid_thw.numpy() pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list) + original_hidden_states = hidden_states hidden_states = hidden_states + pos_embeds rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list) + if is_cudagraph_mode: + rotary_pos_emb_sin = self._persistent_rotary_pos_emb_sin_buffer[ + :seq_len + ].copy_(rotary_pos_emb_sin) + rotary_pos_emb_cos = self._persistent_rotary_pos_emb_cos_buffer[ + :seq_len + ].copy_(rotary_pos_emb_cos) + cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( axis=0, dtype=np.int32 ) @@ -558,22 +651,40 @@ def forward( max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens) cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) + if is_cudagraph_mode: + # The above operations will produce temporary new tensors. 
+ # That is not friendly to cudagraphs, + # so we need to copy them back to the persistent buffer + original_hidden_states = original_hidden_states.view(hidden_states.shape) + original_hidden_states.copy_(hidden_states) + hidden_states = original_hidden_states + deepstack_feature_lists = [] - for layer_num, blk in enumerate(self.blocks): - hidden_states = blk( - hidden_states, - cu_seqlens=cu_seqlens, - rotary_pos_emb_cos=rotary_pos_emb_cos, - rotary_pos_emb_sin=rotary_pos_emb_sin, - max_seqlen=max_seqlen, - ) - if layer_num in self.deepstack_visual_indexes: - deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num) - deepstack_feature = self.deepstack_merger_list[deepstack_merger_idx]( - hidden_states + with ( + set_is_first_graph_in_mm_encoder_sequence(False), + set_is_last_graph_in_mm_encoder_sequence(False), + ): + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + max_seqlen=max_seqlen, ) - deepstack_feature_lists.append(deepstack_feature) - hidden_states = self.merger(hidden_states) + if layer_num in self.deepstack_visual_indexes: + deepstack_merger_idx = self.deepstack_visual_indexes.index( + layer_num + ) + deepstack_feature = self.deepstack_merger_list[ + deepstack_merger_idx + ](hidden_states) + deepstack_feature_lists.append(deepstack_feature) + with ( + set_is_first_graph_in_mm_encoder_sequence(False), + set_is_last_graph_in_mm_encoder_sequence(True), + ): + hidden_states = self.merger(hidden_states) hidden_states = torch.cat( [hidden_states] + deepstack_feature_lists, dim=1 ) # [seq_len, hidden_size * (1 + depth_of_deepstack)] @@ -906,6 +1017,25 @@ def _get_dummy_videos( video_items.append(video_item) return video_items + def _calculate_patch_size(self, patches: int) -> tuple[int, int]: + vision_config = self.info.get_hf_config().vision_config + merge_size = vision_config.spatial_merge_size + + 
assert patches % (merge_size * merge_size) == 0, ( + f"Qwen3-VL: Number of patches ({patches}) must be multiple of " + f"merge_size squared ({merge_size}^2)" + ) + h_patches = merge_size + w_patches = patches // merge_size + return h_patches, w_patches + + def _get_img_feature_dim(self) -> int: + vision_config = self.info.get_hf_config().vision_config + in_channels = vision_config.in_channels + temporal_patch_size = vision_config.temporal_patch_size + patch_size = vision_config.patch_size + return in_channels * temporal_patch_size * patch_size * patch_size + class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]): def _call_hf_processor( @@ -1257,6 +1387,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): multimodal_config = vllm_config.model_config.multimodal_config self.config = config + self.vllm_config = vllm_config self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.video_pruning_rate = multimodal_config.video_pruning_rate @@ -1405,21 +1536,29 @@ def _parse_and_validate_video_input( ) def _process_image_input( - self, image_input: Qwen2_5_VLImageInputs + self, + image_input: Qwen2_5_VLImageInputs, + mm_cudagraph_manager: Any | None = None, ) -> tuple[torch.Tensor, ...]: grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() if image_input["type"] == "image_embeds": image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: pixel_values = image_input["pixel_values"].type(self.visual.dtype) - if self.use_data_parallel: + + if self.use_data_parallel and not self.vllm_config.in_mm_encoder_tracing: return run_dp_sharded_mrope_vision_model( - self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d" + self.visual, + pixel_values, + grid_thw_list, + rope_type="rope_3d", + mm_cudagraph_manager=mm_cudagraph_manager, ) else: - image_embeds = self.visual(pixel_values, grid_thw=grid_thw) 
+ image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) # Split concatenated embeddings for each image item. merge_size = self.visual.spatial_merge_size @@ -1438,13 +1577,20 @@ def _process_video_input( pixel_values_videos = video_input["pixel_values_videos"].type( self.visual.dtype ) - if self.use_data_parallel: - grid_thw_list = grid_thw.tolist() - return run_dp_sharded_mrope_vision_model( - self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d" - ) - else: - video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + with set_forward_context(None, self.vllm_config): + if ( + self.use_data_parallel + and not self.vllm_config.in_mm_encoder_tracing + ): + grid_thw_list = grid_thw.tolist() + return run_dp_sharded_mrope_vision_model( + self.visual, + pixel_values_videos, + grid_thw_list, + rope_type="rope_3d", + ) + else: + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size @@ -1888,6 +2034,7 @@ def get_mrope_input_positions( return torch.from_numpy(llm_positions), mrope_position_delta def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: + mm_cudagraph_manager = kwargs.pop("mm_cudagraph_manager", None) mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return None @@ -1901,7 +2048,9 @@ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: for modality in mm_input_by_modality: multimodal_input = mm_input_by_modality[modality] if modality == "image": - image_embeddings = self._process_image_input(multimodal_input) + image_embeddings = self._process_image_input( + multimodal_input, mm_cudagraph_manager=mm_cudagraph_manager + ) if self.is_multimodal_pruning_enabled: image_embeddings = self._postprocess_image_embeds_evs( image_embeddings, multimodal_input diff --git a/vllm/model_executor/models/vision.py 
b/vllm/model_executor/models/vision.py index a2b78753a0c6..837ddf4a2534 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib import itertools import math from abc import ABC, abstractmethod @@ -16,9 +17,11 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, ) +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.v1.attention.backends.registry import AttentionBackendEnum +from vllm.v1.worker.mm_cudagraph import MMEncoderCudagraphManager logger = init_logger(__name__) @@ -387,6 +390,7 @@ def run_dp_sharded_mrope_vision_model( grid_thw_list: list[list[int]], *, rope_type: Literal["rope_3d", "rope_2d"], + mm_cudagraph_manager: MMEncoderCudagraphManager | None = None, ) -> tuple[torch.Tensor, ...]: """Run a vision model with data parallelism (DP) sharding. 
The function will shard the input image tensor on the @@ -474,31 +478,58 @@ def run_dp_sharded_mrope_vision_model( max_len_per_rank = max(grouped_pixel_values_len) // embed_dim_reduction_factor local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local] - # Run the vision model on the local pixel_values_local - if rope_type == "rope_2d": - if pixel_values_local.shape[0] > 0: - image_embeds_local = vision_model( - pixel_values_local, torch.tensor(local_grid_thw_list) - ) - if isinstance(image_embeds_local, list): - image_embeds_local = torch.cat(image_embeds_local, dim=0) - else: - out_dim = getattr(vision_model.config, "hidden_size", None) - image_embeds_local = torch.empty( - (0, embed_dim_reduction_factor, out_dim), - device=pixel_values.device, - dtype=pixel_values.dtype, - ) - else: - if pixel_values_local.shape[0] > 0: - image_embeds_local = vision_model(pixel_values_local, local_grid_thw_list) + # Context setup + ctx = contextlib.nullcontext() + + if mm_cudagraph_manager is not None: + mm_groups: dict[str, torch.Tensor | list] = { + "pixel_values": pixel_values_local, + "image_grid_thw": local_grid_thw_list, + } + ( + cudagraph_runtime_mode, + batch_descriptor, + _, + mm_groups, + ) = mm_cudagraph_manager.dispatch_and_pad_mm_input(mm_groups) + pixel_values_local = mm_groups["pixel_values"] + local_grid_thw_list = mm_groups["image_grid_thw"] + + ctx = set_forward_context( + None, + vllm_config=mm_cudagraph_manager.vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ) + + with ctx: + # Run the vision model on the local pixel_values_local + if rope_type == "rope_2d": + if pixel_values_local.shape[0] > 0: + image_embeds_local = vision_model( + pixel_values_local, torch.tensor(local_grid_thw_list) + ) + if isinstance(image_embeds_local, list): + image_embeds_local = torch.cat(image_embeds_local, dim=0) + else: + out_dim = getattr(vision_model.config, "hidden_size", None) + image_embeds_local = torch.empty( + 
(0, embed_dim_reduction_factor, out_dim), + device=pixel_values.device, + dtype=pixel_values.dtype, + ) else: - # Handle empty case - image_embeds_local = torch.empty( - (0, vision_model.out_hidden_size), - device=pixel_values.device, - dtype=pixel_values.dtype, - ) + if pixel_values_local.shape[0] > 0: + image_embeds_local = vision_model( + pixel_values_local, local_grid_thw_list + ) + else: + # Handle empty case + image_embeds_local = torch.empty( + (0, vision_model.out_hidden_size), + device=pixel_values.device, + dtype=pixel_values.dtype, + ) # Pad the output based on max_len_per_rank # for tensor_model_parallel_all_gather to work @@ -522,6 +553,9 @@ def run_dp_sharded_mrope_vision_model( device=image_embeds_local.device, ) image_embeds_local_padded = torch.cat([image_embeds_local, padding], dim=0) + # truncate the padded output from CUDA graph execution + elif current_len > max_len_per_rank: + image_embeds_local_padded = image_embeds_local[:max_len_per_rank] else: image_embeds_local_padded = image_embeds_local diff --git a/vllm/multimodal/processing/dummy_inputs.py b/vllm/multimodal/processing/dummy_inputs.py index b23e2b86cc20..9eb1020db681 100644 --- a/vllm/multimodal/processing/dummy_inputs.py +++ b/vllm/multimodal/processing/dummy_inputs.py @@ -7,6 +7,7 @@ import numpy as np import numpy.typing as npt +import torch from PIL import Image from vllm.config.multimodal import ( @@ -199,3 +200,52 @@ def _get_dummy_videos( height = min(height, overrides.height) video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8) return [video] * num_videos + + @abstractmethod + def _get_img_feature_dim(self) -> int: + """ + Get the image feature dimension for MM encoder CUDA graph capture. + + Returns: + The image feature dimension. + """ + raise NotImplementedError + + @abstractmethod + def _calculate_patch_size(self, patches: int) -> tuple[int, int]: + """ + Calculate the patch grid size (height, width) from the total number of + patches. 
+ """ + raise NotImplementedError + + def get_dummy_mm_encoder_input( + self, + num_patches: int, + ) -> "dict[str, torch.Tensor]": + """ + Get dummy MM encoder input for CUDA graph capture or padding. + + Args: + num_patches: Number of patches (tokens) for the dummy input + + Returns: + dict with pixel_values and image_grid_thw + """ + img_feature_dim = self._get_img_feature_dim() + + dtype = self.info.ctx.model_config.dtype + + h_patches, w_patches = self._calculate_patch_size(num_patches) + + pixel_values = torch.zeros( + (num_patches, img_feature_dim), dtype=dtype, device="cuda" + ) + grid_thw_list = torch.tensor( + [[1, h_patches, w_patches]], dtype=torch.long, device="cpu" + ) + + return { + "pixel_values": pixel_values, + "image_grid_thw": grid_thw_list, + } diff --git a/vllm/v1/attention/ops/vit_attn_wrappers.py b/vllm/v1/attention/ops/vit_attn_wrappers.py index f077a61c984f..b226fb8c1134 100644 --- a/vllm/v1/attention/ops/vit_attn_wrappers.py +++ b/vllm/v1/attention/ops/vit_attn_wrappers.py @@ -26,6 +26,7 @@ def flash_attn_maxseqlen_wrapper( v: torch.Tensor, batch_size: int, is_rocm_aiter: bool, + output: torch.Tensor, fa_version: int | None, scale: float | None = None, cu_seqlens: torch.Tensor | None = None, @@ -48,7 +49,7 @@ def flash_attn_maxseqlen_wrapper( max_seqlen = q_len if max_seqlen is None else max_seqlen.item() q, k, v = (einops.rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) - output = flash_attn_varlen_func( + fa_output = flash_attn_varlen_func( q, k, v, @@ -61,8 +62,9 @@ def flash_attn_maxseqlen_wrapper( softmax_scale=scale, **kwargs, ) - context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size) - return context_layer + context_layer = einops.rearrange(fa_output, "(b s) h d -> b s h d", b=batch_size) + output.copy_(context_layer) + return output def flash_attn_maxseqlen_wrapper_fake( @@ -71,6 +73,7 @@ def flash_attn_maxseqlen_wrapper_fake( v: torch.Tensor, batch_size: int, is_rocm_aiter: bool, + output: torch.Tensor, fa_version: int | None, scale: float | None = None, cu_seqlens: torch.Tensor | None = None, @@ -97,12 +100,15 @@ def vit_flash_attn_wrapper( cu_seqlens: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, ) -> torch.Tensor: + b, s, h, d = q.shape + output = torch.empty((b, s, h, d), dtype=q.dtype, device=q.device) return torch.ops.vllm.flash_attn_maxseqlen_wrapper( q, k, v, batch_size, is_rocm_aiter, + output, fa_version, scale, cu_seqlens, @@ -132,6 +138,7 @@ def torch_sdpa_wrapper( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + output: torch.Tensor, scale: float | None = None, cu_seqlens: torch.Tensor | None = None, ) -> torch.Tensor: @@ -143,7 +150,9 @@ def torch_sdpa_wrapper( v = v.contiguous() if cu_seqlens is None: - return apply_sdpa(q, k, v, scale=scale) + context_layer = apply_sdpa(q, k, v, scale=scale) + output.copy_(context_layer) + return output outputs = [] @@ -155,13 +164,15 @@ def torch_sdpa_wrapper( output_i = apply_sdpa(q_i, k_i, v_i, scale=scale) outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) - return context_layer + output.copy_(context_layer) + return output def torch_sdpa_wrapper_fake( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + output: torch.Tensor, scale: float | None, cu_seqlens: torch.Tensor | None, ) -> torch.Tensor: @@ -182,4 +193,6 @@ def vit_torch_sdpa_wrapper( scale: float | None = 
None, cu_seqlens: torch.Tensor | None = None, ) -> torch.Tensor: - return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, scale, cu_seqlens) + b, s, h, d = q.shape + output = torch.empty((b, s, h, d), dtype=q.dtype, device=q.device) + return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, output, scale, cu_seqlens) diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 6f3e029c793b..2fca21831c86 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -65,15 +65,15 @@ def __init__(self, vllm_config: VllmConfig): ) # Default cudagraph_mode to NONE until initialize_cudagraph_keys is called self.cudagraph_mode = CUDAGraphMode.NONE + self.capture_sizes: list[int] = [] + self.max_capture_size: int = 0 def _compute_bs_to_padded_graph_size(self) -> None: """Pre-compute the mapping from batch size to padded graph size.""" - max_size = self.compilation_config.max_cudagraph_capture_size - capture_sizes = self.compilation_config.cudagraph_capture_sizes - self._bs_to_padded_graph_size: list[int] = [0] * (max_size + 1) + self._bs_to_padded_graph_size: list[int] = [0] * (self.max_capture_size + 1) for end, start in zip( - capture_sizes + [max_size + 1], - [0] + capture_sizes, + self.capture_sizes + [self.max_capture_size + 1], + [0] + self.capture_sizes, ): for bs in range(start, end): if bs == start: @@ -88,7 +88,7 @@ def _compute_bs_to_padded_graph_size(self) -> None: and self.cudagraph_mode != CUDAGraphMode.NONE ): for size in self.compilation_config.compile_sizes: - if size <= self.compilation_config.max_cudagraph_capture_size: + if size <= self.max_capture_size: padded = self._bs_to_padded_graph_size[size] if padded != size: raise ValueError( @@ -154,12 +154,22 @@ def add_cudagraph_key( self.cudagraph_keys[runtime_mode].add(batch_descriptor) def initialize_cudagraph_keys( - self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int = 1 + self, + cudagraph_mode: CUDAGraphMode, + uniform_decode_query_len: int = 1, + 
capture_sizes: list[int] | None = None, + max_capture_size: int | None = None, + enable_lora: bool = True, ): # This should be called only after attention backend is initialized. So we can # get the correct cudagraph mode after backend support is resolved. self.cudagraph_mode = cudagraph_mode - + self.capture_sizes = ( + capture_sizes or self.compilation_config.cudagraph_capture_sizes + ) + self.max_capture_size = ( + max_capture_size or self.compilation_config.max_cudagraph_capture_size + ) # Early exit if cudagraphs are disabled if cudagraph_mode == CUDAGraphMode.NONE: self.keys_initialized = True @@ -168,7 +178,7 @@ def initialize_cudagraph_keys( self._compute_bs_to_padded_graph_size() # Get LoRA cases to capture - lora_cases = self._get_lora_cases() + lora_cases = self._get_lora_cases() if enable_lora else [0] self.captured_lora_counts = [ lora_count for lora_count in lora_cases if lora_count ] @@ -177,9 +187,7 @@ def initialize_cudagraph_keys( # guarantee all keys would be used. For example, if we allow lazy # capturing in future PR, some keys may never be triggered. 
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: - for bs, num_active_loras in product( - self.compilation_config.cudagraph_capture_sizes, lora_cases - ): + for bs, num_active_loras in product(self.capture_sizes, lora_cases): self.add_cudagraph_key( cudagraph_mode.mixed_mode(), self._create_padded_batch_descriptor( @@ -199,7 +207,7 @@ def initialize_cudagraph_keys( ) cudagraph_capture_sizes_for_decode = [ x - for x in self.compilation_config.cudagraph_capture_sizes + for x in self.capture_sizes if x <= max_num_tokens and x >= uniform_decode_query_len ] for bs, num_active_loras in product( @@ -241,7 +249,7 @@ def dispatch( if ( not self.keys_initialized or self.cudagraph_mode == CUDAGraphMode.NONE - or num_tokens > self.compilation_config.max_cudagraph_capture_size + or num_tokens > self.max_capture_size ): return CUDAGraphMode.NONE, BatchDescriptor(num_tokens) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 49211c6805ce..bc48153ad060 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -170,6 +170,7 @@ from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin +from vllm.v1.worker.mm_cudagraph import MMEncoderCudagraphManager from vllm.v1.worker.ubatch_utils import ( UBatchSlices, check_ubatch_thresholds, @@ -650,6 +651,16 @@ def __init__( # Cudagraph dispatcher for runtime cudagraph dispatching. self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) + # MM encoder CUDA graph manager for ViT piecewise CUDA graph. 
+ self.mm_cudagraph_manager: MMEncoderCudagraphManager | None = None + if self.supports_mm_inputs: + processor = self.mm_registry.create_processor(self.model_config) + dummy_inputs_builder = processor.dummy_inputs + self.mm_cudagraph_manager = MMEncoderCudagraphManager( + self.vllm_config, + dummy_inputs_builder, + ) + self.mm_budget = ( MultiModalBudget(self.vllm_config, self.mm_registry) if self.supports_mm_inputs @@ -2420,12 +2431,44 @@ def _execute_mm_encoder( # 2. A list or tuple (length: num_items) of tensors, # each of shape (feature_size, hidden_size) in case the feature # size is dynamic depending on the input multimodal items. - - with self.timed_encoder_operation( - should_time, mm_lora_refs, current_item_idx, num_items - ): - curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) - + mm_mgr = self.mm_cudagraph_manager + is_vit_dp_mode = mm_mgr.is_vit_dp_mode if mm_mgr else False + + if not is_vit_dp_mode: + original_num_imgs = -1 + + # Default values for non-mm_encoder cudagraph case + cudagraph_runtime_mode = CUDAGraphMode.NONE + batch_descriptor = None + if mm_mgr is not None and "pixel_values" in mm_kwargs_group: + ( + cudagraph_runtime_mode, + batch_descriptor, + original_num_imgs, + mm_kwargs_group, + ) = mm_mgr.dispatch_and_pad_mm_input(mm_kwargs_group) + + with ( + set_forward_context( + None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ), + self.timed_encoder_operation( + should_time, mm_lora_refs, current_item_idx, num_items + ), + ): + curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) + # Remove the padded items before sanity check + if original_num_imgs != -1: + curr_group_outputs = curr_group_outputs[:original_num_imgs] + else: + with self.timed_encoder_operation( + should_time, mm_lora_refs, current_item_idx, num_items + ): + mm_kwargs_group["mm_cudagraph_manager"] = mm_mgr + curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) 
sanity_check_mm_encoder_outputs( curr_group_outputs, expected_num_items=num_items, @@ -5098,6 +5141,7 @@ def _dummy_pooler_run( return self._dummy_pooler_run_task(hidden_states, max_task) def profile_run(self) -> None: + self.vllm_config.in_mm_encoder_tracing = True # Profile with multimodal encoder & encoder cache. if self.supports_mm_inputs: mm_config = self.model_config.multimodal_config @@ -5135,9 +5179,10 @@ def profile_run(self) -> None: ) # Run multimodal encoder. - dummy_encoder_outputs = self.model.embed_multimodal( - **batched_dummy_mm_inputs - ) + with set_forward_context(None, self.vllm_config): + dummy_encoder_outputs = self.model.embed_multimodal( + **batched_dummy_mm_inputs + ) sanity_check_mm_encoder_outputs( dummy_encoder_outputs, @@ -5161,6 +5206,7 @@ def profile_run(self) -> None: del hidden_states, output self.encoder_cache.clear() gc.collect() + self.vllm_config.in_mm_encoder_tracing = False def capture_model(self) -> int: if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: @@ -5205,6 +5251,17 @@ def freeze_gc(): batch_descriptors=batch_descs, cudagraph_runtime_mode=runtime_mode, ) + # Capture MM encoder CUDA graphs if enabled + if self.mm_cudagraph_manager is not None: + for ( + runtime_mode, + batch_descs, + ) in self.mm_cudagraph_manager.dispatcher.get_capture_descs(): + self.mm_cudagraph_manager.capture( + model=self.model, + batch_descs=batch_descs, + cudagraph_mode=runtime_mode, + ) torch.cuda.synchronize() end_free_gpu_memory = torch.cuda.mem_get_info()[0] @@ -5569,6 +5626,9 @@ def _check_and_update_cudagraph_mode( cudagraph_mode, self.uniform_decode_query_len ) + if self.mm_cudagraph_manager is not None: + self.mm_cudagraph_manager.initialize_cudagraph_keys(cudagraph_mode) + # Initialize eagle's cudagraph dispatcher if using eagle spec decode. 
class MMEncoderCudagraphManager:
    """Piecewise CUDA-graph capture/dispatch for the multimodal (ViT) encoder.

    Wraps a dedicated :class:`CudagraphDispatcher` configured from the
    ``mm_encoder_cudagraph_capture_sizes`` /
    ``max_mm_encoder_cudagraph_capture_size`` compilation-config options,
    and pads incoming image inputs up to the nearest captured size so the
    encoder forward can replay a captured graph.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        dummy_input_builder: BaseDummyInputsBuilder[Any],
    ):
        self.vllm_config = vllm_config
        # Dispatcher dedicated to the MM encoder. Its keys are initialized
        # later via initialize_cudagraph_keys(), once the attention backend
        # has been resolved.
        self.dispatcher = CudagraphDispatcher(self.vllm_config)
        # Builder used to synthesize dummy encoder inputs (pixel values and
        # grid metadata) for padding and for graph capture.
        self.dummy_input_builder = dummy_input_builder

        # Whether the ViT encoder runs in data-parallel mode across TP ranks.
        self.is_vit_dp_mode = self._check_vit_dp_mode(vllm_config)

    def _check_vit_dp_mode(self, vllm_config: VllmConfig) -> bool:
        """Return True if the ViT encoder runs in data-parallel mode.

        This is the case when the multimodal encoder TP mode is "data" and
        more than one tensor-parallel rank is in use.
        """
        mm_config = getattr(vllm_config.model_config, "multimodal_config", None)
        if mm_config is None:
            return False

        return (
            mm_config.mm_encoder_tp_mode == "data"
            and vllm_config.parallel_config.tensor_parallel_size > 1
        )

    def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
        """Initialize cudagraph dispatcher keys for the MM encoder.

        The MM encoder only supports PIECEWISE cudagraphs, so any mixed mode
        of PIECEWISE or FULL is downgraded to PIECEWISE; everything else
        disables MM-encoder cudagraphs (NONE).
        """
        if cudagraph_mode.mixed_mode() in (
            CUDAGraphMode.PIECEWISE,
            CUDAGraphMode.FULL,
        ):
            mm_cudagraph_mode = CUDAGraphMode.PIECEWISE
        else:
            mm_cudagraph_mode = CUDAGraphMode.NONE

        compilation_config = self.vllm_config.compilation_config
        self.dispatcher.initialize_cudagraph_keys(
            mm_cudagraph_mode,
            capture_sizes=compilation_config.mm_encoder_cudagraph_capture_sizes,
            max_capture_size=(
                compilation_config.max_mm_encoder_cudagraph_capture_size
            ),
            # LoRA does not apply to the vision encoder.
            enable_lora=False,
        )

    def dispatch_and_pad_mm_input(
        self,
        mm_kwargs_group: BatchedTensorInputs,
    ) -> tuple[CUDAGraphMode, BatchDescriptor | None, int, BatchedTensorInputs]:
        """Dispatch on the encoder token count and pad inputs to that size.

        Returns the runtime cudagraph mode, the (padded) batch descriptor,
        the original number of images (so padded outputs can be trimmed by
        the caller), and the possibly-padded kwargs group.
        """
        pixel_values = cast(torch.Tensor, mm_kwargs_group["pixel_values"])
        num_tokens = pixel_values.shape[0]

        image_grid_thw = mm_kwargs_group["image_grid_thw"]
        if isinstance(image_grid_thw, torch.Tensor):
            original_num_imgs = image_grid_thw.shape[0]
        else:
            original_num_imgs = len(image_grid_thw)

        # Dispatch to get the target padded size.
        cudagraph_runtime_mode, batch_descriptor = self.dispatcher.dispatch(
            num_tokens=num_tokens,
        )
        target_num_tokens = batch_descriptor.num_tokens

        # Pad up to the captured size if necessary; padding items are dummy
        # images produced by the dummy-input builder.
        if target_num_tokens > num_tokens:
            padding_size = target_num_tokens - num_tokens
            padding_mm_inputs = self.dummy_input_builder.get_dummy_mm_encoder_input(
                padding_size,
            )

            mm_kwargs_group["pixel_values"] = torch.cat(
                [pixel_values, padding_mm_inputs["pixel_values"]], dim=0
            )

            padding_image_grid_thw = padding_mm_inputs["image_grid_thw"]
            if isinstance(image_grid_thw, torch.Tensor):
                mm_kwargs_group["image_grid_thw"] = torch.cat(
                    [image_grid_thw, padding_image_grid_thw], dim=0
                )
            else:
                mm_kwargs_group["image_grid_thw"] = (
                    image_grid_thw + padding_image_grid_thw.tolist()
                )

        return (
            cudagraph_runtime_mode,
            batch_descriptor,
            original_num_imgs,
            mm_kwargs_group,
        )

    def capture_graph(
        self,
        num_tokens: int,
        model: nn.Module,
        cudagraph_mode: CUDAGraphMode,
    ) -> None:
        """Run one dummy encoder forward at `num_tokens` to capture a graph."""
        dummy_mm_inputs = self.dummy_input_builder.get_dummy_mm_encoder_input(
            num_tokens
        )

        batch_descriptor = BatchDescriptor(num_tokens=num_tokens)

        with set_forward_context(
            None,
            self.vllm_config,
            num_tokens=num_tokens,
            cudagraph_runtime_mode=cudagraph_mode,
            batch_descriptor=batch_descriptor,
        ):
            model.embed_multimodal(**dummy_mm_inputs)

    @torch.inference_mode()
    def capture(
        self,
        model: nn.Module,
        batch_descs: "list[BatchDescriptor]",
        cudagraph_mode: CUDAGraphMode,
    ) -> None:
        """Capture one CUDA graph per descriptor in `batch_descs`."""
        self.vllm_config.in_mm_encoder_tracing = True
        try:
            if is_global_first_rank():
                # Only rank 0 shows progress to avoid interleaved output.
                batch_descriptors: Any = tqdm(
                    batch_descs,
                    disable=not self.vllm_config.load_config.use_tqdm_on_load,
                    desc="Capturing MM_Encoder CUDA graphs (PIECEWISE)",
                )
            else:
                batch_descriptors = batch_descs

            for batch_desc in batch_descriptors:
                self.capture_graph(
                    batch_desc.num_tokens,
                    model=model,
                    cudagraph_mode=cudagraph_mode,
                )
        finally:
            # Always reset the flag, even if a capture fails, so subsequent
            # (non-encoder) forwards are not misclassified as encoder tracing.
            self.vllm_config.in_mm_encoder_tracing = False