diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py
index 8be6aedccdd6..c35cc966b835 100644
--- a/vllm/model_executor/models/vision.py
+++ b/vllm/model_executor/models/vision.py
@@ -10,7 +10,12 @@
 import torch
 from transformers import PretrainedConfig

-from vllm.config import MultiModalConfig, VllmConfig, get_current_vllm_config
+from vllm.config import (
+    ModelConfig,
+    MultiModalConfig,
+    VllmConfig,
+    get_current_vllm_config,
+)
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -95,6 +100,24 @@ def _get_vit_attn_backend(
     )


+def _get_current_mm_config() -> MultiModalConfig | None:
+    """
+    Get the current MultiModalConfig if one exists.
+    """
+    try:
+        vllm_config: VllmConfig = get_current_vllm_config()
+    except AssertionError:
+        multimodal_config = None
+    else:
+        # We have a vLLM config; if we don't have a model config for
+        # any reason, set the MM config to None by default as well.
+        model_config: ModelConfig | None = vllm_config.model_config
+        multimodal_config: MultiModalConfig | None = (
+            model_config.multimodal_config if model_config is not None else None
+        )
+    return multimodal_config
+
+
 def get_vit_attn_backend(
     head_size: int,
     dtype: torch.dtype,
@@ -102,13 +125,7 @@
     """
     Get the attention backend for Vision Transformer.
     """
-    try:
-        vllm_config: VllmConfig = get_current_vllm_config()
-        multimodal_config: MultiModalConfig | None = (
-            vllm_config.model_config.multimodal_config
-        )
-    except AssertionError:
-        multimodal_config = None
+    multimodal_config: MultiModalConfig | None = _get_current_mm_config()

     attn_backend_override = (
         multimodal_config.mm_encoder_attn_backend
@@ -127,13 +144,7 @@ def is_vit_use_data_parallel():
     """
     Get the tensor parallel type for Vision Transformer.
     """
-    try:
-        vllm_config: VllmConfig = get_current_vllm_config()
-        multimodal_config: MultiModalConfig | None = (
-            vllm_config.model_config.multimodal_config
-        )
-    except AssertionError:
-        multimodal_config = None
+    multimodal_config: MultiModalConfig | None = _get_current_mm_config()

     mm_encoder_tp_mode = (
         multimodal_config.mm_encoder_tp_mode if multimodal_config is not None else None
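
For context, below is a minimal standalone sketch (plain Python, not vLLM code) of the lookup-with-fallback pattern the new _get_current_mm_config() helper deduplicates. The dataclasses and the set_current_vllm_config context manager here are simplified stand-ins; the assertion behavior of get_current_vllm_config is an assumption mirrored from how the diff catches AssertionError.

# Standalone sketch of the pattern centralized by _get_current_mm_config():
# read a context-scoped config if one is set, and degrade to None instead
# of raising. All classes below are simplified stand-ins, not vLLM's own.
from contextlib import contextmanager
from dataclasses import dataclass


@dataclass
class MultiModalConfig:
    mm_encoder_attn_backend: str | None = None
    mm_encoder_tp_mode: str | None = None


@dataclass
class ModelConfig:
    multimodal_config: MultiModalConfig | None = None


@dataclass
class VllmConfig:
    model_config: ModelConfig | None = None


_current_vllm_config: VllmConfig | None = None


def get_current_vllm_config() -> VllmConfig:
    # Mirrors the behavior the diff relies on: asserts when called
    # outside of any config context.
    assert _current_vllm_config is not None, "No vLLM config is set"
    return _current_vllm_config


@contextmanager
def set_current_vllm_config(config: VllmConfig):
    # Stand-in for vLLM's config context; restores the previous value on exit.
    global _current_vllm_config
    prev, _current_vllm_config = _current_vllm_config, config
    try:
        yield
    finally:
        _current_vllm_config = prev


def _get_current_mm_config() -> MultiModalConfig | None:
    """Same shape as the helper added in the diff."""
    try:
        vllm_config = get_current_vllm_config()
    except AssertionError:
        return None
    model_config = vllm_config.model_config
    return model_config.multimodal_config if model_config is not None else None


# Outside any config context: falls back to None instead of raising.
assert _get_current_mm_config() is None

# Inside a config context: the MM config is found.
cfg = VllmConfig(ModelConfig(MultiModalConfig(mm_encoder_tp_mode="data")))
with set_current_vllm_config(cfg):
    mm = _get_current_mm_config()
    assert mm is not None and mm.mm_encoder_tp_mode == "data"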