diff --git a/python/sglang/srt/layers/moe/cutlass_moe.py b/python/sglang/srt/layers/moe/cutlass_moe.py
index 00b7adf778e..3426ba7fcee 100755
--- a/python/sglang/srt/layers/moe/cutlass_moe.py
+++ b/python/sglang/srt/layers/moe/cutlass_moe.py
@@ -1,11 +1,5 @@
 """CUTLASS based Fused MoE kernels."""
 
-import functools
-import json
-import logging
-import os
-from typing import Any, Callable, Dict, List, Optional, Tuple
-
 import torch
 
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
@@ -46,6 +40,8 @@ def cutlass_fused_experts_fp8(
     expert_offsets: torch.Tensor,
     problem_sizes1: torch.Tensor,
     problem_sizes2: torch.Tensor,
+    a_sf_layout: torch.Tensor,
+    w_sf_layout: torch.Tensor,
     use_fp8_blockscale: bool = True,
 ) -> torch.Tensor:
     """Performs Fused MoE computation using CUTLASS-like kernels with FP8 weights and activations.
@@ -65,14 +61,12 @@ def cutlass_fused_experts_fp8(
             number of tokens and `k` is the hidden size.
             Expected dtype: `torch.half` or `torch.bfloat16`.
         w1_q (torch.Tensor): Pre-quantized FP8 weight tensor for the first GEMM
-            (up-projection part of SwiGLU). Expected shape: `(E, k, n*2)`, where
+            (up-projection part of SwiGLU). Expected shape: `(E, n*2, k)`, where
             `E` is the number of experts, `k` is the hidden size, and `n*2` is
             the intermediate size (`I`). Expected dtype: `torch.float8_e4m3fn`.
-            Note: This shape implies weights are stored as (num_experts, hidden_size, intermediate_size).
         w2_q (torch.Tensor): Pre-quantized FP8 weight tensor for the second GEMM
-            (down-projection). Expected shape: `(E, n, k)`, where `n` is half the
+            (down-projection). Expected shape: `(E, k, n)`, where `n` is half the
             intermediate size (`I // 2`). Expected dtype: `torch.float8_e4m3fn`.
-            Note: This shape implies weights are stored as (num_experts, intermediate_size // 2, hidden_size).
         w1_scale (torch.Tensor): Scales corresponding to `w1_q` (per-block scales).
             Shape: `(E, num_blocks_n, num_blocks_k)`. Dtype: `torch.float32`.
         w2_scale (torch.Tensor): Scales corresponding to `w2_q` (per-block scales).
@@ -99,6 +93,8 @@ def cutlass_fused_experts_fp8(
         out_ptrs (torch.Tensor): Pointers container for calculating offsets of the output activations for each expert.
         a_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the input scales for each expert.
         b_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the input scales for each expert.
+        a_sf_layout (torch.Tensor): Layout tensor for activation scales.
+        w_sf_layout (torch.Tensor): Layout tensor for weight scales.
         use_fp8_blockscale (bool, optional): Flag indicating usage of FP8 with
             block scaling. Currently, only `True` is supported.
             Defaults to `True`.
@@ -113,10 +109,9 @@ def cutlass_fused_experts_fp8(
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert w1_q.dtype == torch.float8_e4m3fn
     assert w2_q.dtype == torch.float8_e4m3fn
-    assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
-    assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
+    assert a.shape[1] == w1_q.shape[2], "Hidden size mismatch w1"
+    assert w1_q.shape[1] == w2_q.shape[2] * 2, "Hidden size mismatch w2"
     assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
-    assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
     assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
     assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
     assert a.dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
@@ -128,17 +123,19 @@ def cutlass_fused_experts_fp8(
 
     out_dtype = a.dtype
     num_experts = w1_q.size(0)
-    m = a.size(0)
-    k = w1_q.size(1)
-    n = w2_q.size(1)
+    m, k = a.shape
+    n = w2_q.size(2)
 
     topk = topk_ids.size(1)
 
     a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
     device = a_q.device
 
-    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
-    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+    numel = topk_ids.numel()
+    maps = torch.empty(numel * 2, dtype=torch.int32, device=device)
+
+    a_map = maps[:numel]
+    c_map = maps[numel:]
 
     prepare_moe_input(
         topk_ids,
@@ -155,11 +152,11 @@ def cutlass_fused_experts_fp8(
     rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
     rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
 
-    c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
-    c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
-
-    a_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)
-    w_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)
+    rows = m * topk
+    c_buffer = torch.empty((rows * (n * 3 + k)), device=device, dtype=out_dtype)
+    c1 = c_buffer[: rows * n * 2].view(rows, -1)
+    c2 = c_buffer[rows * n * 2 : rows * (2 * n + k)].view(rows, -1)
+    intermediate = c_buffer[rows * (2 * n + k) :].view(rows, -1)
 
     fp8_blockwise_scaled_grouped_mm(
         c1,
@@ -182,10 +179,9 @@ def cutlass_fused_experts_fp8(
         workspace,
     )
 
-    intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
     silu_and_mul(c1, intermediate)
 
-    intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+    intermediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
 
     fp8_blockwise_scaled_grouped_mm(
         c2,
@@ -194,7 +190,7 @@ def cutlass_fused_experts_fp8(
         out_ptrs,
         a_scales_ptrs,
         b_scales_ptrs,
-        intemediate_q,
+        intermediate_q,
         w2_q,
         a2_scale,
         w2_scale,
@@ -209,7 +205,8 @@ def cutlass_fused_experts_fp8(
     )
 
     result = torch.empty((m, k), device=device, dtype=out_dtype)
-    return apply_shuffle_mul_sum(c2, result, c_map, topk_weights)
+    apply_shuffle_mul_sum(c2, result, c_map, topk_weights.to(out_dtype))
+    return result
 
 
 FLOAT4_E2M1_MAX = 6.0
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index 4d886de9181..b2e5c13e93c 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -675,6 +675,12 @@ def create_weights(
             self.problem_sizes2 = torch.empty(
                 num_experts, 3, device=w13_weight.device, dtype=torch.int32
             )
+            self.a_sf_layout = torch.empty(
+                num_experts, 5, device=w13_weight.device, dtype=torch.int32
+            )
+            self.w_sf_layout = torch.empty(
+                num_experts, 5, device=w13_weight.device, dtype=torch.int32
+            )
 
         else:
             # Allocate 2 scales for w1 and w3 respectively.
@@ -1058,10 +1064,10 @@ def apply(
 
             return cutlass_fused_experts_fp8(
                 x,
-                layer.w13_weight.transpose(1, 2),
-                layer.w2_weight.transpose(1, 2),
-                layer.w13_weight_scale_inv.transpose(1, 2),
-                layer.w2_weight_scale_inv.transpose(1, 2),
+                layer.w13_weight,
+                layer.w2_weight,
+                layer.w13_weight_scale_inv,
+                layer.w2_weight_scale_inv,
                 topk_weights,
                 topk_ids,
                 self.ab_strides1,
@@ -1077,6 +1083,8 @@ def apply(
                 self.expert_offsets,
                 self.problem_sizes1,
                 self.problem_sizes2,
+                self.a_sf_layout,
+                self.w_sf_layout,
                 use_fp8_blockscale=True,
             )
         # Expert fusion with FP8 quantization
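For readers skimming the patch, here is a minimal standalone sketch of the allocation arithmetic this change introduces: the per-expert (num_experts, 5) scale-factor layout buffers now created once in create_weights, and the single c_buffer that the kernel slices into c1, c2, and the SiLU intermediate. The sizes below are illustrative only, not taken from the PR.

import torch

# Illustrative sizes only; real values come from the model config and batch.
num_experts, rows, n, k = 8, 4, 128, 256
out_dtype = torch.bfloat16

# Caller-side layout buffers, mirroring the create_weights() change:
# one (num_experts, 5) int32 tensor each for activation and weight scale layouts.
a_sf_layout = torch.empty(num_experts, 5, dtype=torch.int32)
w_sf_layout = torch.empty(num_experts, 5, dtype=torch.int32)

# Single allocation that the patched kernel carves into three views:
#   c1:           (rows, 2n) output of the first grouped GEMM
#   c2:           (rows, k)  output of the second grouped GEMM
#   intermediate: (rows, n)  result of silu_and_mul
c_buffer = torch.empty(rows * (n * 3 + k), dtype=out_dtype)
c1 = c_buffer[: rows * n * 2].view(rows, -1)
c2 = c_buffer[rows * n * 2 : rows * (2 * n + k)].view(rows, -1)
intermediate = c_buffer[rows * (2 * n + k) :].view(rows, -1)

assert c1.shape == (rows, 2 * n)
assert c2.shape == (rows, k)
assert intermediate.shape == (rows, n)

Together with reusing one int32 allocation for a_map/c_map, this folds five per-call torch.empty allocations (c1, c2, intermediate, a_map, c_map) into two, while the layout tensors move out of the hot path entirely.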