From 0b4bd65a4b4421e0d38492541187abc8082e7a02 Mon Sep 17 00:00:00 2001 From: Hollow Man Date: Sun, 4 Jan 2026 00:00:43 +0200 Subject: [PATCH] [BugFix] LoRA: Support loading base_layer of experts Signed-off-by: Hollow Man --- vllm/model_executor/layers/fused_moe/layer.py | 13 ++++++++++--- vllm/model_executor/models/afmoe.py | 1 + vllm/model_executor/models/bailing_moe.py | 1 + vllm/model_executor/models/deepseek_eagle.py | 1 + vllm/model_executor/models/deepseek_mtp.py | 1 + vllm/model_executor/models/deepseek_v2.py | 2 ++ vllm/model_executor/models/dots1.py | 1 + vllm/model_executor/models/ernie45_moe.py | 1 + vllm/model_executor/models/ernie45_vl_moe.py | 1 + vllm/model_executor/models/glm4_moe.py | 1 + vllm/model_executor/models/glm4_moe_mtp.py | 1 + vllm/model_executor/models/gpt_oss.py | 1 + vllm/model_executor/models/granitemoe.py | 1 + vllm/model_executor/models/grok1.py | 1 + vllm/model_executor/models/hunyuan_v1.py | 1 + vllm/model_executor/models/jamba.py | 1 + vllm/model_executor/models/kimi_linear.py | 1 + vllm/model_executor/models/kimi_vl.py | 1 + vllm/model_executor/models/lfm2_moe.py | 1 + vllm/model_executor/models/llama4.py | 2 ++ vllm/model_executor/models/longcat_flash.py | 1 + vllm/model_executor/models/mimo_v2_flash.py | 1 + vllm/model_executor/models/minimax_m2.py | 1 + vllm/model_executor/models/mixtral.py | 1 + vllm/model_executor/models/mllama4.py | 1 + vllm/model_executor/models/nemotron_h.py | 1 + vllm/model_executor/models/olmoe.py | 1 + vllm/model_executor/models/openpangu.py | 1 + vllm/model_executor/models/openpangu_mtp.py | 1 + vllm/model_executor/models/phimoe.py | 1 + vllm/model_executor/models/qwen2_moe.py | 1 + vllm/model_executor/models/qwen3_moe.py | 1 + vllm/model_executor/models/qwen3_next.py | 1 + vllm/model_executor/models/qwen3_next_mtp.py | 1 + vllm/model_executor/models/transformers/moe.py | 1 + 35 files changed, 46 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 323e0ee09fc9..48c3ef83dbd5 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -2002,6 +2002,7 @@ def combine_output(states: torch.Tensor) -> torch.Tensor: @classmethod def make_expert_params_mapping( cls, + model: torch.nn.Module, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, ckpt_up_proj_name: str, @@ -2020,13 +2021,19 @@ def make_expert_params_mapping( ) ) + base_layer = ( + "base_layer." + if any(".base_layer." in name for name, _ in model.named_parameters()) + else "" + ) + return [ # (param_name, weight_name, expert_id, shard_id) ( - "experts.w13_" + f"experts.{base_layer}w13_" if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] - else "experts.w2_", - f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.", + else f"experts.{base_layer}w2_", + f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.{base_layer}", expert_id, shard_id, ) diff --git a/vllm/model_executor/models/afmoe.py b/vllm/model_executor/models/afmoe.py index f5dfe4306741..f4248b67f734 100644 --- a/vllm/model_executor/models/afmoe.py +++ b/vllm/model_executor/models/afmoe.py @@ -475,6 +475,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) return SharedFusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index 4bccee752174..e1e675bd5a05 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -476,6 +476,7 @@ def forward( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return SharedFusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py index 8f6b4a4b021f..5c439cdf486d 100644 --- a/vllm/model_executor/models/deepseek_eagle.py +++ b/vllm/model_executor/models/deepseek_eagle.py @@ -106,6 +106,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) expert_params_mapping = FusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index c25e8422da15..b8b73bd24f6a 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -245,6 +245,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ] expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index b22cdb6d6c80..7f1880e44bd8 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -1486,6 +1486,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) return SharedFusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", @@ -1519,6 +1520,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py index 870a37039f15..b64f163761c8 100644 --- a/vllm/model_executor/models/dots1.py +++ b/vllm/model_executor/models/dots1.py @@ -424,6 +424,7 @@ def forward( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return SharedFusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py index fbbd31a48538..8c8cb73b8d6e 100644 --- a/vllm/model_executor/models/ernie45_moe.py +++ b/vllm/model_executor/models/ernie45_moe.py @@ -497,6 +497,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) return SharedFusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py index 72f9957fc882..75be587eedb2 100644 --- a/vllm/model_executor/models/ernie45_vl_moe.py +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -675,6 +675,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index 6fb09be7c67f..0d4111cbad70 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -495,6 +495,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) return SharedFusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py index e34ae6c85a4f..22e62311746d 100644 --- a/vllm/model_executor/models/glm4_moe_mtp.py +++ b/vllm/model_executor/models/glm4_moe_mtp.py @@ -248,6 +248,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) expert_params_mapping = FusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 6a92cf153321..8a8df9f6ed95 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -729,6 +729,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, weight scales, activation scales # (param_name, weight_name, expert_id, shard_id) return FusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 0b1064b6343e..237fabff98f7 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -353,6 +353,7 @@ def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str] # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) expert_params_mapping = FusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index 0a2e5cf39ffd..2c41f2a1123b 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -369,6 +369,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Grok1 uses "num_experts" in its config num_experts = getattr(self.config, "num_experts", 8) return FusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="linear", # Grok1 specific ckpt_down_proj_name="linear_1", # Grok1 specific ckpt_up_proj_name="linear_v", # Grok1 specific diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index 0e82e84c4edb..787f6c674d36 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -705,6 +705,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) return SharedFusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index b2ad12be1e35..34e9feb77855 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -377,6 +377,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) return FusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/kimi_linear.py b/vllm/model_executor/models/kimi_linear.py index 4562b2202c5e..d149c3642406 100644 --- a/vllm/model_executor/models/kimi_linear.py +++ b/vllm/model_executor/models/kimi_linear.py @@ -560,6 +560,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) expert_params_mapping = FusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index 85267ccda8a9..a52e17283327 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -462,6 +462,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) expert_params_mapping = FusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/lfm2_moe.py b/vllm/model_executor/models/lfm2_moe.py index 70804e0a843e..6677eb9f93e8 100644 --- a/vllm/model_executor/models/lfm2_moe.py +++ b/vllm/model_executor/models/lfm2_moe.py @@ -486,6 +486,7 @@ def forward( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return FusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 7b3da3e10ab8..9ed0741acba1 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -539,6 +539,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Expert parameter mapping for the case where the expert weights are # not fused into a single weight tensor. expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", @@ -548,6 +549,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Expert parameter mapping for the case where the expert weights are # fused into a single weight tensor. expert_params_mapping_fused = SharedFusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_up_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="gate_up_proj", diff --git a/vllm/model_executor/models/longcat_flash.py b/vllm/model_executor/models/longcat_flash.py index 774737387639..fed3a1caee8f 100644 --- a/vllm/model_executor/models/longcat_flash.py +++ b/vllm/model_executor/models/longcat_flash.py @@ -626,6 +626,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) return FusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/mimo_v2_flash.py b/vllm/model_executor/models/mimo_v2_flash.py index 12b486f001e0..2649d27b742e 100644 --- a/vllm/model_executor/models/mimo_v2_flash.py +++ b/vllm/model_executor/models/mimo_v2_flash.py @@ -503,6 +503,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) return FusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/minimax_m2.py b/vllm/model_executor/models/minimax_m2.py index 822bf9b5c93a..292969db6d03 100644 --- a/vllm/model_executor/models/minimax_m2.py +++ b/vllm/model_executor/models/minimax_m2.py @@ -392,6 +392,7 @@ def forward( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return FusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index e170c530ca29..89dab5f3cb8e 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -366,6 +366,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) return FusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 886d5151e43f..aeea4a140465 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -1084,6 +1084,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) return FusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 8bc9ce6154d9..32b6326be429 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -640,6 +640,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # what the activation is applied to # - FusedMoe.w3 (aka up_proj) should be ignored since we're # using non-gated MoE + self, ckpt_gate_proj_name="up_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="", diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index a5a926151c5c..0697dfc015bb 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -334,6 +334,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) return FusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py index 662ecef3ac8f..44e3baee0206 100644 --- a/vllm/model_executor/models/openpangu.py +++ b/vllm/model_executor/models/openpangu.py @@ -1161,6 +1161,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: has_experts = hasattr(self.config, "n_routed_experts") if has_experts: expert_merge_mapping = SharedFusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/openpangu_mtp.py b/vllm/model_executor/models/openpangu_mtp.py index 436b7f981b1f..e2cea29c2cba 100644 --- a/vllm/model_executor/models/openpangu_mtp.py +++ b/vllm/model_executor/models/openpangu_mtp.py @@ -149,6 +149,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ] expert_params_mapping = FusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 14f73d0c6458..53951cd65ea8 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -515,6 +515,7 @@ def forward( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return FusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 82837b77e537..fbfd681d59e5 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -423,6 +423,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) return SharedFusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 0be81ecc7dd3..f2f3546047aa 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -470,6 +470,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) return FusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index ccf6cc6e5894..e8799ac75e3f 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -1030,6 +1030,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) return SharedFusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py index 83694caa5248..c07ed593240a 100644 --- a/vllm/model_executor/models/qwen3_next_mtp.py +++ b/vllm/model_executor/models/qwen3_next_mtp.py @@ -147,6 +147,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) expert_params_mapping = FusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/transformers/moe.py b/vllm/model_executor/models/transformers/moe.py index 31db9d682bd4..e9b53e7e926c 100644 --- a/vllm/model_executor/models/transformers/moe.py +++ b/vllm/model_executor/models/transformers/moe.py @@ -165,6 +165,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: for gate_proj, down_proj, up_proj in ckpt_names: expert_mapping.extend( FusedMoE.make_expert_params_mapping( + self, ckpt_gate_proj_name=gate_proj, ckpt_down_proj_name=down_proj, ckpt_up_proj_name=up_proj,