Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -877,3 +877,74 @@ def moe_sum_reduce_triton(
num_warps=num_warps,
)
return


@triton.jit
def _fused_append_shared_experts_kernel(
    topk_ids_ptr,  # (M, K) routed expert ids, row-major
    topk_weights_ptr,  # (M, K) routed expert weights, row-major
    out_ids_ptr,  # (M, K + S) output ids, row-major
    out_weights_ptr,  # (M, K + S) output weights, row-major
    N_BASE,  # runtime scalar: first shared-expert id to append
    scale_factor,  # runtime scalar: weight given to every shared expert
    K: tl.constexpr,  # number of routed experts per row
    S: tl.constexpr,  # number of shared experts appended per row
):
    """Copy the routed top-k row and append the shared-expert columns.

    One program handles one token row (grid = (M,)). Equivalent to:

    for m in range(M):
        for n in range(K):
            fused_ids[m, n] = topk_ids[m, n]
            fused_weights[m, n] = topk_weights[m, n]
        for s in range(S):
            fused_ids[m, K + s] = N + s
            fused_weights[m, K + s] = scale_factor

    NOTE(review): `tl.arange(0, K)` / `tl.arange(0, S)` require K and S to
    be powers of two in Triton — TODO confirm callers only launch with
    power-of-two top-k and shared-expert counts (no masking is done here).
    """
    pid = tl.program_id(0)

    # Flat row offsets: inputs have stride K, outputs have stride K + S.
    ids_row_ptr = pid * K
    w_row_ptr = pid * K
    out_ids_row_ptr = pid * (K + S)
    out_w_row_ptr = pid * (K + S)

    # Copy the K routed ids/weights into columns [0, K) of the output.
    offs_k = tl.arange(0, K)
    ids = tl.load(topk_ids_ptr + ids_row_ptr + offs_k)
    ws = tl.load(topk_weights_ptr + w_row_ptr + offs_k)

    tl.store(out_ids_ptr + out_ids_row_ptr + offs_k, ids)
    tl.store(out_weights_ptr + out_w_row_ptr + offs_k, ws)

    offs_s = tl.arange(0, S)

    # Shared experts get consecutive ids N_BASE..N_BASE+S-1 (cast to the
    # ids' dtype) and a constant weight of scale_factor.
    shared_ids = tl.cast(N_BASE + offs_s, ids.dtype)
    shared_ws = tl.full([S], scale_factor, dtype=ws.dtype)

    tl.store(out_ids_ptr + out_ids_row_ptr + K + offs_s, shared_ids)
    tl.store(out_weights_ptr + out_w_row_ptr + K + offs_s, shared_ws)


def fused_append_shared_experts(
    topk_ids, topk_weights, num_fused_shared_experts, scale_factor, N=None
):
    """Append fused shared-expert ids/weights to routed top-k results.

    Builds (M, K + S) tensors where columns [0, K) copy the routed
    ``topk_ids`` / ``topk_weights`` and columns [K, K + S) hold the
    shared-expert ids ``N .. N + S - 1`` with a constant ``scale_factor``
    weight, using a single fused Triton kernel launch.

    Args:
        topk_ids: (M, K) integer tensor of routed expert ids (CUDA).
        topk_weights: (M, K) tensor of routed expert weights (CUDA).
        num_fused_shared_experts: S, number of shared experts to append.
        scale_factor: weight assigned to every appended shared expert.
        N: base expert id for the shared experts; required when S > 0.

    Returns:
        ``(out_ids, out_weights)`` of shape (M, K + S); the input tensors
        themselves, unchanged, when ``num_fused_shared_experts <= 0``.

    Raises:
        ValueError: if ``N`` is None while S > 0.
    """
    s = int(num_fused_shared_experts)
    # Nothing to append: return the inputs as-is. N is irrelevant on this
    # path, so it is deliberately not validated here.
    if s <= 0:
        return topk_ids, topk_weights

    # Raise (not assert): asserts are stripped under `python -O`, which
    # would let a missing base id through to the kernel launch.
    if N is None:
        raise ValueError("N (shared expert base id) must be provided")

    m, k = topk_ids.shape

    out_ids = torch.empty((m, k + s), dtype=topk_ids.dtype, device=topk_ids.device)
    out_weights = torch.empty(
        (m, k + s), dtype=topk_weights.dtype, device=topk_weights.device
    )

    # One program per token row; K/S are constexpr, so each distinct
    # (k, s) pair compiles its own kernel specialization.
    _fused_append_shared_experts_kernel[(m,)](
        topk_ids,
        topk_weights,
        out_ids,
        out_weights,
        N_BASE=N,
        scale_factor=scale_factor,
        K=k,
        S=s,
        num_warps=1,
    )
    return out_ids, out_weights
36 changes: 11 additions & 25 deletions python/sglang/srt/layers/moe/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,8 +804,8 @@ def biased_grouped_topk_gpu(
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
aiter_biased_grouped_topk(
gating_output.to(dtype=torch.float32),
correction_bias,
gating_output,
correction_bias.to(dtype=gating_output.dtype),
topk_weights,
topk_ids,
num_expert_group,
Expand Down Expand Up @@ -991,7 +991,6 @@ def select_experts(
renormalize=renormalize,
)

# TODO: fused ops of shared experts in topk function itself when num_fused_shared_experts > 0.
if num_fused_shared_experts > 0 and _use_aiter:
M, N = router_logits.shape
scale_factor = (
Expand All @@ -1000,30 +999,17 @@ def select_experts(
else fused_shared_experts_scaling_factor
)

topk_ids = torch.cat(
[
topk_ids,
torch.arange(
N,
N + num_fused_shared_experts,
dtype=topk_ids.dtype,
device=topk_ids.device,
).expand(M, -1),
],
dim=1,
# Lazy import to avoid circular-import issues
from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_kernels import (
fused_append_shared_experts,
)

topk_weights = torch.cat(
[
topk_weights,
torch.full(
(topk_weights.size(0), num_fused_shared_experts),
scale_factor,
dtype=topk_weights.dtype,
device=topk_weights.device,
),
],
dim=1,
topk_ids, topk_weights = fused_append_shared_experts(
topk_ids,
topk_weights,
num_fused_shared_experts,
scale_factor,
N, # base id for shared experts
)

get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)
Expand Down
11 changes: 10 additions & 1 deletion python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,14 @@
_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported

if _use_aiter_gfx95:

from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (
batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant,
)
from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant
from aiter.ops.triton.fused_fp8_quant import (
fused_flatten_fp8_group_quant,
fused_rms_fp8_group_quant,
)

from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
from sglang.srt.layers.quantization.rocm_mxfp4_utils import (
Expand Down Expand Up @@ -2001,6 +2005,11 @@ def forward_absorb_core(
if self.o_proj.weight.dtype == torch.uint8:
attn_bmm_output = attn_bmm_output.transpose(0, 1)
attn_bmm_output = fused_flatten_mxfp4_quant(attn_bmm_output)
elif self.o_proj.weight.dtype == torch.float8_e4m3fn:
attn_bmm_output = attn_bmm_output.transpose(0, 1)
attn_bmm_output = fused_flatten_fp8_group_quant(
attn_bmm_output, group_size=128, dtype_quant=torch.float8_e4m3fn
)
else:
attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)

Expand Down
Loading