diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py
index 0f0b0180ff7..015c380eda3 100644
--- a/python/sglang/srt/layers/moe/ep_moe/kernels.py
+++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -478,11 +478,13 @@ def post_reorder_triton_kernel(
     end_expert_id,
     topk,
     hidden_size,
+    dst_start,
     BLOCK_SIZE: tl.constexpr,
 ):
     InDtype = down_output_ptr.dtype.element_ty
 
-    src_idx = tl.program_id(0)
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
     src2dst_ptr = src2dst_ptr + src_idx * topk
     topk_ids_ptr = topk_ids_ptr + src_idx * topk
     topk_weights_ptr = topk_weights_ptr + src_idx * topk
@@ -501,7 +503,9 @@ def post_reorder_triton_kernel(
         expert_id = tl.load(topk_ids_ptr + idx)
         if expert_id >= start_expert_id and expert_id <= end_expert_id:
             computed = True
-            dst_idx = tl.load(src2dst_ptr + idx)
+            dst_idx_int32 = tl.load(src2dst_ptr + idx)
+            dst_idx = dst_idx_int32.to(tl.int64)
+            dst_idx = dst_idx - dst_start
             weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
             load_ptr = down_output_ptr + dst_idx * hidden_size
             in_data = tl.load(load_ptr + offset, mask=mask)
@@ -1086,3 +1090,156 @@ def tma_align_input_scale(input_scale: torch.Tensor):
         BLOCK_SIZE_K=BLOCK_SIZE_K,
     )
     return output.t()[:m]
+
+
+@triton.jit
+def compute_masked_m_triton_kernel(seg_indptr, masked_m):
+    expert_id = tl.program_id(0)
+    start = tl.load(seg_indptr + expert_id)
+    end = tl.load(seg_indptr + expert_id + 1)
+    tl.store(masked_m + expert_id, (end - start))
+
+
+@triton.jit
+def deepgemm_compute_src2dst_triton_kernel(
+    topk_ids,
+    reorder_ids,
+    seg_indptr,
+    src2dst,
+    m_max,
+    num_toks,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = dst_id < num_toks
+    src_id = tl.load(reorder_ids + dst_id, mask=mask)
+    expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
+    expert_dst_start = tl.load(seg_indptr + expert_id)
+    expert_dst_offset = dst_id - expert_dst_start
+    dst_id = expert_id * m_max + expert_dst_offset
+    tl.store(src2dst + src_id, dst_id, mask=mask)
+
+
+@triton.jit
+def fill_gateup_input_triton_kernel(
+    input_ptr,
+    scale_ptr,
+    gateup_input_ptr,
+    gateup_input_scale_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    start_expert_id,
+    end_expert_id,
+    topk,
+    m_max,
+    hidden_size,
+    scale_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+    src_ptr = input_ptr + src_idx * hidden_size
+    scale_src_ptr = scale_ptr + src_idx * scale_size
+
+    vec = tl.arange(0, BLOCK_SIZE)
+    for idx in range(topk):
+        expert_id = tl.load(topk_ids_ptr + idx)
+        if expert_id >= start_expert_id and expert_id <= end_expert_id:
+            dst_idx_int32 = tl.load(src2dst_ptr + idx)
+            dst_idx = dst_idx_int32.to(tl.int64)
+            dst_idx = dst_idx - start_expert_id * m_max
+            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
+            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+                offset = start_offset + vec
+                mask = offset < hidden_size
+                in_data = tl.load(src_ptr + offset, mask=mask)
+                tl.store(dst_ptr + offset, in_data, mask=mask)
+            scale_dst_ptr = gateup_input_scale_ptr + dst_idx * scale_size
+            for start_offset in tl.range(0, scale_size, BLOCK_SIZE):
+                offset = start_offset + vec
+                mask = offset < scale_size
+                in_scale = tl.load(scale_src_ptr + offset, mask=mask)
+                tl.store(scale_dst_ptr + offset, in_scale, mask=mask)
+
+
+def moe_ep_deepgemm_preprocess(
+    topk_ids: torch.Tensor,
+    num_experts: int,
+    hidden_states: torch.Tensor,
+    top_k: int,
+    start_expert_id,
+    end_expert_id,
+    block_shape,
+    output_dtype: torch.dtype = torch.float8_e4m3fn,
+):
+    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
+    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
+    masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)
+
+    compute_seg_indptr_triton_kernel[(num_experts,)](
+        reorder_topk_ids, seg_indptr, topk_ids.numel()
+    )
+
+    grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
+    compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m)
+
+    # For masked grouped GEMM, M must be a multiple of DeepGEMM's block M (256 here); see https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
+    m_max = (hidden_states.size(0) + 255) // 256 * 256
+    expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
+    gateup_input = torch.empty(
+        (int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
+        device=hidden_states.device,
+        dtype=output_dtype,
+    )
+
+    deepgemm_compute_src2dst_triton_kernel[grid](
+        topk_ids,
+        reorder_ids,
+        seg_indptr,
+        src2dst,
+        m_max,
+        topk_ids.numel(),
+        BLOCK_SIZE=256,
+    )
+
+    if block_shape is None:
+        block_shape = [128, 128]
+    assert len(block_shape) == 2
+    block_n, block_k = block_shape[0], block_shape[1]
+    hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)
+
+    gateup_input_scale = torch.empty(
+        (gateup_input.size(0), gateup_input.size(1), scale.size(1)),
+        device=hidden_states.device,
+        dtype=scale.dtype,
+    )
+
+    fill_gateup_input_triton_kernel[(hidden_states.shape[0],)](
+        hidden_states,
+        scale,
+        gateup_input,
+        gateup_input_scale,
+        src2dst,
+        topk_ids,
+        start_expert_id,
+        end_expert_id,
+        top_k,
+        m_max,
+        hidden_states.size(1),
+        scale.size(1),
+        BLOCK_SIZE=1024,
+    )
+
+    return (
+        m_max,
+        masked_m[start_expert_id : (end_expert_id + 1)],
+        expected_m,
+        src2dst,
+        gateup_input,
+        gateup_input_scale,
+    )
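For review intuition, here is a CPU-only sketch of the index bookkeeping that compute_masked_m_triton_kernel and deepgemm_compute_src2dst_triton_kernel implement above. The helper name and the toy values are illustrative only, not part of the patch:

import torch

def deepgemm_preprocess_reference(topk_ids: torch.Tensor, num_experts: int, m_max: int):
    """CPU sketch of the kernels' index math: masked_m[e] counts token copies
    routed to expert e, and src2dst maps each flattened (token, k) slot to row
    expert_id * m_max + offset_within_expert of the padded gateup buffer."""
    flat = topk_ids.view(-1).to(torch.int64)
    _, reorder_ids = torch.sort(flat, stable=True)
    counts = torch.bincount(flat, minlength=num_experts)
    seg_indptr = torch.zeros(num_experts + 1, dtype=torch.int64)
    seg_indptr[1:] = torch.cumsum(counts, dim=0)
    masked_m = counts.to(torch.int32)

    src2dst = torch.empty_like(flat, dtype=torch.int32)
    for dst, src in enumerate(reorder_ids.tolist()):
        e = int(flat[src])
        src2dst[src] = e * m_max + (dst - int(seg_indptr[e]))
    return masked_m, src2dst

masked_m, src2dst = deepgemm_preprocess_reference(
    torch.tensor([[0, 1], [1, 2]]), num_experts=4, m_max=256
)
# masked_m -> [1, 2, 1, 0]; src2dst -> [0, 256, 257, 512]

Note that, unlike the prefix-sum addressing of the existing contiguous layout, row ids here are anchored at e * m_max per expert, which is what lets the GEMM run in DeepGEMM's fixed-stride masked layout.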
diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 38f123247cb..5b654b2d8c4 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -16,6 +16,7 @@
     ep_scatter,
     gelu_and_mul_triton_kernel,
     grouped_gemm_triton,
+    moe_ep_deepgemm_preprocess,
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
     run_moe_ep_preproess,
@@ -178,6 +179,7 @@ def __init__(
         assert (
             num_fused_shared_experts == 0
         ), "num_fused_shared_experts is not supported in EP"
+        self.num_fused_shared_experts = num_fused_shared_experts
         self.num_experts_per_partition = self.num_experts // self.tp_size
         self.start_expert_id = self.tp_rank * self.num_experts_per_partition
         self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
@@ -227,13 +229,182 @@ def __init__(
 
         self.grouped_gemm_runner = None
 
+        self.w13_weight_fp8 = (
+            self.w13_weight,
+            (
+                self.w13_weight_scale_inv
+                if self.use_block_quant
+                else self.w13_weight_scale
+            ),
+        )
+        self.w2_weight_fp8 = (
+            self.w2_weight,
+            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
+        )
+
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
+            return self.forward_deepgemm(hidden_states, router_logits)
+        else:
+            return self.forward_normal(hidden_states, router_logits)
+
+    def forward_deepgemm(
+        self, hidden_states: torch.Tensor, router_logits: torch.Tensor
+    ):
+        assert self.quant_method is not None
+        assert self.activation == "silu"
         hidden_states_shape = hidden_states.shape
         hidden_states_dtype = hidden_states.dtype
         hidden_states_device = hidden_states.device
+        topk_weights, topk_ids = select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            correction_bias=self.correction_bias,
+            custom_routing_function=self.custom_routing_function,
+            routed_scaling_factor=self.routed_scaling_factor,
+        )
 
-        assert self.quant_method is not None
+        if not self.use_block_quant:
+            # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
+            scale_block_size = 128
+            w13_weight_scale_n = 2 * (
+                (self.intermediate_size + scale_block_size - 1) // scale_block_size
+            )
+            w13_weight_scale_k = (
+                hidden_states_shape[-1] + scale_block_size - 1
+            ) // scale_block_size
+            w13_weight_scale = (
+                self.w13_weight_scale.unsqueeze(1)
+                .repeat_interleave(w13_weight_scale_n, dim=1)
+                .unsqueeze(2)
+                .repeat_interleave(w13_weight_scale_k, dim=2)
+            )
+            self.w13_weight_fp8 = (
+                self.w13_weight,
+                w13_weight_scale,
+            )
+            w2_weight_scale_n = (
+                hidden_states_shape[-1] + scale_block_size - 1
+            ) // scale_block_size
+            w2_weight_scale_k = (
+                self.intermediate_size + scale_block_size - 1
+            ) // scale_block_size
+            w2_weight_scale = (
+                self.w2_weight_scale.unsqueeze(1)
+                .repeat_interleave(w2_weight_scale_n, dim=1)
+                .unsqueeze(2)
+                .repeat_interleave(w2_weight_scale_k, dim=2)
+            )
+            self.w2_weight_fp8 = (
+                self.w2_weight,
+                w2_weight_scale,
+            )
+
+        # PreReorder
+        m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
+            moe_ep_deepgemm_preprocess(
+                topk_ids,
+                self.num_experts,
+                hidden_states,
+                self.top_k,
+                self.start_expert_id,
+                self.end_expert_id,
+                self.block_shape,
+            )
+        )
+
+        dispose_tensor(hidden_states)
+
+        # GroupGemm-0
+        gateup_input_fp8 = (
+            gateup_input,
+            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(gateup_input_scale),
+        )
+        num_groups, m, k = gateup_input_fp8[0].size()
+        n = self.w13_weight.size(1)
+        gateup_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
+        )
+        del gateup_input
+        del gateup_input_fp8
+
+        # Act
+        down_input = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2,
+            ),
+            device=hidden_states_device,
+            dtype=self.fp8_dtype,
+        )
+        scale_block_size = 128
+        down_input_scale = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2 // scale_block_size,
+            ),
+            device=hidden_states_device,
+            dtype=torch.float32,
+        )
+        silu_and_mul_masked_post_quant_fwd(
+            gateup_output,
+            down_input,
+            down_input_scale,
+            scale_block_size,
+            masked_m,
+        )
+        del gateup_output
+
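+        # Note on the masked layout used by both grouped GEMMs: in group g only
+        # the first masked_m[g] rows are valid; rows beyond that (padding up to
+        # m_max) are never read back, since the post-reorder gather below walks
+        # src2dst and touches valid rows only. expected_m is a sizing hint for
+        # kernel selection, not a correctness input.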
+        # GroupGemm-1
+        n = self.w2_weight.size(1)
+        down_input_fp8 = (
+            down_input,
+            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
+        )
+        down_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
+        )
+        del down_input
+        del down_input_fp8
+
+        # PostReorder
+        output = torch.empty(
+            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+        )
+        post_reorder_triton_kernel[(hidden_states_shape[0],)](
+            down_output,
+            output,
+            src2dst,
+            topk_ids,
+            topk_weights,
+            self.start_expert_id,
+            self.end_expert_id,
+            self.top_k,
+            hidden_states_shape[1],
+            m_max * self.start_expert_id,
+            BLOCK_SIZE=512,
+        )
+        return output
+
+    def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert self.quant_method is not None
+        hidden_states_shape = hidden_states.shape
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
         if self.grouped_gemm_runner is None:
             self.grouped_gemm_runner = GroupedGemmRunner(
                 hidden_states.device,
@@ -249,6 +420,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
             renormalize=self.renormalize,
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
             correction_bias=self.correction_bias,
             custom_routing_function=self.custom_routing_function,
             routed_scaling_factor=self.routed_scaling_factor,
@@ -440,6 +612,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
             self.end_expert_id,
             self.top_k,
             hidden_states_shape[1],
+            0,
             BLOCK_SIZE=512,
         )
         return output
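As a review aid, the per-tensor-to-per-block scale expansion in forward_deepgemm above can be checked with a toy-sized sketch; the sizes and names below are illustrative assumptions, not code from the patch:

import torch

# Toy sizes, chosen only to make the block counts small.
num_experts, intermediate_size, hidden_size = 2, 256, 384
scale_block_size = 128

# Per-tensor quantization keeps one scale per expert ...
w13_weight_scale = torch.rand(num_experts)

w13_n = 2 * ((intermediate_size + scale_block_size - 1) // scale_block_size)
w13_k = (hidden_size + scale_block_size - 1) // scale_block_size

# ... which the patch broadcasts to the (experts, n_blocks, k_blocks) layout
# that the block-quant weights already provide.
w13_block_scale = (
    w13_weight_scale.unsqueeze(1)
    .repeat_interleave(w13_n, dim=1)
    .unsqueeze(2)
    .repeat_interleave(w13_k, dim=2)
)
assert w13_block_scale.shape == (num_experts, w13_n, w13_k)  # (2, 4, 3)

Each block simply reuses its expert's single scale, so the expanded tensor is numerically equivalent to per-tensor quantization while satisfying the block-quant interface of the DeepGEMM path.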
diff --git a/python/sglang/test/test_block_fp8_ep.py b/python/sglang/test/test_block_fp8_ep.py
index a89ee1fe384..bd735edbdc5 100644
--- a/python/sglang/test/test_block_fp8_ep.py
+++ b/python/sglang/test/test_block_fp8_ep.py
@@ -182,6 +182,7 @@ def ep_moe(
         end_expert_id,
         top_k,
         hidden_states.size(1),
+        0,
         BLOCK_SIZE=512,
     )
     return output
diff --git a/sgl-kernel/benchmark/bench_moe_ep_post_reorder.py b/sgl-kernel/benchmark/bench_moe_ep_post_reorder.py
index 701fb8c5b8a..078e2c13185 100644
--- a/sgl-kernel/benchmark/bench_moe_ep_post_reorder.py
+++ b/sgl-kernel/benchmark/bench_moe_ep_post_reorder.py
@@ -77,6 +77,7 @@ def run_triton():
         end_expert_id,
         topk,
         hidden_size,
+        0,
         block_size,
     )
 
diff --git a/sgl-kernel/tests/test_ep_moe_post_reorder_kernel.py b/sgl-kernel/tests/test_ep_moe_post_reorder_kernel.py
index d20e9c9a6d9..1891735591c 100644
--- a/sgl-kernel/tests/test_ep_moe_post_reorder_kernel.py
+++ b/sgl-kernel/tests/test_ep_moe_post_reorder_kernel.py
@@ -85,6 +85,7 @@ def run_triton_kernel(
         end_expert_id,
         topk,
         hidden_size,
+        0,
         block_size,
    )
     return output
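Every pre-existing call site above passes 0 for the new trailing dst_start argument of post_reorder_triton_kernel, which preserves the kernel's old addressing; only the DeepGEMM path passes a nonzero offset. A small sketch of the intended semantics (values illustrative only):

def effective_row(dst_idx: int, dst_start: int) -> int:
    """What the patched kernel reads: row dst_idx - dst_start of down_output.

    dst_start = 0 (all non-DeepGEMM call sites) reproduces the previous
    behavior; the DeepGEMM path passes m_max * start_expert_id so that the
    global src2dst row ids land inside this rank's packed expert buffer.
    """
    return dst_idx - dst_start

# Illustrative: expert 5 is this rank's first expert, m_max = 256.
assert effective_row(dst_idx=5 * 256 + 7, dst_start=5 * 256) == 7
assert effective_row(dst_idx=42, dst_start=0) == 42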