Skip to content

Commit dee504d

Browse files
committed
Support CUDA graph for DeepEP
Signed-off-by: Yifei Zhang <[email protected]>
1 parent 2e5850c commit dee504d

File tree

3 files changed

+36
-13
lines changed

3 files changed

+36
-13
lines changed

cpp/tensorrt_llm/deep_ep/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
set(DEEP_EP_COMMIT be2582ffe69b5e7d61c3bc9bf7a5316bc48261f9)
1+
set(DEEP_EP_COMMIT f59b4cbee9486f919405e2cfffe7d5f786c2cd25)
22
set(NVSHMEM_URL_HASH
33
SHA256=eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a)
44

tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def reserve(self, hidden_size: int, hidden_dtype: torch.dtype):
5959

6060
def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
6161
topk_idx: torch.Tensor, topk_weights: torch.Tensor,
62-
num_experts: int, global_expert_id_offset: int) -> \
62+
num_experts: int, global_expert_id_offset: int, all_rank_max_num_tokens: int, ep_size: int, use_cuda_graph: bool) -> \
6363
Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple]:
6464
# NOTES: an optional `previous_event` means a CUDA event captured that you want to make it as a dependency
6565
# of the dispatch kernel, it may be useful with communication-computation overlap. For more information, please
@@ -71,13 +71,16 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
7171
assert event.event is None
7272

7373
# Do MoE dispatch
74-
# NOTES: the CPU will wait for GPU's signal to arrive, so this is not compatible with CUDA graph
7574
# For more advanced usages, please refer to the docs of the `dispatch` function
75+
if use_cuda_graph:
76+
num_worst_tokens = all_rank_max_num_tokens * ep_size
77+
else:
78+
num_worst_tokens = 0
7679
recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
7780
self.buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
7881
num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
7982
is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert,
80-
global_expert_id_offset=global_expert_id_offset)
83+
global_expert_id_offset=global_expert_id_offset, num_worst_tokens=num_worst_tokens)
8184
assert event.event is None
8285

8386
# For event management, please refer to the docs of the `EventOverlap` class

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66

77
from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
8-
from tensorrt_llm._utils import get_sm_version
8+
from tensorrt_llm._utils import get_sm_version, local_mpi_size
99
from tensorrt_llm.functional import AllReduceStrategy
1010
from tensorrt_llm.logger import logger
1111
from tensorrt_llm.mapping import Mapping
@@ -32,7 +32,7 @@ class AlltoallMethodType(IntEnum):
3232
NotEnabled = 0
3333
# MNNVL
3434
MNNVL = 1
35-
# DeepEP intranode or internode: no CUDA Graphs support, IBGDA is required by internode
35+
# DeepEP intranode or internode: CUDA Graphs are supported, IBGDA is required by internode
3636
DeepEP = 2
3737
# DeepEP low latency: CUDA Graphs are supported, IBGDA is required
3838
DeepEPLowLatency = 3
@@ -101,6 +101,8 @@ def __init__(
101101
self.repeat_idx = 0
102102
self.repeat_count = 1
103103

104+
self.use_cuda_graph = model_config.use_cuda_graph
105+
104106
moe_load_balancer_config = model_config.moe_load_balancer
105107
init_expert_size_per_partition = moe_load_balancer_config.num_local_slots if moe_load_balancer_config else self.num_experts // self.ep_size
106108
self.initial_global_assignments = [
@@ -213,6 +215,7 @@ def __init__(
213215
str(
214216
min(model_config.max_num_tokens,
215217
self.moe_max_num_tokens))))
218+
os.environ['NVSHMEM_QP_DEPTH'] = 2 * (self.deep_ep_max_num_tokens + 1)
216219
self.deep_ep_buffer = buffer_pool.get_low_latency_buffer(
217220
model_config.mapping)
218221
self.deep_ep_buffer.reserve(self.deep_ep_max_num_tokens,
@@ -254,6 +257,25 @@ def _check_configs(self):
254257
def select_alltoall_method_type(mapping: Mapping, top_k: int,
255258
dtype: torch.dtype,
256259
use_cuda_graph: bool) -> AlltoallMethodType:
260+
261+
# Check if DeepEP is feasible for the given number of ranks
262+
# DeepEP supports two modes:
263+
# 1. Intranode: Single node with 2, 4, or 8 ranks
264+
# 2. Internode: 2, 4, 8, or 16 nodes with 8 ranks per node
265+
def is_deepep_feasible(num_ranks: int) -> bool:
266+
NUM_INTRANODE_SUPPORTED_RANKS = {2, 4, 8}
267+
REQUIRED_LOCAL_MPI_SIZE = 8
268+
NUM_INTERNODE_SUPPORTED_RDMA_RANKS = {2, 4, 8, 16}
269+
# Intranode cases
270+
if num_ranks in NUM_INTRANODE_SUPPORTED_RANKS:
271+
return True
272+
# Internode cases
273+
mpi_size = local_mpi_size()
274+
if mpi_size != REQUIRED_LOCAL_MPI_SIZE:
275+
return False
276+
num_rdma_nodes = num_ranks // mpi_size
277+
return num_rdma_nodes in NUM_INTERNODE_SUPPORTED_RDMA_RANKS
278+
257279
all2all_method_type = os.environ.get("TRTLLM_FORCE_ALLTOALL_METHOD")
258280
if all2all_method_type is not None:
259281
return AlltoallMethodType[all2all_method_type]
@@ -275,12 +297,10 @@ def select_alltoall_method_type(mapping: Mapping, top_k: int,
275297

276298
if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "0") == "1":
277299
if deep_ep_installed and dtype == torch.bfloat16:
278-
if use_cuda_graph:
279-
# Here we can only choose DeepEPLowLatency since only this method supports CUDA Graphs.
280-
return AlltoallMethodType.DeepEPLowLatency
281-
else:
282-
# Here we can choose DeepEP or DeepEPLowLatency if both are available. Now DeepEP is faster.
300+
# Choose DeepEP if feasible
301+
if is_deepep_feasible(mapping.moe_ep_size):
283302
return AlltoallMethodType.DeepEP
303+
return AlltoallMethodType.DeepEPLowLatency
284304

285305
return AlltoallMethodType.NotEnabled
286306

@@ -534,7 +554,7 @@ def forward_chunk(
534554
if not use_postquant_alltoall:
535555
x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
536556
self.deep_ep_buffer.dispatch(x, token_selected_slots, token_final_scales, self.num_slots,
537-
self.expert_size_per_partition * self.mapping.moe_ep_rank)
557+
self.expert_size_per_partition * self.mapping.moe_ep_rank, all_rank_max_num_tokens, self.ep_size, self.use_cuda_graph)
538558
padded, x, _, token_selected_slots, token_final_scales = self.pad_empty_recv_tensors(
539559
x, None, recv_topk_idx, token_final_scales)
540560
elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
@@ -621,7 +641,7 @@ def forward_chunk(
621641
x_sf = x_sf.view(torch.float32)
622642
(x, x_sf), recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
623643
self.deep_ep_buffer.dispatch((x, x_sf), token_selected_slots, token_final_scales, self.num_slots,
624-
self.expert_size_per_partition * self.mapping.moe_ep_rank)
644+
self.expert_size_per_partition * self.mapping.moe_ep_rank, all_rank_max_num_tokens, self.ep_size, self.use_cuda_graph)
625645
padded, x, x_sf, token_selected_slots, token_final_scales = self.pad_empty_recv_tensors(
626646
x, x_sf, recv_topk_idx, token_final_scales)
627647
if x_sf is not None:

0 commit comments

Comments
 (0)