diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 30fb0b6b7ea..df228e5281f 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -46,6 +46,7 @@
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.utils import (
     DeepEPMode,
+    ceil_div,
     dispose_tensor,
     get_bool_env_var,
     is_hip,
@@ -1370,10 +1371,19 @@ def forward_deepgemm_contiguous(
                 device=hidden_states_fp8.device,
                 dtype=hidden_states_fp8.dtype,
             ),
-            torch.empty(
-                (all_tokens, K // 128),
-                device=hidden_states_fp8.device,
-                dtype=torch.float32,
+            (
+                # TODO check whether need `zeros`
+                torch.zeros(
+                    (ceil_div(K // 128, 4), all_tokens),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.int,
+                ).transpose(0, 1)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else torch.empty(
+                    (all_tokens, K // 128),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.float32,
+                )
             ),
         ]
         m_indices = torch.empty(
@@ -1399,6 +1409,7 @@
             input_tensor[1],
             m_indices,
             output_index,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         dispose_tensor(hidden_states_fp8)
 
@@ -1407,7 +1418,8 @@
             device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
-        input_tensor[1] = tma_align_input_scale(input_tensor[1])
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            input_tensor[1] = tma_align_input_scale(input_tensor[1])
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
@@ -1428,10 +1440,15 @@
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
-            down_input, scale_block_size
+            down_input,
+            scale_block_size,
+            column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del down_input
-        down_input_scale = tma_align_input_scale(down_input_scale)
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            down_input_scale = tma_align_input_scale(down_input_scale)
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             (down_input_fp8, down_input_scale),
             self.w2_weight_fp8,
diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
index ac9217da87a..2c2c4d1f542 100644
--- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -246,7 +246,13 @@ def dispatch_a(
         topk_idx = topk_idx.to(torch.int64)
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             # TODO hard code 128 block quant,use fp8 communication
-            hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
+            hidden_states = sglang_per_token_group_quant_fp8(
+                hidden_states,
+                128,
+                column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            )
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_idx, topk_weights, previous_event
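When `DEEPGEMM_SCALE_UE8M0` is set, the diff allocates the activation-scale buffer as packed `int32` (four one-byte UE8M0 exponents per word, `ceil_div(K // 128, 4)` words per token, stored column-major via the `transpose(0, 1)`) instead of one `float32` per 128-wide group, and the `tma_align_input_scale` calls are skipped on that path. The sketch below only illustrates that packed layout on the CPU; the real buffers are filled on the GPU by `ep_scatter` / `sglang_per_token_group_quant_fp8` with `scale_ue8m0=True`. The helper name `pack_ue8m0_scales`, the e8m0-style bias of 127, and the little-endian byte order are assumptions made for the example.

```python
import torch


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def pack_ue8m0_scales(scales_fp32: torch.Tensor) -> torch.Tensor:
    """Hypothetical CPU sketch: pack per-128-group fp32 scales into the
    int32 layout allocated in the diff (4 one-byte UE8M0 exponents per
    int32, ceil_div(num_groups, 4) words per token).

    Assumes the scales are positive powers of two, so only the biased
    exponent (bias 127, as in e8m0) needs to be kept.
    """
    num_tokens, num_groups = scales_fp32.shape
    exponents = (torch.log2(scales_fp32).round() + 127).clamp(0, 255).to(torch.uint8)
    # Pad the group dimension to a multiple of 4 so every int32 word is fully
    # defined, then reinterpret each run of 4 bytes as one int32
    # (little-endian host assumed).
    padded = ceil_div(num_groups, 4) * 4
    raw = torch.zeros((num_tokens, padded), dtype=torch.uint8)
    raw[:, :num_groups] = exponents
    packed = raw.view(torch.int32)  # shape: (num_tokens, ceil_div(num_groups, 4))
    # Note: the diff stores this buffer column-major, allocating
    # (ceil_div(num_groups, 4), num_tokens) and then calling .transpose(0, 1).
    return packed


if __name__ == "__main__":
    K, tokens = 512, 3                     # 512 // 128 = 4 scale groups per token
    scales = torch.ones(tokens, K // 128)  # power-of-two scales (here 2**0)
    print(pack_ue8m0_scales(scales).shape)  # torch.Size([3, 1])
```

Presumably the point of the packed, column-major allocation is that the scales already land in the layout the DeepGEMM contiguous grouped GEMM expects, which is why the `if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0` guards can skip the `tma_align_input_scale` step that the fp32-scale path still needs.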