Merged
53 commits
92d647c
feat(ep_moe): integrate deepgemm into origin ep moe
TianQiLin666666 Apr 22, 2025
e057acb
fix(ep_moe): group_gemm_mask bug
TianQiLin666666 Apr 22, 2025
19ec50e
fix bugs
TianQiLin666666 Apr 22, 2025
3ce1a91
fix bugs
TianQiLin666666 Apr 23, 2025
3d51a71
fix(em_moe): offset bugs
TianQiLin666666 Apr 24, 2025
c80fc3c
fix(deepgemm): bugfix
TianQiLin666666 Apr 24, 2025
af94a8b
fix: remove redundant code
TianQiLin666666 Apr 28, 2025
2022070
fix: clang-format
TianQiLin666666 Apr 28, 2025
55ea483
fix: remove print
TianQiLin666666 Apr 28, 2025
988a522
fix(ep_moe): replace EPMOE_USE_DEEPGEMM with _ENABLE_JIT_DEEPGEMM
TianQiLin666666 Apr 28, 2025
1f81f01
merge main
xutizhou Jun 2, 2025
b4ae984
Refactor moe_ep_deepgemm_preprocess to remove CUDA-specific handling …
xutizhou Jun 2, 2025
c6d51d2
Fix condition for expert fusion by updating the check for 'enable_dee…
xutizhou Jun 2, 2025
0397c25
Fix typo in function name from 'moe_ep_deepgemm_preproess' to 'moe_ep…
xutizhou Jun 2, 2025
aeb437d
Update moe_ep_deepgemm_preprocess to adjust m_max calculation for mas…
xutizhou Jun 3, 2025
d2c19bf
Refactor compute_masked_m_triton_kernel to remove num_experts paramet…
xutizhou Jun 3, 2025
0d3e793
Refactor moe_ep_deepgemm_preprocess to improve assertions and streaml…
xutizhou Jun 3, 2025
ee1a2b1
Refactor kernel functions in kernels.py to improve variable naming an…
xutizhou Jun 3, 2025
d530ee7
Enhance EPMoE layer by capturing hidden states' shape, dtype, and dev…
xutizhou Jun 3, 2025
cb6ea36
Refactor EPMoE layer to improve memory management by explicitly delet…
xutizhou Jun 3, 2025
2769865
Update EPMoE layer to use hidden_states_device for tensor creation, e…
xutizhou Jun 3, 2025
a7a6235
Refactor deepgemm_post_reorder_triton_kernel to improve variable hand…
xutizhou Jun 3, 2025
a1493e5
Optimize EPMoE layer by explicitly deleting intermediate tensor varia…
xutizhou Jun 3, 2025
b1edec5
fix(ep_moe_deepgemm): use dispose_tensor to really free tensor mem
TianQiLin666666 Jun 3, 2025
773f151
Merge pull request #1 from TianQiLin666666/feat/ep_moe_deepgemm_zxt
xutizhou Jun 3, 2025
4975723
Merge branch 'main' into feat/ep_moe_deepgemm
zhyncs Jun 4, 2025
6584c21
Enhance EPMoE layer to conditionally use deep GEMM based on FP8 setti…
xutizhou Jun 5, 2025
c54a141
Optimize memory management in EPMoE layer by removing unnecessary ten…
xutizhou Jun 5, 2025
86f06b6
Merge branch 'main' into feat/ep_moe_deepgemm
zhyncs Jun 5, 2025
15c9514
Merge branch 'main' into feat/ep_moe_deepgemm
zhyncs Jun 5, 2025
afd55c2
Merge branch 'main' into feat/ep_moe_deepgemm
zhyncs Jun 5, 2025
b596fa2
fix: remove 'del gateup_input_scale' to avoid H20*8 OOM, and move ann…
TianQiLin666666 Jun 6, 2025
fe97d1f
Merge branch 'main' into feat/ep_moe_deepgemm
ch-wan Jun 7, 2025
431f47a
Merge pull request #2 from TianQiLin666666/feat/ep_moe_deepgemm_zxt
xutizhou Jun 9, 2025
1422da6
Merge branch 'main' into feat/ep_moe_deepgemm
ch-wan Jun 10, 2025
bbbd98a
Merge branch 'main' into feat/ep_moe_deepgemm
zhyncs Jun 10, 2025
35cc918
Merge branch 'main' into feat/ep_moe_deepgemm
ch-wan Jun 11, 2025
1a26425
dispose hidden_states
ch-wan Jun 11, 2025
a902338
Merge branch 'main' into feat/ep_moe_deepgemm
TianQiLin666666 Jun 16, 2025
d6b3afd
fix: call deep_gemm_wrapper APIs in epmoe forward_deepgemm
TianQiLin666666 Jun 16, 2025
b51f324
assert act=="silu" in epmoe forward_deepgemm
TianQiLin666666 Jun 16, 2025
ad9abb2
fix(epmoe): remove _ENABLE_JIT_DEEPGEMM
TianQiLin666666 Jun 16, 2025
1869b18
fix(fill_gateup_input_triton_kernel): pre-define a tl.arange() outsid…
TianQiLin666666 Jun 16, 2025
77123c6
replace deepgemm_post_reorder_triton_kernel with post_reorder_triton_…
TianQiLin666666 Jun 16, 2025
ddc0c33
fix args of post_reorder_triton_kernel in all tests and benchmarks
TianQiLin666666 Jun 16, 2025
8e71771
Merge branch 'main' into feat/ep_moe_deepgemm
ch-wan Jun 17, 2025
a6d61f6
Merge branch 'main' into feat/ep_moe_deepgemm
zhyncs Jun 18, 2025
e30b3ab
Merge branch 'main' into feat/ep_moe_deepgemm
zhyncs Jun 18, 2025
aafbe4e
add num_fused_shared_experts
TianQiLin666666 Jun 19, 2025
0ea5bd9
fix(moe_deepgemm): convert per-tensor weight quant to per-block quant…
TianQiLin666666 Jun 19, 2025
87d68e9
Merge branch 'main' into feat/ep_moe_deepgemm
ch-wan Jun 19, 2025
26040f4
Merge branch 'main' into feat/ep_moe_deepgemm
xutizhou Jun 20, 2025
bbb127d
Merge branch 'main' into feat/ep_moe_deepgemm
ch-wan Jun 22, 2025
161 changes: 159 additions & 2 deletions python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -478,11 +478,13 @@ def post_reorder_triton_kernel(
end_expert_id,
topk,
hidden_size,
dst_start,
BLOCK_SIZE: tl.constexpr,
):
InDtype = down_output_ptr.dtype.element_ty

src_idx = tl.program_id(0)
src_idx_int32 = tl.program_id(0)
src_idx = src_idx_int32.to(tl.int64)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk
topk_weights_ptr = topk_weights_ptr + src_idx * topk
@@ -501,7 +503,9 @@ def post_reorder_triton_kernel(
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
computed = True
dst_idx = tl.load(src2dst_ptr + idx)
dst_idx_int32 = tl.load(src2dst_ptr + idx)
dst_idx = dst_idx_int32.to(tl.int64)
dst_idx = dst_idx - dst_start
weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
load_ptr = down_output_ptr + dst_idx * hidden_size
in_data = tl.load(load_ptr + offset, mask=mask)
@@ -1086,3 +1090,156 @@ def tma_align_input_scale(input_scale: torch.Tensor):
BLOCK_SIZE_K=BLOCK_SIZE_K,
)
return output.t()[:m]


@triton.jit
def compute_masked_m_triton_kernel(seg_indptr, masked_m):
expert_id = tl.program_id(0)
start = tl.load(seg_indptr + expert_id)
end = tl.load(seg_indptr + expert_id + 1)
tl.store(masked_m + expert_id, (end - start))
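
# Each program instance handles one expert: masked_m[e] is the number of routed
# (token, top_k) slots assigned to expert e, i.e. the segment length
# seg_indptr[e + 1] - seg_indptr[e]. A rough eager-mode equivalent, assuming
# seg_indptr has shape [num_experts + 1]:
#   masked_m = (seg_indptr[1:] - seg_indptr[:-1]).to(torch.int32)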


@triton.jit
def deepgemm_compute_src2dst_triton_kernel(
topk_ids,
reorder_ids,
seg_indptr,
src2dst,
m_max,
num_toks,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = dst_id < num_toks
src_id = tl.load(reorder_ids + dst_id, mask=mask)
expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
expert_dst_start = tl.load(seg_indptr + expert_id)
expert_dst_offset = dst_id - expert_dst_start
dst_id = expert_id * m_max + expert_dst_offset
tl.store(src2dst + src_id, dst_id, mask=mask)
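
# src2dst maps each flattened (token, top_k) slot into a padded per-expert
# layout: expert e owns rows [e * m_max, (e + 1) * m_max), and the j-th slot
# routed to e lands at row e * m_max + j. A rough eager-mode sketch of the same
# mapping, reusing the tensors built in moe_ep_deepgemm_preprocess below
# (names and dtypes are illustrative only):
#   pos = torch.arange(topk_ids.numel(), device=topk_ids.device)
#   src2dst[reorder_ids] = reorder_topk_ids * m_max + (pos - seg_indptr[reorder_topk_ids])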


@triton.jit
def fill_gateup_input_triton_kernel(
input_ptr,
scale_ptr,
gateup_input_ptr,
gateup_input_scale_ptr,
src2dst_ptr,
topk_ids_ptr,
start_expert_id,
end_expert_id,
topk,
m_max,
hidden_size,
scale_size,
BLOCK_SIZE: tl.constexpr,
):

src_idx_int32 = tl.program_id(0)
src_idx = src_idx_int32.to(tl.int64)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk
src_ptr = input_ptr + src_idx * hidden_size
scale_src_ptr = scale_ptr + src_idx * scale_size

vec = tl.arange(0, BLOCK_SIZE)
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
dst_idx_int32 = tl.load(src2dst_ptr + idx)
dst_idx = dst_idx_int32.to(tl.int64)
dst_idx = dst_idx - start_expert_id * m_max
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + vec
mask = offset < hidden_size
in_data = tl.load(src_ptr + offset, mask=mask)
tl.store(dst_ptr + offset, in_data, mask=mask)
scale_dst_ptr = gateup_input_scale_ptr + dst_idx * scale_size
for start_offset in tl.range(0, scale_size, BLOCK_SIZE):
offset = start_offset + vec
mask = offset < scale_size
in_scale = tl.load(scale_src_ptr + offset, mask=mask)
tl.store(scale_dst_ptr + offset, in_scale, mask=mask)
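
# For every source token, this kernel copies the quantized activations
# (hidden_size values) and their per-group scales (scale_size values) of each
# top_k slot routed to one of this rank's experts into the padded gateup
# buffers. Subtracting start_expert_id * m_max turns the global padded row
# index stored in src2dst into a row of the rank-local buffer.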


def moe_ep_deepgemm_preprocess(
topk_ids: torch.Tensor,
num_experts: int,
hidden_states: torch.Tensor,
top_k: int,
start_expert_id,
end_expert_id,
block_shape,
output_dtype: torch.dtype = torch.float8_e4m3fn,
):
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)

compute_seg_indptr_triton_kernel[(num_experts,)](
reorder_topk_ids, seg_indptr, topk_ids.numel()
)

grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m)

# For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m}) https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
m_max = (hidden_states.size(0) + 255) // 256 * 256
expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
gateup_input = torch.empty(
(int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
device=hidden_states.device,
dtype=output_dtype,
)

deepgemm_compute_src2dst_triton_kernel[grid](
topk_ids,
reorder_ids,
seg_indptr,
src2dst,
m_max,
topk_ids.numel(),
BLOCK_SIZE=256,
)

if block_shape is None:
block_shape = [128, 128]
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)

gateup_input_scale = torch.empty(
(gateup_input.size(0), gateup_input.size(1), scale.size(1)),
device=hidden_states.device,
dtype=scale.dtype,
)

fill_gateup_input_triton_kernel[(hidden_states.shape[0],)](
hidden_states,
scale,
gateup_input,
gateup_input_scale,
src2dst,
topk_ids,
start_expert_id,
end_expert_id,
top_k,
m_max,
hidden_states.size(1),
scale.size(1),
BLOCK_SIZE=1024,
)

return (
m_max,
masked_m[start_expert_id : (end_expert_id + 1)],
expected_m,
src2dst,
gateup_input,
gateup_input_scale,
)
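
Taken together, the returned tuple defines the layout contract that the DeepGEMM path in layer.py (the next file) relies on. A minimal sketch of that contract, assuming the hidden size is a multiple of the 128-wide quantization group and writing num_local for the number of experts owned by this EP rank; the assertions below are illustrative only:

num_local = end_expert_id - start_expert_id + 1
assert gateup_input.shape == (num_local, m_max, hidden_states.size(1))
assert gateup_input_scale.shape == (num_local, m_max, hidden_states.size(1) // 128)
assert masked_m.shape == (num_local,)  # valid rows per local expert
# Rows masked_m[e]..m_max - 1 of each expert slice are uninitialized padding;
# the masked grouped GEMM reads only the first masked_m[e] rows of expert e.
# src2dst still stores global padded row indices (expert_id * m_max + offset).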
175 changes: 174 additions & 1 deletion python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -16,6 +16,7 @@
ep_scatter,
gelu_and_mul_triton_kernel,
grouped_gemm_triton,
moe_ep_deepgemm_preprocess,
post_reorder_triton_kernel,
pre_reorder_triton_kernel,
run_moe_ep_preproess,
@@ -178,6 +179,7 @@ def __init__(
assert (
num_fused_shared_experts == 0
), "num_fused_shared_experts is not supported in EP"
self.num_fused_shared_experts = num_fused_shared_experts
self.num_experts_per_partition = self.num_experts // self.tp_size
self.start_expert_id = self.tp_rank * self.num_experts_per_partition
self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
@@ -227,13 +229,182 @@ def __init__(

self.grouped_gemm_runner = None

self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
)
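
# These (weight, scale) tuples are kept in the pairing that forward_deepgemm
# passes to deep_gemm_wrapper's grouped GEMM, so no tuple construction happens
# per forward pass. When the checkpoint is per-tensor rather than block
# quantized, forward_deepgemm rebuilds them with the per-tensor scale repeated
# into 128x128 blocks (see its use_block_quant branch).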

def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
return self.forward_deepgemm(hidden_states, router_logits)
else:
return self.forward_normal(hidden_states, router_logits)

def forward_deepgemm(
self, hidden_states: torch.Tensor, router_logits: torch.Tensor
):
assert self.quant_method is not None
assert self.activation == "silu"
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
custom_routing_function=self.custom_routing_function,
routed_scaling_factor=self.routed_scaling_factor,
)

assert self.quant_method is not None
if not self.use_block_quant:
# Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
scale_block_size = 128
w13_weight_scale_n = 2 * (
(self.intermediate_size + scale_block_size - 1) // scale_block_size
)
w13_weight_scale_k = (
hidden_states_shape[-1] + scale_block_size - 1
) // scale_block_size
w13_weight_scale = (
self.w13_weight_scale.unsqueeze(1)
.repeat_interleave(w13_weight_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w13_weight_scale_k, dim=2)
)
self.w13_weight_fp8 = (
self.w13_weight,
w13_weight_scale,
)
w2_weight_scale_n = (
hidden_states_shape[-1] + scale_block_size - 1
) // scale_block_size
w2_weight_scale_k = (
self.intermediate_size + scale_block_size - 1
) // scale_block_size
w2_weight_scale = (
self.w2_weight_scale.unsqueeze(1)
.repeat_interleave(w2_weight_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w2_weight_scale_k, dim=2)
)
self.w2_weight_fp8 = (
self.w2_weight,
w2_weight_scale,
)

# PreReorder
m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
moe_ep_deepgemm_preprocess(
topk_ids,
self.num_experts,
hidden_states,
self.top_k,
self.start_expert_id,
self.end_expert_id,
self.block_shape,
)
)

dispose_tensor(hidden_states)

# GroupGemm-0
gateup_input_fp8 = (
gateup_input,
deep_gemm_wrapper.get_col_major_tma_aligned_tensor(gateup_input_scale),
)
num_groups, m, k = gateup_input_fp8[0].size()
n = self.w13_weight.size(1)
gateup_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
)
del gateup_input
del gateup_input_fp8

# Act
down_input = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2,
),
device=hidden_states_device,
dtype=self.fp8_dtype,
)
scale_block_size = 128
down_input_scale = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2 // scale_block_size,
),
device=hidden_states_device,
dtype=torch.float32,
)
silu_and_mul_masked_post_quant_fwd(
gateup_output,
down_input,
down_input_scale,
scale_block_size,
masked_m,
)
del gateup_output

# GroupGemm-1
n = self.w2_weight.size(1)
down_input_fp8 = (
down_input,
deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
)
down_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
)
del down_input
del down_input_fp8

# PostReorder
output = torch.empty(
hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
)
post_reorder_triton_kernel[(hidden_states_shape[0],)](
down_output,
output,
src2dst,
topk_ids,
topk_weights,
self.start_expert_id,
self.end_expert_id,
self.top_k,
hidden_states_shape[1],
m_max * self.start_expert_id,
BLOCK_SIZE=512,
)
return output
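
# Note on the PostReorder call above: src2dst stores row indices of the global
# padded layout (expert_id * m_max + offset), while down_output only holds this
# rank's experts starting at start_expert_id, so dst_start = m_max *
# self.start_expert_id shifts those indices into the local buffer. The
# dispose_tensor / del calls throughout this method release each intermediate
# as soon as the next stage has consumed it, keeping peak memory close to a
# single stage's working set.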

def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
assert self.quant_method is not None
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
if self.grouped_gemm_runner is None:
self.grouped_gemm_runner = GroupedGemmRunner(
hidden_states.device,
@@ -249,6 +420,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
custom_routing_function=self.custom_routing_function,
routed_scaling_factor=self.routed_scaling_factor,
@@ -440,6 +612,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
self.end_expert_id,
self.top_k,
hidden_states_shape[1],
0,
BLOCK_SIZE=512,
)
return output