5 changes: 5 additions & 0 deletions python/sglang/srt/eplb/expert_distribution.py
@@ -47,6 +47,11 @@ def init_new(
rank: int,
):
if server_args.expert_distribution_recorder_mode is not None:
assert expert_location_metadata is not None, (
"ExpertLocationMetadata is required for expert distribution recording. One possible "
"reason is that you are using a model that does not support expert distribution "
"recording. Try setting `get_model_config_for_expert_location` in your model."
)
return _ExpertDistributionRecorderReal(
server_args, expert_location_metadata, rank
)
23 changes: 17 additions & 6 deletions python/sglang/srt/eplb/expert_location.py
@@ -82,6 +82,10 @@ def __post_init__(self):
def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
"""Trivial location - logical expert i corresponds to physical expert i"""
common = ExpertLocationMetadata._init_common(server_args, model_config)

if common is None:
return None

num_physical_experts = common["num_physical_experts"]
model_config_for_expert_location = common["model_config_for_expert_location"]
num_layers = model_config_for_expert_location.num_layers
@@ -109,6 +113,10 @@ def init_by_mapping(
physical_to_logical_map = physical_to_logical_map.to(server_args.device)

common = ExpertLocationMetadata._init_common(server_args, model_config)

if common is None:
return None

model_config_for_expert_location = common["model_config_for_expert_location"]
logical_to_all_physical_map = _compute_logical_to_all_physical_map(
physical_to_logical_map,
@@ -133,6 +141,10 @@ def init_by_eplb(
logical_count = logical_count.to(server_args.device)

common = ExpertLocationMetadata._init_common(server_args, model_config)

if common is None:
return None

model_config_for_expert_location = common["model_config_for_expert_location"]
num_physical_experts = common["num_physical_experts"]
num_groups = model_config_for_expert_location.num_groups
@@ -168,6 +180,9 @@ def _init_common(server_args: ServerArgs, model_config: ModelConfig):
ModelConfigForExpertLocation.from_model_config(model_config)
)

if model_config_for_expert_location is None:
return None

num_physical_experts = (
model_config_for_expert_location.num_logical_experts
+ server_args.ep_num_redundant_experts
@@ -398,10 +413,6 @@ class ModelConfigForExpertLocation:
num_logical_experts: int
num_groups: Optional[int] = None

@staticmethod
def init_dummy():
return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)

@staticmethod
def from_model_config(model_config: ModelConfig):
model_class, _ = get_model_architecture(model_config)
@@ -410,12 +421,12 @@ def from_model_config(model_config: ModelConfig):
model_config.hf_config
)
else:
return ModelConfigForExpertLocation.init_dummy()
return None


def compute_initial_expert_location_metadata(
server_args: ServerArgs, model_config: ModelConfig
) -> ExpertLocationMetadata:
) -> Optional[ExpertLocationMetadata]:
data = server_args.init_expert_location
if data == "trivial":
return ExpertLocationMetadata.init_trivial(server_args, model_config)
1 change: 1 addition & 0 deletions python/sglang/srt/eplb/expert_location_dispatch.py
@@ -36,6 +36,7 @@ class ExpertLocationDispatchInfo:
def init_new(cls, layer_id: int):
ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
expert_location_metadata = get_global_expert_location_metadata()
assert expert_location_metadata is not None

if ep_dispatch_algorithm is None:
return None
2 changes: 2 additions & 0 deletions python/sglang/srt/eplb/expert_location_updater.py
@@ -50,6 +50,8 @@ def update(
torch.cuda.empty_cache()

old_expert_location_metadata = get_global_expert_location_metadata()
assert old_expert_location_metadata is not None

_update_expert_weights(
routed_experts_weights_of_layer=routed_experts_weights_of_layer,
old_expert_location_metadata=old_expert_location_metadata,
19 changes: 16 additions & 3 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -183,6 +183,7 @@ def __init__(
hidden_size: int,
intermediate_size: int,
layer_id: int,
num_fused_shared_experts: int = 0,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
Expand All @@ -196,6 +197,7 @@ def __init__(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
top_k=top_k,
num_fused_shared_experts=num_fused_shared_experts,
layer_id=layer_id,
params_dtype=params_dtype,
quant_config=quant_config,
@@ -728,10 +730,19 @@ def weight_loader(
shard_id: str,
expert_id: int,
) -> None:
physical_expert_ids = (
get_global_expert_location_metadata().logical_to_all_physical(
self.layer_id, expert_id
global_expert_location_metadata = get_global_expert_location_metadata()
if global_expert_location_metadata is None:
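# No expert location metadata is available (e.g. the model does not provide expert location info); fall back to the plain weight loader.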
self._weight_loader_impl(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=expert_id,
)
return

physical_expert_ids = global_expert_location_metadata.logical_to_all_physical(
self.layer_id, expert_id
)
for physical_expert_id in physical_expert_ids:
self._weight_loader_physical(
@@ -778,6 +789,7 @@ def __init__(
hidden_size: int,
intermediate_size: int,
layer_id: int,
num_fused_shared_experts: int = 0,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
Expand All @@ -792,6 +804,7 @@ def __init__(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
layer_id=layer_id,
num_fused_shared_experts=num_fused_shared_experts,
params_dtype=params_dtype,
quant_config=quant_config,
tp_size=tp_size,
45 changes: 44 additions & 1 deletion python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -11,6 +11,7 @@
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
@@ -62,8 +63,9 @@ def __init__(
num_experts: int,
hidden_size: int,
intermediate_size: int,
layer_id: int,
top_k: Optional[int] = None,
layer_id: Optional[int] = None,
num_fused_shared_experts: int = 0,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
quant_config: Optional[QuantizationConfig] = None,
Expand All @@ -84,13 +86,15 @@ def __init__(
if params_dtype is None:
params_dtype = torch.get_default_dtype()

self.layer_id = layer_id
self.top_k = top_k
self.hidden_size = hidden_size
self.tp_size = (
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
)
self.tp_rank = get_tensor_model_parallel_rank()
self.num_experts = num_experts
self.num_fused_shared_experts = num_fused_shared_experts
self.expert_map = None

if enable_flashinfer_cutlass_moe and quant_config is None:
@@ -375,6 +379,45 @@ def weight_loader(
shard_id: str,
expert_id: int,
) -> None:

global_expert_location_metadata = get_global_expert_location_metadata()
if global_expert_location_metadata is None:
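# Without expert location metadata (e.g. the model does not support expert location), use the default loading path.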
self._weight_loader_impl(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=expert_id,
)
return

if expert_id >= self.num_experts - self.num_fused_shared_experts:
# This is a shared expert.
physical_expert_ids = [expert_id]
else:
physical_expert_ids = (
global_expert_location_metadata.logical_to_all_physical(
self.layer_id, expert_id
)
)

for physical_expert_id in physical_expert_ids:
self._weight_loader_physical(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=physical_expert_id,
)

def _weight_loader_physical(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
) -> None:
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
if expert_id == -1:
return
2 changes: 2 additions & 0 deletions python/sglang/srt/models/deepseek_v2.py
@@ -325,6 +325,7 @@ def __init__(
num_experts=config.n_routed_experts
+ self.num_fused_shared_experts
+ global_server_args_dict["ep_num_redundant_experts"],
num_fused_shared_experts=self.num_fused_shared_experts,
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
@@ -2112,6 +2113,7 @@ def determine_num_fused_shared_experts(

if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True
self.num_fused_shared_experts = 0
log_info_on_rank0(
logger,
f"{disable_reason} Shared experts fusion optimization is disabled.",
4 changes: 3 additions & 1 deletion python/sglang/srt/models/glm4_moe.py
@@ -434,6 +434,7 @@ def __init__(
num_experts=config.n_routed_experts
+ self.num_fused_shared_experts
+ global_server_args_dict["ep_num_redundant_experts"],
num_fused_shared_experts=self.num_fused_shared_experts,
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
@@ -740,10 +741,11 @@ def determine_num_fused_shared_experts(
global_server_args_dict["enable_deepep_moe"]
or global_server_args_dict["enable_ep_moe"]
):
disable_reason = "Deepseek GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."

if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True
self.num_fused_shared_experts = 0
log_info_on_rank0(
logger,
f"{disable_reason} Shared experts fusion optimization is disabled.",
3 changes: 3 additions & 0 deletions python/sglang/srt/models/granitemoe.py
@@ -43,6 +43,7 @@ def __init__(
top_k: int,
hidden_size: int,
intermediate_size: int,
layer_id: int,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
@@ -71,6 +72,7 @@ def __init__(
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
layer_id=layer_id,
params_dtype=params_dtype,
reduce_results=True,
quant_config=quant_config,
@@ -203,6 +205,7 @@ def __init__(
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
layer_id=layer_id,
quant_config=quant_config,
prefix=f"{prefix}.block_sparse_moe",
)
3 changes: 3 additions & 0 deletions python/sglang/srt/models/grok.py
@@ -78,6 +78,7 @@ class Grok1MoE(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_id: int,
num_experts: int,
top_k: int,
hidden_size: int,
@@ -128,6 +129,7 @@ def __init__(
self.experts = MoEImpl(
num_experts=num_experts,
top_k=top_k,
layer_id=layer_id,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
@@ -331,6 +333,7 @@ def __init__(
)
self.block_sparse_moe = Grok1MoE(
config=config,
layer_id=layer_id,
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
1 change: 1 addition & 0 deletions python/sglang/srt/models/hunyuan.py
@@ -163,6 +163,7 @@ def __init__(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
reduce_results=False,
layer_id=layer_id,
quant_config=quant_config,
)

3 changes: 3 additions & 0 deletions python/sglang/srt/models/llama4.py
@@ -87,6 +87,7 @@ def custom_routing_function(
def __init__(
self,
config: Llama4TextConfig,
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
@@ -114,6 +115,7 @@ def __init__(
num_experts=config.num_local_experts,
hidden_size=config.hidden_size,
intermediate_size=intermediate_size_moe,
layer_id=layer_id,
reduce_results=False,
quant_config=quant_config,
apply_router_weight_on_input=True,
@@ -373,6 +375,7 @@ def __init__(
if is_moe_layer:
self.feed_forward = Llama4MoE(
config=config,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("feed_forward", prefix),
)
3 changes: 3 additions & 0 deletions python/sglang/srt/models/mixtral.py
@@ -69,6 +69,7 @@ def __init__(
top_k: int,
hidden_size: int,
intermediate_size: int,
layer_id: int,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
@@ -97,6 +98,7 @@ def __init__(
self.experts = MoEImpl(
num_experts=num_experts,
top_k=top_k,
layer_id=layer_id,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
@@ -226,6 +228,7 @@ def __init__(
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("block_sparse_moe", prefix),
)
3 changes: 3 additions & 0 deletions python/sglang/srt/models/olmoe.py
@@ -63,6 +63,7 @@ def __init__(
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
layer_id: int = 0,
prefix: str = "",
):
super().__init__()
@@ -89,6 +90,7 @@ def __init__(
reduce_results=True,
quant_config=quant_config,
tp_size=tp_size,
layer_id=layer_id,
prefix=add_prefix("experts", prefix),
)

@@ -224,6 +226,7 @@ def __init__(
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)