From 737508752e07646fc7648b8a2f51a41aa0e66244 Mon Sep 17 00:00:00 2001 From: Mohammad Miadh Angkad Date: Sun, 31 Aug 2025 17:24:42 +0800 Subject: [PATCH 1/3] Change tensor alignment method to mn major --- python/sglang/srt/layers/moe/ep_moe/layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 175914560156..b37e06880bb0 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -232,7 +232,7 @@ def forward_deepgemm( ( _cast_to_e8m0_with_rounding_up(gateup_input_scale) if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 - else deep_gemm_wrapper.get_col_major_tma_aligned_tensor( + else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor( gateup_input_scale ) ), @@ -289,7 +289,7 @@ def forward_deepgemm( ( down_input_scale if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 - else deep_gemm_wrapper.get_col_major_tma_aligned_tensor( + else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor( down_input_scale ) ), From 52368d6c0d04a0bb172f5d5f69906ba948bb45ae Mon Sep 17 00:00:00 2001 From: Mohammad Miadh Angkad Date: Sun, 31 Aug 2025 18:31:15 +0800 Subject: [PATCH 2/3] Refactor scale handling in gateup and down inputs --- python/sglang/srt/layers/moe/ep_moe/layer.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index b37e06880bb0..ed95cb1718e5 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -232,9 +232,7 @@ def forward_deepgemm( ( _cast_to_e8m0_with_rounding_up(gateup_input_scale) if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 - else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor( - gateup_input_scale - ) + else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(gateup_input_scale) ), ) num_groups, m, k = gateup_input_fp8[0].size() @@ -289,9 +287,7 @@ def forward_deepgemm( ( down_input_scale if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 - else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor( - down_input_scale - ) + else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale) ), ) down_output = torch.empty( From 3eb4e34e94625b15ebc6f06faf4f39ed4a782eef Mon Sep 17 00:00:00 2001 From: zhyncs Date: Tue, 2 Sep 2025 01:22:29 -0700 Subject: [PATCH 3/3] upd --- python/sglang/srt/layers/moe/ep_moe/layer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index ed95cb1718e5..b145cf93e329 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -232,7 +232,9 @@ def forward_deepgemm( ( _cast_to_e8m0_with_rounding_up(gateup_input_scale) if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 - else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(gateup_input_scale) + else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor( + gateup_input_scale + ) ), ) num_groups, m, k = gateup_input_fp8[0].size()