Commit ae968a0

jinyangyuan-nvidia authored and lancelly committed

[fix] Fix wide EP when using DeepEP with online EPLB (NVIDIA#6429)

Signed-off-by: Jinyang Yuan <[email protected]>
Signed-off-by: Lanyu Liao <[email protected]>

1 parent 9520c6d commit ae968a0

File tree

1 file changed: +8 -0 lines


tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 8 additions & 0 deletions
@@ -470,6 +470,10 @@ def forward_chunk(
                     self.expert_size_per_partition * self.mapping.moe_ep_rank)
                 padded, x, _, token_selected_slots, token_final_scales = self.pad_empty_recv_tensors(
                     x, None, recv_topk_idx, token_final_scales)
+                if is_last_call and self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
+                ):
+                    gathered_loadbalancer_local_statistic_info = allgather(
+                        loadbalancer_local_statistic_info, self.mapping, dim=0)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
                 if not use_postquant_alltoall:
                     deep_ep_topk_idx = token_selected_slots
@@ -499,6 +503,10 @@ def forward_chunk(
                         x.shape[0], 1)
                     token_final_scales = torch.ones_like(
                         token_selected_slots, dtype=token_final_scales.dtype)
+                if is_last_call and self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
+                ):
+                    gathered_loadbalancer_local_statistic_info = allgather(
+                        loadbalancer_local_statistic_info, self.mapping, dim=0)
 
         x_sf = None
         x_row = x.shape[0]
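
What the fix does, in pattern form: both DeepEP branches were missing the step that, on the last chunked call with online (non-static) EPLB routing, all-gathers each expert-parallel rank's local load statistics so the balancer can work from global counts. Below is a minimal illustrative sketch of that pattern using plain torch.distributed; the function gather_expert_statistics and its arguments are hypothetical names, not TensorRT-LLM's API, which uses its own allgather helper with self.mapping as shown in the diff.

# Minimal sketch of the pattern the added lines follow, not TensorRT-LLM's code:
# on the last chunk, all-gather per-rank expert-load counts across EP ranks so
# an online load balancer can recompute expert placement from global statistics.
from typing import Optional

import torch
import torch.distributed as dist


def gather_expert_statistics(local_statistic: torch.Tensor,
                             is_last_call: bool) -> Optional[torch.Tensor]:
    """All-gather local expert-load counts, but only on the last chunked call."""
    if not is_last_call:
        return None  # intermediate chunks keep accumulating locally
    world_size = dist.get_world_size()
    # One slot per EP rank; each rank contributes its [num_local_experts] counts.
    gathered = [torch.empty_like(local_statistic) for _ in range(world_size)]
    dist.all_gather(gathered, local_statistic)
    # Concatenate along dim 0, mirroring allgather(..., dim=0) in the diff.
    return torch.cat(gathered, dim=0)

Run under torchrun with an initialized process group, every rank ends up holding all ranks' statistics; in online EPLB, a gathered tensor like this is the kind of input the rebalancing policy consumes when deciding the next expert-to-rank placement.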
