31 changes: 24 additions & 7 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -46,6 +46,7 @@
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.utils import (
     DeepEPMode,
+    ceil_div,
     dispose_tensor,
     get_bool_env_var,
     is_hip,
@@ -1370,10 +1371,19 @@ def forward_deepgemm_contiguous(
                 device=hidden_states_fp8.device,
                 dtype=hidden_states_fp8.dtype,
             ),
-            torch.empty(
-                (all_tokens, K // 128),
-                device=hidden_states_fp8.device,
-                dtype=torch.float32,
+            (
+                # TODO check whether need `zeros`
Contributor review comment (severity: medium):
Regarding the TODO `check whether need zeros`:

If `input_tensor[1]` (the scale tensor) is guaranteed to be fully overwritten by `ep_scatter` before any read, `torch.empty()` could be slightly more performant by avoiding the zero-initialization cost.

However, if `DEEPGEMM_SCALE_UE8M0` is true, this tensor has `dtype=torch.int` and its shape involves `ceil_div(K // 128, 4)`. This suggests a packed format where several scale values are stored within each integer. If `K // 128` is not evenly divisible by 4, `ceil_div` introduces padding. If those padding bits are not guaranteed to be overwritten and could affect subsequent operations (e.g., if the kernel reads them or they feed into a checksum), then `torch.zeros()` is required for correctness to keep them zero.

Could you clarify whether `torch.zeros` is strictly necessary for correctness here, or whether `torch.empty` would suffice under the assumption that `ep_scatter` fully populates the required parts of the tensor? This may also relate to the TODO(FIXME) in `fp8_kernel.py` concerning `sgl_per_token_group_quant_fp8`, which likewise initializes a similar scale tensor with `torch.zeros` when `scale_ue8m0` is true.
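To make the padding concern concrete, here is a minimal, self-contained sketch; the value of K, the helper ceil_div, and the byte-level packing shown here are illustrative assumptions, not the actual DeepGEMM/ep_scatter layout:

```python
import torch


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


K = 7296                               # hypothetical hidden size
num_scales = K // 128                  # 57 one-byte UE8M0 exponents per row
packed_ints = ceil_div(num_scales, 4)  # 15 int32 slots = 60 bytes, so 3 padding bytes

scales_zeroed = torch.zeros((packed_ints,), dtype=torch.int32)
scales_uninit = torch.empty((packed_ints,), dtype=torch.int32)

# Simulate a kernel that writes only the meaningful exponent bytes.
exponents = torch.randint(0, 256, (num_scales,), dtype=torch.uint8)
scales_zeroed.view(torch.uint8)[:num_scales] = exponents
scales_uninit.view(torch.uint8)[:num_scales] = exponents

# The padding bytes are guaranteed to be zero only in the zero-initialized tensor.
print(scales_zeroed.view(torch.uint8)[num_scales:])  # tensor([0, 0, 0], dtype=torch.uint8)
print(scales_uninit.view(torch.uint8)[num_scales:])  # arbitrary leftover bytes
```

Whether those trailing bytes are ever read downstream is exactly what the TODO needs to settle; if they are, only the torch.zeros variant is safe.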

+                torch.zeros(
+                    (ceil_div(K // 128, 4), all_tokens),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.int,
+                ).transpose(0, 1)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else torch.empty(
+                    (all_tokens, K // 128),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.float32,
+                )
             ),
         ]
         m_indices = torch.empty(
@@ -1399,6 +1409,7 @@
             input_tensor[1],
             m_indices,
             output_index,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         dispose_tensor(hidden_states_fp8)

@@ -1407,7 +1418,8 @@
             device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
-        input_tensor[1] = tma_align_input_scale(input_tensor[1])
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            input_tensor[1] = tma_align_input_scale(input_tensor[1])
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
@@ -1428,10 +1440,15 @@
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
-            down_input, scale_block_size
+            down_input,
+            scale_block_size,
+            column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del down_input
-        down_input_scale = tma_align_input_scale(down_input_scale)
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            down_input_scale = tma_align_input_scale(down_input_scale)
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             (down_input_fp8, down_input_scale),
             self.w2_weight_fp8,
8 changes: 7 additions & 1 deletion python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -246,7 +246,13 @@ def dispatch_a(
         topk_idx = topk_idx.to(torch.int64)
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             # TODO hard code 128 block quant,use fp8 communication
Contributor review comment (severity: medium):
The TODO mentions a hardcoded 128 for block quantization. This value 128 also appears in python/sglang/srt/layers/quantization/fp8_kernel.py (e.g., in sglang_per_token_group_quant_fp8 as x_q_k // 128) and in python/sglang/srt/layers/moe/ep_moe/layer.py (e.g., K // 128).

To improve maintainability and readability, consider defining this block size (128) as a named constant. This constant could reside in a shared configuration module (e.g., within deep_gemm_wrapper or a dedicated quantization constants file) and be imported where needed. This would make it easier to understand its significance and update it consistently if required in the future.
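As a minimal sketch of this suggestion (the module path and constant name below are hypothetical, not existing SGLang identifiers):

```python
# Hypothetical shared module, e.g. sglang/srt/layers/quantization/constants.py
FP8_BLOCK_SIZE = 128  # group size for per-token-group FP8 block quantization

# Call sites would then import the constant instead of repeating the literal:
#
#   from sglang.srt.layers.quantization.constants import FP8_BLOCK_SIZE
#
#   hidden_states = sglang_per_token_group_quant_fp8(hidden_states, FP8_BLOCK_SIZE)
#   scales = torch.empty((all_tokens, K // FP8_BLOCK_SIZE), dtype=torch.float32)
```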

-            hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
+            hidden_states = sglang_per_token_group_quant_fp8(
+                hidden_states,
+                128,
+                column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            )
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_idx, topk_weights, previous_event
