32 changes: 26 additions & 6 deletions python/sglang/srt/models/qwen3_moe.py
@@ -766,7 +766,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             num_experts=self.config.num_experts,
         )
 
-        params_dict = dict(self.named_parameters())
+        # Cache params_dict to avoid repeated expensive traversal of model parameters
+        if not hasattr(self, "_cached_params_dict"):
+            self._cached_params_dict = dict(self.named_parameters())
+        params_dict = self._cached_params_dict
         for name, loaded_weight in weights:
             layer_id = get_layer_id(name)
             if (
@@ -805,11 +808,22 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Track if this is an expert weight to enable early skipping
+                is_expert_weight = False
+
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
+
+                    # Mark as expert weight regardless of whether we can process it
+                    is_expert_weight = True
+
                     name = name.replace(weight_name, param_name)
+                    if name not in params_dict:
+                        # Expert weight not on this rank, will be skipped below
+                        continue
+
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
@@ -821,6 +835,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     )
                     break
                 else:
+                    if is_expert_weight:
+                        # This is an expert weight but not mapped to this rank, skip all remaining processing
+                        continue
+
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
@@ -837,11 +855,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                         logger.warning(f"Parameter {name} not found in params_dict")
 
         # TODO mimic deepseek
-        self.routed_experts_weights_of_layer = {
-            layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
-            for layer_id in range(self.start_layer, self.end_layer)
-            if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
-        }
+        # Lazy initialization of expert weights cache to avoid slowing down load_weights
+        if not hasattr(self, "routed_experts_weights_of_layer"):
+            self.routed_experts_weights_of_layer = {
+                layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
+                for layer_id in range(self.start_layer, self.end_layer)
+                if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
+            }
 
     @classmethod
     def get_model_config_for_expert_location(cls, config):
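
For readers skimming the diff, the change combines two small patterns: a dict that is built lazily once via hasattr and reused on later calls (applied to both params_dict and routed_experts_weights_of_layer), and a for/else loop with an is_expert_weight flag so expert weights that are not sharded onto the current rank are skipped quietly instead of falling through to the "Parameter ... not found" warning. Below is a minimal, self-contained sketch of that control flow; FakeModel, the weight names, and the mapping are invented for illustration and are not the PR's or sglang's actual code.

# Illustrative sketch only: FakeModel, the weight names, and the mapping below are
# made-up stand-ins, not code from the PR or from sglang.

class FakeModel:
    def __init__(self):
        # Stand-in for nn.Module.named_parameters(); in the real model this traversal is costly.
        self._params = {
            "experts.w13_weight": None,
            "lm_head.weight": None,
        }

    def named_parameters(self):
        return self._params.items()

    def load_weights(self, weights, expert_params_mapping):
        # Pattern 1: build params_dict once and reuse it on later load_weights calls.
        if not hasattr(self, "_cached_params_dict"):
            self._cached_params_dict = dict(self.named_parameters())
        params_dict = self._cached_params_dict

        for name, loaded_weight in weights:
            # Pattern 2: remember that the name matched an expert mapping, so a weight
            # that is simply not sharded onto this rank can be skipped quietly instead
            # of reaching the generic path and triggering a warning.
            is_expert_weight = False
            for param_name, weight_name in expert_params_mapping:
                if weight_name not in name:
                    continue
                is_expert_weight = True
                mapped = name.replace(weight_name, param_name)
                if mapped not in params_dict:
                    continue  # expert weight lives on another rank
                params_dict[mapped] = loaded_weight
                break
            else:
                if is_expert_weight:
                    continue  # expert weight, just not on this rank: skip without warning
                if name in params_dict:
                    params_dict[name] = loaded_weight
                else:
                    print(f"Parameter {name} not found in params_dict")


model = FakeModel()
model.load_weights(
    weights=[
        ("experts.w1", 1.0),      # expert weight owned by this rank
        ("experts.w2", 2.0),      # expert weight owned by another rank -> skipped quietly
        ("lm_head.weight", 3.0),  # ordinary weight
    ],
    expert_params_mapping=[
        ("experts.w13_weight", "experts.w1"),
        ("experts.w2_weight", "experts.w2"),
    ],
)

The for/else branch only runs when no mapping entry claimed the weight via break, which is why the flag must be set before the membership check: it is the only way to tell "expert weight on another rank" apart from "genuinely unknown parameter".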