@triton.jit
def _fused_append_shared_experts_kernel(
    topk_ids_ptr,
    topk_weights_ptr,
    out_ids_ptr,
    out_weights_ptr,
    N_BASE,  # runtime scalar: expert id of the first fused shared expert
    scale_factor,  # runtime scalar: routing weight assigned to every shared expert
    K: tl.constexpr,  # number of routed experts per token (topk)
    S: tl.constexpr,  # number of fused shared experts appended per token
    K_POW2: tl.constexpr,  # next_power_of_2(K) — tl.arange requires a pow2 extent
    S_POW2: tl.constexpr,  # next_power_of_2(S)
):
    """Per-row copy of (ids, weights) with shared-expert columns appended.

    One program handles one token (row). Equivalent Python:

        for m in range(M):
            for n in range(K):
                fused_ids[m, n] = topk_ids[m, n]
                fused_weights[m, n] = topk_weights[m, n]
            for s in range(S):
                fused_ids[m, K + s] = N_BASE + s
                fused_weights[m, K + s] = scale_factor
    """
    pid = tl.program_id(0)

    # Input rows have K columns; output rows have K + S columns.
    in_row = pid * K
    out_row = pid * (K + S)

    # tl.arange extents must be powers of two, so iterate over the padded
    # range and mask off the tail (no-op when K is already a power of two).
    offs_k = tl.arange(0, K_POW2)
    k_mask = offs_k < K
    ids = tl.load(topk_ids_ptr + in_row + offs_k, mask=k_mask)
    ws = tl.load(topk_weights_ptr + in_row + offs_k, mask=k_mask)
    tl.store(out_ids_ptr + out_row + offs_k, ids, mask=k_mask)
    tl.store(out_weights_ptr + out_row + offs_k, ws, mask=k_mask)

    # Append S shared experts: consecutive ids starting at N_BASE, each
    # carrying the same scale_factor weight.
    offs_s = tl.arange(0, S_POW2)
    s_mask = offs_s < S
    shared_ids = tl.cast(N_BASE + offs_s, ids.dtype)
    shared_ws = tl.full([S_POW2], scale_factor, dtype=ws.dtype)
    tl.store(out_ids_ptr + out_row + K + offs_s, shared_ids, mask=s_mask)
    tl.store(out_weights_ptr + out_row + K + offs_s, shared_ws, mask=s_mask)


def fused_append_shared_experts(
    topk_ids, topk_weights, num_fused_shared_experts, scale_factor, N=None
):
    """Append fused shared experts to routed top-k results in one kernel.

    Replaces the two-``torch.cat`` pattern (ids via ``torch.arange``.expand,
    weights via ``torch.full``) with a single fused Triton launch.

    Args:
        topk_ids: (M, K) int tensor of routed expert ids.
        topk_weights: (M, K) float tensor of routing weights.
        num_fused_shared_experts: S, how many shared experts to append.
        scale_factor: weight written for every appended shared expert.
        N: base expert id for shared experts (shared expert s gets id N + s).
            Required whenever num_fused_shared_experts > 0.

    Returns:
        (out_ids, out_weights): new (M, K + S) tensors; the inputs are
        returned unchanged when S <= 0.
    """
    s = int(num_fused_shared_experts)
    # Nothing to append — N is irrelevant in this case, so check it only
    # after the early return.
    if s <= 0:
        return topk_ids, topk_weights
    assert N is not None, "N (shared expert base id) must be provided"

    m, k = topk_ids.shape
    out_ids = torch.empty((m, k + s), dtype=topk_ids.dtype, device=topk_ids.device)
    out_weights = torch.empty(
        (m, k + s), dtype=topk_weights.dtype, device=topk_weights.device
    )

    # One program per token row; the work per row is tiny, so a single warp
    # is enough.
    _fused_append_shared_experts_kernel[(m,)](
        topk_ids,
        topk_weights,
        out_ids,
        out_weights,
        N_BASE=N,
        scale_factor=scale_factor,
        K=k,
        S=s,
        K_POW2=triton.next_power_of_2(k),
        S_POW2=triton.next_power_of_2(s),
        num_warps=1,
    )
    return out_ids, out_weights
if num_fused_shared_experts > 0 and _use_aiter: M, N = router_logits.shape scale_factor = ( @@ -1000,30 +999,17 @@ def select_experts( else fused_shared_experts_scaling_factor ) - topk_ids = torch.cat( - [ - topk_ids, - torch.arange( - N, - N + num_fused_shared_experts, - dtype=topk_ids.dtype, - device=topk_ids.device, - ).expand(M, -1), - ], - dim=1, + # Lazy import to avoid circular-import issues + from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_kernels import ( + fused_append_shared_experts, ) - topk_weights = torch.cat( - [ - topk_weights, - torch.full( - (topk_weights.size(0), num_fused_shared_experts), - scale_factor, - dtype=topk_weights.dtype, - device=topk_weights.device, - ), - ], - dim=1, + topk_ids, topk_weights = fused_append_shared_experts( + topk_ids, + topk_weights, + num_fused_shared_experts, + scale_factor, + N, # base id for shared experts ) get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index b7ed99d56822..9b47317a331c 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -168,10 +168,14 @@ _use_aiter_gfx95 = _use_aiter and _is_gfx95_supported if _use_aiter_gfx95: + from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant, ) - from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant + from aiter.ops.triton.fused_fp8_quant import ( + fused_flatten_fp8_group_quant, + fused_rms_fp8_group_quant, + ) from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights from sglang.srt.layers.quantization.rocm_mxfp4_utils import ( @@ -2001,6 +2005,11 @@ def forward_absorb_core( if self.o_proj.weight.dtype == torch.uint8: attn_bmm_output = attn_bmm_output.transpose(0, 1) attn_bmm_output = 
fused_flatten_mxfp4_quant(attn_bmm_output) + elif self.o_proj.weight.dtype == torch.float8_e4m3fn: + attn_bmm_output = attn_bmm_output.transpose(0, 1) + attn_bmm_output = fused_flatten_fp8_group_quant( + attn_bmm_output, group_size=128, dtype_quant=torch.float8_e4m3fn + ) else: attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)