[WIP] feat: support cutlass_moe_fp8 for deepepmoe #8273
base: main
Changes from 2 commits
@@ -216,6 +216,58 @@ def cutlass_fused_experts_fp8(
FLOAT8_E4M3_MAX = 448.0


def cutlass_moe_fp8(
    a: torch.Tensor,
    a_scale: torch.Tensor,
    w: torch.Tensor,
    w_scale: torch.Tensor,
    c: torch.Tensor,
    m_indices: torch.Tensor,
) -> None:
    '''Performs EP MoE computation using CUTLASS-like kernels with per-block-fp8-quant weights and per-token-group-fp8-quant activations.
    '''
    device = a.device
    num_experts, k_g, n_g = w.shape
    layout_sfa = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
    layout_sfb = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
    a_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    b_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    out_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    a_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    b_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
    a_strides = torch.full((num_experts,), a.stride(0), device=device, dtype=torch.int64)
    c_strides = torch.full((num_experts,), c.stride(0), device=device, dtype=torch.int64)
    m_tensor = m_indices[1:] - m_indices[:-1]
    n_tensor = torch.full_like(m_tensor, fill_value=n_g)
    k_tensor = torch.full_like(m_tensor, fill_value=k_g)
    problem_sizes = torch.stack([m_tensor, n_tensor, k_tensor], dim=1)
    # (E, K, N):(K*N, N, 1) -> (E, N, K):(N*K, 1, N) -> (E, N, K):(N*K, K, 1)
    # w_scale = w_scale.transpose(1, 2).contiguous()
    # TODO: a_scale

    fp8_blockwise_scaled_grouped_mm(
        c,
        a_ptrs,
        b_ptrs,
        out_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        a,
        w,
        a_scale,
        w_scale,
        a_strides,
        a_strides,
Contributor
The second a_strides argument here should carry the strides of the weight tensor, but b_strides is never defined. You should define

b_strides = torch.full((num_experts,), w.stride(1), device=device, dtype=torch.int64)

before the fp8_blockwise_scaled_grouped_mm call and pass it in place of the second a_strides:

fp8_blockwise_scaled_grouped_mm(
    ...
    a_strides,
    b_strides,
        c_strides,
        layout_sfa,
        layout_sfb,
        problem_sizes,
        m_indices[:-1],
        workspace,
    )


def cutlass_moe_fp4(
    a: torch.Tensor,
    a1_gscale: torch.Tensor,
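For reference on the cutlass_moe_fp8 interface above, here is a minimal sketch (expert counts and GEMM dimensions hypothetical) of how per-expert offsets in m_indices turn into grouped-GEMM problem sizes, mirroring the m_tensor = m_indices[1:] - m_indices[:-1] logic:

import torch

# Hypothetical example: 3 experts receiving 4, 0 and 2 scattered tokens respectively.
tokens_per_expert = torch.tensor([4, 0, 2], dtype=torch.int32)

# m_indices holds cumulative token offsets with a leading zero, i.e. E + 1 entries.
m_indices = torch.nn.functional.pad(
    torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32), (1, 0)
)  # tensor([0, 4, 4, 6])

# Per-expert M dimensions, recovered the same way cutlass_moe_fp8 does it.
m_tensor = m_indices[1:] - m_indices[:-1]  # tensor([4, 0, 2])

# Each expert then contributes one (M_e, N, K) problem to the grouped GEMM.
n_g, k_g = 512, 256  # hypothetical per-expert GEMM dimensions
problem_sizes = torch.stack(
    [m_tensor, torch.full_like(m_tensor, n_g), torch.full_like(m_tensor, k_g)], dim=1
)
print(problem_sizes)  # [[4, 512, 256], [0, 512, 256], [2, 512, 256]]

Note that this offsets format is what the review comment further down constructs as m_indices_for_cutlass; it differs from what ep_scatter writes into its m_indices output.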
@@ -40,6 +40,7 @@
    sglang_per_token_group_quant_fp8,
    sglang_per_token_quant_fp8,
)
from sglang.srt.layers.quantization.fp8_utils import cutlass_fp8_supported
from sglang.srt.layers.quantization.unquant import UnquantizedEPMoEMethod
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict

@@ -51,16 +52,17 @@
    get_bool_env_var,
    is_hip,
    is_npu,
    is_cuda,
)

_is_hip = is_hip()
_is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()
_is_cuda = is_cuda()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if not _is_npu:
    from sgl_kernel import silu_and_mul

    from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe

if _is_hip:

@@ -71,6 +73,9 @@
    from aiter.fused_moe import fused_moe
    from aiter.ops.shuffle import shuffle_weight

if _is_cuda:
    from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp8

logger = logging.getLogger(__name__)

@@ -991,6 +996,7 @@ def __init__(
                else self.w2_weight_scale
            ),
        )
        self.cutlass_moe_fp8_supported = cutlass_fp8_supported() and (torch.cuda.get_device_capability(torch.cuda.current_device())[0] == 9) and (torch.version.cuda >= "12.3")
Contributor
This condition is quite complex and long, which affects readability. You could simplify it by extracting the capability and version checks and combining them into a more readable expression, for example:

major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
self.cutlass_moe_fp8_supported = (
    cutlass_fp8_supported() and major == 9 and torch.version.cuda >= "12.3"
)

This assumes you only want to support Hopper (SM 90) with CUDA 12.3+, which seems to be the intent.
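As a side note, torch.version.cuda >= "12.3" compares strings lexicographically, which would mis-order a hypothetical "12.10"; a minimal sketch of a numeric comparison (the helper name is illustrative, not part of this PR):

import torch


def cuda_version_at_least(major: int, minor: int) -> bool:
    # torch.version.cuda is None on builds without CUDA (e.g. CPU-only or ROCm).
    if torch.version.cuda is None:
        return False
    cuda_major, cuda_minor = (int(x) for x in torch.version.cuda.split(".")[:2])
    return (cuda_major, cuda_minor) >= (major, minor)


# e.g. cutlass_fp8_supported() and major == 9 and cuda_version_at_least(12, 3)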

    def forward(
        self,
@@ -1011,7 +1017,11 @@ def forward(
            forward_batch.is_extend_in_batch
        )
        if resolved_deepep_mode == DeepEPMode.normal:
-           if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+           if get_bool_env_var("SGLANG_CUTLASS_MOE") and self.cutlass_moe_fp8_supported:
+               return self.forward_cutlass_moe(
+                   hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
+               )
+           elif deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
                return self.forward_deepgemm_contiguous(
                    hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
                )
@@ -1171,6 +1181,117 @@ def forward_aiter(
            ),
            expert_mask=self.expert_mask,
        )

    def forward_cutlass_moe(self,
        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
Contributor
The parameter hidden_states_fp8 is actually a tuple of the FP8 hidden states and their scale, which the name does not convey. To improve clarity, I suggest renaming the parameter to reflect that it's a tuple, for example:

hidden_states_fp8_and_scale: Tuple[torch.Tensor, torch.Tensor],

(A sketch of the applied rename follows the signature below.)
        topk_idx,
        topk_weights,
        num_recv_tokens_per_expert: List[int]
    ):
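Following the renaming suggestion above, a sketch of how the signature and the unpacking line would look with the new parameter name (this is a fragment of the method; the remaining body is unchanged from the diff below):

    def forward_cutlass_moe(self,
        hidden_states_fp8_and_scale: Tuple[torch.Tensor, torch.Tensor],
        topk_idx,
        topk_weights,
        num_recv_tokens_per_expert: List[int]
    ):
        # Unpack once at the top; downstream code keeps using the original names.
        hidden_states_fp8, hidden_states_scale = hidden_states_fp8_and_scale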
        hidden_states_fp8, hidden_states_scale = hidden_states_fp8
        assert self.quant_method is not None
        assert self.activation == "silu"
        if num_recv_tokens_per_expert is None:
            return hidden_states_fp8.bfloat16()
        all_tokens = sum(num_recv_tokens_per_expert)
        if all_tokens <= 0:
            return hidden_states_fp8.bfloat16()
        M, K = hidden_states_fp8.size()
        N = self.w13_weight.size(1)
        scale_block_size = 128

        hidden_states_fp8_shape = hidden_states_fp8.shape
        hidden_states_fp8_device = hidden_states_fp8.device
        hidden_states_fp8_dtype = hidden_states_fp8.dtype

        gateup_input_fp8 = torch.empty(
            (all_tokens, K),
            device=hidden_states_fp8_device,
            dtype=hidden_states_fp8_dtype)
        gateup_input_scale = torch.empty(
            (all_tokens, K // 128),
            device=hidden_states_fp8_device,
            dtype=torch.float32)
        m_indices = torch.empty(
            all_tokens, device=hidden_states_fp8_device, dtype=torch.int32
        )
        output_index = torch.empty_like(topk_idx)

        num_recv_tokens_per_expert_gpu = torch.tensor(
            num_recv_tokens_per_expert,
            dtype=torch.int32,
            pin_memory=True,
            device="cpu",
        ).cuda(non_blocking=True)
        expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)

        ep_scatter(
            hidden_states_fp8,
            hidden_states_scale,
            topk_idx,
            num_recv_tokens_per_expert_gpu,
            expert_start_loc,
            gateup_input_fp8,
            gateup_input_scale,
            m_indices,
            output_index,
            scale_ue8m0=False,
        )
        dispose_tensor(hidden_states_fp8)

        gateup_output = torch.empty(
            (all_tokens, N),
            device=hidden_states_fp8_device,
            dtype=torch.bfloat16,
        )
        gateup_input_scale = tma_align_input_scale(gateup_input_scale)

        cutlass_moe_fp8(a=gateup_input_fp8,
            a_scale=gateup_input_scale,
            w=self.w13_weight_fp8[0],
            w_scale=self.w13_weight_fp8[1],
            c=gateup_output,
            m_indices=m_indices)
Contributor
The m_indices tensor filled by ep_scatter is not in the format cutlass_moe_fp8 expects: cutlass_moe_fp8 derives the per-expert problem sizes from consecutive offsets (m_indices[1:] - m_indices[:-1]), so passing the ep_scatter output directly will result in incorrect calculations within cutlass_moe_fp8.

You should compute the cumulative sum of tokens from num_recv_tokens_per_expert_gpu instead:

m_indices_for_cutlass = torch.nn.functional.pad(
    torch.cumsum(num_recv_tokens_per_expert_gpu, dim=0, dtype=torch.int32), (1, 0)
)

Then, pass m_indices_for_cutlass to cutlass_moe_fp8:

m_indices = torch.empty(
    all_tokens, device=hidden_states_fp8_device, dtype=torch.int32
)
output_index = torch.empty_like(topk_idx)
num_recv_tokens_per_expert_gpu = torch.tensor(
    num_recv_tokens_per_expert,
    dtype=torch.int32,
    pin_memory=True,
    device="cpu",
).cuda(non_blocking=True)
expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
ep_scatter(
    hidden_states_fp8,
    hidden_states_scale,
    topk_idx,
    num_recv_tokens_per_expert_gpu,
    expert_start_loc,
    gateup_input_fp8,
    gateup_input_scale,
    m_indices,
    output_index,
    scale_ue8m0=False,
)
m_indices_for_cutlass = torch.nn.functional.pad(
    torch.cumsum(num_recv_tokens_per_expert_gpu, dim=0, dtype=torch.int32), (1, 0)
)
cutlass_moe_fp8(a=gateup_input_fp8,
    a_scale=gateup_input_scale,
    w=self.w13_weight_fp8[0],
    w_scale=self.w13_weight_fp8[1],
    c=gateup_output,
    m_indices=m_indices_for_cutlass)
        del gateup_input_fp8, gateup_input_scale

        down_input = torch.empty(
            (
                all_tokens,
                N // 2,
            ),
            device=hidden_states_fp8_device,
            dtype=torch.bfloat16,
        )
        silu_and_mul(gateup_output.view(-1, N), down_input)
        del gateup_output
        down_output = torch.empty(
            (all_tokens, K),
            device=hidden_states_fp8_device,
            dtype=torch.bfloat16,
        )
        down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
            down_input,
            scale_block_size,
        )
        del down_input
        down_input_scale = tma_align_input_scale(down_input_scale)
        cutlass_moe_fp8(a=down_input_fp8,
            a_scale=down_input_scale,
            w=self.w2_weight_fp8[0],
            w_scale=self.w2_weight_fp8[1],
            c=down_output,
            m_indices=m_indices)
        del down_input_fp8, down_input_scale

        gather_out = torch.empty(
            hidden_states_fp8_shape,
            device=hidden_states_fp8_device,
            dtype=torch.bfloat16,
        )
        ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)

        return gather_out

    def forward_deepgemm_contiguous(
        self,
Allocating a 1GB workspace tensor with torch.empty is risky and can easily lead to out-of-memory (OOM) errors, especially in environments with limited memory. This size is hardcoded and might be excessive for the actual needs of the fp8_blockwise_scaled_grouped_mm kernel.

Consider calculating the required workspace size dynamically based on the problem size, or using a much smaller, more reasonable default. For reference, the test file test_cutlass_moe.py allocates a workspace of about 7MB, which is significantly smaller. A smaller buffer or dynamic allocation would be safer and more memory-efficient.
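A minimal sketch of the smaller-default idea; the 7MB figure follows the test file mentioned above, and the size the kernel actually requires would need to be confirmed against fp8_blockwise_scaled_grouped_mm:

import torch

# Hypothetical default: a modest, configurable workspace instead of a hardcoded 1GB.
CUTLASS_MOE_WORKSPACE_BYTES = 7 * 1024 * 1024  # ~7MB, as in test_cutlass_moe.py

device = torch.device("cuda")
workspace = torch.empty((CUTLASS_MOE_WORKSPACE_BYTES,), device=device, dtype=torch.uint8)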