diff --git a/tests/compile/fullgraph/test_multimodal_compile.py b/tests/compile/fullgraph/test_multimodal_compile.py index 621f6a51a918..c5dc6f96b2a5 100644 --- a/tests/compile/fullgraph/test_multimodal_compile.py +++ b/tests/compile/fullgraph/test_multimodal_compile.py @@ -71,3 +71,40 @@ def test_qwen2_5_vl_no_vit_compilation(vllm_runner, monkeypatch): ) as _, ): pass + + +# forked is needed to work around https://github.com/vllm-project/vllm/issues/21073 +# Requires CUDA and 8 GPUs as well +@pytest.mark.forked +@pytest.mark.skip(reason="Skipping due to CI resource constraints") +def test_mllama4_vit_compilation(vllm_runner, monkeypatch): + """Test that Mllama4 vision submodules are compiled. + + This test is meant to verify that the 2 vision submodules (Llama4VisionEncoder, + Llama4VisionPixelShuffleMLP) are properly tagged + for compilation by checking that num_models_seen increases to 3. + + However, since we are using TP=8, the compilation_counter will not + work properly, so for now we just check that the run succeeds. + """ + # Disable multiprocessing so that the counter is in the same process + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + + with ( + monkeypatch.context(), + # TODO: Since we require TP=8, this messes with the compilation + # counter. 
We should fix this in the future, but leave for now + # to make sure that compilation runs (no crash) with llama vision encoder + compilation_counter.expect(num_models_seen=0), + vllm_runner( + "meta-llama/Llama-4-Scout-17B-16E-Instruct", + max_model_len=512, + gpu_memory_utilization=0.8, + tensor_parallel_size=8, + compilation_config={ + "mode": CompilationMode.VLLM_COMPILE, + "compile_mm_encoder": True, + }, + ), + ): + pass diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index c8a531c02fd7..61730d6819db 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -427,8 +427,9 @@ class CompilationConfig: If empty list [], no ops are excluded (suitable for full cudagraphs).""" compile_mm_encoder: bool = False """Whether or not to compile the multimodal encoder. - Currently, this only works for `Qwen2_5_vl` on selected platforms. - Disabled by default until more models are supported/tested to work.""" + Currently, this only works for `Qwen2_5_vl` and `mLLaMa4` models + on selected platforms. 
Disabled by default until more models + are supported/tested to work.""" # Inductor capture compile_sizes: list[int | str] | None = None diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py index 099fe23914cc..44e990d29c16 100644 --- a/vllm/model_executor/layers/attention/mm_encoder_attention.py +++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py @@ -171,12 +171,12 @@ def _forward_fa( q=query, k=key, v=value, - scale=self.scale, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, batch_size=bsz, is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA), fa_version=self._fa_version, + scale=self.scale, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, ) if is_reshaped: output = output.reshape(bsz, q_len, -1) diff --git a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py index 9fdac309df7e..f51429cd75c1 100644 --- a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py @@ -60,14 +60,17 @@ def forward_native( # type: ignore[override] assert key is not None # self.cos_sin_cache here is complex tensor so we cannot cast into # query's dtype directly with self._match_cos_sin_cache_dtype - self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device) + + # NOTE: by not storing cos_sin_cache in self, we can avoid a + # memory buffer update, which is costly at runtime + cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device) query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2)) key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2)) broadcast_shape = [ d if i == 1 or i == (query_.ndim - 1) else 1 for i, d in enumerate(query_.shape) ] - freqs_ci = self.cos_sin_cache.view(*broadcast_shape) + freqs_ci = cos_sin_cache.view(*broadcast_shape) 
query_out = torch.view_as_real(query_ * freqs_ci).flatten(3) key_out = torch.view_as_real(key_ * freqs_ci).flatten(3) return query_out.type_as(query), key_out.type_as(key) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 95b5f0f5bf19..b6427b866aa3 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -369,7 +369,11 @@ def llama_model_invariants( torch._check(positions.size()[0] == input_ids.size()[0]) -@support_torch_compile(shape_invariants=llama_model_invariants) +@support_torch_compile( + # TODO[#32068]: Investigate recompilation + # mark_unbacked_dims={"input_ids": 0}, + shape_invariants=llama_model_invariants +) class LlamaModel(nn.Module): def __init__( self, diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index fb66a03b8b22..13f79c91fec5 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -31,9 +31,11 @@ get_best_fit, ) -from vllm.config import VllmConfig +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig, set_current_vllm_config from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import set_forward_context from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import ( @@ -47,6 +49,7 @@ from vllm.model_executor.model_loader.utils import initialize_model from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.vision import should_torch_compile_mm_vit from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalDataDict, @@ -456,6 +459,9 @@ def forward(self, 
hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states +@support_torch_compile( + dynamic_arg_dims={"images_flattened": 0}, enable_if=should_torch_compile_mm_vit +) class Llama4VisionModel(nn.Module): def __init__( self, @@ -497,6 +503,7 @@ def __init__( prefix=f"{prefix}.model", use_data_parallel=use_data_parallel, ) + self.vision_adapter = Llama4VisionPixelShuffleMLP( config, quant_config, @@ -762,18 +769,28 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): multimodal_config = vllm_config.model_config.multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + self.vllm_config = vllm_config self.config = config self.quant_config = quant_config self.multimodal_config = multimodal_config if multimodal_config.get_limit_per_prompt("image"): - self.vision_model = Llama4VisionModel( - config.vision_config, - None, - prefix=maybe_prefix(prefix, "vision_model"), - use_data_parallel=self.use_data_parallel, - ) + from vllm.compilation.backends import set_model_tag + + with ( + set_current_vllm_config(vllm_config), + set_model_tag("Llama4VisionModel", is_encoder=True), + ): + self.vision_model = Llama4VisionModel( + config=config.vision_config, + quant_config=None, + prefix=maybe_prefix(prefix, "vision_model"), + use_data_parallel=self.use_data_parallel, + ) + self.multi_modal_projector = Llama4MultiModalProjector( - self.config, None, prefix=maybe_prefix(prefix, "multi_modal_projector") + config=self.config, + quant_config=None, + prefix=maybe_prefix(prefix, "multi_modal_projector"), ) else: self.vision_model = None @@ -883,7 +900,10 @@ def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings: if image_input is None: return [] - return self._process_image_input(image_input) + with ( + set_forward_context(None, self.vllm_config), + ): + return self._process_image_input(image_input) def forward( self, diff --git a/vllm/v1/attention/ops/vit_attn_wrappers.py b/vllm/v1/attention/ops/vit_attn_wrappers.py index 
72c45571f89a..f077a61c984f 100644 --- a/vllm/v1/attention/ops/vit_attn_wrappers.py +++ b/vllm/v1/attention/ops/vit_attn_wrappers.py @@ -72,9 +72,9 @@ def flash_attn_maxseqlen_wrapper_fake( batch_size: int, is_rocm_aiter: bool, fa_version: int | None, - scale: float | None, - cu_seqlens: torch.Tensor | None, - max_seqlen: torch.Tensor | None, + scale: float | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, ) -> torch.Tensor: return torch.empty_like(q)