Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 34 additions & 15 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1342,22 +1342,41 @@ def load_weights(
weight_name = qual_name.replace(weight_name, param_name)
param_name = weight_name.removeprefix(f"{self.layer_name}.")
param = getattr(self, param_name)
success = self.weight_loader(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
logger.debug(
"Loaded %s for expert %d into %s",
param_name,
expert_id,
self.layer_name,
# Fused expert weights can be identified by their 3D tensors
if loaded_weight.dim() == 3:
# Repurpose expert_id as shard_idx for deconcatenating w1 and w3
if shard_id in {"w1", "w3"}:
shard_idx = expert_id
experts_shard = loaded_weight.chunk(2, dim=1)[shard_idx]
else:
experts_shard = loaded_weight
start = 0
else:
# loaded_weight is a single expert weight, so we add a dummy expert
# dimension to unify the loading logic with the fused case
experts_shard = loaded_weight.unsqueeze(0)
start = expert_id

# Unified loading logic for fused and non-fused experts
loaded_experts = experts_shard.unbind()
for expert_id, loaded_expert in enumerate(loaded_experts, start=start):
success = self.weight_loader(
param=param,
loaded_weight=loaded_expert,
weight_name=weight_name,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
yield param_name
if success:
logger.debug(
"Loaded expert %d of shard %s into %s for layer %s",
expert_id,
shard_id,
param_name,
self.layer_name,
)
yield param_name

def get_expert_weights(self) -> Iterable[torch.Tensor]:
def _maybe_make_contiguous(
Expand Down
12 changes: 11 additions & 1 deletion vllm/model_executor/models/transformers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,17 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
Params for weights, fp8 weight scales, fp8 activation scales
(param_name, weight_name, expert_id, shard_id)
"""
# Models saved with fused experts. These are checkpoints released:
# - After Transformers v5
# - Before Transformers v5, but re-saved with save_original_format=False
# In the fused experts case, we repurpose the expert_id as shard_idx for
# deconcatenating w1 and w3 in FusedMoE.load_weights.
expert_mapping = [
("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
("experts.w13_weight", "experts.gate_up_proj", 1, "w3"),
("experts.w2_weight", "experts.down_proj", 0, "w2"),
]
# Models saved with ModuleList experts
ckpt_names = [
# (ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name)
("gate_proj", "down_proj", "up_proj"), # Most common MoE style
Expand All @@ -164,7 +175,6 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
]
num_experts = self.model_config.get_num_experts()
num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
expert_mapping = []
for gate_proj, down_proj, up_proj in ckpt_names:
expert_mapping.extend(
FusedMoE.make_expert_params_mapping(
Expand Down
Loading