diff --git a/cpp/tensorrt_llm/deep_ep/CMakeLists.txt b/cpp/tensorrt_llm/deep_ep/CMakeLists.txt
index 603f26796e6..a404013aad3 100644
--- a/cpp/tensorrt_llm/deep_ep/CMakeLists.txt
+++ b/cpp/tensorrt_llm/deep_ep/CMakeLists.txt
@@ -1,4 +1,4 @@
-set(DEEP_EP_COMMIT c381dadf43a85062f6a8947592017ee513abc70b)
+set(DEEP_EP_COMMIT eb3f072664251c05074c3ecc3c3f5dad179c29a9)
 set(NVSHMEM_URL_HASH
     SHA256=eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a)
diff --git a/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py b/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py
index 62146d9295f..bf808c93c1d 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py
@@ -59,7 +59,7 @@ def reserve(self, hidden_size: int, hidden_dtype: torch.dtype):
     def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                  topk_idx: torch.Tensor, topk_weights: torch.Tensor,
-                 num_experts: int) -> \
+                 num_experts: int, global_expert_id_offset: int) -> \
             Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple]:
         # NOTES: an optional `previous_event` means a CUDA event captured that you want to make it as a dependency
         # of the dispatch kernel, it may be useful with communication-computation overlap. For more information, please
@@ -76,7 +76,8 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
         recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
             self.buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
                                  num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
-                                 is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert)
+                                 is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert,
+                                 global_expert_id_offset=global_expert_id_offset)
         assert event.event is None
         # For event management, please refer to the docs of the `EventOverlap` class
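
The wrapper above now takes a trailing global_expert_id_offset argument. A minimal sketch of the intended call shape follows; it is illustrative only, not part of the diff, and the names buffer, hidden_states, topk_idx, topk_weights, num_experts, experts_per_rank, and ep_rank are placeholders. The offset expression mirrors the self.expert_size_per_partition * self.mapping.moe_ep_rank value used by the caller in the next file.

# Hedged usage sketch (placeholder names): `buffer` is assumed to be the
# deep_ep_utils wrapper whose dispatch signature is shown in the hunk above.
import torch

def dispatch_with_global_ids(buffer, hidden_states: torch.Tensor,
                             topk_idx: torch.Tensor, topk_weights: torch.Tensor,
                             num_experts: int, experts_per_rank: int, ep_rank: int):
    # The offset is the first global expert id owned by this EP rank,
    # assuming a contiguous expert-to-rank partitioning.
    global_expert_id_offset = experts_per_rank * ep_rank
    # Returns recv_x, recv_topk_idx, recv_topk_weights,
    # num_recv_tokens_per_expert_list, handle (per the annotated return type),
    # with recv_topk_idx expected to be expressed in global slot coordinates.
    return buffer.dispatch(hidden_states, topk_idx, topk_weights,
                           num_experts, global_expert_id_offset)
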
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
index 1d46d0712ff..2bf7a45c7fc 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -455,12 +455,13 @@ def forward_chunk(
         elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
             if not use_postquant_alltoall:
                 x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
-                    self.deep_ep_buffer.dispatch(x, token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
-                padded, x, _, recv_topk_idx, token_final_scales = self.pad_empty_recv_tensors(
+                    self.deep_ep_buffer.dispatch(x, token_selected_slots, token_final_scales, self.num_slots,
+                                                 self.expert_size_per_partition * self.mapping.moe_ep_rank)
+                padded, x, _, token_selected_slots, token_final_scales = self.pad_empty_recv_tensors(
                     x, None, recv_topk_idx, token_final_scales)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
             if not use_postquant_alltoall:
-                deep_ep_topk_idx = token_selected_slots.to(torch.int64)
+                deep_ep_topk_idx = token_selected_slots
                 deep_ep_topk_weights = token_final_scales
                 x, recv_expert_count, deep_ep_handle = \
                     self.deep_ep_buffer.low_latency_dispatch(x, deep_ep_topk_idx, self.deep_ep_max_num_tokens, self.num_slots)
@@ -588,8 +589,9 @@ def forward_chunk(
                 x_sf_dtype = x_sf.dtype
                 x_sf = x_sf.view(torch.float32)
             (x, x_sf), recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
-                self.deep_ep_buffer.dispatch((x, x_sf), token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
-            padded, x, x_sf, recv_topk_idx, token_final_scales = self.pad_empty_recv_tensors(
+                self.deep_ep_buffer.dispatch((x, x_sf), token_selected_slots, token_final_scales, self.num_slots,
+                                             self.expert_size_per_partition * self.mapping.moe_ep_rank)
+            padded, x, x_sf, token_selected_slots, token_final_scales = self.pad_empty_recv_tensors(
                 x, x_sf, recv_topk_idx, token_final_scales)
             if x_sf is not None:
                 x_sf = x_sf.view(x_sf_dtype)
@@ -619,7 +621,7 @@ def forward_chunk(
                 fp4_packed_tensor[:, x.shape[1]:x.shape[1] + x_sf.shape[1]] = x_sf
 
-            deep_ep_topk_idx = token_selected_slots.to(torch.int64)
+            deep_ep_topk_idx = token_selected_slots
             deep_ep_topk_weights = token_final_scales
             # Each LL combine/dispatch kernel call requires that the `dispatch_rdma_recv_count_buffer` be properly cleaned.
             # However, the offset of this buffer within the entire RDMA buffer changes according to the hidden size.
@@ -668,15 +670,6 @@ def forward_chunk(
                 f"Not available alltoall method type: {self.alltoall_method_type!r}"
             )
 
-        if use_all_to_all:
-            # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
-            # TODO: remove the adapter by changing APIs
-            if self.alltoall_method_type == AlltoallMethodType.DeepEP:
-                token_selected_slots = recv_topk_idx.to(torch.int32)
-                mask = token_selected_slots == -1
-                token_selected_slots += self.expert_size_per_partition * self.mapping.moe_ep_rank
-                token_selected_slots[mask] = self.num_slots
-
         final_hidden_states = torch.ops.trtllm.fused_moe(
             x,
             token_selected_slots,
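
The last hunk deletes the Python adapter that shifted DeepEP's local expert indices into global slot coordinates before calling torch.ops.trtllm.fused_moe. Below is a minimal, CPU-only reference of the remapping that adapter performed, which the offset-aware dispatch is now expected to produce on the receive side; how the updated kernel represents entries for experts not hosted on this rank (previously -1, remapped to num_slots) is an assumption, not stated in this diff.

# Reference emulation of the removed adapter logic; not the kernel implementation.
import torch

def emulate_global_slot_mapping(recv_topk_idx: torch.Tensor,
                                global_expert_id_offset: int,
                                num_slots: int) -> torch.Tensor:
    # Local expert ids become global slot ids; entries for experts not hosted on
    # this rank (marked -1) are redirected to the sentinel slot `num_slots`.
    slots = (recv_topk_idx + global_expert_id_offset).masked_fill(
        recv_topk_idx == -1, num_slots)
    return slots.to(torch.int32)

# Toy check: 2 local experts per rank, EP rank 1, 8 slots total.
print(emulate_global_slot_mapping(torch.tensor([[0, 1], [-1, 0]]), 2 * 1, 8))
# tensor([[2, 3], [8, 2]], dtype=torch.int32)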