Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions vllm/model_executor/models/qwen2_5_omni_thinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2_5_vl import (
Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs,
Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs,
Expand All @@ -66,7 +67,8 @@
from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
Expand Down Expand Up @@ -726,14 +728,30 @@ def _process_video_input(
dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder,
)
class Qwen2_5OmniThinkerForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP,
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
Qwen2_5OmniConditionalGenerationMixin):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"thinker.lm_head.": "language_model.lm_head.",
"thinker.model.": "language_model.model.",
"thinker.": "",
})
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"attn.qkv": [
"attn.q",
"attn.k",
"attn.v",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}

@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
Expand Down Expand Up @@ -956,3 +974,12 @@ def load_weights(self, weights: Iterable[tuple[str,
mapper=self.hf_to_vllm_mapper)

return loaded_weights

def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="merger.",
tower_model=["visual.", "audio_tower."])