1 change: 1 addition & 0 deletions docs/backend/server_arguments.md
@@ -91,6 +91,7 @@ Please consult the documentation below to learn more about the parameters you ma
* `enable_ep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for MoE models.
* `ep_size`: The size of EP. Please shard the model weights with `tp_size=ep_size`, for detailed benchmarking refer to [this PR](https://github.com/sgl-project/sglang/pull/2203). If not set, `ep_size` will be automatically set to `tp_size`.
* `enable_deepep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for DeepSeek-V3 model based on deepseek-ai/DeepEP.
* `deepep_mode`: Selects the mode used when DeepEP MoE is enabled; one of `normal`, `low_latency`, or `auto`. The default is `auto`, which uses `low_latency` for decode batches and `normal` for prefill batches, as sketched below.
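A minimal sketch of this selection rule (the helper name and the `is_decode_batch` flag are illustrative, not actual SGLang internals):

```python
def resolve_deepep_mode(deepep_mode: str, is_decode_batch: bool) -> str:
    """Map the configured deepep_mode to the mode used for one batch."""
    if deepep_mode == "auto":
        return "low_latency" if is_decode_batch else "normal"
    if deepep_mode in ("normal", "low_latency"):
        return deepep_mode
    raise ValueError(f"Invalid deepep_mode: {deepep_mode}")
```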

## Memory and scheduling

142 changes: 142 additions & 0 deletions python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -244,6 +244,148 @@ def silu_and_mul_triton_kernel(
tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)


# Copied from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
@triton.jit
def _silu_and_mul_post_quant_kernel(
input_ptr,
stride_input_0,
stride_input_1,
stride_input_2,
output_ptr,
stride_output_0,
stride_output_1,
stride_output_2,
output_scale_ptr,
stride_output_scale_0,
stride_output_scale_1,
stride_output_scale_2,
masked_m_ptr,
size_n,
fp8_max,
fp8_min,
BLOCK_N: tl.constexpr,
NUM_STAGE: tl.constexpr,
):
expert_id = tl.program_id(2)
token_id = tl.program_id(1)
hidden_dim_block_index = tl.program_id(0)

block_num_per_expert = tl.num_programs(1)

token_num_cur_expert = tl.load(masked_m_ptr + expert_id)

stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)

offs_in_d = hidden_dim_block_index * BLOCK_N + tl.arange(0, BLOCK_N)
input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d
output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d
output_scale_offs = (
output_scale_ptr
+ expert_id * stride_output_scale_0
+ hidden_dim_block_index * stride_output_scale_2
)

# Each program block strides over the valid (non-padded) tokens of its expert
for token_index in tl.range(
token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE
):
gate = tl.load(
input_ptr_offs + token_index * stride_input_1,
mask=offs_in_d < size_n,
other=0.0,
).to(tl.float32)
up = tl.load(
input_ptr_offs + token_index * stride_input_1 + size_n,
mask=offs_in_d < size_n,
other=0.0,
)
# SiLU on the gate projection, then multiply elementwise with the up projection
gate = gate / (1 + tl.exp(-gate))
gate = gate.to(input_ptr.dtype.element_ty)
gate_up = up * gate
# Dynamic per-group FP8 scale; clamp the absmax to avoid division by zero
_absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
output_s = _absmax / fp8_max
output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
output_ptr.dtype.element_ty
)
tl.store(
output_ptr_offs + token_index * stride_output_1,
output_q,
mask=offs_in_d < size_n,
)
tl.store(
output_scale_offs + token_index * stride_output_scale_1,
output_s,
)


def silu_and_mul_masked_post_quant_fwd(
input: torch.Tensor,
output: torch.Tensor,
output_scale: torch.Tensor,
quant_group_size: int,
masked_m: torch.Tensor,
):
"""
input shape [expert_num, token_num_padded, hidden_dim]
output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8
output_scale [expert_num token_num_paddded, hidden_dim // 2 // 128] dtype float32
quant_group_size int,
masked_m shape [expert_num],
"""

assert input.is_contiguous()
assert output.dtype == torch.float8_e4m3fn
assert output.is_contiguous()
assert len(input.shape) == 3
assert input.shape[0] == masked_m.shape[0]
assert input.shape[-1] % 2 == 0

size_n = input.shape[-1] // 2
assert size_n % quant_group_size == 0

expert_num = len(masked_m)

if expert_num < 4:
BLOCK_NUM_PER_EXPERT = 64
else:
BLOCK_NUM_PER_EXPERT = 32

BLOCK_N = quant_group_size
num_warps = 1
NUM_STAGES = 6
hidden_dim_split_block_num = triton.cdiv(size_n, BLOCK_N)
assert BLOCK_N % quant_group_size == 0

grid = (
hidden_dim_split_block_num,
BLOCK_NUM_PER_EXPERT,
expert_num,
)

finfo = torch.finfo(torch.float8_e4m3fn)
fp8_max = finfo.max
fp8_min = -fp8_max

_silu_and_mul_post_quant_kernel[grid](
input,
*input.stride(),
output,
*output.stride(),
output_scale,
*output_scale.stride(),
masked_m,
size_n,
fp8_max,
fp8_min,
BLOCK_N=BLOCK_N,
NUM_STAGE=NUM_STAGES,
num_warps=num_warps,
)
return
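
# Minimal usage sketch for silu_and_mul_masked_post_quant_fwd (illustrative only,
# not part of the original file): shapes follow the docstring above, and a CUDA
# device with Triton support is assumed. `masked_m` holds the number of valid
# (non-padded) tokens per expert.
def _example_silu_and_mul_masked_post_quant_fwd():
    expert_num, token_num_padded, hidden_dim, group_size = 4, 256, 4096, 128
    x = torch.randn(
        expert_num, token_num_padded, hidden_dim, device="cuda", dtype=torch.bfloat16
    )
    out = torch.empty(
        expert_num,
        token_num_padded,
        hidden_dim // 2,
        device="cuda",
        dtype=torch.float8_e4m3fn,
    )
    out_scale = torch.empty(
        expert_num,
        token_num_padded,
        hidden_dim // 2 // group_size,
        device="cuda",
        dtype=torch.float32,
    )
    masked_m = torch.full((expert_num,), 64, device="cuda", dtype=torch.int32)
    silu_and_mul_masked_post_quant_fwd(x, out, out_scale, group_size, masked_m)
    return out, out_scale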


@triton.jit
def tanh(x):
return 2 * tl.sigmoid(2 * x) - 1
159 changes: 81 additions & 78 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -3,12 +3,16 @@

import torch

try:
    from deep_gemm import (
        get_col_major_tma_aligned_tensor,
        m_grouped_gemm_fp8_fp8_bf16_nt_masked,
    )

    use_deep_gemm = True
except ImportError:
    use_deep_gemm = False

from torch.nn import Module

from sglang.srt.custom_op import CustomOp
@@ -22,6 +26,7 @@
post_reorder_triton_kernel,
pre_reorder_triton_kernel,
run_moe_ep_preproess,
silu_and_mul_masked_post_quant_fwd,
silu_and_mul_triton_kernel,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
@@ -809,6 +814,7 @@ def __init__(
correction_bias: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
activation: str = "silu",
deepep_mode: str = "auto",
):
super().__init__(
num_experts,
@@ -827,21 +833,41 @@
custom_routing_function,
activation,
)
self.deepep_mode = deepep_mode
if self.deepep_mode in ["low_latency", "auto"]:
assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
)

def forward(
self,
hidden_states: torch.Tensor,
reorder_topk_ids: torch.Tensor,
seg_indptr: torch.Tensor,
masked_m: torch.Tensor,
expected_m: int,
forward_mode: ForwardMode,
):
    if self.deepep_mode == "normal" or (
        self.deepep_mode == "auto" and not forward_mode.is_decode()
    ):
        return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
    elif self.deepep_mode == "low_latency" or (
        self.deepep_mode == "auto" and forward_mode.is_decode()
    ):
        return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
    else:
        raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")

def forward_normal(
self,
@@ -958,89 +984,66 @@ def forward_normal(

def forward_deepgemm_masked(
    self,
    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
    masked_m: torch.Tensor,
    expected_m: int,
):
    assert self.quant_method is not None
    assert self.activation == "silu"
    assert (
        hidden_states_fp8[0].size(0) % 4 == 0
    ), f"TMA alignment error: {hidden_states_fp8[0].size(0)}"

    # GroupGemm-0: gate/up projection via the masked grouped FP8 GEMM
    num_groups, m, k = hidden_states_fp8[0].size()
    n = self.w13_weight.size(1)
    expected_m = min(expected_m, m)
    gateup_output = torch.empty(
        (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
    )
    m_grouped_gemm_fp8_fp8_bf16_nt_masked(
        hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
    )

    # Act: fused SiLU-and-mul followed by per-group FP8 quantization
    down_input = torch.empty(
        (
            gateup_output.shape[0],
            gateup_output.shape[1],
            gateup_output.shape[2] // 2,
        ),
        device=gateup_output.device,
        dtype=self.fp8_dtype,
    )
    scale_block_size = 128
    down_input_scale = torch.empty(
        (
            gateup_output.shape[0],
            gateup_output.shape[1],
            gateup_output.shape[2] // 2 // scale_block_size,
        ),
        device=gateup_output.device,
        dtype=torch.float32,
    )
    silu_and_mul_masked_post_quant_fwd(
        gateup_output,
        down_input,
        down_input_scale,
        scale_block_size,
        masked_m,
    )

    # GroupGemm-1: down projection; the scales need the column-major TMA-aligned layout
    n = self.w2_weight.size(1)
    down_input_fp8 = (
        down_input,
        get_col_major_tma_aligned_tensor(down_input_scale),
    )
    down_output = torch.empty(
        (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
    )
    m_grouped_gemm_fp8_fp8_bf16_nt_masked(
        down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
    )

    return down_output