Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions tests/compile/fullgraph/test_multimodal_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,40 @@ def test_qwen2_5_vl_no_vit_compilation(vllm_runner, monkeypatch):
) as _,
):
pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
# Requires CUDA and 8 GPUs as well.
@pytest.mark.forked
@pytest.mark.skip(reason="Skipping due to CI resource constraints")
def test_mllama4_vit_compilation(vllm_runner, monkeypatch):
    """Test that Mllama4 vision submodules are compiled.

    This test verifies that the 2 vision submodules (Llama4VisionEncoder,
    Llama4VisionPixelShuffleMLP) are properly tagged
    for compilation by checking that num_models_seen increases to 3.

    However, since we are using TP=8, the compilation_counter will not
    work properly, so for now we only check that the run succeeds.
    """
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with (
        monkeypatch.context(),
        # TODO: Since we require TP=8, this messes with the compilation
        # counter. We should fix this in the future, but leave for now
        # to make sure that compilation runs (no crash) with llama vision encoder
        compilation_counter.expect(num_models_seen=0),
        # Constructing the runner triggers engine/model initialization (and
        # therefore compilation); reaching the body means no crash occurred.
        vllm_runner(
            "meta-llama/Llama-4-Scout-17B-16E-Instruct",
            max_model_len=512,
            gpu_memory_utilization=0.8,
            tensor_parallel_size=8,
            compilation_config={
                "mode": CompilationMode.VLLM_COMPILE,
                "compile_mm_encoder": True,
            },
        ),
    ):
        pass
5 changes: 3 additions & 2 deletions vllm/config/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,9 @@ class CompilationConfig:
If empty list [], no ops are excluded (suitable for full cudagraphs)."""
compile_mm_encoder: bool = False
"""Whether or not to compile the multimodal encoder.
Currently, this only works for `Qwen2_5_vl` on selected platforms.
Disabled by default until more models are supported/tested to work."""
Currently, this only works for `Qwen2_5_vl` and `mLLaMa4` models
on selected platforms. Disabled by default until more models
are supported/tested to work."""

# Inductor capture
compile_sizes: list[int | str] | None = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,12 @@ def _forward_fa(
q=query,
k=key,
v=value,
scale=self.scale,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
batch_size=bsz,
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
fa_version=self._fa_version,
scale=self.scale,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
if is_reshaped:
output = output.reshape(bsz, q_len, -1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,17 @@ def forward_native( # type: ignore[override]
assert key is not None
# self.cos_sin_cache here is complex tensor so we cannot cast into
# query's dtype directly with self._match_cos_sin_cache_dtype
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)

# NOTE: by not storing cos_sin_cache in self, we can avoid
# memory buffer update which is costly to runtime
cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2))
key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2))
broadcast_shape = [
d if i == 1 or i == (query_.ndim - 1) else 1
for i, d in enumerate(query_.shape)
]
freqs_ci = self.cos_sin_cache.view(*broadcast_shape)
freqs_ci = cos_sin_cache.view(*broadcast_shape)
query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
return query_out.type_as(query), key_out.type_as(key)
Expand Down
6 changes: 5 additions & 1 deletion vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,11 @@ def llama_model_invariants(
torch._check(positions.size()[0] == input_ids.size()[0])


@support_torch_compile(shape_invariants=llama_model_invariants)
@support_torch_compile(
# TODO[#32068]: Investigate recompilation
# mark_unbacked_dims={"input_ids": 0},
shape_invariants=llama_model_invariants
)
class LlamaModel(nn.Module):
def __init__(
self,
Expand Down
38 changes: 29 additions & 9 deletions vllm/model_executor/models/mllama4.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@
get_best_fit,
)

from vllm.config import VllmConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
Expand All @@ -47,6 +49,7 @@
from vllm.model_executor.model_loader.utils import initialize_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.vision import should_torch_compile_mm_vit
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
Expand Down Expand Up @@ -456,6 +459,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


@support_torch_compile(
dynamic_arg_dims={"images_flattened": 0}, enable_if=should_torch_compile_mm_vit
)
class Llama4VisionModel(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -497,6 +503,7 @@ def __init__(
prefix=f"{prefix}.model",
use_data_parallel=use_data_parallel,
)

self.vision_adapter = Llama4VisionPixelShuffleMLP(
config,
quant_config,
Expand Down Expand Up @@ -762,18 +769,28 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
multimodal_config = vllm_config.model_config.multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

self.vllm_config = vllm_config
self.config = config
self.quant_config = quant_config
self.multimodal_config = multimodal_config
if multimodal_config.get_limit_per_prompt("image"):
self.vision_model = Llama4VisionModel(
config.vision_config,
None,
prefix=maybe_prefix(prefix, "vision_model"),
use_data_parallel=self.use_data_parallel,
)
from vllm.compilation.backends import set_model_tag

with (
set_current_vllm_config(vllm_config),
set_model_tag("Llama4VisionModel", is_encoder=True),
):
self.vision_model = Llama4VisionModel(
config=config.vision_config,
quant_config=None,
prefix=maybe_prefix(prefix, "vision_model"),
use_data_parallel=self.use_data_parallel,
)

self.multi_modal_projector = Llama4MultiModalProjector(
self.config, None, prefix=maybe_prefix(prefix, "multi_modal_projector")
config=self.config,
quant_config=None,
prefix=maybe_prefix(prefix, "multi_modal_projector"),
)
else:
self.vision_model = None
Expand Down Expand Up @@ -883,7 +900,10 @@ def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
if image_input is None:
return []

return self._process_image_input(image_input)
with (
set_forward_context(None, self.vllm_config),
):
return self._process_image_input(image_input)

def forward(
self,
Expand Down
6 changes: 3 additions & 3 deletions vllm/v1/attention/ops/vit_attn_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ def flash_attn_maxseqlen_wrapper_fake(
batch_size: int,
is_rocm_aiter: bool,
fa_version: int | None,
scale: float | None,
cu_seqlens: torch.Tensor | None,
max_seqlen: torch.Tensor | None,
scale: float | None = None,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.empty_like(q)

Expand Down