55import torch
66
77from tensorrt_llm ._mnnvl_utils import MnnvlMemory , MnnvlMoe , MoEAlltoallInfo
8- from tensorrt_llm ._utils import get_sm_version
8+ from tensorrt_llm ._utils import get_sm_version , local_mpi_size
99from tensorrt_llm .functional import AllReduceStrategy
1010from tensorrt_llm .logger import logger
1111from tensorrt_llm .mapping import Mapping
@@ -32,7 +32,7 @@ class AlltoallMethodType(IntEnum):
3232 NotEnabled = 0
3333 # MNNVL
3434 MNNVL = 1
35- # DeepEP intranode or internode: no CUDA Graphs support , IBGDA is required by internode
35+ # DeepEP intranode or internode: CUDA Graphs are supported , IBGDA is required by internode
3636 DeepEP = 2
3737 # DeepEP low latency: CUDA Graphs are supported, IBGDA is required
3838 DeepEPLowLatency = 3
@@ -101,6 +101,8 @@ def __init__(
101101 self .repeat_idx = 0
102102 self .repeat_count = 1
103103
104+ self .use_cuda_graph = model_config .use_cuda_graph
105+
104106 moe_load_balancer_config = model_config .moe_load_balancer
105107 init_expert_size_per_partition = moe_load_balancer_config .num_local_slots if moe_load_balancer_config else self .num_experts // self .ep_size
106108 self .initial_global_assignments = [
@@ -213,6 +215,7 @@ def __init__(
213215 str (
214216 min (model_config .max_num_tokens ,
215217 self .moe_max_num_tokens ))))
218+ os .environ ['NVSHMEM_QP_DEPTH' ] = str (2 * (self .deep_ep_max_num_tokens + 1 ))
216219 self .deep_ep_buffer = buffer_pool .get_low_latency_buffer (
217220 model_config .mapping )
218221 self .deep_ep_buffer .reserve (self .deep_ep_max_num_tokens ,
@@ -254,6 +257,25 @@ def _check_configs(self):
254257 def select_alltoall_method_type (mapping : Mapping , top_k : int ,
255258 dtype : torch .dtype ,
256259 use_cuda_graph : bool ) -> AlltoallMethodType :
260+
261+ # Check if DeepEP is feasible for the given number of ranks
262+ # DeepEP supports two modes:
263+ # 1. Intranode: Single node with 2, 4, or 8 ranks
264+ # 2. Internode: 2, 4, 8, or 16 nodes with 8 ranks per node
265+ def is_deepep_feasible (num_ranks : int ) -> bool :
266+ NUM_INTRANODE_SUPPORTED_RANKS = {2 , 4 , 8 }
267+ REQUIRED_LOCAL_MPI_SIZE = 8
268+ NUM_INTERNODE_SUPPORTED_RDMA_RANKS = {2 , 4 , 8 , 16 }
269+ # Intranode cases
270+ if num_ranks in NUM_INTRANODE_SUPPORTED_RANKS :
271+ return True
272+ # Internode cases
273+ mpi_size = local_mpi_size ()
274+ if mpi_size != REQUIRED_LOCAL_MPI_SIZE :
275+ return False
276+ num_rdma_nodes = num_ranks // mpi_size
277+ return num_rdma_nodes in NUM_INTERNODE_SUPPORTED_RDMA_RANKS
278+
257279 all2all_method_type = os .environ .get ("TRTLLM_FORCE_ALLTOALL_METHOD" )
258280 if all2all_method_type is not None :
259281 return AlltoallMethodType [all2all_method_type ]
@@ -275,12 +297,10 @@ def select_alltoall_method_type(mapping: Mapping, top_k: int,
275297
276298 if os .environ .get ("TRTLLM_CAN_USE_DEEP_EP" , "0" ) == "1" :
277299 if deep_ep_installed and dtype == torch .bfloat16 :
278- if use_cuda_graph :
279- # Here we can only choose DeepEPLowLatency since only this method supports CUDA Graphs.
280- return AlltoallMethodType .DeepEPLowLatency
281- else :
282- # Here we can choose DeepEP or DeepEPLowLatency if both are available. Now DeepEP is faster.
300+ # Choose DeepEP if feasible
301+ if is_deepep_feasible (mapping .moe_ep_size ):
283302 return AlltoallMethodType .DeepEP
303+ return AlltoallMethodType .DeepEPLowLatency
284304
285305 return AlltoallMethodType .NotEnabled
286306
@@ -534,7 +554,7 @@ def forward_chunk(
534554 if not use_postquant_alltoall :
535555 x , recv_topk_idx , token_final_scales , num_recv_tokens_per_expert_list , deep_ep_handle = \
536556 self .deep_ep_buffer .dispatch (x , token_selected_slots , token_final_scales , self .num_slots ,
537- self .expert_size_per_partition * self .mapping .moe_ep_rank )
557+ self .expert_size_per_partition * self .mapping .moe_ep_rank , all_rank_max_num_tokens , self . ep_size , self . use_cuda_graph )
538558 padded , x , _ , token_selected_slots , token_final_scales = self .pad_empty_recv_tensors (
539559 x , None , recv_topk_idx , token_final_scales )
540560 elif self .alltoall_method_type == AlltoallMethodType .DeepEPLowLatency :
@@ -621,7 +641,7 @@ def forward_chunk(
621641 x_sf = x_sf .view (torch .float32 )
622642 (x , x_sf ), recv_topk_idx , token_final_scales , num_recv_tokens_per_expert_list , deep_ep_handle = \
623643 self .deep_ep_buffer .dispatch ((x , x_sf ), token_selected_slots , token_final_scales , self .num_slots ,
624- self .expert_size_per_partition * self .mapping .moe_ep_rank )
644+ self .expert_size_per_partition * self .mapping .moe_ep_rank , all_rank_max_num_tokens , self . ep_size , self . use_cuda_graph )
625645 padded , x , x_sf , token_selected_slots , token_final_scales = self .pad_empty_recv_tensors (
626646 x , x_sf , recv_topk_idx , token_final_scales )
627647 if x_sf is not None :
0 commit comments