cpp/tensorrt_llm/deep_ep/CMakeLists.txt (2 changes: 1 addition & 1 deletion)
@@ -1,4 +1,4 @@
-set(DEEP_EP_COMMIT c381dadf43a85062f6a8947592017ee513abc70b)
+set(DEEP_EP_COMMIT eb3f072664251c05074c3ecc3c3f5dad179c29a9)
 set(NVSHMEM_URL_HASH
     SHA256=eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a)

tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py (5 changes: 3 additions & 2 deletions)
@@ -59,7 +59,7 @@ def reserve(self, hidden_size: int, hidden_dtype: torch.dtype):
 
     def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                  topk_idx: torch.Tensor, topk_weights: torch.Tensor,
-                 num_experts: int) -> \
+                 num_experts: int, global_expert_id_offset: int) -> \
             Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple]:
         # NOTES: an optional `previous_event` is a captured CUDA event that you want to make a dependency
         # of the dispatch kernel; it may be useful for communication-computation overlap. For more information, please
@@ -76,7 +76,8 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
         recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
             self.buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
                                  num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
-                                 is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert)
+                                 is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert,
+                                 global_expert_id_offset=global_expert_id_offset)
         assert event.event is None
 
         # For event management, please refer to the docs of the `EventOverlap` class
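Taken together, this change threads a new `global_expert_id_offset` argument through the wrapper's `dispatch` down to the underlying DeepEP `Buffer.dispatch`, so the kernel can return expert IDs already shifted into the caller's global slot range. Below is a minimal caller sketch; the buffer instance, tensor shapes, and dtypes are illustrative assumptions, not taken from this PR:

import torch

num_slots = 64               # total expert slots across the EP group (assumed)
experts_per_rank = 8         # slots owned by each EP rank (assumed)
ep_rank = 3                  # this rank's index in the EP group (assumed)

x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")  # [tokens, hidden]
topk_idx = torch.randint(0, num_slots, (16, 2), device="cuda")  # [tokens, top_k]
topk_weights = torch.rand(16, 2, device="cuda")

# `deep_ep_buffer` stands in for an already-reserved buffer from deep_ep_utils.
recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle = \
    deep_ep_buffer.dispatch(x, topk_idx, topk_weights, num_slots,
                            global_expert_id_offset=experts_per_rank * ep_rank)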
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (23 changes: 8 additions & 15 deletions)
@@ -455,12 +455,13 @@ def forward_chunk(
         elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
             if not use_postquant_alltoall:
                 x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
-                    self.deep_ep_buffer.dispatch(x, token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
-                padded, x, _, recv_topk_idx, token_final_scales = self.pad_empty_recv_tensors(
+                    self.deep_ep_buffer.dispatch(x, token_selected_slots, token_final_scales, self.num_slots,
+                                                 self.expert_size_per_partition * self.mapping.moe_ep_rank)
+                padded, x, _, token_selected_slots, token_final_scales = self.pad_empty_recv_tensors(
                     x, None, recv_topk_idx, token_final_scales)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
             if not use_postquant_alltoall:
-                deep_ep_topk_idx = token_selected_slots.to(torch.int64)
+                deep_ep_topk_idx = token_selected_slots
                 deep_ep_topk_weights = token_final_scales
                 x, recv_expert_count, deep_ep_handle = \
                     self.deep_ep_buffer.low_latency_dispatch(x, deep_ep_topk_idx, self.deep_ep_max_num_tokens, self.num_slots)
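In the DeepEP branch above, the value passed as the new fifth argument to `dispatch`, `self.expert_size_per_partition * self.mapping.moe_ep_rank`, is the first global slot owned by the current EP rank, which is why the received IDs can now flow into `pad_empty_recv_tensors` and be used as `token_selected_slots` directly. A tiny worked example with assumed sizes:

# Assumed sizes for illustration: 64 slots evenly split over 8 EP ranks.
num_slots = 64
ep_size = 8
expert_size_per_partition = num_slots // ep_size  # 8 local experts per rank

for moe_ep_rank in range(ep_size):
    offset = expert_size_per_partition * moe_ep_rank
    # Rank r owns global slots [offset, offset + expert_size_per_partition);
    # e.g. on rank 3 the offset is 24, so local expert 5 maps to global slot 29.
    print(moe_ep_rank, offset)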
@@ -588,8 +589,9 @@ def forward_chunk(
                 x_sf_dtype = x_sf.dtype
                 x_sf = x_sf.view(torch.float32)
                 (x, x_sf), recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
-                    self.deep_ep_buffer.dispatch((x, x_sf), token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
-                padded, x, x_sf, recv_topk_idx, token_final_scales = self.pad_empty_recv_tensors(
+                    self.deep_ep_buffer.dispatch((x, x_sf), token_selected_slots, token_final_scales, self.num_slots,
+                                                 self.expert_size_per_partition * self.mapping.moe_ep_rank)
+                padded, x, x_sf, token_selected_slots, token_final_scales = self.pad_empty_recv_tensors(
                     x, x_sf, recv_topk_idx, token_final_scales)
                 if x_sf is not None:
                     x_sf = x_sf.view(x_sf_dtype)
@@ -619,7 +621,7 @@
                     fp4_packed_tensor[:,
                                       x.shape[1]:x.shape[1] + x_sf.shape[1]] = x_sf
 
-                deep_ep_topk_idx = token_selected_slots.to(torch.int64)
+                deep_ep_topk_idx = token_selected_slots
                 deep_ep_topk_weights = token_final_scales
                 # Each LL combine/dispatch kernel call requires that the `dispatch_rdma_recv_count_buffer` be properly cleaned.
                 # However, the offset of this buffer within the entire RDMA buffer changes according to the hidden size.
@@ -668,15 +670,6 @@
                 f"Not available alltoall method type: {self.alltoall_method_type!r}"
             )
 
-        if use_all_to_all:
-            # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
-            # TODO: remove the adapter by changing APIs
-            if self.alltoall_method_type == AlltoallMethodType.DeepEP:
-                token_selected_slots = recv_topk_idx.to(torch.int32)
-                mask = token_selected_slots == -1
-                token_selected_slots += self.expert_size_per_partition * self.mapping.moe_ep_rank
-                token_selected_slots[mask] = self.num_slots
-
         final_hidden_states = torch.ops.trtllm.fused_moe(
             x,
             token_selected_slots,
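For reference, the deleted adapter block was equivalent to the pure-tensor remap below. Since the DeepEP dispatch kernel now receives `global_expert_id_offset`, the remap happens during dispatch itself, so both this Python-side fixup and the `.to(torch.int64)` casts are dropped. A runnable sketch of the old behavior (the function name is hypothetical; the body mirrors the removed lines):

import torch

def legacy_remap(recv_topk_idx: torch.Tensor, offset: int, num_slots: int) -> torch.Tensor:
    """What the removed adapter did: cast to int32, shift local expert IDs
    into the rank's global slot range, and send dropped tokens (-1 from
    DeepEP) to the out-of-range `num_slots` sentinel."""
    slots = recv_topk_idx.to(torch.int32)
    mask = slots == -1                # tokens this rank did not receive
    slots += offset                   # local expert ID -> global slot ID
    slots[mask] = num_slots           # out-of-range value marks "no expert"
    return slots

print(legacy_remap(torch.tensor([[5, -1], [0, 7]]), offset=24, num_slots=64))
# tensor([[29, 64], [24, 31]], dtype=torch.int32)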