
Commit d2608f4

Expand to other AlltoallMethodType
Signed-off-by: Jiang Shao <[email protected]>
1 parent 9844fbb

File tree

1 file changed (+1, -8 lines)

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 1 addition & 8 deletions
@@ -11,7 +11,7 @@
 from ...distributed import allgather, reducescatter
 from ...expert_statistic import ExpertStatistic
 from ...model_config import ModelConfig
-from ...utils import EventType, Fp4QuantizedTensor, swizzle_sf
+from ...utils import EventType, Fp4QuantizedTensor
 from .deep_ep_utils import buffer_pool, deep_ep_installed
 from .interface import MoE
 from .moe_load_balancer import get_moe_load_balancer
@@ -552,7 +552,6 @@ def forward_chunk(
         # Fp4 gemm has extra scaling factor
         if x_sf is not None:
             assert not x_is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before allgather"
-            x_sf = swizzle_sf(x_sf, x_row, x_col, self.scaling_vector_size)
 
         if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
         ):
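This hunk drops the swizzle that used to run before allgather, so the scale factors now cross the collective in plain row-major layout for every AlltoallMethodType. For readers unfamiliar with the transform being removed, here is a minimal, self-contained sketch of a swizzle_sf-style layout change; the toy_swizzle_sf name and the 128x4 tiling are illustrative assumptions, not the exact layout implemented by TensorRT-LLM's swizzle_sf.

import torch
import torch.nn.functional as F

def toy_swizzle_sf(sf: torch.Tensor, rows: int, cols: int,
                   scaling_vector_size: int = 16) -> torch.Tensor:
    """Reorder a row-major [rows, cols // scaling_vector_size] scale-factor
    matrix into a tile-interleaved 1-D buffer (toy 128x4 tiles)."""
    sf_cols = cols // scaling_vector_size
    sf = sf.reshape(rows, sf_cols)
    # Pad up to whole tiles so the reshape below divides evenly.
    sf = F.pad(sf, (0, (-sf_cols) % 4, 0, (-rows) % 128))
    r, c = sf.shape
    # [r, c] -> [r/128, 128, c/4, 4], then order tiles first, tile contents second.
    tiles = sf.reshape(r // 128, 128, c // 4, 4).permute(0, 2, 1, 3)
    return tiles.contiguous().reshape(-1)

# Example: 3 tokens, hidden size 1024 -> 64 scale columns, padded to one tile.
sf = torch.randint(0, 255, (3, 64), dtype=torch.uint8)
print(toy_swizzle_sf(sf, rows=3, cols=1024).shape)  # torch.Size([8192])

A transform like this is specific to the GEMM that consumes the scale factors, so running it before communication ties every dispatch path to one kernel layout; deferring it keeps the communication paths layout-agnostic, which is presumably what lets the same code serve the other AlltoallMethodType values named in the commit title.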
@@ -657,8 +656,6 @@ def forward_chunk(
             x = x.reshape(x.shape[0] * x.shape[1], x.shape[2]).view(x_dtype)
             x_sf = x_sf.reshape(x_sf.shape[0] * x_sf.shape[1],
                                 x_sf.shape[2]).view(x_sf_dtype)
-            x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2,
-                              self.scaling_vector_size)
             token_selected_slots = token_selected_slots.view(x.shape[0], 1)
             token_final_scales = torch.ones_like(
                 token_selected_slots, dtype=token_final_scales.dtype)
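The surviving reshape/view pair flattens the [ranks, tokens, ...] buffers returned by the alltoall and reinterprets the bytes as their original element type; only the trailing swizzle is removed. A minimal sketch of that dtype round-trip, with illustrative shapes and an assumed scale-factor dtype (requires a PyTorch build with float8 dtypes):

import torch

# Hypothetical post-alltoall buffer: [ep_size, tokens_per_rank, sf_bytes],
# carried as uint8 during communication.
x_sf_dtype = torch.float8_e4m3fn  # assumed original scale-factor dtype
x_sf = torch.zeros(4, 8, 16, dtype=torch.uint8)

# Flatten ranks*tokens into one dimension, then relabel the bytes as the
# original dtype; both types are 1 byte wide, so view() copies nothing.
x_sf = x_sf.reshape(x_sf.shape[0] * x_sf.shape[1], x_sf.shape[2]).view(x_sf_dtype)
print(x_sf.shape, x_sf.dtype)  # torch.Size([32, 16]) torch.float8_e4m3fn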
@@ -967,10 +964,6 @@ def alltoall_postquant_dispatch(self, x: torch.Tensor, x_sf: torch.Tensor,
                                                 self.alltoall_workspace,
                                                 self.ep_rank, self.ep_size)
 
-        if self.has_nvfp4:
-            x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2,
-                              self.scaling_vector_size)
-
         return x, x_sf
 
     def alltoall_combine(self, final_hidden_states: torch.Tensor,
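Taken together, the three removals mean none of the dispatch branches in this file swizzles scale factors anymore; the transform presumably happens once further downstream, next to the kernel that requires the tiled layout. A sketch of the resulting flow, with stubs standing in for the real collectives and GEMM (every name here is a placeholder, not TensorRT-LLM API):

import torch

def alltoall_stub(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for any AlltoallMethodType dispatch; layout-agnostic.
    return t

def swizzle_for_gemm(sf: torch.Tensor) -> torch.Tensor:
    # Stand-in for the single downstream swizzle (cf. toy_swizzle_sf above).
    return sf.contiguous()

def postquant_dispatch(x: torch.Tensor, x_sf: torch.Tensor):
    # Quantized activations and scale factors cross the network in plain
    # row-major layout...
    x, x_sf = alltoall_stub(x), alltoall_stub(x_sf)
    # ...and the GEMM-specific layout is produced exactly once, at the end.
    return x, swizzle_for_gemm(x_sf)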
