|
35 | 35 | from vllm.config import VllmConfig |
36 | 36 | from vllm.config.multimodal import BaseDummyOptions |
37 | 37 | from vllm.distributed import get_tensor_model_parallel_world_size |
| 38 | +from vllm.model_executor.layers.fused_moe import FusedMoE |
38 | 39 | from vllm.model_executor.layers.linear import ( |
39 | 40 | ColumnParallelLinear, |
40 | 41 | QKVParallelLinear, |
|
45 | 46 | from vllm.model_executor.layers.rotary_embedding import get_rope |
46 | 47 | from vllm.model_executor.model_loader.utils import initialize_model |
47 | 48 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader |
| 49 | +from vllm.model_executor.models.module_mapping import MultiModelKeys |
48 | 50 | from vllm.multimodal import MULTIMODAL_REGISTRY |
49 | 51 | from vllm.multimodal.inputs import ( |
50 | 52 | MultiModalDataDict, |
|
68 | 70 | MixtureOfExperts, |
69 | 71 | MultiModalEmbeddings, |
70 | 72 | SupportsEagle3, |
| 73 | + SupportsLoRA, |
71 | 74 | SupportsMultiModal, |
72 | 75 | SupportsPP, |
73 | 76 | ) |
74 | 77 | from .llama4 import Llama4ForCausalLM |
75 | | -from .utils import AutoWeightsLoader, maybe_prefix |
| 78 | +from .utils import ( |
| 79 | + AutoWeightsLoader, |
| 80 | + maybe_prefix, |
| 81 | +) |
76 | 82 | from .vision import run_dp_sharded_vision_model |
77 | 83 |
|
78 | 84 |
|
@@ -724,7 +730,12 @@ def get_dummy_mm_data( |
724 | 730 | dummy_inputs=Mllama4DummyInputsBuilder, |
725 | 731 | ) |
726 | 732 | class Llama4ForConditionalGeneration( |
727 | | - nn.Module, SupportsMultiModal, SupportsPP, MixtureOfExperts, SupportsEagle3 |
| 733 | + nn.Module, |
| 734 | + SupportsMultiModal, |
| 735 | + SupportsPP, |
| 736 | + MixtureOfExperts, |
| 737 | + SupportsEagle3, |
| 738 | + SupportsLoRA, |
728 | 739 | ): |
729 | 740 | merge_by_field_config = True |
730 | 741 |
|
@@ -1067,6 +1078,17 @@ def _load_other_weights( |
1067 | 1078 |
|
1068 | 1079 | return updated_params |
1069 | 1080 |
|
| 1081 | + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: |
| 1082 | + # Params for weights, fp8 weight scales, fp8 activation scales |
| 1083 | + # (param_name, weight_name, expert_id, shard_id) |
| 1084 | + return FusedMoE.make_expert_params_mapping( |
| 1085 | + ckpt_gate_proj_name="gate_proj", |
| 1086 | + ckpt_down_proj_name="down_proj", |
| 1087 | + ckpt_up_proj_name="up_proj", |
| 1088 | + num_experts=self.config.text_config.num_local_experts, |
| 1089 | + num_redundant_experts=self.num_redundant_experts, |
| 1090 | + ) |
| 1091 | + |
1070 | 1092 | def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: |
1071 | 1093 | stacked_params_mapping = [ |
1072 | 1094 | # (param_name, shard_name, shard_id) |
@@ -1113,3 +1135,13 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: |
1113 | 1135 | ) |
1114 | 1136 |
|
1115 | 1137 | return updated_params |
| 1138 | + |
    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models.

        Returns a :class:`MultiModelKeys` that maps each model component to
        its module-name prefix: the language model under ``language_model``,
        the connector under ``multi_modal_projector.``, and the vision tower
        under ``vision_model.``.
        # NOTE(review): the trailing dots on the connector/tower prefixes
        # appear intentional (prefix matching) — confirm against consumers
        # of MultiModelKeys before changing.
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector.",
            tower_model="vision_model.",
        )
0 commit comments