Merged

Changes from 26 commits (53 commits total)
92d647c
feat(ep_moe): integrate deepgemm into origin ep moe
TianQiLin666666 Apr 22, 2025
e057acb
fix(ep_moe): group_gemm_mask bug
TianQiLin666666 Apr 22, 2025
19ec50e
fix bugs
TianQiLin666666 Apr 22, 2025
3ce1a91
fix bugs
TianQiLin666666 Apr 23, 2025
3d51a71
fix(em_moe): offset bugs
TianQiLin666666 Apr 24, 2025
c80fc3c
fix(deepgemm): bugfix
TianQiLin666666 Apr 24, 2025
af94a8b
fix: remove redundant code
TianQiLin666666 Apr 28, 2025
2022070
fix: clang-format
TianQiLin666666 Apr 28, 2025
55ea483
fix: remove print
TianQiLin666666 Apr 28, 2025
988a522
fix(ep_moe): replace EPMOE_USE_DEEPGEMM with _ENABLE_JIT_DEEPGEMM
TianQiLin666666 Apr 28, 2025
1f81f01
merge main
xutizhou Jun 2, 2025
b4ae984
Refactor moe_ep_deepgemm_preprocess to remove CUDA-specific handling …
xutizhou Jun 2, 2025
c6d51d2
Fix condition for expert fusion by updating the check for 'enable_dee…
xutizhou Jun 2, 2025
0397c25
Fix typo in function name from 'moe_ep_deepgemm_preproess' to 'moe_ep…
xutizhou Jun 2, 2025
aeb437d
Update moe_ep_deepgemm_preprocess to adjust m_max calculation for mas…
xutizhou Jun 3, 2025
d2c19bf
Refactor compute_masked_m_triton_kernel to remove num_experts paramet…
xutizhou Jun 3, 2025
0d3e793
Refactor moe_ep_deepgemm_preprocess to improve assertions and streaml…
xutizhou Jun 3, 2025
ee1a2b1
Refactor kernel functions in kernels.py to improve variable naming an…
xutizhou Jun 3, 2025
d530ee7
Enhance EPMoE layer by capturing hidden states' shape, dtype, and dev…
xutizhou Jun 3, 2025
cb6ea36
Refactor EPMoE layer to improve memory management by explicitly delet…
xutizhou Jun 3, 2025
2769865
Update EPMoE layer to use hidden_states_device for tensor creation, e…
xutizhou Jun 3, 2025
a7a6235
Refactor deepgemm_post_reorder_triton_kernel to improve variable hand…
xutizhou Jun 3, 2025
a1493e5
Optimize EPMoE layer by explicitly deleting intermediate tensor varia…
xutizhou Jun 3, 2025
b1edec5
fix(ep_moe_deepgemm): use dispose_tensor to really free tensor mem
TianQiLin666666 Jun 3, 2025
773f151
Merge pull request #1 from TianQiLin666666/feat/ep_moe_deepgemm_zxt
xutizhou Jun 3, 2025
4975723
Merge branch 'main' into feat/ep_moe_deepgemm
zhyncs Jun 4, 2025
6584c21
Enhance EPMoE layer to conditionally use deep GEMM based on FP8 setti…
xutizhou Jun 5, 2025
c54a141
Optimize memory management in EPMoE layer by removing unnecessary ten…
xutizhou Jun 5, 2025
86f06b6
Merge branch 'main' into feat/ep_moe_deepgemm
zhyncs Jun 5, 2025
15c9514
Merge branch 'main' into feat/ep_moe_deepgemm
zhyncs Jun 5, 2025
afd55c2
Merge branch 'main' into feat/ep_moe_deepgemm
zhyncs Jun 5, 2025
b596fa2
fix: remove 'del gateup_input_scale' to avoid H20*8 OOM, and move ann…
TianQiLin666666 Jun 6, 2025
fe97d1f
Merge branch 'main' into feat/ep_moe_deepgemm
ch-wan Jun 7, 2025
431f47a
Merge pull request #2 from TianQiLin666666/feat/ep_moe_deepgemm_zxt
xutizhou Jun 9, 2025
1422da6
Merge branch 'main' into feat/ep_moe_deepgemm
ch-wan Jun 10, 2025
bbbd98a
Merge branch 'main' into feat/ep_moe_deepgemm
zhyncs Jun 10, 2025
35cc918
Merge branch 'main' into feat/ep_moe_deepgemm
ch-wan Jun 11, 2025
1a26425
dispose hidden_states
ch-wan Jun 11, 2025
a902338
Merge branch 'main' into feat/ep_moe_deepgemm
TianQiLin666666 Jun 16, 2025
d6b3afd
fix: call deep_gemm_wrapper APIs in epmoe forward_deepgemm
TianQiLin666666 Jun 16, 2025
b51f324
assert act=="silu" in epmoe forward_deepgemm
TianQiLin666666 Jun 16, 2025
ad9abb2
fix(epmoe): remove _ENABLE_JIT_DEEPGEMM
TianQiLin666666 Jun 16, 2025
1869b18
fix(fill_gateup_input_triton_kernel): pre-define a tl.arange() outsid…
TianQiLin666666 Jun 16, 2025
77123c6
replace deepgemm_post_reorder_triton_kernel with post_reorder_triton_…
TianQiLin666666 Jun 16, 2025
ddc0c33
fix args of post_reorder_triton_kernel in all tests and benchmarks
TianQiLin666666 Jun 16, 2025
8e71771
Merge branch 'main' into feat/ep_moe_deepgemm
ch-wan Jun 17, 2025
a6d61f6
Merge branch 'main' into feat/ep_moe_deepgemm
zhyncs Jun 18, 2025
e30b3ab
Merge branch 'main' into feat/ep_moe_deepgemm
zhyncs Jun 18, 2025
aafbe4e
add num_fused_shared_experts
TianQiLin666666 Jun 19, 2025
0ea5bd9
fix(moe_deepgemm): convert per-tensor weight quant to per-block quant…
TianQiLin666666 Jun 19, 2025
87d68e9
Merge branch 'main' into feat/ep_moe_deepgemm
ch-wan Jun 19, 2025
26040f4
Merge branch 'main' into feat/ep_moe_deepgemm
xutizhou Jun 20, 2025
bbb127d
Merge branch 'main' into feat/ep_moe_deepgemm
ch-wan Jun 22, 2025
204 changes: 204 additions & 0 deletions python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -1085,3 +1085,207 @@ def tma_align_input_scale(input_scale: torch.Tensor):
BLOCK_SIZE_K=BLOCK_SIZE_K,
)
return output.t()[:m]


@triton.jit
def compute_masked_m_triton_kernel(seg_indptr, masked_m):
expert_id = tl.program_id(0)
start = tl.load(seg_indptr + expert_id)
end = tl.load(seg_indptr + expert_id + 1)
tl.store(masked_m + expert_id, (end - start))


@triton.jit
def deepgemm_compute_src2dst_triton_kernel(
topk_ids,
reorder_ids,
seg_indptr,
src2dst,
m_max,
num_toks,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = dst_id < num_toks
src_id = tl.load(reorder_ids + dst_id, mask=mask)
expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
expert_dst_start = tl.load(seg_indptr + expert_id)
expert_dst_offset = dst_id - expert_dst_start
dst_id = expert_id * m_max + expert_dst_offset
tl.store(src2dst + src_id, dst_id, mask=mask)


@triton.jit
def fill_gateup_input_triton_kernel(
input_ptr,
scale_ptr,
gateup_input_ptr,
gateup_input_scale_ptr,
src2dst_ptr,
topk_ids_ptr,
start_expert_id,
end_expert_id,
topk,
m_max,
hidden_size,
scale_size,
BLOCK_SIZE: tl.constexpr,
):

src_idx_int32 = tl.program_id(0)
src_idx = src_idx_int32.to(tl.int64)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk
src_ptr = input_ptr + src_idx * hidden_size
scale_src_ptr = scale_ptr + src_idx * scale_size

for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
dst_idx_int32 = tl.load(src2dst_ptr + idx)
dst_idx = dst_idx_int32.to(tl.int64)
dst_idx = dst_idx - start_expert_id * m_max
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
Review comment (Collaborator): If we pre-define a tl.arange() outside this for-loop and reuse it, can this kernel be faster?
Reply (Collaborator): Done! (A sketch of the hoisted pattern follows after this kernel.)

mask = offset < hidden_size
in_data = tl.load(src_ptr + offset, mask=mask)
tl.store(dst_ptr + offset, in_data, mask=mask)
scale_dst_ptr = gateup_input_scale_ptr + dst_idx * scale_size
for start_offset in tl.range(0, scale_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < scale_size
in_scale = tl.load(scale_src_ptr + offset, mask=mask)
tl.store(scale_dst_ptr + offset, in_scale, mask=mask)
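
Per the review exchange above, the per-iteration tl.arange(0, BLOCK_SIZE) can be hoisted out of the copy loops and reused, the same pattern deepgemm_post_reorder_triton_kernel uses further down. A minimal sketch of that pattern, using a hypothetical row-copy kernel (the name and body are illustrative, not part of this PR):

import triton
import triton.language as tl

@triton.jit
def copy_row_kernel(src_ptr, dst_ptr, hidden_size, BLOCK_SIZE: tl.constexpr):
    vec = tl.arange(0, BLOCK_SIZE)  # computed once, reused in every iteration
    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
        offset = start_offset + vec
        mask = offset < hidden_size
        data = tl.load(src_ptr + offset, mask=mask)
        tl.store(dst_ptr + offset, data, mask=mask)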


def moe_ep_deepgemm_preprocess(
topk_ids: torch.Tensor,
num_experts: int,
hidden_states: torch.Tensor,
top_k: int,
start_expert_id,
end_expert_id,
block_shape,
output_dtype: torch.dtype = torch.float8_e4m3fn,
):
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)

compute_seg_indptr_triton_kernel[(num_experts,)](
reorder_topk_ids, seg_indptr, topk_ids.numel()
)

grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m)

# For masked grouped GEMM, shape M should be a multiple of the block M (current block M: {block_m}); see https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
m_max = (hidden_states.size(0) + 255) // 256 * 256
expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
gateup_input = torch.empty(
(int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
device=hidden_states.device,
dtype=output_dtype,
)

deepgemm_compute_src2dst_triton_kernel[grid](
topk_ids,
reorder_ids,
seg_indptr,
src2dst,
m_max,
topk_ids.numel(),
BLOCK_SIZE=256,
)

assert block_shape is not None, "block_shape must not be None"
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)

gateup_input_scale = torch.empty(
(gateup_input.size(0), gateup_input.size(1), scale.size(1)),
device=hidden_states.device,
dtype=scale.dtype,
)

fill_gateup_input_triton_kernel[(hidden_states.shape[0],)](
hidden_states,
scale,
gateup_input,
gateup_input_scale,
src2dst,
topk_ids,
start_expert_id,
end_expert_id,
top_k,
m_max,
hidden_states.size(1),
scale.size(1),
BLOCK_SIZE=1024,
)

return (
m_max,
masked_m[start_expert_id : (end_expert_id + 1)],
expected_m,
src2dst,
gateup_input,
gateup_input_scale,
)
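
For intuition, a quick numeric check of the padding arithmetic above (the token and expert counts are made up for illustration):

# Illustrative values only: 300 tokens, top_k = 8, 256 experts.
num_tokens, top_k, num_experts = 300, 8, 256
m_max = (num_tokens + 255) // 256 * 256  # 512: rounded up to a multiple of 256
expected_m = (num_tokens * top_k + num_experts - 1) // num_experts  # 10 expected rows per expert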


@triton.jit
def deepgemm_post_reorder_triton_kernel(
Review comment (Collaborator): This kernel is very similar to post_reorder_triton_kernel. Can we simply add an additional argument to post_reorder_triton_kernel to control the workflow? That would avoid a repetitive implementation. (A sketch of this idea follows after this kernel.)
Reply (Collaborator): Done!

down_output_ptr,
output_ptr,
src2dst_ptr,
topk_ids_ptr,
topk_weights_ptr,
start_expert_id,
end_expert_id,
topk,
hidden_size,
max_m,
BLOCK_SIZE: tl.constexpr,
):
InDtype = down_output_ptr.dtype.element_ty

src_idx_int32 = tl.program_id(0)
src_idx = src_idx_int32.to(tl.int64)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk
topk_weights_ptr = topk_weights_ptr + src_idx * topk

computed = False
store_ptr = output_ptr + src_idx * hidden_size

vec = tl.arange(0, BLOCK_SIZE)
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + vec
mask = offset < hidden_size

sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
computed = True
dst_idx_int32 = tl.load(src2dst_ptr + idx)
dst_idx = dst_idx_int32.to(tl.int64)
dst_idx = dst_idx - start_expert_id * max_m
weight_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
load_ptr = down_output_ptr + dst_idx * hidden_size
in_data = tl.load(load_ptr + offset, mask=mask)
sum_vec += in_data * weight_scale
tl.store(store_ptr + offset, sum_vec, mask=mask)

if not computed:
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + vec
mask = offset < hidden_size
tl.store(
store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
)
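
As the review exchange above suggests (and as commit 77123c6 later does), the only behavioral difference from post_reorder_triton_kernel is the rebasing of the destination index into the per-expert masked layout, so one kernel can serve both paths through a single extra argument. A hedged sketch of that idea; the real post_reorder_triton_kernel signature is not shown in this diff, so the kernel name and the max_m parameter below are assumptions:

import triton
import triton.language as tl

@triton.jit
def rebase_dst_idx_kernel(src2dst_ptr, out_ptr, start_expert_id, max_m, topk):
    # Illustrative helper: with max_m == 0 the dst index is used as-is (the
    # contiguous layout of post_reorder_triton_kernel); with max_m == m_max it
    # is rebased into the (expert, m_max) masked layout used by the DeepGEMM path.
    src_idx = tl.program_id(0).to(tl.int64)
    for idx in range(topk):
        dst_idx = tl.load(src2dst_ptr + src_idx * topk + idx).to(tl.int64)
        tl.store(out_ptr + src_idx * topk + idx, dst_idx - start_expert_id * max_m)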
147 changes: 145 additions & 2 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -31,10 +31,12 @@
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.moe.ep_moe.kernels import (
deepgemm_post_reorder_triton_kernel,
ep_gather,
ep_scatter,
gelu_and_mul_triton_kernel,
grouped_gemm_triton,
moe_ep_deepgemm_preprocess,
post_reorder_triton_kernel,
pre_reorder_triton_kernel,
run_moe_ep_preproess,
@@ -49,13 +51,21 @@
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.layers.quantization.fp8_kernel import (
scaled_fp8_quant,
sglang_per_token_quant_fp8,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
from sglang.srt.utils import (
DeepEPMode,
dispose_tensor,
get_bool_env_var,
is_cuda,
is_hip,
set_weight_attrs,
)

_is_hip = is_hip()

@@ -228,13 +238,146 @@ def __init__(

self.grouped_gemm_runner = None

self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
)

def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
if use_deep_gemm and _ENABLE_JIT_DEEPGEMM:
return self.forward_deepgemm(hidden_states, router_logits)
else:
return self.forward_normal(hidden_states, router_logits)

def forward_deepgemm(
self, hidden_states: torch.Tensor, router_logits: torch.Tensor
):
assert self.quant_method is not None
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
correction_bias=self.correction_bias,
custom_routing_function=self.custom_routing_function,
routed_scaling_factor=self.routed_scaling_factor,
)

assert self.quant_method is not None
# PreReorder
m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
moe_ep_deepgemm_preprocess(
topk_ids,
self.num_experts,
hidden_states,
self.top_k,
self.start_expert_id,
self.end_expert_id,
self.block_shape,
)
)
gateup_input_fp8 = (
gateup_input,
get_col_major_tma_aligned_tensor(gateup_input_scale),
)
del gateup_input, gateup_input_scale

# GroupGemm-0
num_groups, m, k = gateup_input_fp8[0].size()
n = self.w13_weight.size(1)
gateup_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)

m_grouped_gemm_fp8_fp8_bf16_nt_masked(
gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
)
dispose_tensor(gateup_input_fp8[0])
dispose_tensor(gateup_input_fp8[1])
del gateup_input_fp8

# Act
down_input = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2,
),
device=hidden_states_device,
dtype=self.fp8_dtype,
)
scale_block_size = 128
down_input_scale = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2 // scale_block_size,
),
device=hidden_states_device,
dtype=torch.float32,
)
silu_and_mul_masked_post_quant_fwd(
gateup_output,
down_input,
down_input_scale,
scale_block_size,
masked_m,
)
dispose_tensor(gateup_output)

# GroupGemm-1
n = self.w2_weight.size(1)
down_input_fp8 = (
down_input,
get_col_major_tma_aligned_tensor(down_input_scale),
)
del down_input, down_input_scale
down_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
)
dispose_tensor(down_input_fp8[0])
dispose_tensor(down_input_fp8[1])
del down_input_fp8
# PostReorder
output = torch.empty(
hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
)
deepgemm_post_reorder_triton_kernel[(hidden_states_shape[0],)](
down_output,
output,
src2dst,
topk_ids,
topk_weights,
self.start_expert_id,
self.end_expert_id,
self.top_k,
hidden_states_shape[1],
m_max,
BLOCK_SIZE=512,
)
return output
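
The explicit dispose_tensor + del pairs above implement the memory-management pattern from commits b1edec5 and a1493e5: each large intermediate is released as soon as the next stage has consumed it rather than living until the end of forward_deepgemm. A minimal sketch of the pattern with hypothetical stage names (only dispose_tensor itself is the PR's API):

import torch
from sglang.srt.utils import dispose_tensor

def two_stage(x: torch.Tensor) -> torch.Tensor:
    a = x * 2            # stand-in for the GroupGemm-0 output
    b = torch.relu(a)    # stand-in for the activation stage
    dispose_tensor(a)    # free the underlying storage immediately
    del a                # and drop the Python reference
    return b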

def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
assert self.quant_method is not None
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
if self.grouped_gemm_runner is None:
self.grouped_gemm_runner = GroupedGemmRunner(
hidden_states.device,