3 changes: 3 additions & 0 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -1295,6 +1295,9 @@ def forward_deepgemm_masked(
def get_moe_impl_class():
if global_server_args_dict["enable_deepep_moe"]:
return DeepEPMoE
if global_server_args_dict["enable_flashinfer_moe"]:
# Must come before EPMoE because FusedMoE also supports enable_ep_moe
return FusedMoE
if global_server_args_dict["enable_ep_moe"]:
return EPMoE
return FusedMoE
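
A quick way to read the new precedence: the FlashInfer check has to sit above the enable_ep_moe check because FusedMoE now handles EP itself when the FlashInfer backend is on. The following minimal sketch is a stand-alone restatement of the selection order above, with args standing in for global_server_args_dict; it is illustrative only, not the PR's code:

def select_moe_impl(args: dict) -> str:
    # Mirrors get_moe_impl_class(): DeepEP first, then FlashInfer, then EPMoE.
    if args.get("enable_deepep_moe"):
        return "DeepEPMoE"
    if args.get("enable_flashinfer_moe"):
        # Checked before enable_ep_moe: FusedMoE handles EP on this path.
        return "FusedMoE"
    if args.get("enable_ep_moe"):
        return "EPMoE"
    return "FusedMoE"

assert select_moe_impl({"enable_flashinfer_moe": True, "enable_ep_moe": True}) == "FusedMoE"
assert select_moe_impl({"enable_ep_moe": True}) == "EPMoE"
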
87 changes: 76 additions & 11 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -314,6 +314,8 @@ def __init__(
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
enable_flashinfer_moe: Optional[bool] = False,
enable_ep_moe: Optional[bool] = False,
):
super().__init__()

@@ -324,12 +326,45 @@ def __init__(
self.tp_size = (
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
)
self.tp_rank = get_tensor_model_parallel_rank()
self.num_experts = num_experts
self.reduce_results = reduce_results
self.expert_map = None
self.enable_flashinfer_moe = enable_flashinfer_moe
if self.enable_flashinfer_moe:
Collaborator comment: I think we will need to handle MTP which is not quantized. Can we just disable flashinfer_moe and ep_moe for MTP module if quant_config is None (context)?
self.reduce_results = True
if enable_ep_moe:
assert (
self.enable_flashinfer_moe
), "FusedMoE only supports EP with --enable-flashinfer-moe"
self.ep_size = self.tp_size
self.ep_rank = self.tp_rank
self.tp_size = 1
self.tp_rank = 0
# Create a tensor of size num_experts filled with -1
self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
# Create an expert map for the local experts
local_num_experts = num_experts // self.ep_size
if self.ep_rank < (self.ep_size - 1):
# Each non-last rank gets local_num_experts experts.
self.expert_map[
self.ep_rank
* local_num_experts : (self.ep_rank + 1)
* local_num_experts
] = torch.arange(0, local_num_experts, dtype=torch.int32)
else:
# All remaining experts are assigned to the last rank.
local_num_experts = num_experts - self.ep_rank * local_num_experts
self.expert_map[-local_num_experts:] = torch.arange(
0, local_num_experts, dtype=torch.int32
)
else:
self.ep_size = 1
self.ep_rank = 0
self.routed_scaling_factor = routed_scaling_factor
self.top_k = top_k
self.num_experts = num_experts
assert intermediate_size % self.tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
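
Worked sketch of the expert partitioning introduced above: when enable_ep_moe is set (only valid with the FlashInfer backend), the layer repurposes the TP group as the EP group (tp_size becomes 1) and builds expert_map so each rank owns a contiguous slice of experts, with the remainder going to the last rank; weight_loader later skips any expert mapped to -1. The sizes below (10 experts, 4 ranks) are assumed for illustration and mirror, rather than reuse, the constructor code:

import torch

num_experts, ep_size = 10, 4  # assumed sizes, not from the PR
for ep_rank in range(ep_size):
    expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
    local_num_experts = num_experts // ep_size
    if ep_rank < ep_size - 1:
        # Non-last ranks own a contiguous block of floor(num_experts / ep_size) experts.
        expert_map[
            ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts
        ] = torch.arange(local_num_experts, dtype=torch.int32)
    else:
        # The last rank absorbs the remainder.
        local_num_experts = num_experts - ep_rank * local_num_experts
        expert_map[-local_num_experts:] = torch.arange(local_num_experts, dtype=torch.int32)
    print(ep_rank, expert_map.tolist())
# 0 [0, 1, -1, -1, -1, -1, -1, -1, -1, -1]
# 1 [-1, -1, 0, 1, -1, -1, -1, -1, -1, -1]
# 2 [-1, -1, -1, -1, 0, 1, -1, -1, -1, -1]
# 3 [-1, -1, -1, -1, -1, -1, 0, 1, 2, 3]
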
@@ -344,19 +379,25 @@ def __init__(
self.use_presharded_weights = use_presharded_weights
self.inplace = inplace
self.no_combine = no_combine
self.local_num_experts = num_experts
self.local_num_experts = (
torch.sum(self.expert_map != -1)
if self.expert_map is not None
else num_experts
)

if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
UnquantizedFusedMoEMethod()
)
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
assert self.quant_method is not None

self.quant_method.create_weights(
layer=self,
num_experts=num_experts,
num_experts=self.local_num_experts,
hidden_size=hidden_size,
# FIXME: figure out which intermediate_size to use
intermediate_size=self.intermediate_size_per_partition,
@@ -450,12 +491,15 @@ def _load_w13(

# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
# w3, up_proj: Load into second logical weight of w13.
# The TRT-LLM/FlashInfer CUTLASS kernel assumes the swapped [up_proj, gate_proj] layout (see load_up_proj_weight_first)
assert shard_id in ("w1", "w3")
switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
start = shard_size
else:
assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
start = 0
expert_data = expert_data.narrow(shard_dim, start, shard_size)
expert_data.copy_(loaded_weight)

def _load_w2(
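
The offset logic above encodes the weight-layout difference between backends: by default w13 is packed as [gate_proj (w1) | up_proj (w3)], while the FlashInfer/TRT-LLM CUTLASS kernel expects [up_proj | gate_proj], signalled by the load_up_proj_weight_first property added in modelopt_quant.py. A minimal stand-alone restatement of that logic (the shard_size value is arbitrary):

def w13_offset(shard_id: str, shard_size: int, switch_w13: bool) -> int:
    # switch_w13 corresponds to quant_method.load_up_proj_weight_first.
    assert shard_id in ("w1", "w3")
    if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
        return shard_size
    return 0

assert w13_offset("w1", 128, switch_w13=False) == 0    # default: gate_proj first
assert w13_offset("w3", 128, switch_w13=False) == 128
assert w13_offset("w3", 128, switch_w13=True) == 0     # FlashInfer: up_proj first
assert w13_offset("w1", 128, switch_w13=True) == 128
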
@@ -509,6 +553,11 @@ def _load_g_idx(
assert shard_id in ("w1", "w3")
expert_data.copy_(loaded_weight)

def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
if self.expert_map is None:
return expert_id
return self.expert_map[expert_id].item()

def weight_loader(
self,
param: torch.nn.Parameter,
@@ -517,6 +566,13 @@ def weight_loader(
shard_id: str,
expert_id: int,
) -> None:
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
if expert_id == -1:
return

# TP rank is set to 0 if EP is enabled
tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()

# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
@@ -541,15 +597,14 @@
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

expert_data = param.data[expert_id]
tp_rank = get_tensor_model_parallel_rank()

# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size is
is_transposed = getattr(param, "is_transposed", False)
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed:
shard_dim = ~shard_dim
shard_dim = int(not shard_dim)

# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
@@ -690,9 +745,19 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
routed_scaling_factor=self.routed_scaling_factor,
**(
dict(
tp_rank=self.tp_rank,
tp_size=self.tp_size,
ep_rank=self.ep_rank,
ep_size=self.ep_size,
)
if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
else {}
),
)

if self.reduce_results and self.tp_size > 1:
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

return final_hidden_states
60 changes: 58 additions & 2 deletions python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -29,11 +29,17 @@
requantize_with_max_scale,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import is_cuda
from sglang.srt.utils import get_bool_env_var, is_cuda

if is_cuda():
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

try:
from flashinfer import fp4_quantize as fp4_quantize
from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
except ImportError:
flashinfer_cutlass_fused_moe = None

# Initialize logger for the module
logger = logging.getLogger(__name__)

@@ -521,6 +527,7 @@ def __init__(self, quant_config: ModelOptFp4Config):
" quantization. Please use Blackwell and"
" above."
)
self.enable_flashinfer_moe = False

def create_weights(
self,
@@ -727,11 +734,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.cutlass_moe_params = CutlassMoEParams(
CutlassMoEType.BlockscaledFP4,
device,
num_experts=layer.num_experts,
num_experts=layer.num_experts, # global num experts
intermediate_size_per_partition=layer.w2_weight.shape[2] * 2, # n
hidden_size=layer.w13_weight.shape[2] * 2,
) # k

@property
def load_up_proj_weight_first(self) -> bool:
# FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
return self.enable_flashinfer_moe

def apply(
self,
layer: torch.nn.Module,
@@ -750,6 +762,10 @@ def apply(
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
ep_rank: Optional[int] = None,
ep_size: Optional[int] = None,
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
) -> torch.Tensor:

assert activation == "silu", "Only SiLU activation is supported."
@@ -771,6 +787,46 @@
routed_scaling_factor=routed_scaling_factor,
)

if self.enable_flashinfer_moe:
assert (
not apply_router_weight_on_input
), "apply_router_weight_on_input is not supported for Flashinfer"
a1_gs = torch.min(layer.w13_input_scale_quant)
a2_gs = torch.min(layer.w2_input_scale_quant)
w1_blockscale = layer.w13_blockscale_swizzled
w2_blockscale = layer.w2_blockscale_swizzled
g1_alphas = layer.g1_alphas
g2_alphas = layer.g2_alphas

quant_scales = [
a1_gs,
w1_blockscale.view(torch.int32),
g1_alphas,
a2_gs,
w2_blockscale.view(torch.int32),
g2_alphas,
]
# TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
# and fp4 quantized weights loaded from the checkpoint
out_dtype = x.dtype
output = x if inplace else torch.zeros_like(x)
x, x_sf = fp4_quantize(x, a1_gs)
output = flashinfer_cutlass_fused_moe(
x,
topk_ids.to(torch.int),
topk_weights,
layer.w13_weight.view(torch.long),
layer.w2_weight.view(torch.long),
out_dtype,
quant_scales=quant_scales,
input_sf=x_sf,
ep_size=ep_size,
ep_rank=ep_rank,
tp_size=tp_size,
tp_rank=tp_rank,
Collaborator comment: To avoid confusion, could you add a comment here indicating that only EP or TP is supported under the current parallel configuration? Once EP and TP can coexist, the comment can be removed.

)
return output[0]

from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4

return cutlass_moe_fp4(
1 change: 1 addition & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -85,6 +85,7 @@
"enable_deepep_moe",
"deepep_mode",
"enable_ep_moe",
"enable_flashinfer_moe",
"moe_dense_tp_size",
"ep_dispatch_algorithm",
"deepep_config",
9 changes: 9 additions & 0 deletions python/sglang/srt/models/deepseek_v2.py
@@ -275,6 +275,15 @@ def __init__(
if global_server_args_dict["enable_deepep_moe"]
else {}
),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_moe=True,
enable_ep_moe=global_server_args_dict["enable_ep_moe"],
)
if global_server_args_dict["enable_flashinfer_moe"]
else {}
),
)

if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
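
The conditional **(...) block added above only injects the two new keyword arguments when --enable-flashinfer-moe is set, so other code paths see an unchanged FusedMoE signature. A self-contained sketch of that pattern, with placeholder config values and server_args standing in for global_server_args_dict:

def build_moe_kwargs(server_args: dict) -> dict:
    # Only add the FlashInfer-related kwargs when the backend is enabled.
    return dict(
        num_experts=8,  # placeholder model-config values
        top_k=2,
        **(
            dict(
                enable_flashinfer_moe=True,
                enable_ep_moe=server_args["enable_ep_moe"],
            )
            if server_args["enable_flashinfer_moe"]
            else {}
        ),
    )

print(build_moe_kwargs({"enable_flashinfer_moe": True, "enable_ep_moe": True}))
# {'num_experts': 8, 'top_k': 2, 'enable_flashinfer_moe': True, 'enable_ep_moe': True}
print(build_moe_kwargs({"enable_flashinfer_moe": False, "enable_ep_moe": False}))
# {'num_experts': 8, 'top_k': 2}
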
15 changes: 14 additions & 1 deletion python/sglang/srt/server_args.py
@@ -152,6 +152,7 @@ class ServerArgs:
ep_size: int = 1
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
enable_flashinfer_moe: bool = False
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
ep_num_redundant_experts: int = 0
ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
@@ -244,7 +245,14 @@ def __post_init__(self):
logger.warning(
f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)

if self.enable_flashinfer_moe:
assert (
self.quantization == "modelopt_fp4"
), "modelopt_fp4 quantization is required for Flashinfer MOE"
self.disable_shared_experts_fusion = True
logger.warning(
f"Flashinfer MoE is enabled. Shared expert fusion is disabled."
)
# Set missing default values
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
@@ -1162,6 +1170,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
)
parser.add_argument(
"--enable-flashinfer-moe",
action="store_true",
help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
)
parser.add_argument(
"--enable-deepep-moe",
action="store_true",