
Commit 9658755 (1 parent: 391d8c6)

Expand to other AlltoallMethodType

Signed-off-by: Jiang Shao <[email protected]>

File tree

1 file changed (+1, -10 lines)


tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 1 addition & 10 deletions

@@ -11,7 +11,7 @@
 from ...distributed import allgather, reducescatter
 from ...expert_statistic import ExpertStatistic
 from ...model_config import ModelConfig
-from ...utils import EventType, Fp4QuantizedTensor, swizzle_sf
+from ...utils import EventType, Fp4QuantizedTensor
 from .deep_ep_utils import buffer_pool, deep_ep_installed
 from .interface import MoE
 from .moe_load_balancer import get_moe_load_balancer
@@ -562,9 +562,6 @@ def forward_chunk(
                 dim=0,
                 sizes=None if use_dp_padding else all_rank_num_tokens)
             x_row = x.shape[0]
-            # Fp4 gemm has extra scaling factor
-            if x_sf is not None:
-                x_sf = swizzle_sf(x_sf, x_row, x_col, self.scaling_vector_size)

         if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
         ):
@@ -638,8 +635,6 @@ def forward_chunk(
             x = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
             x_sf = x_sf.reshape(x_sf.shape[0] * x_sf.shape[1],
                                 x_sf.shape[2])
-            x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2,
-                              self.scaling_vector_size)
             token_selected_slots = token_selected_slots.view(x.shape[0], 1)
             token_final_scales = torch.ones_like(
                 token_selected_slots, dtype=token_final_scales.dtype)
@@ -937,10 +932,6 @@ def alltoall_postquant_dispatch(self, x: torch.Tensor, x_sf: torch.Tensor,
                                                  self.alltoall_workspace,
                                                  self.ep_rank, self.ep_size)

-        if self.has_nvfp4:
-            x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2,
-                              self.scaling_vector_size)
-
        return x, x_sf

    def alltoall_combine(self, final_hidden_states: torch.Tensor,
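
Every removed statement is a call to swizzle_sf, which reorders the NVFP4 per-block scaling factors (x_sf) into the layout the FP4 GEMM kernels consume; this commit drops that host-side reordering from the code paths shown so they no longer depend on a specific alltoall method. As background, below is a minimal, purely illustrative sketch of what such a scale-factor swizzle amounts to. The helper name toy_swizzle_sf and the tile shape (tile_rows, tile_cols) are assumptions for illustration only and do not reflect the actual swizzle_sf layout in TensorRT-LLM.

# Illustrative sketch only -- not the real swizzle_sf implementation.
# It shows the general pattern of reordering a row-major scale-factor
# buffer into a padded, tiled layout under assumed tile sizes.
import torch
import torch.nn.functional as F


def toy_swizzle_sf(x_sf: torch.Tensor, rows: int, cols: int,
                   scaling_vector_size: int,
                   tile_rows: int = 128, tile_cols: int = 4) -> torch.Tensor:
    """Reorder a flat scale-factor buffer into a padded, tiled layout."""
    # One scaling factor per vector of `scaling_vector_size` elements per row.
    sf_cols = cols // scaling_vector_size
    sf = x_sf.reshape(rows, sf_cols)

    # Pad both dimensions so they divide evenly into tiles.
    pad_r = (-rows) % tile_rows
    pad_c = (-sf_cols) % tile_cols
    sf = F.pad(sf, (0, pad_c, 0, pad_r))

    r, c = sf.shape
    # Split into (row-tile, col-tile) blocks and lay each block out contiguously.
    sf = (sf.reshape(r // tile_rows, tile_rows, c // tile_cols, tile_cols)
            .permute(0, 2, 1, 3)
            .contiguous())
    return sf.reshape(-1)

In the removed call sites the second and third arguments are the token count and the logical hidden size (x.shape[1] * 2, presumably because FP4 activations pack two values per byte), which matches the rows/cols roles in the sketch above.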
