Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 26 additions & 15 deletions vllm/model_executor/models/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
import torch
from transformers import PretrainedConfig

from vllm.config import MultiModalConfig, VllmConfig, get_current_vllm_config
from vllm.config import (
ModelConfig,
MultiModalConfig,
VllmConfig,
get_current_vllm_config,
)
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
Expand Down Expand Up @@ -95,20 +100,32 @@ def _get_vit_attn_backend(
)


def _get_current_mm_config() -> MultiModalConfig | None:
    """
    Return the currently active ``MultiModalConfig``, or ``None``.

    ``None`` is returned in two situations: no vLLM config is set in the
    current context (``get_current_vllm_config`` raises ``AssertionError``),
    or the vLLM config carries no model config.
    """
    try:
        vllm_config: VllmConfig = get_current_vllm_config()
    except AssertionError:
        # No vLLM config is active in this context.
        return None

    model_config: ModelConfig | None = vllm_config.model_config
    if model_config is None:
        # A vLLM config exists but has no model config; treat the
        # multimodal config as absent as well.
        return None
    return model_config.multimodal_config


def get_vit_attn_backend(
head_size: int,
dtype: torch.dtype,
) -> AttentionBackendEnum:
"""
Get the attention backend for Vision Transformer.
"""
try:
vllm_config: VllmConfig = get_current_vllm_config()
multimodal_config: MultiModalConfig | None = (
vllm_config.model_config.multimodal_config
)
except AssertionError:
multimodal_config = None
multimodal_config: MultiModalConfig | None = _get_current_mm_config()

attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
Expand All @@ -127,13 +144,7 @@ def is_vit_use_data_parallel():
"""
Get the tensor parallel type for Vision Transformer.
"""
try:
vllm_config: VllmConfig = get_current_vllm_config()
multimodal_config: MultiModalConfig | None = (
vllm_config.model_config.multimodal_config
)
except AssertionError:
multimodal_config = None
multimodal_config: MultiModalConfig | None = _get_current_mm_config()

mm_encoder_tp_mode = (
multimodal_config.mm_encoder_tp_mode if multimodal_config is not None else None
Expand Down