52 changes: 52 additions & 0 deletions python/sglang/srt/layers/moe/cutlass_moe.py
@@ -216,6 +216,58 @@ def cutlass_fused_experts_fp8(
FLOAT8_E4M3_MAX = 448.0


def cutlass_moe_fp8(
    a: torch.Tensor,
    a_scale: torch.Tensor,
    w: torch.Tensor,
    w_scale: torch.Tensor,
    c: torch.Tensor,
    m_indices: torch.Tensor,
) -> None:
    '''Performs EP MoE computation using CUTLASS-like kernels with
    per-block-fp8-quant weights and per-token-group-fp8-quant activations.
    '''
    device = a.device
    num_experts, k_g, n_g = w.shape
    layout_sfa = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
    layout_sfb = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
    a_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    b_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    out_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    a_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    b_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
Contributor review comment (critical):

Allocating a 1GB workspace tensor with torch.empty is risky and can easily lead to Out-of-Memory (OOM) errors, especially in environments with limited memory. This size is hardcoded and might be excessive for the actual needs of the fp8_blockwise_scaled_grouped_mm kernel.

Consider calculating the required workspace size dynamically based on the problem size or using a much smaller, more reasonable default size. For reference, the test file test_cutlass_moe.py allocates a workspace of about 7MB, which is significantly smaller. A smaller buffer or dynamic allocation would be safer and more memory-efficient.
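
A minimal sketch of the smaller-buffer option, assuming the ~7 MB workspace used in test_cutlass_moe.py is actually sufficient for fp8_blockwise_scaled_grouped_mm here (that size is an assumption and should be verified against the kernel's real requirements):

    # Sketch only: mirror the ~7 MiB workspace from test_cutlass_moe.py
    # instead of the hardcoded 1 GiB buffer.
    workspace = torch.empty((7 * 1024 * 1024,), device=device, dtype=torch.uint8)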

    a_strides = torch.full((num_experts,), a.stride(0), device=device, dtype=torch.int64)
    c_strides = torch.full((num_experts,), c.stride(0), device=device, dtype=torch.int64)
    m_tensor = m_indices[1:] - m_indices[:-1]
    n_tensor = torch.full_like(m_tensor, fill_value=n_g)
    k_tensor = torch.full_like(m_tensor, fill_value=k_g)
    problem_sizes = torch.stack([m_tensor, n_tensor, k_tensor], dim=1)
    # (E, K, N):(K*N, N, 1) -> (E, N, K):(N*K, 1, N) -> (E, N, K):(N*K, K, 1)
    # w_scale = w_scale.transpose(1, 2).contiguous()
    # TODO: a_scale

    fp8_blockwise_scaled_grouped_mm(
        c,
        a_ptrs,
        b_ptrs,
        out_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        a,
        w,
        a_scale,
        w_scale,
        a_strides,
        a_strides,
Contributor review comment (high):

The stride_b argument to fp8_blockwise_scaled_grouped_mm is incorrectly set to a_strides. The stride for the weight tensor w should be used instead, as a and w have different shapes and layouts.

You should define b_strides based on w's dimensions and pass it here. For example:

b_strides = torch.full((num_experts,), w.stride(1), device=device, dtype=torch.int64)

This should be defined before the fp8_blockwise_scaled_grouped_mm call.

        b_strides = torch.full((num_experts,), w.stride(1), device=device, dtype=torch.int64)
        fp8_blockwise_scaled_grouped_mm(
        ...
        a_strides,
        b_strides,

        c_strides,
        layout_sfa,
        layout_sfb,
        problem_sizes,
        m_indices[:-1],
        workspace,
    )


def cutlass_moe_fp4(
    a: torch.Tensor,
    a1_gscale: torch.Tensor,
125 changes: 123 additions & 2 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -40,6 +40,7 @@
    sglang_per_token_group_quant_fp8,
    sglang_per_token_quant_fp8,
)
from sglang.srt.layers.quantization.fp8_utils import cutlass_fp8_supported
from sglang.srt.layers.quantization.unquant import UnquantizedEPMoEMethod
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -51,16 +52,17 @@
    get_bool_env_var,
    is_hip,
    is_npu,
    is_cuda,
)

_is_hip = is_hip()
_is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()
_is_cuda = is_cuda()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if not _is_npu:
    from sgl_kernel import silu_and_mul

    from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe

if _is_hip:
@@ -71,6 +73,9 @@
    from aiter.fused_moe import fused_moe
    from aiter.ops.shuffle import shuffle_weight

if _is_cuda:
    from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp8

logger = logging.getLogger(__name__)


@@ -991,6 +996,7 @@ def __init__(
                else self.w2_weight_scale
            ),
        )
        self.cutlass_moe_fp8_supported = cutlass_fp8_supported() and (torch.cuda.get_device_capability(torch.cuda.current_device())[0] == 9) and (torch.version.cuda >= "12.3")
Contributor review comment (medium):

This condition is quite complex and long, which affects readability. It can be simplified.

The expression cutlass_fp8_supported() and (torch.cuda.get_device_capability(torch.cuda.current_device())[0] == 9) and (torch.version.cuda >= "12.3") appears to be redundant. The cutlass_fp8_supported() function already checks for device capability.

You could simplify this by extracting the capability and version checks and combining them into a more readable expression. For example:

major, _ = torch.cuda.get_device_capability()
self.cutlass_moe_fp8_supported = (
    cutlass_fp8_supported() and major == 9 and torch.version.cuda >= "12.3"
)

This assumes you only want to support Hopper (SM 90) with CUDA 12.3+, which seems to be the intent.

        major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
        self.cutlass_moe_fp8_supported = (
            cutlass_fp8_supported() and major == 9 and torch.version.cuda >= "12.3"
        )


    def forward(
        self,
@@ -1011,7 +1017,11 @@ def forward(
            forward_batch.is_extend_in_batch
        )
        if resolved_deepep_mode == DeepEPMode.normal:
            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
            if get_bool_env_var("SGLANG_CUTLASS_MOE") and self.cutlass_moe_fp8_supported:
                return self.forward_cutlass_moe(
                    hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
                )
            elif deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
                return self.forward_deepgemm_contiguous(
                    hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
                )
@@ -1171,6 +1181,117 @@ def forward_aiter(
            ),
            expert_mask=self.expert_mask,
        )

    def forward_cutlass_moe(self,
        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
Contributor review comment (medium):

The parameter hidden_states_fp8 is a tuple of two tensors, but its name suggests it's a single tensor. This is confusing because the next line unpacks it into hidden_states_fp8, hidden_states_scale, reusing the name hidden_states_fp8.

To improve clarity, I suggest renaming the parameter to reflect that it's a tuple, for example, hidden_states_fp8_and_scale.

        hidden_states_fp8_and_scale: Tuple[torch.Tensor, torch.Tensor],
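
Under that rename, the unpacking at the top of the method would become the following sketch (assuming no other code depends on reusing the hidden_states_fp8 name for the tuple):

        # Sketch: unpack the renamed tuple parameter into the fp8 tensor and its scale.
        hidden_states_fp8, hidden_states_scale = hidden_states_fp8_and_scale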

        topk_idx,
        topk_weights,
        num_recv_tokens_per_expert: List[int]
    ):
        hidden_states_fp8, hidden_states_scale = hidden_states_fp8
        assert self.quant_method is not None
        assert self.activation == "silu"
        if num_recv_tokens_per_expert is None:
            return hidden_states_fp8.bfloat16()
        all_tokens = sum(num_recv_tokens_per_expert)
        if all_tokens <= 0:
            return hidden_states_fp8.bfloat16()
        M, K = hidden_states_fp8.size()
        N = self.w13_weight.size(1)
        scale_block_size = 128

        hidden_states_fp8_shape = hidden_states_fp8.shape
        hidden_states_fp8_device = hidden_states_fp8.device
        hidden_states_fp8_dtype = hidden_states_fp8.dtype

        gateup_input_fp8 = torch.empty(
            (all_tokens, K),
            device=hidden_states_fp8_device,
            dtype=hidden_states_fp8_dtype)
        gateup_input_scale = torch.empty(
            (all_tokens, K // 128),
            device=hidden_states_fp8_device,
            dtype=torch.float32)
        m_indices = torch.empty(
            all_tokens, device=hidden_states_fp8_device, dtype=torch.int32
        )
        output_index = torch.empty_like(topk_idx)

        num_recv_tokens_per_expert_gpu = torch.tensor(
            num_recv_tokens_per_expert,
            dtype=torch.int32,
            pin_memory=True,
            device="cpu",
        ).cuda(non_blocking=True)
        expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)

        ep_scatter(
            hidden_states_fp8,
            hidden_states_scale,
            topk_idx,
            num_recv_tokens_per_expert_gpu,
            expert_start_loc,
            gateup_input_fp8,
            gateup_input_scale,
            m_indices,
            output_index,
            scale_ue8m0=False,
        )
        dispose_tensor(hidden_states_fp8)

        gateup_output = torch.empty(
            (all_tokens, N),
            device=hidden_states_fp8_device,
            dtype=torch.bfloat16,
        )
        gateup_input_scale = tma_align_input_scale(gateup_input_scale)

        cutlass_moe_fp8(a=gateup_input_fp8,
                        a_scale=gateup_input_scale,
                        w=self.w13_weight_fp8[0],
                        w_scale=self.w13_weight_fp8[1],
                        c=gateup_output,
                        m_indices=m_indices)
Contributor review comment (critical):

The m_indices tensor passed to cutlass_moe_fp8 is incorrect. The cutlass_moe_fp8 function expects m_indices to be the cumulative sum of token counts per expert to calculate problem sizes. However, the m_indices tensor passed here is the output of ep_scatter, which contains expert IDs for each token.

This will result in incorrect calculations within cutlass_moe_fp8 and likely lead to errors or wrong results.

You should compute the cumulative sum of tokens from num_recv_tokens_per_expert_gpu and pass that to cutlass_moe_fp8. For example:

m_indices_for_cutlass = torch.nn.functional.pad(
    torch.cumsum(num_recv_tokens_per_expert_gpu, dim=0, dtype=torch.int32), (1, 0)
)

Then, pass m_indices_for_cutlass to the function. The same applies to the second call to cutlass_moe_fp8.

        m_indices = torch.empty(
            all_tokens, device=hidden_states_fp8_device, dtype=torch.int32
        )
        output_index = torch.empty_like(topk_idx)

        num_recv_tokens_per_expert_gpu = torch.tensor(
            num_recv_tokens_per_expert,
            dtype=torch.int32,
            pin_memory=True,
            device="cpu",
        ).cuda(non_blocking=True)
        expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)

        ep_scatter(
            hidden_states_fp8,
            hidden_states_scale,
            topk_idx,
            num_recv_tokens_per_expert_gpu,
            expert_start_loc,
            gateup_input_fp8,
            gateup_input_scale,
            m_indices,
            output_index,
            scale_ue8m0=False,
        )
        
        m_indices_for_cutlass = torch.nn.functional.pad(
            torch.cumsum(num_recv_tokens_per_expert_gpu, dim=0, dtype=torch.int32), (1, 0)
        )

        cutlass_moe_fp8(a=gateup_input_fp8,
                        a_scale=gateup_input_scale,
                        w=self.w13_weight_fp8[0],
                        w_scale=self.w13_weight_fp8[1],
                        c=gateup_output,
                        m_indices=m_indices_for_cutlass)

        del gateup_input_fp8, gateup_input_scale

        down_input = torch.empty(
            (
                all_tokens,
                N // 2,
            ),
            device=hidden_states_fp8_device,
            dtype=torch.bfloat16,
        )
        silu_and_mul(gateup_output.view(-1, N), down_input)
        del gateup_output
        down_output = torch.empty(
            (all_tokens, K),
            device=hidden_states_fp8_device,
            dtype=torch.bfloat16,
        )
        down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
            down_input,
            scale_block_size,
        )
        del down_input
        down_input_scale = tma_align_input_scale(down_input_scale)
        cutlass_moe_fp8(a=down_input_fp8,
                        a_scale=down_input_scale,
                        w=self.w2_weight_fp8[0],
                        w_scale=self.w2_weight_fp8[1],
                        c=down_output,
                        m_indices=m_indices)
        del down_input_fp8, down_input_scale

        gather_out = torch.empty(
            hidden_states_fp8_shape,
            device=hidden_states_fp8_device,
            dtype=torch.bfloat16,
        )
        ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)

        return gather_out


    def forward_deepgemm_contiguous(
        self,