diff --git a/cpp/tensorrt_llm/deep_ep/CMakeLists.txt b/cpp/tensorrt_llm/deep_ep/CMakeLists.txt
index f690ab5a905..db3d59e9a28 100644
--- a/cpp/tensorrt_llm/deep_ep/CMakeLists.txt
+++ b/cpp/tensorrt_llm/deep_ep/CMakeLists.txt
@@ -1,4 +1,4 @@
-set(DEEP_EP_COMMIT be2582ffe69b5e7d61c3bc9bf7a5316bc48261f9)
+set(DEEP_EP_COMMIT 5be51b228a7c82dbdb213ea58e77bffd12b38af8)
 set(NVSHMEM_URL_HASH
     SHA256=eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a)
diff --git a/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py b/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py
index 84673163b82..0dcd115eee6 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py
@@ -59,7 +59,7 @@ def reserve(self, hidden_size: int, hidden_dtype: torch.dtype):
 
     def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                  topk_idx: torch.Tensor, topk_weights: torch.Tensor,
-                 num_experts: int, global_expert_id_offset: int) -> \
+                 num_experts: int, global_expert_id_offset: int, all_rank_max_num_tokens: int, ep_size: int, use_cuda_graph: bool) -> \
         Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple]:
         # NOTES: an optional `previous_event` means a CUDA event captured that you want to make it as a dependency
         # of the dispatch kernel, it may be useful with communication-computation overlap. For more information, please
@@ -71,13 +71,16 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
         assert event.event is None
 
         # Do MoE dispatch
-        # NOTES: the CPU will wait for GPU's signal to arrive, so this is not compatible with CUDA graph
         # For more advanced usages, please refer to the docs of the `dispatch` function
+        if use_cuda_graph:
+            num_worst_tokens = all_rank_max_num_tokens * ep_size
+        else:
+            num_worst_tokens = 0
         recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
             self.buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
                                  num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
                                  is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert,
-                                 global_expert_id_offset=global_expert_id_offset)
+                                 global_expert_id_offset=global_expert_id_offset, num_worst_tokens=num_worst_tokens)
         assert event.event is None
 
         # For event management, please refer to the docs of the `EventOverlap` class
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
index 2f8cd3346ff..75f234474d7 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -5,7 +5,7 @@
 import torch
 
 from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
-from tensorrt_llm._utils import get_sm_version
+from tensorrt_llm._utils import get_sm_version, local_mpi_size
 from tensorrt_llm.functional import AllReduceStrategy
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
@@ -32,7 +32,7 @@ class AlltoallMethodType(IntEnum):
     NotEnabled = 0
     # MNNVL
     MNNVL = 1
-    # DeepEP intranode or internode: no CUDA Graphs support, IBGDA is required by internode
+    # DeepEP intranode or internode: CUDA Graphs are supported, IBGDA is required by internode
     DeepEP = 2
     # DeepEP low latency: CUDA Graphs are supported, IBGDA is required
     DeepEPLowLatency = 3
@@ -101,6 +101,8 @@ def __init__(
         self.repeat_idx = 0
         self.repeat_count = 1
 
+        self.use_cuda_graph = model_config.use_cuda_graph
+
         moe_load_balancer_config = model_config.moe_load_balancer
         init_expert_size_per_partition = moe_load_balancer_config.num_local_slots if moe_load_balancer_config else self.num_experts // self.ep_size
         self.initial_global_assignments = [
@@ -212,6 +214,9 @@ def __init__(
                     str(
                         min(model_config.max_num_tokens,
                             self.moe_max_num_tokens))))
+            # Set the NVSHMEM queue pair depth larger than the number of in-flight WRs (ref: https://github.com/deepseek-ai/DeepEP/issues/427)
+            os.environ['NVSHMEM_QP_DEPTH'] = str(
+                2 * (self.deep_ep_max_num_tokens + 1))
             self.deep_ep_buffer = buffer_pool.get_low_latency_buffer(
                 model_config.mapping)
             self.deep_ep_buffer.reserve(self.deep_ep_max_num_tokens,
@@ -253,6 +258,25 @@ def _check_configs(self):
     def select_alltoall_method_type(mapping: Mapping, top_k: int,
                                     dtype: torch.dtype,
                                     use_cuda_graph: bool) -> AlltoallMethodType:
+
+        # Check if DeepEP is feasible for the given number of ranks
+        # DeepEP supports two modes:
+        # 1. Intranode: a single node with 2, 4, or 8 ranks
+        # 2. Internode: 2, 4, 8, or 16 nodes with 8 ranks per node
+        def is_deepep_feasible(num_ranks: int) -> bool:
+            NUM_INTRANODE_SUPPORTED_RANKS = {2, 4, 8}
+            REQUIRED_LOCAL_MPI_SIZE = 8
+            NUM_INTERNODE_SUPPORTED_RDMA_RANKS = {2, 4, 8, 16}
+            mpi_size = local_mpi_size()
+            # Intranode cases
+            if num_ranks == mpi_size and num_ranks in NUM_INTRANODE_SUPPORTED_RANKS:
+                return True
+            # Internode cases
+            if mpi_size != REQUIRED_LOCAL_MPI_SIZE:
+                return False
+            num_rdma_nodes = num_ranks // mpi_size
+            return num_rdma_nodes in NUM_INTERNODE_SUPPORTED_RDMA_RANKS
+
         all2all_method_type = os.environ.get("TRTLLM_FORCE_ALLTOALL_METHOD")
         if all2all_method_type is not None:
             return AlltoallMethodType[all2all_method_type]
@@ -274,12 +298,10 @@ def select_alltoall_method_type(mapping: Mapping, top_k: int,
 
         if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "0") == "1":
             if deep_ep_installed and dtype == torch.bfloat16:
-                if use_cuda_graph:
-                    # Here we can only choose DeepEPLowLatency since only this method supports CUDA Graphs.
-                    return AlltoallMethodType.DeepEPLowLatency
-                else:
-                    # Here we can choose DeepEP or DeepEPLowLatency if both are available. Now DeepEP is faster.
+                # Choose DeepEP if feasible
+                if is_deepep_feasible(mapping.moe_ep_size):
                     return AlltoallMethodType.DeepEP
+                return AlltoallMethodType.DeepEPLowLatency
 
         return AlltoallMethodType.NotEnabled
@@ -546,7 +568,7 @@ def forward_chunk(
             if not use_postquant_alltoall:
                 x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
                     self.deep_ep_buffer.dispatch(x, token_selected_slots, token_final_scales, self.num_slots,
-                                                 self.expert_size_per_partition * self.mapping.moe_ep_rank)
+                                                 self.expert_size_per_partition * self.mapping.moe_ep_rank, all_rank_max_num_tokens, self.ep_size, self.use_cuda_graph)
                 padded, x, _, token_selected_slots, token_final_scales = self.pad_empty_recv_tensors(
                     x, None, recv_topk_idx, token_final_scales)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
@@ -634,7 +656,7 @@ def forward_chunk(
                 x_sf = x_sf.view(torch.float32)
             (x, x_sf), recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
                 self.deep_ep_buffer.dispatch((x, x_sf), token_selected_slots, token_final_scales, self.num_slots,
-                                             self.expert_size_per_partition * self.mapping.moe_ep_rank)
+                                             self.expert_size_per_partition * self.mapping.moe_ep_rank, all_rank_max_num_tokens, self.ep_size, self.use_cuda_graph)
             padded, x, x_sf, token_selected_slots, token_final_scales = self.pad_empty_recv_tensors(
                 x, x_sf, recv_topk_idx, token_final_scales)
             if x_sf is not None:
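Reviewer note: the sketch below re-implements, outside the repo, the two behaviors this patch introduces so they can be sanity-checked in isolation: (1) the rank-topology check that now gates AlltoallMethodType.DeepEP, and (2) the worst-case receive-token count passed to dispatch as num_worst_tokens so output shapes stay static under CUDA graph capture. It is a minimal, standalone illustration under assumptions, not the code in the patch: local_mpi_size() is replaced by an explicit ranks_per_node argument, and pick_deepep_variant() / worst_case_recv_tokens() are hypothetical helper names used only here.

from enum import IntEnum


class AlltoallMethodType(IntEnum):
    NotEnabled = 0
    MNNVL = 1
    DeepEP = 2
    DeepEPLowLatency = 3


def is_deepep_feasible(num_ranks: int, ranks_per_node: int) -> bool:
    # Intranode: all ranks on one node, and the count must be 2, 4, or 8.
    NUM_INTRANODE_SUPPORTED_RANKS = {2, 4, 8}
    # Internode: exactly 8 ranks per node, and 2, 4, 8, or 16 nodes.
    REQUIRED_LOCAL_MPI_SIZE = 8
    NUM_INTERNODE_SUPPORTED_RDMA_RANKS = {2, 4, 8, 16}
    if num_ranks == ranks_per_node and num_ranks in NUM_INTRANODE_SUPPORTED_RANKS:
        return True
    if ranks_per_node != REQUIRED_LOCAL_MPI_SIZE:
        return False
    return num_ranks // ranks_per_node in NUM_INTERNODE_SUPPORTED_RDMA_RANKS


def pick_deepep_variant(num_ranks: int, ranks_per_node: int) -> AlltoallMethodType:
    # After this patch, enabling CUDA graphs no longer forces the low-latency path;
    # plain DeepEP is preferred whenever the topology supports it.
    if is_deepep_feasible(num_ranks, ranks_per_node):
        return AlltoallMethodType.DeepEP
    return AlltoallMethodType.DeepEPLowLatency


def worst_case_recv_tokens(all_rank_max_num_tokens: int, ep_size: int,
                           use_cuda_graph: bool) -> int:
    # Under CUDA graph capture, dispatch outputs must have fixed shapes, so the
    # receive side is sized for the worst case: every EP rank sends its maximum.
    # Returning 0 keeps dynamic sizing when graphs are not in use.
    return all_rank_max_num_tokens * ep_size if use_cuda_graph else 0


if __name__ == "__main__":
    # 8 GPUs on one node -> intranode DeepEP.
    assert pick_deepep_variant(8, ranks_per_node=8) == AlltoallMethodType.DeepEP
    # 4 nodes x 8 GPUs -> internode DeepEP.
    assert pick_deepep_variant(32, ranks_per_node=8) == AlltoallMethodType.DeepEP
    # 4 nodes x 4 GPUs -> unsupported topology, fall back to low latency.
    assert pick_deepep_variant(16, ranks_per_node=4) == AlltoallMethodType.DeepEPLowLatency
    # Worst-case sizing only applies when CUDA graphs are enabled.
    assert worst_case_recv_tokens(256, ep_size=8, use_cuda_graph=True) == 2048
    assert worst_case_recv_tokens(256, ep_size=8, use_cuda_graph=False) == 0

The assertions mirror the topologies named in the new comments: 8 ranks on one node and 4 nodes of 8 ranks select DeepEP, while an unsupported 4-ranks-per-node layout falls back to DeepEPLowLatency.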