Skip to content

Commit 964d65d

Browse files
kfhf and frank-wei authored
LLaMA4 LoRA Adapter Enablement (#28602)
Signed-off-by: Fardin Hoque <[email protected]> Co-authored-by: Wei Wei <[email protected]>
1 parent 9261eb3 commit 964d65d

File tree

1 file changed

+34
-2
lines changed

1 file changed

+34
-2
lines changed

vllm/model_executor/models/mllama4.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from vllm.config import VllmConfig
3636
from vllm.config.multimodal import BaseDummyOptions
3737
from vllm.distributed import get_tensor_model_parallel_world_size
38+
from vllm.model_executor.layers.fused_moe import FusedMoE
3839
from vllm.model_executor.layers.linear import (
3940
ColumnParallelLinear,
4041
QKVParallelLinear,
@@ -45,6 +46,7 @@
4546
from vllm.model_executor.layers.rotary_embedding import get_rope
4647
from vllm.model_executor.model_loader.utils import initialize_model
4748
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
49+
from vllm.model_executor.models.module_mapping import MultiModelKeys
4850
from vllm.multimodal import MULTIMODAL_REGISTRY
4951
from vllm.multimodal.inputs import (
5052
MultiModalDataDict,
@@ -68,11 +70,15 @@
6870
MixtureOfExperts,
6971
MultiModalEmbeddings,
7072
SupportsEagle3,
73+
SupportsLoRA,
7174
SupportsMultiModal,
7275
SupportsPP,
7376
)
7477
from .llama4 import Llama4ForCausalLM
75-
from .utils import AutoWeightsLoader, maybe_prefix
78+
from .utils import (
79+
AutoWeightsLoader,
80+
maybe_prefix,
81+
)
7682
from .vision import run_dp_sharded_vision_model
7783

7884

@@ -724,7 +730,12 @@ def get_dummy_mm_data(
724730
dummy_inputs=Mllama4DummyInputsBuilder,
725731
)
726732
class Llama4ForConditionalGeneration(
727-
nn.Module, SupportsMultiModal, SupportsPP, MixtureOfExperts, SupportsEagle3
733+
nn.Module,
734+
SupportsMultiModal,
735+
SupportsPP,
736+
MixtureOfExperts,
737+
SupportsEagle3,
738+
SupportsLoRA,
728739
):
729740
merge_by_field_config = True
730741

@@ -1067,6 +1078,17 @@ def _load_other_weights(
10671078

10681079
return updated_params
10691080

1081+
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
1082+
# Params for weights, fp8 weight scales, fp8 activation scales
1083+
# (param_name, weight_name, expert_id, shard_id)
1084+
return FusedMoE.make_expert_params_mapping(
1085+
ckpt_gate_proj_name="gate_proj",
1086+
ckpt_down_proj_name="down_proj",
1087+
ckpt_up_proj_name="up_proj",
1088+
num_experts=self.config.text_config.num_local_experts,
1089+
num_redundant_experts=self.num_redundant_experts,
1090+
)
1091+
10701092
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
10711093
stacked_params_mapping = [
10721094
# (param_name, shard_name, shard_id)
@@ -1113,3 +1135,13 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
11131135
)
11141136

11151137
return updated_params
1138+
1139+
def get_mm_mapping(self) -> MultiModelKeys:
1140+
"""
1141+
Get the module prefix in multimodal models
1142+
"""
1143+
return MultiModelKeys.from_string_field(
1144+
language_model="language_model",
1145+
connector="multi_modal_projector.",
1146+
tower_model="vision_model.",
1147+
)

0 commit comments

Comments (0)