@@ -11,7 +11,7 @@
 from ...distributed import allgather, reducescatter
 from ...expert_statistic import ExpertStatistic
 from ...model_config import ModelConfig
-from ...utils import EventType, Fp4QuantizedTensor, swizzle_sf
+from ...utils import EventType, Fp4QuantizedTensor
 from .deep_ep_utils import buffer_pool, deep_ep_installed
 from .interface import MoE
 from .moe_load_balancer import get_moe_load_balancer
@@ -562,9 +562,6 @@ def forward_chunk(
                 dim=0,
                 sizes=None if use_dp_padding else all_rank_num_tokens)
             x_row = x.shape[0]
-            # Fp4 gemm has extra scaling factor
-            if x_sf is not None:
-                x_sf = swizzle_sf(x_sf, x_row, x_col, self.scaling_vector_size)
 
         if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
         ):
@@ -638,8 +635,6 @@ def forward_chunk(
                 x = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
                 x_sf = x_sf.reshape(x_sf.shape[0] * x_sf.shape[1],
                                     x_sf.shape[2])
-                x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2,
-                                  self.scaling_vector_size)
                 token_selected_slots = token_selected_slots.view(x.shape[0], 1)
                 token_final_scales = torch.ones_like(
                     token_selected_slots, dtype=token_final_scales.dtype)
@@ -937,10 +932,6 @@ def alltoall_postquant_dispatch(self, x: torch.Tensor, x_sf: torch.Tensor,
                                                      self.alltoall_workspace,
                                                      self.ep_rank, self.ep_size)
 
-        if self.has_nvfp4:
-            x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2,
-                              self.scaling_vector_size)
-
         return x, x_sf
 
     def alltoall_combine(self, final_hidden_states: torch.Tensor,
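In summary, the diff drops all three explicit swizzle_sf calls on the fp4 scaling-factor tensor x_sf in forward_chunk and alltoall_postquant_dispatch, together with the now-unused import. The sketch below is a hypothetical illustration of the pattern being removed, based only on the call sites visible in this diff; swizzle_sf is passed in as a parameter so the snippet stays self-contained, and the swizzled layout it produces is not modeled here.

# Hypothetical sketch (not the library's code) of the explicit post-quant
# swizzle step that this diff removes: after x and x_sf have been
# gathered/reshaped, the Python layer re-swizzled the scaling factors.
from typing import Callable, Optional

import torch


def explicit_postquant_swizzle(
        x: torch.Tensor,
        x_sf: Optional[torch.Tensor],
        scaling_vector_size: int,
        swizzle_sf: Callable[[torch.Tensor, int, int, int], torch.Tensor]
) -> Optional[torch.Tensor]:
    # Nothing to do when no fp4 scaling factors are present.
    if x_sf is None:
        return None
    # x holds packed fp4 values (presumably two per byte), which is why the
    # removed call sites used x.shape[1] * 2 as the logical column count.
    x_row, x_col = x.shape[0], x.shape[1] * 2
    return swizzle_sf(x_sf, x_row, x_col, scaling_vector_size)

With the diff applied, x_sf flows through these paths unswizzled, so any swizzling the fp4 gemm requires would have to happen downstream of these call sites.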