|
11 | 11 | from ...distributed import allgather, reducescatter |
12 | 12 | from ...expert_statistic import ExpertStatistic |
13 | 13 | from ...model_config import ModelConfig |
14 | | -from ...utils import EventType, Fp4QuantizedTensor, swizzle_sf |
| 14 | +from ...utils import EventType, Fp4QuantizedTensor |
15 | 15 | from .deep_ep_utils import buffer_pool, deep_ep_installed |
16 | 16 | from .interface import MoE |
17 | 17 | from .moe_load_balancer import get_moe_load_balancer |
@@ -552,7 +552,6 @@ def forward_chunk( |
552 | 552 | # Fp4 gemm has extra scaling factor |
553 | 553 | if x_sf is not None: |
554 | 554 | assert not x_is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before allgather" |
555 | | - x_sf = swizzle_sf(x_sf, x_row, x_col, self.scaling_vector_size) |
556 | 555 |
557 | 556 | if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing( |
558 | 557 | ): |
@@ -657,8 +656,6 @@ def forward_chunk( |
657 | 656 | x = x.reshape(x.shape[0] * x.shape[1], x.shape[2]).view(x_dtype) |
658 | 657 | x_sf = x_sf.reshape(x_sf.shape[0] * x_sf.shape[1], |
659 | 658 | x_sf.shape[2]).view(x_sf_dtype) |
660 | | - x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2, |
661 | | - self.scaling_vector_size) |
662 | 659 | token_selected_slots = token_selected_slots.view(x.shape[0], 1) |
663 | 660 | token_final_scales = torch.ones_like( |
664 | 661 | token_selected_slots, dtype=token_final_scales.dtype) |
@@ -967,10 +964,6 @@ def alltoall_postquant_dispatch(self, x: torch.Tensor, x_sf: torch.Tensor, |
967 | 964 | self.alltoall_workspace, |
968 | 965 | self.ep_rank, self.ep_size) |
969 | 966 |
970 | | - if self.has_nvfp4: |
971 | | - x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2, |
972 | | - self.scaling_vector_size) |
973 | | - |
974 | 967 | return x, x_sf |
975 | 968 |
976 | 969 | def alltoall_combine(self, final_hidden_states: torch.Tensor, |
|