2 changes: 1 addition & 1 deletion python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -621,7 +621,7 @@ def grouped_gemm_triton_kernel(
        b_ptr += BLOCK_SIZE_K

    if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
-       scale_a_value = tl.load(scale_a + expert_id)
+       scale_a_value = tl.load(scale_a + m_range_start + offs_am[:, None])
Contributor comment (critical):

This change loads scale_a_value using per-token indexing (m_range_start + offs_am[:, None]). This implies that the scale_a tensor (which originates from self.w13_input_scale or self.w2_input_scale in layer.py) is expected to be a 1D tensor of per-token scale factors.

However, there are a few concerns:

  1. Potential Out-of-Bounds Access: If scale_a is actually a smaller, per-expert scale tensor (as suggested by its calculation in EPMoE.forward at layer.py#L262-L268, where it's derived from torch.max(hidden_states) and has shape (num_experts_per_partition,)), then indexing it with global token indices (m_range_start + offs_am) could lead to out-of-bounds memory access, which is a critical issue.
  2. Consistency with Activation Quantization: For this per-token dequantization to be correct, the activation tensor a (input to this kernel) must have been quantized using corresponding per-token scales.
    • If a is pre-quantized to FP8 by pre_reorder_triton_kernel (as suggested by gateup_input dtype in layer.py#L252-L259), then pre_reorder_triton_kernel must use per-token scales. However, its current implementation (kernels.py#L163-L173) appears to use per-expert scales (tl.load(scale_a + expert_id_cur_rank)).
    • If a is not pre-quantized (e.g., it's bf16/fp16), then this kernel performs the quantization of a in its main loop (around kernels.py#L593-L603 in the full file). The same scale_a_value (now per-token) would be used for quantizing a and later for dequantizing the accumulator. This part would be consistent if scale_a is indeed per-token.

Could you clarify the structure of the scale_a tensor in this specific scenario (use_fp8_w8a8 and not (group_k > 0 and group_n > 0)) and ensure that it is compatible with per-token indexing? If scale_a is indeed intended to be per-token, the upstream code responsible for computing w13_input_scale/w2_input_scale and its usage in pre_reorder_triton_kernel might need adjustments to ensure consistency.
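
To make the shape concern concrete, here is a hypothetical host-side sanity check (not part of this PR; the tensor names and sizes below are assumptions based on the reading of layer.py above) contrasting the two possible layouts of scale_a:

import torch

# Assumed sizes, purely for illustration.
num_experts_per_partition = 8      # per-expert layout: one scale per local expert
num_reordered_tokens = 4096        # per-token layout: one scale per GEMM row

per_expert_scale = torch.ones(num_experts_per_partition, dtype=torch.float32)
per_token_scale = torch.ones(num_reordered_tokens, dtype=torch.float32)

def check_scale_layout(scale_a: torch.Tensor, num_rows: int) -> None:
    # The updated kernel indexes scale_a with global row indices
    # (m_range_start + offs_am), so scale_a needs at least num_rows entries.
    if scale_a.numel() < num_rows:
        raise ValueError(
            f"scale_a has {scale_a.numel()} elements but the kernel may read "
            f"indices up to {num_rows - 1}; a per-expert tensor would read out of bounds"
        )

check_scale_layout(per_token_scale, num_reordered_tokens)      # fine
# check_scale_layout(per_expert_scale, num_reordered_tokens)   # would raise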

        scale_b_value = tl.load(scale_b + expert_id)
        accumulator *= scale_a_value * scale_b_value
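
For reference, the epilogue above amounts to the following plain-PyTorch sketch; this is a simplification of the Triton code, and the per-token layout of scale_a is exactly the assumption questioned in the comment:

import torch

def dequant_epilogue(accumulator: torch.Tensor,  # [BLOCK_M, BLOCK_N] fp32 GEMM result
                     scale_a: torch.Tensor,      # assumed per-token scales over all GEMM rows
                     scale_b: torch.Tensor,      # per-expert weight scales
                     m_range_start: int,
                     expert_id: int) -> torch.Tensor:
    block_m = accumulator.shape[0]
    # Mirrors tl.load(scale_a + m_range_start + offs_am[:, None]):
    # one scale per row of the block, broadcast across the N dimension.
    row_idx = m_range_start + torch.arange(block_m, device=accumulator.device)
    scale_a_value = scale_a[row_idx].unsqueeze(1)   # shape [BLOCK_M, 1]
    # Mirrors tl.load(scale_b + expert_id): a single scalar for this expert.
    scale_b_value = scale_b[expert_id]
    return accumulator * scale_a_value * scale_b_value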
