Skip to content

Commit fc92e82

Browse files
yifeizhang-cdominicshanshan
authored andcommitted
[TRTLLM-6589][feat] Support CUDA graph for DeepEP (NVIDIA#7514)
Signed-off-by: Yifei Zhang <[email protected]>
1 parent bd2fd9c commit fc92e82

File tree

3 files changed

+37
-12
lines changed

3 files changed

+37
-12
lines changed

cpp/tensorrt_llm/deep_ep/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
set(DEEP_EP_COMMIT be2582ffe69b5e7d61c3bc9bf7a5316bc48261f9)
1+
set(DEEP_EP_COMMIT 5be51b228a7c82dbdb213ea58e77bffd12b38af8)
22
set(NVSHMEM_URL_HASH
33
SHA256=eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a)
44

tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def reserve(self, hidden_size: int, hidden_dtype: torch.dtype):
5959

6060
def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
6161
topk_idx: torch.Tensor, topk_weights: torch.Tensor,
62-
num_experts: int, global_expert_id_offset: int) -> \
62+
num_experts: int, global_expert_id_offset: int, all_rank_max_num_tokens: int, ep_size: int, use_cuda_graph: bool) -> \
6363
Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple]:
6464
# NOTES: an optional `previous_event` means a CUDA event captured that you want to make it as a dependency
6565
# of the dispatch kernel, it may be useful with communication-computation overlap. For more information, please
@@ -71,13 +71,16 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
7171
assert event.event is None
7272

7373
# Do MoE dispatch
74-
# NOTES: the CPU will wait for GPU's signal to arrive, so this is not compatible with CUDA graph
7574
# For more advanced usages, please refer to the docs of the `dispatch` function
75+
if use_cuda_graph:
76+
num_worst_tokens = all_rank_max_num_tokens * ep_size
77+
else:
78+
num_worst_tokens = 0
7679
recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
7780
self.buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
7881
num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
7982
is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert,
80-
global_expert_id_offset=global_expert_id_offset)
83+
global_expert_id_offset=global_expert_id_offset, num_worst_tokens=num_worst_tokens)
8184
assert event.event is None
8285

8386
# For event management, please refer to the docs of the `EventOverlap` class

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class AlltoallMethodType(IntEnum):
3232
NotEnabled = 0
3333
# MNNVL
3434
MNNVL = 1
35-
# DeepEP intranode or internode: no CUDA Graphs support, IBGDA is required by internode
35+
# DeepEP intranode or internode: CUDA Graphs are supported, IBGDA is required by internode
3636
DeepEP = 2
3737
# DeepEP low latency: CUDA Graphs are supported, IBGDA is required
3838
DeepEPLowLatency = 3
@@ -101,6 +101,8 @@ def __init__(
101101
self.repeat_idx = 0
102102
self.repeat_count = 1
103103

104+
self.use_cuda_graph = model_config.use_cuda_graph
105+
104106
moe_load_balancer_config = model_config.moe_load_balancer
105107
init_expert_size_per_partition = moe_load_balancer_config.num_local_slots if moe_load_balancer_config else self.num_experts // self.ep_size
106108
self.initial_global_assignments = [
@@ -212,6 +214,9 @@ def __init__(
212214
str(
213215
min(model_config.max_num_tokens,
214216
self.moe_max_num_tokens))))
217+
# Set nvshmem queue pair depth larger than the number of on-flight WRs (ref: https://github.com/deepseek-ai/DeepEP/issues/427)
218+
os.environ['NVSHMEM_QP_DEPTH'] = str(
219+
2 * (self.deep_ep_max_num_tokens + 1))
215220
self.deep_ep_buffer = buffer_pool.get_low_latency_buffer(
216221
model_config.mapping)
217222
self.deep_ep_buffer.reserve(self.deep_ep_max_num_tokens,
@@ -255,6 +260,25 @@ def _check_configs(self):
255260
def select_alltoall_method_type(mapping: Mapping, top_k: int,
256261
dtype: torch.dtype,
257262
use_cuda_graph: bool) -> AlltoallMethodType:
263+
264+
# Check if DeepEP is feasible for the given number of ranks
265+
# DeepEP supports two modes:
266+
# 1. Intranode: Single node with 2, 4, or 8 ranks
267+
# 2. Internode: 2, 4, 8, or 16 nodes with 8 ranks per node
268+
def is_deepep_feasible(num_ranks: int) -> bool:
269+
NUM_INTRANODE_SUPPORTED_RANKS = {2, 4, 8}
270+
REQUIRED_LOCAL_MPI_SIZE = 8
271+
NUM_INTERNODE_SUPPORTED_RDMA_RANKS = {2, 4, 8, 16}
272+
mpi_size = local_mpi_size()
273+
# Intranode cases
274+
if num_ranks == mpi_size and num_ranks in NUM_INTRANODE_SUPPORTED_RANKS:
275+
return True
276+
# Internode cases
277+
if mpi_size != REQUIRED_LOCAL_MPI_SIZE:
278+
return False
279+
num_rdma_nodes = num_ranks // mpi_size
280+
return num_rdma_nodes in NUM_INTERNODE_SUPPORTED_RDMA_RANKS
281+
258282
all2all_method_type = os.environ.get("TRTLLM_FORCE_ALLTOALL_METHOD")
259283
if all2all_method_type is not None:
260284
return AlltoallMethodType[all2all_method_type]
@@ -276,12 +300,10 @@ def select_alltoall_method_type(mapping: Mapping, top_k: int,
276300

277301
if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "0") == "1":
278302
if deep_ep_installed and dtype == torch.bfloat16:
279-
if use_cuda_graph:
280-
# Here we can only choose DeepEPLowLatency since only this method supports CUDA Graphs.
281-
return AlltoallMethodType.DeepEPLowLatency
282-
else:
283-
# Here we can choose DeepEP or DeepEPLowLatency if both are available. Now DeepEP is faster.
303+
# Choose DeepEP if feasible
304+
if is_deepep_feasible(mapping.moe_ep_size):
284305
return AlltoallMethodType.DeepEP
306+
return AlltoallMethodType.DeepEPLowLatency
285307

286308
return AlltoallMethodType.NotEnabled
287309

@@ -548,7 +570,7 @@ def forward_chunk(
548570
if not use_postquant_alltoall:
549571
x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
550572
self.deep_ep_buffer.dispatch(x, token_selected_slots, token_final_scales, self.num_slots,
551-
self.expert_size_per_partition * self.mapping.moe_ep_rank)
573+
self.expert_size_per_partition * self.mapping.moe_ep_rank, all_rank_max_num_tokens, self.ep_size, self.use_cuda_graph)
552574
padded, x, _, token_selected_slots, token_final_scales = self.pad_empty_recv_tensors(
553575
x, None, recv_topk_idx, token_final_scales)
554576
elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
@@ -636,7 +658,7 @@ def forward_chunk(
636658
x_sf = x_sf.view(torch.float32)
637659
(x, x_sf), recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
638660
self.deep_ep_buffer.dispatch((x, x_sf), token_selected_slots, token_final_scales, self.num_slots,
639-
self.expert_size_per_partition * self.mapping.moe_ep_rank)
661+
self.expert_size_per_partition * self.mapping.moe_ep_rank, all_rank_max_num_tokens, self.ep_size, self.use_cuda_graph)
640662
padded, x, x_sf, token_selected_slots, token_final_scales = self.pad_empty_recv_tensors(
641663
x, x_sf, recv_topk_idx, token_final_scales)
642664
if x_sf is not None:

0 commit comments

Comments
 (0)