Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions vllm/lora/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0

import copy
import math
import os
from collections.abc import Sequence
Expand Down Expand Up @@ -34,6 +33,7 @@
from vllm.model_executor.models.interfaces import is_pooling_model
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
from vllm.model_executor.utils import get_packed_modules_mapping
from vllm.utils import is_pin_memory_available

logger = init_logger(__name__)
Expand Down Expand Up @@ -364,8 +364,8 @@ def __init__(
# We need to replace rotary emb layer to do batch computation
# for long lora.
self.supported_lora_modules.append("rotary_emb")
self.packed_modules_mapping = copy.deepcopy(
self.model.packed_modules_mapping)

self.packed_modules_mapping = get_packed_modules_mapping(self.model)
# Used to indicate whether the model is a multimodal model
self.supports_mm: bool = (
supports_multimodal(self.model)
Expand Down
8 changes: 4 additions & 4 deletions vllm/model_executor/model_loader/bitsandbytes_loader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: SIM117
import copy
import fnmatch
import glob
import itertools
Expand Down Expand Up @@ -36,7 +35,8 @@
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models import is_pooling_model
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.utils import (get_packed_modules_mapping,
set_weight_attrs)
from vllm.platforms import current_platform

logger = init_logger(__name__)
Expand Down Expand Up @@ -420,8 +420,8 @@ def _load_weights(self, model_config: ModelConfig,
f"Model {type(model).__name__} does not support BitsAndBytes "
"quantization yet. No 'packed_modules_mapping' found.")
self.is_pool_model=is_pooling_model(model)
self.modules_mapping = ParamMapping(
copy.deepcopy(model.packed_modules_mapping))

self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))

# For some models like Molmo, we need to use hf_to_vllm_mapper
# to ensure correct loading of weights.
Expand Down
4 changes: 4 additions & 0 deletions vllm/model_executor/models/intern_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,10 @@ def forward(self, inputs_embeds: torch.Tensor):

class InternVisionModel(nn.Module):

packed_modules_mapping = {
"qkv": ["qkv"],
}

def __init__(
self,
config: PretrainedConfig,
Expand Down
9 changes: 0 additions & 9 deletions vllm/model_executor/models/internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,15 +1019,6 @@ def get_video_replacement_internvl(item_idx: int):
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
SupportsLoRA):

packed_modules_mapping = {
"wqkv": ["wqkv"],
"qkv": ["qkv"],
"gate_up_proj": [
"w1",
"w3",
],
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()

Expand Down
11 changes: 0 additions & 11 deletions vllm/model_executor/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,17 +821,6 @@ def _get_mm_fields_config(
dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}

# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
Expand Down
11 changes: 0 additions & 11 deletions vllm/model_executor/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,17 +1069,6 @@ def _get_mm_fields_config(
dummy_inputs=Qwen2VLDummyInputsBuilder)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}

# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
Expand Down
21 changes: 21 additions & 0 deletions vllm/model_executor/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""Utils for model executor."""
import copy
from typing import Any, Optional

import torch
Expand Down Expand Up @@ -51,3 +52,23 @@ def _synced_weight_loader(param, *args, **kwargs):
torch._sync(param)

return _synced_weight_loader


def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
    """Return a copy of *model*'s ``packed_modules_mapping``.

    If the model defines ``packed_modules_mapping`` itself, a deep copy of
    that mapping is returned unchanged. Otherwise the mapping is inferred by
    merging the ``packed_modules_mapping`` attributes of the model's direct
    children (e.g. a language backbone and a vision tower).

    Args:
        model: Module whose packed-modules mapping is requested.

    Returns:
        A mapping from packed module name to the list of constituent module
        names. May be empty. The result never aliases the model's (or its
        children's) attributes, so callers may mutate it freely.

    Raises:
        ValueError: If two children define conflicting entries for the same
            key, so no single merged mapping is safe.
    """
    # Deep-copy so callers can mutate the result without touching the model.
    parent_map = copy.deepcopy(getattr(model, "packed_modules_mapping", {}))

    # Don't infer the mapping if the model has defined it explicitly.
    if parent_map:
        return parent_map

    # We only check main components (direct children) instead of the whole
    # submodule tree.
    for child in model.children():
        child_map = getattr(child, "packed_modules_mapping", {})
        if any(k in parent_map and parent_map[k] != v
               for k, v in child_map.items()):
            raise ValueError(
                f"Can't update {type(model).__name__}'s packed_modules_mapping "
                f"safely because of conflicts from {type(child).__name__}.")
        # Deep-copy here too: child mappings are often shared class-level
        # dicts, and the caller must not be able to mutate them through the
        # returned mapping.
        parent_map.update(copy.deepcopy(child_map))
    return parent_map