Merged
3 changes: 3 additions & 0 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -1295,6 +1295,9 @@ def forward_deepgemm_masked(
def get_moe_impl_class():
if global_server_args_dict["enable_deepep_moe"]:
return DeepEPMoE
if global_server_args_dict["enable_flashinfer_moe"]:
# Must come before EPMoE because FusedMoE also supports enable_ep_moe
return FusedMoE
if global_server_args_dict["enable_ep_moe"]:
return EPMoE
return FusedMoE
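A minimal standalone sketch (not part of this diff) of the resulting selection order, with a mocked args dict and class names as strings; it illustrates why the enable_flashinfer_moe check has to sit before the enable_ep_moe check:

def pick_moe_impl(args: dict) -> str:
    # Mirrors get_moe_impl_class: DeepEP first, then FlashInfer-backed FusedMoE
    # (which handles EP itself), then EPMoE, then the default FusedMoE.
    if args.get("enable_deepep_moe"):
        return "DeepEPMoE"
    if args.get("enable_flashinfer_moe"):
        return "FusedMoE"
    if args.get("enable_ep_moe"):
        return "EPMoE"
    return "FusedMoE"

assert pick_moe_impl({"enable_flashinfer_moe": True, "enable_ep_moe": True}) == "FusedMoE"
assert pick_moe_impl({"enable_ep_moe": True}) == "EPMoE"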
72 changes: 62 additions & 10 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -314,6 +314,8 @@ def __init__(
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
enable_flashinfer_moe: Optional[bool] = False,
enable_ep_moe: Optional[bool] = False,
):
super().__init__()

@@ -324,9 +326,34 @@ def __init__(
self.tp_size = (
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
)
self.tp_rank = get_tensor_model_parallel_rank()
self.num_experts = num_experts
self.expert_map = None
self.enable_flashinfer_moe = enable_flashinfer_moe
if enable_ep_moe:
assert (
self.enable_flashinfer_moe
), "FusedMoE only supports EP with --enable-flashinfer-moe"
self.ep_size = self.tp_size
self.ep_rank = self.tp_rank
self.tp_size = 1
self.tp_rank = 0
# Create a tensor of size num_experts filled with -1
self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
# Create an expert map for the local experts
assert num_experts % self.ep_size == 0
self.local_num_experts = num_experts // self.ep_size
self.expert_map[
self.ep_rank
* self.local_num_experts : (self.ep_rank + 1)
* self.local_num_experts
] = torch.arange(0, self.local_num_experts, dtype=torch.int32, device="cpu")
else:
self.ep_size = 1
self.ep_rank = 0
self.local_num_experts = num_experts
self.routed_scaling_factor = routed_scaling_factor
self.top_k = top_k
self.num_experts = num_experts
assert intermediate_size % self.tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
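For context, the EP branch above gives each rank a dense map from global expert id to local expert id, with -1 marking experts owned by other ranks; weight_loader later uses it to skip non-local experts. A standalone sketch of the same construction (assuming, as the assert does, that num_experts divides evenly by ep_size):

import torch

def build_expert_map(num_experts: int, ep_size: int, ep_rank: int) -> torch.Tensor:
    # Global expert id -> local expert id on this rank, or -1 if the expert
    # lives on another rank.
    assert num_experts % ep_size == 0
    local_num_experts = num_experts // ep_size
    expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
    expert_map[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts] = (
        torch.arange(local_num_experts, dtype=torch.int32)
    )
    return expert_map

# 8 experts over 4 EP ranks: rank 2 owns global experts 4 and 5.
print(build_expert_map(8, 4, 2))  # tensor([-1, -1, -1, -1,  0,  1, -1, -1], dtype=torch.int32)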
@@ -344,19 +371,20 @@ def __init__(
self.use_presharded_weights = use_presharded_weights
self.inplace = inplace
self.no_combine = no_combine
self.local_num_experts = num_experts

if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
UnquantizedFusedMoEMethod()
)
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
assert self.quant_method is not None

self.quant_method.create_weights(
layer=self,
num_experts=num_experts,
num_experts=self.local_num_experts,
hidden_size=hidden_size,
# FIXME: figure out which intermediate_size to use
intermediate_size=self.intermediate_size_per_partition,
@@ -450,12 +478,15 @@ def _load_w13(

# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
# w3, up_proj: Load into second logical weight of w13.
# The TRT-LLM/FlashInfer CUTLASS kernel assumes the opposite ([Up, Gate]) ordering, so the halves may be swapped.
assert shard_id in ("w1", "w3")
switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
start = shard_size
else:
assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
start = 0
expert_data = expert_data.narrow(shard_dim, start, shard_size)
expert_data.copy_(loaded_weight)
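In short, only the start offset of the narrow depends on the layout: the default fused w13 stores [w1 (gate), w3 (up)], while a quant method that sets load_up_proj_weight_first (added later in this PR for the FlashInfer CUTLASS path) expects [w3, w1]. A minimal sketch of that offset rule in isolation:

def w13_copy_offset(shard_id: str, shard_size: int, up_proj_first: bool) -> int:
    # Returns the offset along the shard dim at which this shard is copied
    # into the fused w13 weight.
    assert shard_id in ("w1", "w3")
    if (up_proj_first and shard_id == "w1") or (not up_proj_first and shard_id == "w3"):
        return shard_size  # second half of w13
    return 0  # first half of w13

assert w13_copy_offset("w1", 128, up_proj_first=False) == 0
assert w13_copy_offset("w3", 128, up_proj_first=True) == 0
assert w13_copy_offset("w1", 128, up_proj_first=True) == 128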

def _load_w2(
@@ -509,6 +540,11 @@ def _load_g_idx(
assert shard_id in ("w1", "w3")
expert_data.copy_(loaded_weight)

def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
if self.expert_map is None:
return expert_id
return self.expert_map[expert_id].item()

def weight_loader(
self,
param: torch.nn.Parameter,
@@ -517,6 +553,13 @@ def weight_loader(
shard_id: str,
expert_id: int,
) -> None:
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
if expert_id == -1:
return

# TP rank is set to 0 if EP is enabled
tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()

# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
@@ -541,15 +584,14 @@
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

expert_data = param.data[expert_id]
tp_rank = get_tensor_model_parallel_rank()

# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size is
is_transposed = getattr(param, "is_transposed", False)
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed:
shard_dim = ~shard_dim
shard_dim = int(not shard_dim)

# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
@@ -690,9 +732,19 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
routed_scaling_factor=self.routed_scaling_factor,
**(
dict(
tp_rank=self.tp_rank,
tp_size=self.tp_size,
ep_rank=self.ep_rank,
ep_size=self.ep_size,
)
if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
else {}
),
)

if self.reduce_results and self.tp_size > 1:
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

return final_hidden_states
70 changes: 62 additions & 8 deletions python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -29,11 +29,17 @@
requantize_with_max_scale,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import is_cuda
from sglang.srt.utils import is_cuda, next_power_of_2

if is_cuda():
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

try:
from flashinfer import fp4_quantize as fp4_quantize
from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
except ImportError:
flashinfer_cutlass_fused_moe = None

# Initialize logger for the module
logger = logging.getLogger(__name__)

@@ -429,6 +435,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.alpha = Parameter(
layer.input_scale * layer.weight_scale_2, requires_grad=False
)
layer.input_scale_inv = Parameter(
(1 / input_scale_2).to(torch.float32), requires_grad=False
)

# Pad and blockwise interleave weight_scale
scales = layer.weight_scale
@@ -467,7 +476,7 @@ def apply(
output_shape = [x_m, w_n]

# Quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_scale_interleaved = scaled_fp4_quant(x, 1 / layer.input_scale)
x_fp4, x_scale_interleaved = scaled_fp4_quant(x, layer.input_scale_inv)

assert x_fp4.dtype == torch.uint8
assert x_scale_interleaved.dtype == torch.float8_e4m3fn
@@ -521,6 +530,7 @@ def __init__(self, quant_config: ModelOptFp4Config):
" quantization. Please use Blackwell and"
" above."
)
self.enable_flashinfer_moe = False

def create_weights(
self,
@@ -674,7 +684,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
if self.enable_flashinfer_moe:
w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
else:
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
layer.g1_alphas = Parameter(
(w13_input_scale * w13_weight_scale_2).to(torch.float32),
requires_grad=False,
@@ -700,14 +713,19 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)

# GEMM 2
if self.enable_flashinfer_moe:
w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
else:
w2_input_scale = layer.w2_input_scale

layer.g2_alphas = Parameter(
(layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
(w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
requires_grad=False,
)

# This is for quantization, so we need to invert it.
layer.w2_input_scale_quant = Parameter(
(1 / layer.w2_input_scale).to(torch.float32), requires_grad=False
(1 / w2_input_scale).to(torch.float32), requires_grad=False
)

assert (
Expand All @@ -727,11 +745,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.cutlass_moe_params = CutlassMoEParams(
CutlassMoEType.BlockscaledFP4,
device,
num_experts=layer.num_experts,
num_experts=layer.num_experts, # global num experts
intermediate_size_per_partition=layer.w2_weight.shape[2] * 2, # n
hidden_size=layer.w13_weight.shape[2] * 2,
) # k

@property
def load_up_proj_weight_first(self) -> bool:
# FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
return self.enable_flashinfer_moe

def apply(
self,
layer: torch.nn.Module,
Expand All @@ -750,11 +773,13 @@ def apply(
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
ep_rank: Optional[int] = None,
ep_size: Optional[int] = None,
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
) -> torch.Tensor:

assert activation == "silu", "Only SiLU activation is supported."

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts

topk_weights, topk_ids = select_experts(
@@ -771,6 +796,35 @@
routed_scaling_factor=routed_scaling_factor,
)

if self.enable_flashinfer_moe:
assert (
not apply_router_weight_on_input
), "apply_router_weight_on_input is not supported for Flashinfer"
# The TRT-LLM CUTLASS MoE kernel takes activations in BF16/Half/NVFP4 precision
# and FP4-quantized weights loaded from the checkpoint.
output = flashinfer_cutlass_fused_moe(
x,
topk_ids.to(torch.int),
topk_weights,
layer.w13_weight.view(torch.long),
layer.w2_weight.view(torch.long),
x.dtype,
quant_scales=[
layer.w13_input_scale_quant,
layer.w13_blockscale_swizzled.view(torch.int32),
layer.g1_alphas,
layer.w2_input_scale_quant,
layer.w2_blockscale_swizzled.view(torch.int32),
layer.g2_alphas,
],
ep_size=ep_size,
ep_rank=ep_rank,
tp_size=tp_size,
tp_rank=tp_rank,
Collaborator (inline review comment): To avoid confusion, could you add a comment here indicating that only EP or TP is supported under the current parallel configuration? Once EP and TP can coexist, the comment can be removed.

tune_max_num_tokens=next_power_of_2(x.shape[0]),
)
return output[0]

from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4

return cutlass_moe_fp4(
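One detail worth noting in the FlashInfer branch above: tune_max_num_tokens=next_power_of_2(x.shape[0]) rounds the current token count up to a power of two before handing it to the kernel tuner. A sketch of what such a helper typically computes (the real one is imported from sglang.srt.utils and may differ in edge cases):

def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n; n <= 1 maps to 1.
    if n <= 1:
        return 1
    return 1 << (n - 1).bit_length()

assert [next_power_of_2(n) for n in (1, 3, 64, 65)] == [1, 4, 64, 128]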
1 change: 1 addition & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -86,6 +86,7 @@
"enable_deepep_moe",
"deepep_mode",
"enable_ep_moe",
"enable_flashinfer_moe",
"moe_dense_tp_size",
"ep_dispatch_algorithm",
"deepep_config",
40 changes: 39 additions & 1 deletion python/sglang/srt/models/deepseek_v2.py
@@ -226,6 +226,7 @@ def __init__(
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
@@ -238,6 +239,7 @@
)
self.config = config
self.layer_id = layer_id
self.alt_stream = alt_stream

if self.tp_size > config.n_routed_experts:
raise ValueError(
@@ -275,6 +277,15 @@ def __init__(
if global_server_args_dict["enable_deepep_moe"]
else {}
),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_moe=True,
enable_ep_moe=global_server_args_dict["enable_ep_moe"],
)
if global_server_args_dict["enable_flashinfer_moe"]
else {}
),
)

if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
@@ -338,10 +349,36 @@ def forward(
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
) -> torch.Tensor:
if not self._enable_deepep_moe:
return self.forward_normal(hidden_states)
DUAL_STREAM_TOKEN_THRESHOLD = 1024
if (
self.alt_stream is not None
and self.num_fused_shared_experts == 0
and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
):
return self.forward_normal_dual_stream(hidden_states)
else:
return self.forward_normal(hidden_states)
else:
return self.forward_deepep(hidden_states, forward_batch)

def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor:
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
shared_output = self._forward_shared_experts(hidden_states)
with torch.cuda.stream(self.alt_stream):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
if not _is_cuda:
final_hidden_states *= self.routed_scaling_factor
current_stream.wait_stream(self.alt_stream)
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states
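forward_normal_dual_stream overlaps the shared-expert MLP on the current stream with gating plus routed experts on alt_stream, synchronizing both ways so the final add sees completed results. A stripped-down sketch of the same stream-overlap pattern (compute_a / compute_b are placeholders; requires a CUDA device):

import torch

def overlapped(x: torch.Tensor, compute_a, compute_b, alt_stream: torch.cuda.Stream) -> torch.Tensor:
    current = torch.cuda.current_stream()
    alt_stream.wait_stream(current)      # alt_stream waits for prior work on x
    out_a = compute_a(x)                 # runs on the current stream (e.g. shared experts)
    with torch.cuda.stream(alt_stream):
        out_b = compute_b(x)             # runs concurrently on alt_stream (e.g. gate + routed experts)
    current.wait_stream(alt_stream)      # rejoin before combining
    return out_a + out_b

# usage sketch: alt = torch.cuda.Stream(); y = overlapped(x, shared_mlp, routed_moe, alt)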

def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
shared_output = self._forward_shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
@@ -1446,6 +1483,7 @@ def __init__(
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
layer_id=self.layer_id,
alt_stream=alt_stream,
)
else:
if enable_moe_dense_fully_dp():