diff --git a/python/sglang/srt/layers/moe/cutlass_moe.py b/python/sglang/srt/layers/moe/cutlass_moe.py index 262f1ae3937..d5ed74b2fbf 100755 --- a/python/sglang/srt/layers/moe/cutlass_moe.py +++ b/python/sglang/srt/layers/moe/cutlass_moe.py @@ -226,6 +226,58 @@ def cutlass_fused_experts_fp8( FLOAT8_E4M3_MAX = 448.0 +def cutlass_moe_fp8( + a: torch.Tensor, + a_scale: torch.Tensor, + w: torch.Tensor, + w_scale: torch.Tensor, + c: torch.Tensor, + m_indices: torch.Tensor, +) -> None: + '''Performs EP MoE computation using CUTLASS-like kernels with per-block-fp8-quant weights and per-token-group-fp8-quant activations. + ''' + device = a.device + num_experts, k_g, n_g = w.shape + layout_sfa = torch.zeros((num_experts, 5), device=device, dtype=torch.int32) + layout_sfb = torch.zeros((num_experts, 5), device=device, dtype=torch.int32) + a_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64) + b_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64) + out_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64) + a_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64) + b_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64) + workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8) + a_strides = torch.full((num_experts,), a.stride(0), device=device, dtype=torch.int64) + c_strides = torch.full((num_experts,), c.stride(0), device=device, dtype=torch.int64) + m_tensor = m_indices[1:] - m_indices[:-1] + n_tensor = torch.full_like(m_tensor, fill_value=n_g) + k_tensor = torch.full_like(m_tensor, fill_value=k_g) + problem_sizes = torch.stack([m_tensor, n_tensor, k_tensor], dim=1) + # (E, K, N):(K*N, N, 1) -> (E, N, K):(N*K, 1, N) -> (E, N, K):(N*K, K, 1) + # w_scale = w_scale.transpose(1, 2).contiguous() + # TODO: a_scale + + fp8_blockwise_scaled_grouped_mm( + c, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + a, + w, + a_scale, + w_scale, + a_strides, + a_strides, + c_strides, + layout_sfa, + layout_sfb, + problem_sizes, + m_indices[:-1], + workspace, + ) + + def cutlass_moe_fp4( a: torch.Tensor, a1_gscale: torch.Tensor, diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 8e99d212d87..bc06eb862de 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -43,6 +43,7 @@ _is_hip = is_hip() _is_npu = is_npu() _is_fp8_fnuz = is_fp8_fnuz() +_is_cuda = is_cuda() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip if not (_is_npu or _is_hip): @@ -53,6 +54,9 @@ from aiter.fused_moe import fused_moe from aiter.ops.shuffle import shuffle_weight +if _is_cuda: + from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp8 + logger = logging.getLogger(__name__) @@ -423,6 +427,7 @@ def __init__( else self.w2_weight_scale ), ) + self.cutlass_moe_fp8_supported = cutlass_fp8_supported() and (torch.cuda.get_device_capability(torch.cuda.current_device())[0] == 9) and (torch.version.cuda >= "12.3") def forward( self, @@ -521,6 +526,117 @@ def forward_aiter( ), expert_mask=self.expert_mask, ) + + def forward_cutlass_moe(self, + hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor], + topk_idx, + topk_weights, + num_recv_tokens_per_expert: List[int] + ): + hidden_states_fp8, hidden_states_scale = hidden_states_fp8 + assert self.quant_method is not None + assert self.activation == "silu" + if num_recv_tokens_per_expert is None: + return hidden_states_fp8.bfloat16() + all_tokens = sum(num_recv_tokens_per_expert) + if all_tokens <= 0: + return hidden_states_fp8.bfloat16() + M, K = hidden_states_fp8.size() + N = self.w13_weight.size(1) + scale_block_size = 128 + + hidden_states_fp8_shape = hidden_states_fp8.shape + hidden_states_fp8_device = hidden_states_fp8.device + hidden_states_fp8_dtype = hidden_states_fp8.dtype + + gateup_input_fp8 = torch.empty( + (all_tokens, K), + device=hidden_states_fp8_device, + dtype=hidden_states_fp8_dtype) + gateup_input_scale = torch.empty( + (all_tokens, K // 128), + device=hidden_states_fp8_device, + dtype=torch.float32) + m_indices = torch.empty( + all_tokens, device=hidden_states_fp8_device, dtype=torch.int32 + ) + output_index = torch.empty_like(topk_idx) + + num_recv_tokens_per_expert_gpu = torch.tensor( + num_recv_tokens_per_expert, + dtype=torch.int32, + pin_memory=True, + device="cpu", + ).cuda(non_blocking=True) + expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu) + + ep_scatter( + hidden_states_fp8, + hidden_states_scale, + topk_idx, + num_recv_tokens_per_expert_gpu, + expert_start_loc, + gateup_input_fp8, + gateup_input_scale, + m_indices, + output_index, + scale_ue8m0=False, + ) + dispose_tensor(hidden_states_fp8) + + gateup_output = torch.empty( + (all_tokens, N), + device=hidden_states_fp8_device, + dtype=torch.bfloat16, + ) + gateup_input_scale = tma_align_input_scale(gateup_input_scale) + + cutlass_moe_fp8(a=gateup_input_fp8, + a_scale=gateup_input_scale, + w=self.w13_weight_fp8[0], + w_scale=self.w13_weight_fp8[1], + c=gateup_output, + m_indices=m_indices) + del gateup_input_fp8, gateup_input_scale + + down_input = torch.empty( + ( + all_tokens, + N // 2, + ), + device=hidden_states_fp8_device, + dtype=torch.bfloat16, + ) + silu_and_mul(gateup_output.view(-1, N), down_input) + del gateup_output + down_output = torch.empty( + (all_tokens, K), + device=hidden_states_fp8_device, + dtype=torch.bfloat16, + ) + down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8( + down_input, + scale_block_size, + ) + del down_input + down_input_scale = tma_align_input_scale(down_input_scale) + cutlass_moe_fp8(a=down_input_fp8, + a_scale=down_input_scale, + w=self.w2_weight_fp8[0], + w_scale=self.w2_weight_fp8[1], + c=down_output, + m_indices=m_indices) + del down_input_fp8, down_input_scale + + gather_out = torch.empty( + hidden_states_fp8_shape, + device=hidden_states_fp8_device, + dtype=torch.bfloat16, + ) + ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out) + + return gather_out + def forward_deepgemm_contiguous( self,