Skip to content

Commit 8fea5d4

Browse files
authored
make deep_ep's ll_internode use NVLink when intranode (#72883)
1 parent cf56cbd commit 8fea5d4

File tree

2 files changed

+89
-22
lines changed
  • paddle/fluid/distributed/collective/deep_ep/kernels
  • python/paddle/distributed/communication/deep_ep

2 files changed

+89
-22
lines changed

paddle/fluid/distributed/collective/deep_ep/kernels/internode_ll.cu

Lines changed: 87 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -279,13 +279,33 @@ __global__ __launch_bounds__(
279279
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
280280
slot_idx * num_bytes_per_msg;
281281
if (dst_rank != rank) {
282-
nvshmemi_ibgda_put_nbi_warp(dst_ptr,
283-
src_ptr,
284-
num_bytes_per_msg,
285-
dst_rank,
286-
dst_expert_local_idx,
287-
lane_id,
288-
slot_idx);
282+
void* peer_base_addr = reinterpret_cast<void*>(
283+
__ldg(reinterpret_cast<const uint64_t*>(
284+
nvshmemi_device_state_d.peer_heap_base_p2p) +
285+
dst_rank));
286+
if (peer_base_addr) {
287+
char* req_rptr_actual =
288+
reinterpret_cast<char*>(peer_base_addr) +
289+
(reinterpret_cast<char*>(dst_ptr) -
290+
reinterpret_cast<char*>(nvshmemi_device_state_d.heap_base));
291+
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
292+
const auto* dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
293+
UNROLLED_WARP_COPY(8,
294+
lane_id,
295+
num_int4_per_msg,
296+
dst_int4_ptr,
297+
src_int4_ptr,
298+
ld_nc_global,
299+
st_na_global);
300+
} else {
301+
nvshmemi_ibgda_put_nbi_warp(dst_ptr,
302+
src_ptr,
303+
num_bytes_per_msg,
304+
dst_rank,
305+
dst_expert_local_idx,
306+
lane_id,
307+
slot_idx);
308+
}
289309
} else {
290310
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
291311
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
@@ -367,11 +387,24 @@ __global__ __launch_bounds__(
367387
responsible_expert_idx) != FINISHED_SUM_TAG * 2) {
368388
}
369389
if (dst_rank != rank) {
370-
nvshmemi_ibgda_amo_nonfetch_add(
371-
rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
372-
-num_tokens_sent - 1,
373-
dst_rank,
374-
dst_expert_local_idx);
390+
void* peer_base_addr = reinterpret_cast<void*>(
391+
__ldg(reinterpret_cast<const uint64_t*>(
392+
nvshmemi_device_state_d.peer_heap_base_p2p) +
393+
dst_rank));
394+
if (peer_base_addr) { // P2P enabled
395+
int* rptr_actual = reinterpret_cast<int*>(
396+
reinterpret_cast<char*>(peer_base_addr) +
397+
(reinterpret_cast<char*>(rdma_recv_count +
398+
dst_expert_local_idx * num_ranks + rank) -
399+
reinterpret_cast<char*>(nvshmemi_device_state_d.heap_base)));
400+
st_na_release(rptr_actual, -num_tokens_sent - 1);
401+
} else {
402+
nvshmemi_ibgda_amo_nonfetch_add(
403+
rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
404+
-num_tokens_sent - 1,
405+
dst_rank,
406+
dst_expert_local_idx);
407+
}
375408
} else {
376409
st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
377410
-num_tokens_sent - 1);
@@ -691,13 +724,32 @@ __global__ __launch_bounds__(
691724
x_int4,
692725
ld_nc_global,
693726
st_na_global);
694-
nvshmemi_ibgda_put_nbi_warp(dst_ptr,
695-
buf_ptr,
696-
hidden * sizeof(nv_bfloat16),
697-
dst_rank,
698-
local_expert_idx,
699-
lane_id,
700-
token_idx - offset);
727+
void* peer_base_addr = reinterpret_cast<void*>(
728+
__ldg(reinterpret_cast<const uint64_t*>(
729+
nvshmemi_device_state_d.peer_heap_base_p2p) +
730+
dst_rank));
731+
if (peer_base_addr) {
732+
char* req_rptr_actual =
733+
reinterpret_cast<char*>(peer_base_addr) +
734+
(reinterpret_cast<char*>(dst_ptr) -
735+
reinterpret_cast<char*>(nvshmemi_device_state_d.heap_base));
736+
const auto dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
737+
UNROLLED_WARP_COPY(7,
738+
lane_id,
739+
hidden_bf16_int4,
740+
dst_int4_ptr,
741+
x_int4,
742+
ld_nc_global,
743+
st_na_global);
744+
} else {
745+
nvshmemi_ibgda_put_nbi_warp(dst_ptr,
746+
buf_ptr,
747+
hidden * sizeof(nv_bfloat16),
748+
dst_rank,
749+
local_expert_idx,
750+
lane_id,
751+
token_idx - offset);
752+
}
701753
}
702754
}
703755

@@ -710,8 +762,22 @@ __global__ __launch_bounds__(
710762
while (ld_acquire_global(atomic_clean_flag) == 0) {
711763
}
712764
if (dst_rank != rank) {
713-
nvshmemi_ibgda_amo_nonfetch_add(
714-
rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
765+
void* peer_base_addr = reinterpret_cast<void*>(
766+
__ldg(reinterpret_cast<const uint64_t*>(
767+
nvshmemi_device_state_d.peer_heap_base_p2p) +
768+
dst_rank));
769+
if (peer_base_addr) {
770+
int* req_rptr_actual = reinterpret_cast<int*>(
771+
reinterpret_cast<char*>(peer_base_addr) +
772+
(reinterpret_cast<char*>(rdma_recv_flag + global_expert_idx) -
773+
reinterpret_cast<char*>(nvshmemi_device_state_d.heap_base)));
774+
st_na_release(req_rptr_actual, 1);
775+
} else {
776+
nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx,
777+
1,
778+
dst_rank,
779+
local_expert_idx);
780+
}
715781
} else {
716782
st_na_release(rdma_recv_flag + global_expert_idx, 1);
717783
}

python/paddle/distributed/communication/deep_ep/buffer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ def __init__(
108108
# Enable IBGDA for the low latency mode, which refers to "no package forwarding between NVLink and RDMA"
109109
if low_latency_mode:
110110
assert num_qps_per_rank > 0
111-
os.environ['NVSHMEM_DISABLE_P2P'] = '1'
111+
if not os.getenv("NVSHMEM_DISABLE_P2P"):
112+
os.environ['NVSHMEM_DISABLE_P2P'] = '1'
112113
os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
113114
os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
114115
os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = (

0 commit comments

Comments
 (0)