diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 175914560156..b145cf93e329 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -232,7 +232,7 @@ def forward_deepgemm( ( _cast_to_e8m0_with_rounding_up(gateup_input_scale) if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 - else deep_gemm_wrapper.get_col_major_tma_aligned_tensor( + else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor( gateup_input_scale ) ), @@ -289,9 +289,7 @@ def forward_deepgemm( ( down_input_scale if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 - else deep_gemm_wrapper.get_col_major_tma_aligned_tensor( - down_input_scale - ) + else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale) ), ) down_output = torch.empty(