2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/deep_ep/CMakeLists.txt
@@ -1,4 +1,4 @@
set(DEEP_EP_COMMIT be2582ffe69b5e7d61c3bc9bf7a5316bc48261f9)
set(DEEP_EP_COMMIT 5be51b228a7c82dbdb213ea58e77bffd12b38af8)
set(NVSHMEM_URL_HASH
SHA256=eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a)

9 changes: 6 additions & 3 deletions tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py
@@ -59,7 +59,7 @@ def reserve(self, hidden_size: int, hidden_dtype: torch.dtype):

def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
topk_idx: torch.Tensor, topk_weights: torch.Tensor,
num_experts: int, global_expert_id_offset: int) -> \
num_experts: int, global_expert_id_offset: int, all_rank_max_num_tokens: int, ep_size: int, use_cuda_graph: bool) -> \
Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple]:
# NOTES: an optional `previous_event` means a CUDA event captured that you want to make it as a dependency
# of the dispatch kernel, it may be useful with communication-computation overlap. For more information, please
@@ -71,13 +71,16 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
assert event.event is None

# Do MoE dispatch
# NOTES: the CPU will wait for GPU's signal to arrive, so this is not compatible with CUDA graph
# For more advanced usages, please refer to the docs of the `dispatch` function
if use_cuda_graph:
num_worst_tokens = all_rank_max_num_tokens * ep_size
else:
num_worst_tokens = 0
recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
self.buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert,
global_expert_id_offset=global_expert_id_offset)
global_expert_id_offset=global_expert_id_offset, num_worst_tokens=num_worst_tokens)
assert event.event is None

# For event management, please refer to the docs of the `EventOverlap` class
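The new `num_worst_tokens` argument is what makes this dispatch path usable under CUDA graphs: when graphs are enabled, the receive side is sized for the worst case (`all_rank_max_num_tokens * ep_size`) instead of a token count the CPU would have to wait for. Below is a minimal sketch of that sizing rule, assuming DeepEP pads its outputs to `num_worst_tokens` whenever the value is non-zero; `worst_case_recv_tokens` is a hypothetical helper, not part of this change.

```python
def worst_case_recv_tokens(all_rank_max_num_tokens: int, ep_size: int,
                           use_cuda_graph: bool) -> int:
    # With CUDA graphs, every peer rank may in the worst case route all of its
    # `all_rank_max_num_tokens` tokens to this rank, so the receive buffers are
    # sized to `all_rank_max_num_tokens * ep_size` and keep a static shape
    # across graph replays. Passing 0 keeps the original dynamic sizing.
    return all_rank_max_num_tokens * ep_size if use_cuda_graph else 0


# Example: 8 EP ranks, at most 256 tokens per rank in this step.
assert worst_case_recv_tokens(256, 8, use_cuda_graph=True) == 2048
assert worst_case_recv_tokens(256, 8, use_cuda_graph=False) == 0
```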
40 changes: 31 additions & 9 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -5,7 +5,7 @@
import torch

from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm._utils import get_sm_version, local_mpi_size
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
@@ -32,7 +32,7 @@ class AlltoallMethodType(IntEnum):
NotEnabled = 0
# MNNVL
MNNVL = 1
# DeepEP intranode or internode: no CUDA Graphs support, IBGDA is required by internode
# DeepEP intranode or internode: CUDA Graphs are supported, IBGDA is required by internode
DeepEP = 2
# DeepEP low latency: CUDA Graphs are supported, IBGDA is required
DeepEPLowLatency = 3
@@ -101,6 +101,8 @@ def __init__(
self.repeat_idx = 0
self.repeat_count = 1

self.use_cuda_graph = model_config.use_cuda_graph

moe_load_balancer_config = model_config.moe_load_balancer
init_expert_size_per_partition = moe_load_balancer_config.num_local_slots if moe_load_balancer_config else self.num_experts // self.ep_size
self.initial_global_assignments = [
@@ -212,6 +214,9 @@ def __init__(
str(
min(model_config.max_num_tokens,
self.moe_max_num_tokens))))
# Set nvshmem queue pair depth larger than the number of in-flight WRs (ref: https://github.com/deepseek-ai/DeepEP/issues/427)
os.environ['NVSHMEM_QP_DEPTH'] = str(
2 * (self.deep_ep_max_num_tokens + 1))
self.deep_ep_buffer = buffer_pool.get_low_latency_buffer(
model_config.mapping)
self.deep_ep_buffer.reserve(self.deep_ep_max_num_tokens,
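The queue-pair depth is derived straight from the low-latency buffer size. Here is a small sketch of the same arithmetic, assuming (per the linked DeepEP issue) that `deep_ep_max_num_tokens` bounds the number of in-flight work requests per QP; `configure_nvshmem_qp_depth` is an illustrative name, not an API introduced by this change.

```python
import os


def configure_nvshmem_qp_depth(deep_ep_max_num_tokens: int) -> None:
    # Mirror the setting above: make the queue pair depth exceed the
    # worst-case number of outstanding WRs, i.e. 2 * (max_tokens + 1).
    os.environ['NVSHMEM_QP_DEPTH'] = str(2 * (deep_ep_max_num_tokens + 1))


configure_nvshmem_qp_depth(256)  # e.g. at most 256 tokens per rank
assert os.environ['NVSHMEM_QP_DEPTH'] == '514'
```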
@@ -253,6 +258,25 @@ def _check_configs(self):
def select_alltoall_method_type(mapping: Mapping, top_k: int,
dtype: torch.dtype,
use_cuda_graph: bool) -> AlltoallMethodType:

# Check if DeepEP is feasible for the given number of ranks
# DeepEP supports two modes:
# 1. Intranode: Single node with 2, 4, or 8 ranks
# 2. Internode: 2, 4, 8, or 16 nodes with 8 ranks per node
def is_deepep_feasible(num_ranks: int) -> bool:
NUM_INTRANODE_SUPPORTED_RANKS = {2, 4, 8}
REQUIRED_LOCAL_MPI_SIZE = 8
NUM_INTERNODE_SUPPORTED_RDMA_RANKS = {2, 4, 8, 16}
mpi_size = local_mpi_size()
# Intranode cases
if num_ranks == mpi_size and num_ranks in NUM_INTRANODE_SUPPORTED_RANKS:
return True
# Internode cases
if mpi_size != REQUIRED_LOCAL_MPI_SIZE:
return False
num_rdma_nodes = num_ranks // mpi_size
return num_rdma_nodes in NUM_INTERNODE_SUPPORTED_RDMA_RANKS

all2all_method_type = os.environ.get("TRTLLM_FORCE_ALLTOALL_METHOD")
if all2all_method_type is not None:
return AlltoallMethodType[all2all_method_type]
@@ -274,12 +298,10 @@ def select_alltoall_method_type(mapping: Mapping, top_k: int,

if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "0") == "1":
if deep_ep_installed and dtype == torch.bfloat16:
if use_cuda_graph:
# Here we can only choose DeepEPLowLatency since only this method supports CUDA Graphs.
return AlltoallMethodType.DeepEPLowLatency
else:
# Here we can choose DeepEP or DeepEPLowLatency if both are available. Now DeepEP is faster.
# Choose DeepEP if feasible
if is_deepep_feasible(mapping.moe_ep_size):
return AlltoallMethodType.DeepEP
return AlltoallMethodType.DeepEPLowLatency

return AlltoallMethodType.NotEnabled
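Taken together, the selection now prefers the regular DeepEP kernels whenever the rank layout is one they support and otherwise falls back to the low-latency kernels, with CUDA graphs allowed in both cases. A condensed, self-contained sketch of that decision follows; `ranks_per_node` stands in for `local_mpi_size()` so the example runs without MPI, and `pick_deepep_variant` is illustrative only.

```python
def is_deepep_feasible(num_ranks: int, ranks_per_node: int) -> bool:
    # Intranode: a single node with 2, 4, or 8 ranks.
    if num_ranks == ranks_per_node and num_ranks in {2, 4, 8}:
        return True
    # Internode: exactly 8 ranks per node and 2, 4, 8, or 16 nodes.
    if ranks_per_node != 8:
        return False
    return num_ranks // ranks_per_node in {2, 4, 8, 16}


def pick_deepep_variant(ep_size: int, ranks_per_node: int) -> str:
    # Mirrors the branch above: DeepEP when feasible, low latency otherwise.
    if is_deepep_feasible(ep_size, ranks_per_node):
        return "DeepEP"
    return "DeepEPLowLatency"


assert pick_deepep_variant(8, 8) == "DeepEP"             # one 8-GPU node
assert pick_deepep_variant(32, 8) == "DeepEP"            # 4 nodes x 8 GPUs
assert pick_deepep_variant(4, 2) == "DeepEPLowLatency"   # 2 nodes x 2 GPUs
```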

@@ -546,7 +568,7 @@ def forward_chunk(
if not use_postquant_alltoall:
x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
self.deep_ep_buffer.dispatch(x, token_selected_slots, token_final_scales, self.num_slots,
self.expert_size_per_partition * self.mapping.moe_ep_rank)
self.expert_size_per_partition * self.mapping.moe_ep_rank, all_rank_max_num_tokens, self.ep_size, self.use_cuda_graph)
padded, x, _, token_selected_slots, token_final_scales = self.pad_empty_recv_tensors(
x, None, recv_topk_idx, token_final_scales)
elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
@@ -634,7 +656,7 @@ def forward_chunk(
x_sf = x_sf.view(torch.float32)
(x, x_sf), recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
self.deep_ep_buffer.dispatch((x, x_sf), token_selected_slots, token_final_scales, self.num_slots,
self.expert_size_per_partition * self.mapping.moe_ep_rank)
self.expert_size_per_partition * self.mapping.moe_ep_rank, all_rank_max_num_tokens, self.ep_size, self.use_cuda_graph)
padded, x, x_sf, token_selected_slots, token_final_scales = self.pad_empty_recv_tensors(
x, x_sf, recv_topk_idx, token_final_scales)
if x_sf is not None: