Commit 02bf032

fix refactor issues

Signed-off-by: Dongxu Yang <[email protected]>
1 parent f007144 commit 02bf032

2 files changed (+30, -78 lines)

tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py

Lines changed: 0 additions & 52 deletions
@@ -179,58 +179,6 @@ def _(
     return (input.new_empty(output_shape, dtype=torch.uint8),
             global_scale.new_empty(scale_shape, dtype=torch.uint8))
 
-@torch.library.register_fake("trtllm::moe_comm_prepare_indices")
-def _(
-    gathered_target_rank_ids: torch.Tensor,
-    real_rank_token_count_cum_sum: Optional[torch.Tensor],
-    max_token_count_per_rank: int,
-    expert_count: int,
-    top_k: int,
-    ep_rank: int,
-    ep_size: int,
-):
-    max_send_ranks_per_token = max(ep_size, top_k)
-    local_gather_indices_shape = (max_token_count_per_rank * ep_size, )
-    rank_count_cum_sum_shape = (ep_size, )
-    send_rank_local_indices_shape = (max_token_count_per_rank *
-                                     max_send_ranks_per_token, )
-    recv_rank_local_indices_shape = (max_token_count_per_rank * ep_size, )
-    backward_recv_rank_local_indices_shape = (max_token_count_per_rank *
-                                              max_send_ranks_per_token, )
-
-    local_gather_indices = gathered_target_rank_ids.new_empty(
-        local_gather_indices_shape, dtype=torch.int32)
-    send_rank_count_cum_sum = gathered_target_rank_ids.new_empty(
-        rank_count_cum_sum_shape, dtype=torch.int32)
-    send_rank_local_indices = gathered_target_rank_ids.new_empty(
-        send_rank_local_indices_shape, dtype=torch.int32)
-    recv_rank_count_cum_sum = gathered_target_rank_ids.new_empty(
-        rank_count_cum_sum_shape, dtype=torch.int32)
-    recv_rank_local_indices = gathered_target_rank_ids.new_empty(
-        recv_rank_local_indices_shape, dtype=torch.int32)
-    backward_recv_rank_local_indices = gathered_target_rank_ids.new_empty(
-        backward_recv_rank_local_indices_shape, dtype=torch.int32)
-
-    return (local_gather_indices, send_rank_count_cum_sum,
-            send_rank_local_indices, recv_rank_count_cum_sum,
-            recv_rank_local_indices, backward_recv_rank_local_indices)
-
-@torch.library.register_fake("trtllm::moe_local_gather")
-def _(
-    recv_rank_cum_sum: torch.Tensor,
-    local_gather_indices: torch.Tensor,
-    gathered_expert_ids: torch.Tensor,
-    gathered_scales: Optional[torch.Tensor],
-    local_expert_ids: torch.Tensor,
-    local_scales: Optional[torch.Tensor],
-    max_token_count_per_rank: int,
-    expert_count: int,
-    top_k: int,
-    ep_rank: int,
-    ep_size: int,
-):
-    pass
-
 @torch.library.register_fake("trtllm::moe_comm")
 def _(
     input: torch.Tensor,
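
Note on what is being deleted here: torch.library.register_fake attaches a shape-and-dtype-only "fake" (meta) implementation to a custom op, so tracing tools such as torch.compile and torch.export can propagate tensor metadata without running the real kernel. A minimal, self-contained sketch of the same pattern; the op name "mylib::gather_indices" is hypothetical, since the real trtllm:: ops are defined and registered in C++:

import torch

# Hypothetical op for illustration only; the trtllm ops deleted above are
# defined in C++ under the "trtllm::" namespace.
torch.library.define("mylib::gather_indices",
                     "(Tensor target_rank_ids, int ep_size) -> Tensor")

@torch.library.register_fake("mylib::gather_indices")
def _(target_rank_ids: torch.Tensor, ep_size: int) -> torch.Tensor:
    # The fake kernel computes output metadata only: an empty tensor with the
    # expected shape and dtype, mirroring the new_empty calls removed above.
    return target_rank_ids.new_empty((target_rank_ids.shape[0] * ep_size, ),
                                     dtype=torch.int32)

With only the fake registered, the op can be traced with FakeTensors, even though invoking it on real data would still require a backend kernel.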

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 30 additions & 26 deletions
@@ -192,7 +192,7 @@ def __init__(
         self.use_low_precision_combine = (os.environ.get(
             "TRTLLM_MOE_USE_LOW_PRECISION_COMBINE", "0")
                                           == "1") and qm.has_nvfp4()
-
+
         if self.alltoall_method_type == AlltoallMethodType.MNNVL:
             MnnvlMemory.initialize()
             self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
@@ -296,6 +296,9 @@ def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
                 1) // self.moe_max_num_tokens
 
     def can_use_alltoall(self, all_rank_num_tokens, all_rank_max_num_tokens):
+        if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+            return True
+
         # Disable alltoall when chunking is used
         if self.calculate_num_chunks(all_rank_num_tokens) > 1:
            return False
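
Condensed into a free-standing sketch, the new guard short-circuits for MNNVL before the chunking check ever runs. The enum members and the trailing return True (standing in for the remaining per-backend checks) are assumptions for illustration:

from enum import Enum

class AlltoallMethodType(Enum):  # members assumed for this sketch
    NotEnabled = 0
    MNNVL = 1
    DeepEP = 2

def can_use_alltoall(method_type: AlltoallMethodType, num_chunks: int) -> bool:
    # MNNVL now always takes the alltoall path, even when chunking is active.
    if method_type == AlltoallMethodType.MNNVL:
        return True
    # Other backends still disable alltoall when chunking is used.
    if num_chunks > 1:
        return False
    return True  # placeholder for the remaining checks in the real method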
@@ -453,12 +456,12 @@ def forward_chunk(
         else:
             tuner_num_tokens = None
             tuner_top_k = None
+        alltoall_info = None
         if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
                     self.dummy_allreduce()
                 token_count = x.shape[0]
-                alltoall_info = None
                 if is_last_call and self.layer_load_balancer is not None and not self.layer_load_balancer.is_static_routing(
                 ):
                     loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor(
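
Hoisting alltoall_info = None out of the MNNVL-only branch is a correctness fix, not just formatting: a name bound only inside one branch raises UnboundLocalError when a later line reads it on any other path (such as the alltoall_dispatch call further down, which takes alltoall_info as an argument). A minimal reproduction of the hazard:

def before(use_mnnvl: bool):
    if use_mnnvl:
        alltoall_info = None  # bound only on the MNNVL branch
    return alltoall_info      # UnboundLocalError when use_mnnvl is False

def after(use_mnnvl: bool):
    alltoall_info = None      # always bound before branching
    if use_mnnvl:
        pass                  # the MNNVL path would populate it here
    return alltoall_info      # safe on every path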
@@ -469,7 +472,7 @@
                     self.alltoall_prepare(all_rank_max_num_tokens,
                                           token_selected_slots,
                                           loadbalancer_local_statistic_info)
-
+
                 if gathered_loadbalancer_local_statistic_info is not None:
                     gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view(
                         (self.mapping.moe_ep_size, self.num_experts))
@@ -577,10 +580,13 @@
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 top_k = self.routing_method.experts_per_token
                 x, x_sf, token_selected_slots, token_final_scales = self.alltoall_dispatch(
-                    x, x_sf, token_selected_slots, token_final_scales, all_rank_max_num_tokens, top_k, alltoall_info)
+                    x, x_sf, token_selected_slots, token_final_scales,
+                    all_rank_max_num_tokens, top_k, alltoall_info)
 
         if use_postquant_alltoall:
-            if self.alltoall_method_type == AlltoallMethodType.DeepEP:
+            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+                pass
+            elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
                 if x_sf is not None:
                     # Adapter between `x_sf` and DeepEP
                     # TODO: remove the adapter by adding dtype support to DeepEP
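
The new explicit pass branch makes MNNVL's exclusion from the post-quant alltoall block deliberate rather than letting it silently fail the DeepEP check; presumably the MNNVL dispatch above has already moved the payload. A condensed sketch of the branch style, with hypothetical names standing in for the real method:

from enum import Enum

class MethodType(Enum):  # stand-in for AlltoallMethodType
    MNNVL = 1
    DeepEP = 2

def postquant_comm(method_type: MethodType, x, x_sf):
    if method_type == MethodType.MNNVL:
        pass  # explicit no-op: MNNVL already dispatched before this block
    elif method_type == MethodType.DeepEP:
        x, x_sf = adapt_for_deepep(x, x_sf)  # hypothetical adapter helper
    return x, x_sf

def adapt_for_deepep(x, x_sf):
    # Stub standing in for the real x_sf/DeepEP dtype adapter noted above.
    return x, x_sf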
@@ -858,34 +864,32 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
         self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1
         return outputs
 
-    def alltoall_prepare(
-            self, all_rank_max_num_tokens: int,
-            token_selected_slots: torch.Tensor,
-            local_statistic_tensor: Optional[torch.Tensor]):
+    def alltoall_prepare(self, all_rank_max_num_tokens: int,
+                         token_selected_slots: torch.Tensor,
+                         local_statistic_tensor: Optional[torch.Tensor]):
         top_k = self.routing_method.experts_per_token
 
         alltoall_info, gathered_local_statistic_tensor = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
-            token_selected_slots,
-            local_statistic_tensor, self.alltoall_prepare_workspace,
-            all_rank_max_num_tokens, self.ep_rank, self.ep_size,
-            self.num_experts, self.num_slots, top_k)
+            token_selected_slots, local_statistic_tensor,
+            self.alltoall_prepare_workspace, all_rank_max_num_tokens,
+            self.ep_rank, self.ep_size, self.num_experts, self.num_slots, top_k)
 
         return token_selected_slots, gathered_local_statistic_tensor, alltoall_info
 
     def alltoall_dispatch(self, x: torch.Tensor, x_sf: Optional[torch.Tensor],
-                          token_selected_slots: torch.Tensor,
-                          token_final_scales: Optional[torch.Tensor],
-                          all_rank_max_num_tokens: int,
-                          top_k: int,
-                          alltoall_info: MoEAlltoallInfo):
-
-        x, x_sf, token_selected_slots, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv([x, x_sf, token_selected_slots, token_final_scales], alltoall_info,
-                                                                                         self.alltoall_workspace, self.ep_rank,
-                                                                                         self.ep_size)
-
-        torch.ops.trtllm.memset_expert_ids(
-            token_selected_slots, alltoall_info.recv_rank_count_cumsum,
-            all_rank_max_num_tokens, top_k, self.num_slots, self.ep_size)
+                          token_selected_slots: torch.Tensor,
+                          token_final_scales: Optional[torch.Tensor],
+                          all_rank_max_num_tokens: int, top_k: int,
+                          alltoall_info: MoEAlltoallInfo):
+
+        x, x_sf, token_selected_slots, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv(
+            [x, x_sf, token_selected_slots, token_final_scales], alltoall_info,
+            self.alltoall_workspace, self.ep_rank, self.ep_size)
+
+        torch.ops.trtllm.memset_expert_ids(token_selected_slots,
+                                           alltoall_info.recv_rank_count_cumsum,
+                                           all_rank_max_num_tokens, top_k,
+                                           self.num_slots, self.ep_size)
 
         return x, x_sf, token_selected_slots, token_final_scales
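
Read together, the two reformatted helpers implement a two-phase exchange: alltoall_prepare derives routing metadata (alltoall_info) from token_selected_slots via MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather, then alltoall_dispatch moves x, x_sf, slots, and scales with MnnvlMoe.mnnvl_moe_alltoallv and post-processes the received slot ids with torch.ops.trtllm.memset_expert_ids. A runnable mock of that prepare-then-dispatch shape; every function here is an illustrative stand-in, not the real trtllm API:

import torch
from typing import Tuple

def prepare(token_selected_slots: torch.Tensor, num_slots: int) -> dict:
    # Phase 1: derive communication metadata from routing decisions only.
    counts = torch.bincount(token_selected_slots.flatten(), minlength=num_slots)
    return {"recv_counts": counts}

def dispatch(x: torch.Tensor, slots: torch.Tensor,
             info: dict) -> Tuple[torch.Tensor, torch.Tensor]:
    # Phase 2: move payloads according to the phase-1 metadata.
    # (The real implementation performs a cross-rank alltoallv here.)
    return x, slots

x = torch.randn(8, 16)               # 8 tokens, hidden size 16
slots = torch.randint(0, 4, (8, 2))  # top-2 routing over 4 slots
info = prepare(slots, num_slots=4)
x, slots = dispatch(x, slots, info)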