Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 178 additions & 1 deletion python/sglang/srt/layers/moe/cutlass_moe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Cutlass MoE kernel."""
"""CUTLASS based Fused MoE kernels."""

import functools
import json
Expand All @@ -14,8 +14,10 @@
if _is_cuda:
import sgl_kernel
from sgl_kernel import (
cutlass_fp4_group_mm,
fp8_blockwise_scaled_grouped_mm,
prepare_moe_input,
scaled_fp4_experts_quant,
silu_and_mul,
)

Expand Down Expand Up @@ -205,3 +207,178 @@ def cutlass_fused_experts(
return (
c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
).sum(dim=1)


FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = 448.0


def cutlass_moe_fp4(
a: torch.Tensor,
a1_gscale: torch.Tensor,
w1_fp4: torch.Tensor,
w1_blockscale: torch.Tensor,
w1_alphas: torch.Tensor,
a2_gscale: torch.Tensor,
w2_fp4: torch.Tensor,
w2_blockscale: torch.Tensor,
w2_alphas: torch.Tensor,
ab_strides_13: torch.Tensor,
ab_strides_2: torch.Tensor,
c_strides_13: torch.Tensor,
c_strides_2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
m: int,
n: int,
k: int,
e: int,
device: torch.device,
):
"""
MoE implementation for FP4 Inputs

# Gemm 1
a: Input tensor: [m, k] (half/bfloat16)
a1_gscale: Activation scale per expert: [e] (float32)
w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k]
w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
(Note: `n` is the up projection output dim, `k` is the input dim in
full precision)
w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
(Block size = 16 for NVFP4)

# Gemm 2
a2_gscale: Activation scale per expert: [e]
w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n]
w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1)
w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3

Strides for activations, weights and output in logical number of elements.
The activations & output stride is the number of elements to the next row.
The weights stride is the number of elements to the next row per expert.
For example, if the weight is [e, n, k], then the b_stride is a tensor of
shape [e] with each element being k. Similarly for activations, if the
shape is [m, k], then the a_stride has shape [e] with each value k.
Similarly for output, if the output is [m, n], then the c_stride is a
tensor of shape [e] with each element being k.

Note: cutlass_fp4_group_mm is designed to accept the strides of
activations and weights to be the same, so it is passed in as a single
tensor.
ab_strides_13: [e] dtype: int64 [Gemm 1: Activation / Weight strides]
ab_strides_2: [e] dtype: int64 [Gemm 2: Activation / Weight strides]
c_strides_13: [e] dtype: int64 [Gemm 1: Output Strides]
c_strides_2: [e] dtype: int64 [Gemm 1: Output Strides]

topk_weights: [m, topk] dtype: float8
topk_ids: [m, topk] dtype: float8

m, n, k: Unquantized weight shapes, dtype: int
e: number of experts for the current rank, dtype: int
assumes that topk < k < n to satisfy - up/down projection expectations.
"""
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
assert (
w1_fp4.ndim == 3
and w2_fp4.ndim == 3
and w1_blockscale.ndim == 3
and w2_blockscale.ndim == 3
), "All Weights must be of rank 3 for cutlass_moe_fp4"
m_a, k_a = a.shape
e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
e_w2, k_w2, half_n_w2 = w2_fp4.shape

assert e_w1 == e_w2 and e_w1 == e, (
"Number of experts must match",
" between weights.",
)
assert (
k_a // 2 == half_k_w1 and k == k_w2
), "Hidden size mismatch between a, w1 and w2"
assert nx2_w1 == n * 2 and half_n_w2 == n // 2, "mismatch in " "expected `n`"
assert m == m_a, "input shape mismatch"
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
assert (
topk_weights.shape[0] == m and topk_ids.shape[0] == m
), "topk must be provided for each row of a"

out_dtype = a.dtype
num_topk = topk_ids.shape[1]

expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
# Problem size: (num_experts, (m,2n,k))
problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
# Problem size: (num_experts, (m,n,k))
problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device)

a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)

# problem shapes should have [m, n, k]
# Note that problem sizes are based on logical number of elements.
blockscale_offsets = torch.empty(e + 1, dtype=torch.int32, device=device)
prepare_moe_input(
topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
a_map,
c_map,
e,
n,
k,
blockscale_offsets,
)

rep_a_fp4, rep_a_blockscale = scaled_fp4_experts_quant(
a, a1_gscale, expert_offsets, blockscale_offsets, num_topk, expert_map=a_map
)

c1 = cutlass_fp4_group_mm(
rep_a_fp4,
w1_fp4,
rep_a_blockscale,
w1_blockscale,
w1_alphas,
ab_strides_13,
c_strides_13,
problem_sizes1,
expert_offsets[:-1],
blockscale_offsets[:-1],
out_dtype,
device,
)
del rep_a_fp4, rep_a_blockscale
# hidden size dimension is split to one halfpytho sized tensor.
intermediate = torch.empty(
(m * num_topk, w1_fp4.shape[1] // 2), device=device, dtype=out_dtype
)

silu_and_mul(c1, intermediate)

int_fp4, int_blockscale = scaled_fp4_experts_quant(
intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk
)
c2 = cutlass_fp4_group_mm(
int_fp4,
w2_fp4,
int_blockscale,
w2_blockscale,
w2_alphas,
ab_strides_2,
c_strides_2,
problem_sizes2,
expert_offsets[:-1],
blockscale_offsets[:-1],
out_dtype,
device,
)
del int_fp4, int_blockscale
out = (
c2[c_map].view(m, num_topk, k) * topk_weights.view(m, num_topk, 1).half()
).sum(dim=1)
return out.to(dtype=out_dtype)
Loading
Loading