Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 21 additions & 17 deletions python/sglang/srt/layers/moe/cutlass_w4a8_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@
silu_and_mul,
)

from sglang.srt.distributed import get_moe_expert_parallel_world_size
from sglang.srt.layers.moe.ep_moe.kernels import (
cutlass_w4_run_moe_ep_preproess,
deepep_ll_get_cutlass_w4a8_moe_mm_data,
deepep_permute_triton_kernel,
deepep_post_reorder_triton_kernel,
deepep_run_moe_deep_preprocess,
post_reorder_triton_kernel_for_cutlass_moe,
pre_reorder_triton_kernel_for_cutlass_moe,
run_moe_ep_preproess,
post_reorder_for_cutlass_moe,
pre_reorder_for_cutlass_moe,
silu_and_mul_masked_post_per_tensor_quant_fwd,
silu_mul_static_tensorwise_quant_for_cutlass_moe,
)


Expand All @@ -44,6 +46,7 @@ def cutlass_w4a8_moe(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor:
"""
This function computes a w4a8-quantized Mixture of Experts (MoE) layer
Expand Down Expand Up @@ -108,11 +111,11 @@ def cutlass_w4a8_moe(
assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"

device = a.device
topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids)
if get_moe_expert_parallel_world_size() > 1:
topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids)

_, src2dst, _ = run_moe_ep_preproess(
src2dst = cutlass_w4_run_moe_ep_preproess(
topk_ids,
num_local_experts,
)

gateup_input = torch.empty(
Expand All @@ -121,16 +124,16 @@ def cutlass_w4a8_moe(
dtype=torch.float8_e4m3fn,
)

pre_reorder_triton_kernel_for_cutlass_moe[(m,)](
pre_reorder_for_cutlass_moe(
a,
gateup_input,
src2dst,
topk_ids,
a1_scale,
num_local_experts,
topk,
m,
k,
BLOCK_SIZE=512,
)

# NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
Expand All @@ -151,7 +154,7 @@ def cutlass_w4a8_moe(
)

c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)
c2 = torch.empty((m * topk, k), device=device, dtype=torch.bfloat16)

cutlass_w4a8_moe_mm(
c1,
Expand All @@ -169,13 +172,12 @@ def cutlass_w4a8_moe(
topk,
)

intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
silu_and_mul(c1, intermediate)

intermediate_q = torch.empty(
intermediate.shape, dtype=torch.float8_e4m3fn, device=device
(m * topk, n), dtype=torch.float8_e4m3fn, device=device
)
silu_mul_static_tensorwise_quant_for_cutlass_moe(
c1, intermediate_q, a2_scale.float(), expert_offsets[-1:], m * topk, n
)
sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale.float(), True)

cutlass_w4a8_moe_mm(
c2,
Expand All @@ -194,16 +196,18 @@ def cutlass_w4a8_moe(
)

output = torch.empty_like(a)
post_reorder_triton_kernel_for_cutlass_moe[(m,)](

post_reorder_for_cutlass_moe(
c2,
output,
src2dst,
topk_ids,
topk_weights,
topk,
num_local_experts,
topk,
m,
k,
BLOCK_SIZE=512,
routed_scaling_factor,
)
return output

Expand Down
Loading
Loading