diff --git a/docs/design/torch_compile_multimodal.md b/docs/design/torch_compile_multimodal.md index 674ddd801d65..4abf1d08c517 100644 --- a/docs/design/torch_compile_multimodal.md +++ b/docs/design/torch_compile_multimodal.md @@ -26,7 +26,7 @@ This feature is off by default, but can be enabled by setting `compile_mm_encode To compile a multimodal component such as an encoder, we follow the same mechanism as the LLM text backbone, with a few additional scaffoldings: -1. The `@support_torch_compile` decorator should include `enable_if=should_torch_compile_mm_vit`. This will gate the compilation behind our +1. The `@support_torch_compile` decorator should include `enable_if=should_torch_compile_mm_encoder`. This will gate the compilation behind our `compile_mm_encoder` configuration 2. `with set_model_tag("", is_encoder=True)` context manager should be used around the nn.Module's instantiation. Since torch.compile diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index f8629be34b53..d52d457083ec 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -47,6 +47,11 @@ _T = TypeVar("_T", bound=nn.Module) +def should_torch_compile_mm_encoder(vllm_config: VllmConfig) -> bool: + """Callable to be passed to `@support_torch_compile`'s `enable_if` argument.""" + return vllm_config.compilation_config.compile_mm_encoder + + def ignore_torch_compile(cls: type[_T]) -> type[_T]: """ A decorator to ignore support_torch_compile decorator diff --git a/vllm/model_executor/models/lfm2_siglip2.py b/vllm/model_executor/models/lfm2_siglip2.py index 92ea42f27100..15ce3d8de428 100644 --- a/vllm/model_executor/models/lfm2_siglip2.py +++ b/vllm/model_executor/models/lfm2_siglip2.py @@ -10,7 +10,10 @@ from torch.nn import functional as F from transformers import Siglip2VisionConfig -from vllm.compilation.decorators import support_torch_compile +from vllm.compilation.decorators import ( + should_torch_compile_mm_encoder, + support_torch_compile, +) from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import MMEncoderAttention @@ -25,7 +28,6 @@ from .vision import ( is_vit_use_data_parallel, resolve_visual_encoder_outputs, - should_torch_compile_mm_vit, ) @@ -269,7 +271,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @support_torch_compile( dynamic_arg_dims={"hidden_states": [0, 1], "cu_seqlens": 0}, - enable_if=should_torch_compile_mm_vit, + enable_if=should_torch_compile_mm_encoder, ) class Siglip2EncoderLayer(nn.Module): def __init__( diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 305d13996b5a..6956f70235d5 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -31,7 +31,10 @@ get_best_fit, ) -from vllm.compilation.decorators import support_torch_compile +from vllm.compilation.decorators import ( + should_torch_compile_mm_encoder, + support_torch_compile, +) from vllm.config import VllmConfig, set_current_vllm_config from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size @@ -49,7 +52,6 @@ from vllm.model_executor.model_loader.utils import initialize_model from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.models.vision import should_torch_compile_mm_vit from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalDataDict, @@ -454,7 +456,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @support_torch_compile( - dynamic_arg_dims={"images_flattened": 0}, enable_if=should_torch_compile_mm_vit + dynamic_arg_dims={"images_flattened": 0}, enable_if=should_torch_compile_mm_encoder ) class Llama4VisionModel(nn.Module): def __init__( diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index cd5c5356e558..245748249819 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -42,7 +42,10 @@ Qwen2_5_VLVisionConfig, ) -from vllm.compilation.decorators import support_torch_compile +from vllm.compilation.decorators import ( + should_torch_compile_mm_encoder, + support_torch_compile, +) from vllm.config import VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils @@ -65,7 +68,6 @@ ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.models.vision import should_torch_compile_mm_vit from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.evs import ( compute_mrope_for_media, @@ -424,7 +426,7 @@ def forward( "rotary_pos_emb_cos": 0, "rotary_pos_emb_sin": 0, }, - enable_if=should_torch_compile_mm_vit, + enable_if=should_torch_compile_mm_encoder, ) class Qwen2_5_VisionBlock(nn.Module): def __init__( @@ -483,7 +485,7 @@ def forward( dynamic_arg_dims={ "x": 0, }, - enable_if=should_torch_compile_mm_vit, + enable_if=should_torch_compile_mm_encoder, ) class Qwen2_5_VisionPatchEmbed(nn.Module): def __init__( @@ -518,7 +520,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dynamic_arg_dims={ "x": 0, }, - enable_if=should_torch_compile_mm_vit, + enable_if=should_torch_compile_mm_encoder, ) class Qwen2_5_VisionPatchMerger(nn.Module): def __init__( diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 8882754b3cc2..e6a243006759 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -143,11 +143,6 @@ def is_vit_use_data_parallel(): return mm_encoder_tp_mode == "data" -def should_torch_compile_mm_vit(vllm_config: VllmConfig) -> bool: - """Callable to be passed to `@support_torch_compile`'s `enable_if` argument.""" - return vllm_config.compilation_config.compile_mm_encoder - - VisionFeatureSelectStrategyStr = Literal["class", "default", "full"] VisionFeatureSelectStrategy: TypeAlias = (