
Commit 860c924

Author: 白永斌
Message: extract load_expert_weight method
Signed-off-by: 白永斌 <[email protected]>
Parent: a7c1b8a

File tree

4 files changed: 68 additions & 89 deletions

vllm/distributed/eplb/model_register_gpu.py
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/glm4_moe.py
vllm/model_executor/models/qwen3_moe.py

vllm/distributed/eplb/model_register_gpu.py

Lines changed: 56 additions & 2 deletions

@@ -5,7 +5,11 @@
 
 import torch
 from vllm.model_executor.layers.fused_moe import FusedMoE
-
+from vllm.model_executor.models.utils import (PPMissingLayer, is_pp_missing_parameter,
+                                              make_empty_intermediate_tensors_factory, make_layers,
+                                              maybe_prefix)
+import typing
+from typing import Callable
 def set_eplb_state(
     self,
     expert_load_view: torch.Tensor,

@@ -50,6 +54,55 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         num_experts=self.config.n_routed_experts,
         num_redundant_experts=self.num_redundant_experts)
 
+def load_expert_weight(self, name, mapping, loaded_weight, params_dict):
+    ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale",
+                       ".v_scale", "_v_scale", ".weight_scale",
+                       "_weight_scale", ".input_scale", "_input_scale")
+
+    is_continue = False
+    is_expert_weight = False
+    success = False
+
+    param_name, weight_name, expert_id, shard_id = mapping
+    if weight_name not in name:
+        is_continue = True
+        return is_continue, is_expert_weight, success, name
+
+    # Anyway, this is an expert weight and should not be
+    # attempted to load as other weights later
+    is_expert_weight = True
+
+    # Do not modify `name` since the loop may continue here
+    # Instead, create a new variable
+    name_mapped = name.replace(weight_name, param_name)
+
+    if is_pp_missing_parameter(name_mapped, self):
+        is_continue = True
+        return is_continue, is_expert_weight, success, name
+
+    # Skip loading extra parameters for GPTQ/modelopt models.
+    if name_mapped.endswith(
+            ignore_suffixes
+    ) and name_mapped not in params_dict:
+        is_continue = True
+        return is_continue, is_expert_weight, success, name
+
+    param = params_dict[name_mapped]
+    # We should ask the weight loader to return success or not
+    # here since otherwise we may skip experts with other
+    # available replicas.
+    weight_loader = typing.cast(Callable[..., bool],
+                                param.weight_loader)
+    success = weight_loader(param,
+                            loaded_weight,
+                            name_mapped,
+                            shard_id=shard_id,
+                            expert_id=expert_id,
+                            return_success=True)
+    if success:
+        name = name_mapped
+    return is_continue, is_expert_weight, success, name
+
 def model_register(model):
     """
     Registers custom methods related to Expert Parallel Load Balancing (EPLB)

@@ -60,7 +113,8 @@ def model_register(model):
        model: The vLLM model instance to which the methods will be added.
     """
     model.set_eplb_state = types.MethodType(set_eplb_state, model)
+    model.load_expert_weight = types.MethodType(load_expert_weight, model)
     model.update_physical_experts_metadata = \
-        types.MethodType(update_physical_experts_metadata, model)
+        types.MethodType(update_physical_experts_metadata, model)
     model.model.get_expert_mapping = \
         types.MethodType(get_expert_mapping, model.model)
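The registration above relies on Python's types.MethodType to graft module-level functions onto a model instance at runtime. A minimal, self-contained sketch of that binding pattern (DummyModel and greet are hypothetical stand-ins, not vLLM code):

import types

class DummyModel:
    # Hypothetical stand-in for a vLLM model instance.
    pass

def greet(self, who):
    # Once bound with MethodType, `self` is the specific instance,
    # just as `self` refers to the model inside load_expert_weight.
    return f"{type(self).__name__} greets {who}"

model = DummyModel()
# Mirrors model_register: attach the function as a bound method of
# this one instance, without touching the class definition.
model.greet = types.MethodType(greet, model)
assert model.greet("expert 0") == "DummyModel greets expert 0"

Because the method is attached per instance, models that never pass through model_register are left untouched.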

vllm/model_executor/models/deepseek_v2.py

Lines changed: 4 additions & 27 deletions

@@ -976,35 +976,12 @@ def load_weights(self, weights: Iterable[tuple[str,
             else:
                 is_expert_weight = False
                 for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
+                    is_continue, is_expert_weight, success, name = \
+                        self.load_expert_weight(
+                            name, mapping, loaded_weight, params_dict)
+                    if is_continue:
                         continue
-
-                    # Anyway, this is an expert weight and should not be
-                    # attempted to load as other weights later
-                    is_expert_weight = True
-
-                    # Do not modify `name` since the loop may continue here
-                    # Instead, create a new variable
-                    name_mapped = name.replace(weight_name, param_name)
-
-                    if is_pp_missing_parameter(name_mapped, self):
-                        continue
-
-                    param = params_dict[name_mapped]
-                    # We should ask the weight loader to return success or not
-                    # here since otherwise we may skip experts with other
-                    # available replicas.
-                    weight_loader = typing.cast(Callable[..., bool],
-                                                param.weight_loader)
-                    success = weight_loader(param,
-                                            loaded_weight,
-                                            name_mapped,
-                                            shard_id=shard_id,
-                                            expert_id=expert_id,
-                                            return_success=True)
                     if success:
-                        name = name_mapped
                         break
                 else:
                     if is_expert_weight:
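The hunk above leans on Python's for/else: the else branch runs only if the loop finishes without break, i.e. no expert mapping actually loaded the weight. A simplified sketch of how the helper's return values drive that control flow (fake_load_expert_weight and the two-field mappings are hypothetical stand-ins for the real helper and the four-field expert_params_mapping entries):

def fake_load_expert_weight(name, mapping):
    # Returns (is_continue, is_expert_weight, success, name), like the helper.
    param_name, weight_name = mapping
    if weight_name not in name:
        return True, False, False, name        # not this mapping's weight
    return False, True, True, name.replace(weight_name, param_name)

expert_params_mapping = [("experts.w13_", "experts.0.gate_proj."),
                         ("experts.w13_", "experts.1.gate_proj.")]
name = "model.layers.0.mlp.experts.1.gate_proj.weight"

is_expert_weight = False
for mapping in expert_params_mapping:
    is_continue, is_expert_weight, success, name = \
        fake_load_expert_weight(name, mapping)
    if is_continue:
        continue                # try the next mapping
    if success:
        break                   # loaded; the else branch below is skipped
else:
    if is_expert_weight:
        pass                    # expert weight with no local replica: skip it
print(name)                     # model.layers.0.mlp.experts.w13_weight

Note that each iteration overwrites is_expert_weight with the helper's latest value; the pre-refactor code only ever set it to True, so callers that rely on the else branch may want to accumulate it with `or` instead.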

vllm/model_executor/models/glm4_moe.py

Lines changed: 4 additions & 27 deletions

@@ -515,35 +515,12 @@ def load_weights(self, weights: Iterable[tuple[str,
             else:
                 is_expert_weight = False
                 for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
+                    is_continue, is_expert_weight, success, name = \
+                        self.load_expert_weight(
+                            name, mapping, loaded_weight, params_dict)
+                    if is_continue:
                         continue
-
-                    # Anyway, this is an expert weight and should not be
-                    # attempted to load as other weights later
-                    is_expert_weight = True
-
-                    # Do not modify `name` since the loop may continue here
-                    # Instead, create a new variable
-                    name_mapped = name.replace(weight_name, param_name)
-
-                    if is_pp_missing_parameter(name_mapped, self):
-                        continue
-
-                    param = params_dict[name_mapped]
-                    # We should ask the weight loader to return success or not
-                    # here since otherwise we may skip experts with other
-                    # available replicas.
-                    weight_loader = typing.cast(Callable[..., bool],
-                                                param.weight_loader)
-                    success = weight_loader(param,
-                                            loaded_weight,
-                                            name_mapped,
-                                            shard_id=shard_id,
-                                            expert_id=expert_id,
-                                            return_success=True)
                     if success:
-                        name = name_mapped
                         break
                 else:
                     if is_expert_weight:

vllm/model_executor/models/qwen3_moe.py

Lines changed: 4 additions & 33 deletions

@@ -498,41 +498,12 @@ def load_weights(self, weights: Iterable[tuple[str,
             else:
                 is_expert_weight = False
                 for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
+                    is_continue, is_expert_weight, success, name = \
+                        self.load_expert_weight(
+                            name, mapping, loaded_weight, params_dict)
+                    if is_continue:
                         continue
-
-                    # Anyway, this is an expert weight and should not be
-                    # attempted to load as other weights later
-                    is_expert_weight = True
-
-                    # Do not modify `name` since the loop may continue here
-                    # Instead, create a new variable
-                    name_mapped = name.replace(weight_name, param_name)
-
-                    if is_pp_missing_parameter(name_mapped, self):
-                        continue
-
-                    # Skip loading extra parameters for GPTQ/modelopt models.
-                    if name_mapped.endswith(
-                            ignore_suffixes
-                    ) and name_mapped not in params_dict:
-                        continue
-
-                    param = params_dict[name_mapped]
-                    # We should ask the weight loader to return success or not
-                    # here since otherwise we may skip experts with other
-                    # available replicas.
-                    weight_loader = typing.cast(Callable[..., bool],
-                                                param.weight_loader)
-                    success = weight_loader(param,
-                                            loaded_weight,
-                                            name_mapped,
-                                            shard_id=shard_id,
-                                            expert_id=expert_id,
-                                            return_success=True)
                     if success:
-                        name = name_mapped
                         break
                 else:
                     if is_expert_weight:
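All three call sites depend on the weight loader reporting success rather than raising, which is why load_expert_weight casts param.weight_loader to Callable[..., bool] and passes return_success=True: with redundant experts, a False result simply means this rank holds no replica of that expert and the caller should keep scanning. A hypothetical stub illustrating that contract (not FusedMoE's actual weight_loader):

import torch

# Hypothetical: the global expert ids hosted on this rank.
LOCAL_EXPERT_IDS = {0, 1, 7}

def stub_weight_loader(param, loaded_weight, name, shard_id, expert_id,
                       return_success=False):
    if expert_id not in LOCAL_EXPERT_IDS:
        # Report "not mine" instead of failing, so load_weights can try
        # the other replicas listed in expert_params_mapping.
        return False if return_success else None
    param.data.copy_(loaded_weight)     # simplified: no sharding by shard_id
    return True if return_success else None

param = torch.nn.Parameter(torch.zeros(4))
ok = stub_weight_loader(param, torch.ones(4), "experts.w13_weight",
                        shard_id="w1", expert_id=9, return_success=True)
assert ok is False      # expert 9 lives elsewhere; the caller moves on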
