Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 24 additions & 27 deletions python/sglang/srt/layers/moe/cutlass_moe.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
"""CUTLASS based Fused MoE kernels."""

import functools
import json
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch

from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
Expand Down Expand Up @@ -46,6 +40,8 @@ def cutlass_fused_experts_fp8(
expert_offsets: torch.Tensor,
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
a_sf_layout: torch.Tensor,
w_sf_layout: torch.Tensor,
use_fp8_blockscale: bool = True,
) -> torch.Tensor:
"""Performs Fused MoE computation using CUTLASS-like kernels with FP8 weights and activations.
Expand All @@ -65,14 +61,12 @@ def cutlass_fused_experts_fp8(
number of tokens and `k` is the hidden size. Expected dtype: `torch.half`
or `torch.bfloat16`.
w1_q (torch.Tensor): Pre-quantized FP8 weight tensor for the first GEMM
(up-projection part of SwiGLU). Expected shape: `(E, k, n*2)`, where
(up-projection part of SwiGLU). Expected shape: `(E, n*2, k)`, where
`E` is the number of experts, `k` is the hidden size, and `n*2` is the
intermediate size (`I`). Expected dtype: `torch.float8_e4m3fn`.
Note: This shape implies weights are stored as (num_experts, hidden_size, intermediate_size).
w2_q (torch.Tensor): Pre-quantized FP8 weight tensor for the second GEMM
(down-projection). Expected shape: `(E, n, k)`, where `n` is half the
(down-projection). Expected shape: `(E, k, n)`, where `n` is half the
intermediate size (`I // 2`). Expected dtype: `torch.float8_e4m3fn`.
Note: This shape implies weights are stored as (num_experts, intermediate_size // 2, hidden_size).
w1_scale (torch.Tensor): Scales corresponding to `w1_q` (per-block scales).
Shape: `(E, num_blocks_n, num_blocks_k)`. Dtype: `torch.float32`.
w2_scale (torch.Tensor): Scales corresponding to `w2_q` (per-block scales).
Expand All @@ -99,6 +93,8 @@ def cutlass_fused_experts_fp8(
out_ptrs (torch.Tensor): Pointers container for calculating offsets of the output activations for each expert.
a_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the input scales for each expert.
b_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the input scales for each expert.
a_sf_layout (torch.Tensor): Layout tensor for activation scales.
w_sf_layout (torch.Tensor): Layout tensor for weight scales.
use_fp8_blockscale (bool, optional): Flag indicating usage of FP8 with
block scaling. Currently, only `True` is supported. Defaults to `True`.

Expand All @@ -113,10 +109,9 @@ def cutlass_fused_experts_fp8(
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert w1_q.dtype == torch.float8_e4m3fn
assert w2_q.dtype == torch.float8_e4m3fn
assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
assert a.shape[1] == w1_q.shape[2], "Hidden size mismatch w1"
assert w1_q.shape[1] == w2_q.shape[2] * 2, "Hidden size mismatch w2"
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
assert a.dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
Expand All @@ -128,17 +123,19 @@ def cutlass_fused_experts_fp8(

out_dtype = a.dtype
num_experts = w1_q.size(0)
m = a.size(0)
k = w1_q.size(1)
n = w2_q.size(1)
m, k = a.shape
n = w2_q.size(2)

topk = topk_ids.size(1)

a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
device = a_q.device

a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
numel = topk_ids.numel()
maps = torch.empty(numel * 2, dtype=torch.int32, device=device)

a_map = maps[:numel]
c_map = maps[numel:]

prepare_moe_input(
topk_ids,
Expand All @@ -155,11 +152,11 @@ def cutlass_fused_experts_fp8(
rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))

c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)

a_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)
w_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)
rows = m * topk
c_buffer = torch.empty((rows * (n * 3 + k)), device=device, dtype=out_dtype)
c1 = c_buffer[: rows * n * 2].view(rows, -1)
c2 = c_buffer[rows * n * 2 : rows * (2 * n + k)].view(rows, -1)
intermediate = c_buffer[rows * (2 * n + k) :].view(rows, -1)

fp8_blockwise_scaled_grouped_mm(
c1,
Expand All @@ -182,10 +179,9 @@ def cutlass_fused_experts_fp8(
workspace,
)

intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
silu_and_mul(c1, intermediate)

intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
intermediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)

fp8_blockwise_scaled_grouped_mm(
c2,
Expand All @@ -194,7 +190,7 @@ def cutlass_fused_experts_fp8(
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
intemediate_q,
intermediate_q,
w2_q,
a2_scale,
w2_scale,
Expand All @@ -209,7 +205,8 @@ def cutlass_fused_experts_fp8(
)

result = torch.empty((m, k), device=device, dtype=out_dtype)
return apply_shuffle_mul_sum(c2, result, c_map, topk_weights)
apply_shuffle_mul_sum(c2, result, c_map, topk_weights.to(out_dtype))
return result


FLOAT4_E2M1_MAX = 6.0
Expand Down
16 changes: 12 additions & 4 deletions python/sglang/srt/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,12 @@ def create_weights(
self.problem_sizes2 = torch.empty(
num_experts, 3, device=w13_weight.device, dtype=torch.int32
)
self.a_sf_layout = torch.empty(
num_experts, 5, device=w13_weight.device, dtype=torch.int32
)
self.w_sf_layout = torch.empty(
num_experts, 5, device=w13_weight.device, dtype=torch.int32
)

else:
# Allocate 2 scales for w1 and w3 respectively.
Expand Down Expand Up @@ -1058,10 +1064,10 @@ def apply(

return cutlass_fused_experts_fp8(
x,
layer.w13_weight.transpose(1, 2),
layer.w2_weight.transpose(1, 2),
layer.w13_weight_scale_inv.transpose(1, 2),
layer.w2_weight_scale_inv.transpose(1, 2),
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
topk_weights,
topk_ids,
self.ab_strides1,
Expand All @@ -1077,6 +1083,8 @@ def apply(
self.expert_offsets,
self.problem_sizes1,
self.problem_sizes2,
self.a_sf_layout,
self.w_sf_layout,
use_fp8_blockscale=True,
)
# Expert fusion with FP8 quantization
Expand Down
Loading