Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/design/torch_compile_multimodal.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ This feature is off by default, but can be enabled by setting `compile_mm_encode

To compile a multimodal component such as an encoder, we follow the same mechanism as the LLM text backbone, with a few additional scaffoldings:

1. The `@support_torch_compile` decorator should include `enable_if=should_torch_compile_mm_vit`. This will gate the compilation behind our
1. The `@support_torch_compile` decorator should include `enable_if=should_torch_compile_mm_encoder`. This will gate the compilation behind our
`compile_mm_encoder` configuration

2. `with set_model_tag("<component_name>", is_encoder=True)` context manager should be used around the nn.Module's instantiation. Since torch.compile
Expand Down
5 changes: 5 additions & 0 deletions vllm/compilation/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@
_T = TypeVar("_T", bound=nn.Module)


def should_torch_compile_mm_encoder(vllm_config: VllmConfig) -> bool:
    """Callable to be passed to `@support_torch_compile`'s `enable_if` argument.

    Returns True when the compilation config has `compile_mm_encoder`
    enabled, gating torch.compile of multimodal encoder modules behind
    that flag.
    """
    return vllm_config.compilation_config.compile_mm_encoder


def ignore_torch_compile(cls: type[_T]) -> type[_T]:
"""
A decorator to ignore support_torch_compile decorator
Expand Down
8 changes: 5 additions & 3 deletions vllm/model_executor/models/lfm2_siglip2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from torch.nn import functional as F
from transformers import Siglip2VisionConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.decorators import (
should_torch_compile_mm_encoder,
support_torch_compile,
)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import MMEncoderAttention
Expand All @@ -25,7 +28,6 @@
from .vision import (
is_vit_use_data_parallel,
resolve_visual_encoder_outputs,
should_torch_compile_mm_vit,
)


Expand Down Expand Up @@ -269,7 +271,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

@support_torch_compile(
dynamic_arg_dims={"hidden_states": [0, 1], "cu_seqlens": 0},
enable_if=should_torch_compile_mm_vit,
enable_if=should_torch_compile_mm_encoder,
)
class Siglip2EncoderLayer(nn.Module):
def __init__(
Expand Down
8 changes: 5 additions & 3 deletions vllm/model_executor/models/mllama4.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
get_best_fit,
)

from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.decorators import (
should_torch_compile_mm_encoder,
support_torch_compile,
)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
Expand All @@ -49,7 +52,6 @@
from vllm.model_executor.model_loader.utils import initialize_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.vision import should_torch_compile_mm_vit
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
Expand Down Expand Up @@ -454,7 +456,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:


@support_torch_compile(
dynamic_arg_dims={"images_flattened": 0}, enable_if=should_torch_compile_mm_vit
dynamic_arg_dims={"images_flattened": 0}, enable_if=should_torch_compile_mm_encoder
)
class Llama4VisionModel(nn.Module):
def __init__(
Expand Down
12 changes: 7 additions & 5 deletions vllm/model_executor/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@
Qwen2_5_VLVisionConfig,
)

from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.decorators import (
should_torch_compile_mm_encoder,
support_torch_compile,
)
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
Expand All @@ -65,7 +68,6 @@
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.vision import should_torch_compile_mm_vit
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.evs import (
compute_mrope_for_media,
Expand Down Expand Up @@ -424,7 +426,7 @@ def forward(
"rotary_pos_emb_cos": 0,
"rotary_pos_emb_sin": 0,
},
enable_if=should_torch_compile_mm_vit,
enable_if=should_torch_compile_mm_encoder,
)
class Qwen2_5_VisionBlock(nn.Module):
def __init__(
Expand Down Expand Up @@ -483,7 +485,7 @@ def forward(
dynamic_arg_dims={
"x": 0,
},
enable_if=should_torch_compile_mm_vit,
enable_if=should_torch_compile_mm_encoder,
)
class Qwen2_5_VisionPatchEmbed(nn.Module):
def __init__(
Expand Down Expand Up @@ -518,7 +520,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
dynamic_arg_dims={
"x": 0,
},
enable_if=should_torch_compile_mm_vit,
enable_if=should_torch_compile_mm_encoder,
)
class Qwen2_5_VisionPatchMerger(nn.Module):
def __init__(
Expand Down
5 changes: 0 additions & 5 deletions vllm/model_executor/models/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,6 @@ def is_vit_use_data_parallel():
return mm_encoder_tp_mode == "data"


def should_torch_compile_mm_vit(vllm_config: VllmConfig) -> bool:
    """Callable to be passed to `@support_torch_compile`'s `enable_if` argument.

    Returns True when the compilation config has `compile_mm_encoder`
    enabled, gating torch.compile of vision-transformer (ViT) encoder
    modules behind that flag.
    """
    return vllm_config.compilation_config.compile_mm_encoder


VisionFeatureSelectStrategyStr = Literal["class", "default", "full"]

VisionFeatureSelectStrategy: TypeAlias = (
Expand Down