Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1295,6 +1295,9 @@ def forward_deepgemm_masked(
def get_moe_impl_class():
if global_server_args_dict["enable_deepep_moe"]:
return DeepEPMoE
if global_server_args_dict["enable_flashinfer_moe"]:
# Must come before EPMoE because FusedMoE also supports enable_ep_moe
return FusedMoE
if global_server_args_dict["enable_ep_moe"]:
return EPMoE
return FusedMoE
72 changes: 62 additions & 10 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ def __init__(
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
enable_flashinfer_moe: Optional[bool] = False,
enable_ep_moe: Optional[bool] = False,
):
super().__init__()

Expand All @@ -324,9 +326,34 @@ def __init__(
self.tp_size = (
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
)
self.tp_rank = get_tensor_model_parallel_rank()
self.num_experts = num_experts
self.expert_map = None
self.enable_flashinfer_moe = enable_flashinfer_moe
if enable_ep_moe:
assert (
self.enable_flashinfer_moe
), "FusedMoE only supports EP with --enable-flashinfer-moe"
self.ep_size = self.tp_size
self.ep_rank = self.tp_rank
self.tp_size = 1
self.tp_rank = 0
# Create a tensor of size num_experts filled with -1
self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
# Create a expert map for the local experts
assert num_experts % self.ep_size == 0
self.local_num_experts = num_experts // self.ep_size
self.expert_map[
self.ep_rank
* self.local_num_experts : (self.ep_rank + 1)
* self.local_num_experts
] = torch.arange(0, self.local_num_experts, dtype=torch.int32, device="cpu")
else:
self.ep_size = 1
self.ep_rank = 0
self.local_num_experts = num_experts
self.routed_scaling_factor = routed_scaling_factor
self.top_k = top_k
self.num_experts = num_experts
assert intermediate_size % self.tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
Expand All @@ -344,19 +371,20 @@ def __init__(
self.use_presharded_weights = use_presharded_weights
self.inplace = inplace
self.no_combine = no_combine
self.local_num_experts = num_experts

if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
UnquantizedFusedMoEMethod()
)
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
assert self.quant_method is not None

self.quant_method.create_weights(
layer=self,
num_experts=num_experts,
num_experts=self.local_num_experts,
hidden_size=hidden_size,
# FIXME: figure out which intermediate_size to use
intermediate_size=self.intermediate_size_per_partition,
Expand Down Expand Up @@ -450,12 +478,15 @@ def _load_w13(

# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
# w3, up_proj: Load into second logical weight of w13.
# trtllm cutlass kernel assumes differently
assert shard_id in ("w1", "w3")
switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
start = shard_size
else:
assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
start = 0
expert_data = expert_data.narrow(shard_dim, start, shard_size)
expert_data.copy_(loaded_weight)

def _load_w2(
Expand Down Expand Up @@ -509,6 +540,11 @@ def _load_g_idx(
assert shard_id in ("w1", "w3")
expert_data.copy_(loaded_weight)

def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
if self.expert_map is None:
return expert_id
return self.expert_map[expert_id].item()

def weight_loader(
self,
param: torch.nn.Parameter,
Expand All @@ -517,6 +553,13 @@ def weight_loader(
shard_id: str,
expert_id: int,
) -> None:
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
if expert_id == -1:
return

# TP rank is set to 0 if EP is enabled
tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()

# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
Expand All @@ -541,15 +584,14 @@ def weight_loader(
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

expert_data = param.data[expert_id]
tp_rank = get_tensor_model_parallel_rank()

# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size is
is_transposed = getattr(param, "is_transposed", False)
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed:
shard_dim = ~shard_dim
shard_dim = int(not shard_dim)

# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
Expand Down Expand Up @@ -690,9 +732,19 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
routed_scaling_factor=self.routed_scaling_factor,
**(
dict(
tp_rank=self.tp_rank,
tp_size=self.tp_size,
ep_rank=self.ep_rank,
ep_size=self.ep_size,
)
if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
else {}
),
)

if self.reduce_results and self.tp_size > 1:
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

return final_hidden_states
Expand Down
65 changes: 58 additions & 7 deletions python/sglang/srt/layers/quantization/modelopt_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,17 @@
requantize_with_max_scale,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import is_cuda
from sglang.srt.utils import is_cuda, next_power_of_2

if is_cuda():
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

try:
from flashinfer import fp4_quantize as fp4_quantize
from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
except ImportError:
flashinfer_cutlass_fused_moe = None

# Initialize logger for the module
logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -521,6 +527,7 @@ def __init__(self, quant_config: ModelOptFp4Config):
" quantization. Please use Blackwell and"
" above."
)
self.enable_flashinfer_moe = False

def create_weights(
self,
Expand Down Expand Up @@ -674,7 +681,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
if self.enable_flashinfer_moe:
w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
else:
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
layer.g1_alphas = Parameter(
(w13_input_scale * w13_weight_scale_2).to(torch.float32),
requires_grad=False,
Expand All @@ -700,14 +710,19 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)

# GEMM 2
if self.enable_flashinfer_moe:
w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
else:
w2_input_scale = layer.w2_input_scale

layer.g2_alphas = Parameter(
(layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
(w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
requires_grad=False,
)

# This is for quantization, so we need to invert it.
layer.w2_input_scale_quant = Parameter(
(1 / layer.w2_input_scale).to(torch.float32), requires_grad=False
(1 / w2_input_scale).to(torch.float32), requires_grad=False
)

assert (
Expand All @@ -727,11 +742,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.cutlass_moe_params = CutlassMoEParams(
CutlassMoEType.BlockscaledFP4,
device,
num_experts=layer.num_experts,
num_experts=layer.num_experts, # global num experts
intermediate_size_per_partition=layer.w2_weight.shape[2] * 2, # n
hidden_size=layer.w13_weight.shape[2] * 2,
) # k

@property
def load_up_proj_weight_first(self) -> bool:
# FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
return self.enable_flashinfer_moe

def apply(
self,
layer: torch.nn.Module,
Expand All @@ -750,11 +770,13 @@ def apply(
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
ep_rank: Optional[int] = None,
ep_size: Optional[int] = None,
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
) -> torch.Tensor:

assert activation == "silu", "Only SiLU activation is supported."

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts

topk_weights, topk_ids = select_experts(
Expand All @@ -771,6 +793,35 @@ def apply(
routed_scaling_factor=routed_scaling_factor,
)

if self.enable_flashinfer_moe:
assert (
not apply_router_weight_on_input
), "apply_router_weight_on_input is not supported for Flashinfer"
# TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
# and fp4 quantized weights loaded from the checkpoint
output = flashinfer_cutlass_fused_moe(
x,
topk_ids.to(torch.int),
topk_weights,
layer.w13_weight.view(torch.long),
layer.w2_weight.view(torch.long),
x.dtype,
quant_scales=[
layer.w13_input_scale_quant,
layer.w13_blockscale_swizzled.view(torch.int32),
layer.g1_alphas,
layer.w2_input_scale_quant,
layer.w2_blockscale_swizzled.view(torch.int32),
layer.g2_alphas,
],
ep_size=ep_size,
ep_rank=ep_rank,
tp_size=tp_size,
tp_rank=tp_rank,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid confusion, could you add a comment here indicating that only EP or TP is supported under the current parallel configuration? Once EP and TP can coexist, the comment can be removed.

tune_max_num_tokens=next_power_of_2(x.shape[0]),
)
return output[0]

from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4

return cutlass_moe_fp4(
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
"enable_deepep_moe",
"deepep_mode",
"enable_ep_moe",
"enable_flashinfer_moe",
"moe_dense_tp_size",
"ep_dispatch_algorithm",
"deepep_config",
Expand Down
35 changes: 34 additions & 1 deletion python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def __init__(
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
Expand All @@ -238,6 +239,7 @@ def __init__(
)
self.config = config
self.layer_id = layer_id
self.alt_stream = alt_stream

if self.tp_size > config.n_routed_experts:
raise ValueError(
Expand Down Expand Up @@ -275,6 +277,15 @@ def __init__(
if global_server_args_dict["enable_deepep_moe"]
else {}
),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_moe=True,
enable_ep_moe=global_server_args_dict["enable_ep_moe"],
)
if global_server_args_dict["enable_flashinfer_moe"]
else {}
),
)

if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
Expand Down Expand Up @@ -338,10 +349,31 @@ def forward(
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
) -> torch.Tensor:
if not self._enable_deepep_moe:
return self.forward_normal(hidden_states)
if self.alt_stream is not None and self.num_fused_shared_experts == 0:
return self.forward_normal_dual_stream(hidden_states)
else:
return self.forward_normal(hidden_states)
else:
return self.forward_deepep(hidden_states, forward_batch)

def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor:
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
shared_output = self._forward_shared_experts(hidden_states)
with torch.cuda.stream(self.alt_stream):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
if not _is_cuda:
final_hidden_states *= self.routed_scaling_factor
current_stream.wait_stream(self.alt_stream)
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states

def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
shared_output = self._forward_shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
Expand Down Expand Up @@ -1446,6 +1478,7 @@ def __init__(
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
layer_id=self.layer_id,
alt_stream=alt_stream,
)
else:
if enable_moe_dense_fully_dp():
Expand Down
Loading
Loading