6 changes: 4 additions & 2 deletions README.md
@@ -18,8 +18,10 @@ We test normal kernels on H800 (~160 GB/s NVLink maximum bandwidth), with each c
|:---------:|:------------:|:--------------------:|:-----------:|:--------------------:|
| Intranode | 8 | 153 GB/s (NVLink) | 8 | 158 GB/s (NVLink) |
| Internode | 16 | 43 GB/s (RDMA) | 16 | 43 GB/s (RDMA) |
| Internode | 32 | 44 GB/s (RDMA) | 32 | 47 GB/s (RDMA) |
| Internode | 64 | 46 GB/s (RDMA) | 64 | 45 GB/s (RDMA) |
| Internode | 32 | 58 GB/s (RDMA) | 32 | 57 GB/s (RDMA) |
| Internode | 64 | 51 GB/s (RDMA) | 64 | 50 GB/s (RDMA) |

**News (2025.04.22)**: with optimizations from the Tencent Network Platform Department, performance was enhanced by up to 30%; see [#130](https://github.com/deepseek-ai/DeepEP/pull/130) for more details. Thanks for the contribution!

### Low-latency kernels with pure RDMA

28 changes: 17 additions & 11 deletions csrc/kernels/ibgda_device.cuh
@@ -325,6 +325,7 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
}

template <bool kAlwaysDoPostSend = false>
__device__ static __forceinline__ void
nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {
// Get lkey and rkey, store them into lanes
@@ -365,7 +366,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,

// Submit
if (lane_id == 0)
ibgda_submit_requests<false>(qp, base_wqe_idx, num_wqes, message_idx);
ibgda_submit_requests<kAlwaysDoPostSend>(qp, base_wqe_idx, num_wqes, message_idx);
__syncwarp();
}

@@ -410,20 +411,25 @@ __device__ static __forceinline__ void ibgda_write_amo_add_wqe(
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<int4*>(&data_seg));
}

__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id) {
nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id);
__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id, bool is_local_copy = false) {
if (is_local_copy) {
// Fallback to NVSHMEM legacy API
nvshmemx_signal_op(static_cast<uint64_t*>(rptr), value, NVSHMEM_SIGNAL_ADD, pe);
} else {
nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id);

__be32 rkey;
uint64_t raddr;
ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey);
__be32 rkey;
uint64_t raddr;
ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey);

uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx);
uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx);

ibgda_write_amo_add_wqe(qp, value, reinterpret_cast<uint64_t>(qp->ibuf.buf),
qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs);
ibgda_write_amo_add_wqe(qp, value, reinterpret_cast<uint64_t>(qp->ibuf.buf),
qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs);

ibgda_submit_requests<true>(qp, my_wqe_idx, 1);
ibgda_submit_requests<true>(qp, my_wqe_idx, 1);
}
}

} // namespace deep_ep
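The two device-side helpers above are what the normal (internode) kernels now call in place of the host-proxy NVSHMEM APIs. Below is a condensed, illustrative sketch of the intended call pattern. The wrapper name `send_chunk_and_bump_tail` and its parameter list are invented for this note; only `nvshmemi_ibgda_put_nbi_warp`, `nvshmemi_ibgda_amo_nonfetch_add`, `translate_dst_rdma_rank`, and `memory_fence` come from the repository, and the real call sites (in `csrc/kernels/internode.cu` later in this diff) interleave this logic with per-lane bookkeeping.

```cuda
// Illustrative only, not part of this PR: the put-then-signal pattern used by the
// dispatch/combine senders with the updated IBGDA helpers.
template <bool kLowLatencyMode>
__device__ __forceinline__ void send_chunk_and_bump_tail(
        uint8_t* dst_buf, const uint8_t* src_buf, size_t bytes,
        uint64_t* remote_tail, int num_tokens,
        int dst_rdma_rank, int rdma_rank, int nvl_rank,
        int channel_id, int lane_id) {
    if (dst_rdma_rank != rdma_rank) {
        // Remote RDMA rank: GPU-initiated RDMA write; the call sites in this PR
        // always pass kAlwaysDoPostSend = true here.
        nvshmemi_ibgda_put_nbi_warp<true>(
            reinterpret_cast<uint64_t>(dst_buf), reinterpret_cast<uint64_t>(src_buf), bytes,
            translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
            channel_id, lane_id, 0);
    } else {
        // Local RDMA rank: data was already written in place, a memory fence suffices.
        memory_fence();
    }
    __syncwarp();

    // Advance the consumer-visible tail counter. For the local rank the helper
    // falls back to nvshmemx_signal_op instead of building an IBGDA atomic WQE.
    if (lane_id == 0)
        nvshmemi_ibgda_amo_nonfetch_add(remote_tail, num_tokens,
                                        translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
                                        channel_id, /*is_local_copy=*/ dst_rdma_rank == rdma_rank);
}
```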
64 changes: 38 additions & 26 deletions csrc/kernels/internode.cu
@@ -3,6 +3,7 @@
#include "exception.cuh"
#include "launch.cuh"
#include "utils.cuh"
#include "ibgda_device.cuh"

namespace deep_ep {

@@ -479,6 +480,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
const bool is_forwarder = sm_id % 2 == 0;
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;

EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_channels);

const auto role_meta = [=]() -> std::pair<WarpRole, int> {
if (is_forwarder) {
if (warp_id < NUM_MAX_NVL_PEERS) {
@@ -555,19 +558,27 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Send number of tokens in this channel by `-value - 1`
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers");
for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank);
if (lane_id < NUM_MAX_NVL_PEERS) {
rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;
dst_ptr[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;
} else if (lane_id < NUM_MAX_NVL_PEERS * 2) {
rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;
dst_ptr[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;
} else if (lane_id == NUM_MAX_NVL_PEERS * 2) {
rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;
dst_ptr[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;
} else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) {
rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
}
__syncwarp();

// Issue RDMA for non-local ranks
if (dst_rdma_rank != rdma_rank) {
nvshmemi_ibgda_put_nbi_warp<true>(reinterpret_cast<uint64_t>(rdma_channel_meta.recv_buffer(rdma_rank)),
reinterpret_cast<uint64_t>(rdma_channel_meta.send_buffer(dst_rdma_rank)),
sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2),
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
channel_id, lane_id, 0);
}
nvshmemx_int_put_nbi_warp(rdma_channel_meta.recv_buffer(rdma_rank), rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
nvshmem_fence();
sync_rdma_sender_smem();

// Iterate over tokens and copy into buffer
@@ -710,11 +721,11 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if (dst_rdma_rank != rdma_rank) {
auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
num_bytes_per_rdma_token * num_tokens_to_issue,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
nvshmem_fence();
const size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue;
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr, src_ptr, num_bytes_per_msg,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0);
} else {
// Lighter fence for local RDMA rank
memory_fence();
@@ -725,8 +736,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if (lane_id == dst_rdma_rank) {
last_issued_tail += num_tokens_to_issue;
num_tokens_to_send -= num_tokens_to_issue;
nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank);
}
}
}
@@ -926,8 +937,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv

// Update remote head
if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD,
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head,
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank), channel_id, lane_id == rdma_rank);
last_head = min_head;
}

@@ -1558,20 +1569,21 @@ combine(int4* combined_x, float* combined_topk_weights,
if (sub_warp_id == kNumWarpsPerForwarder - 1) {
if (dst_rdma_rank != rdma_rank) {
auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
num_chunked_tokens * num_bytes_per_rdma_token,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
nvshmem_fence();
const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_rdma_token;
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr, src_ptr, num_bytes_per_msg,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0);
} else {
memory_fence();
}

// Write new RDMA tail
__syncwarp();
if (lane_id == 0)
nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
if (lane_id == 0) {
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank);
}
}
}

@@ -1656,8 +1668,8 @@ combine(int4* combined_x, float* combined_topk_weights,
for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i])
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank);
last_rdma_head = min_head;
}
} else {
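A note on the `-value - 1` encoding used by the metadata-sending loop earlier in this file's diff (the `// Send number of tokens in this channel by `-value - 1`` comment): prefix-sum values are transmitted as `-value - 1` so that every written entry is strictly negative, letting the receiver distinguish an arrived value from the initial non-negative buffer contents even when the count itself is zero. A minimal sketch of the arithmetic; these helper functions are illustrative and do not exist in the repository, the kernels inline the expressions directly.

```cuda
// Illustrative only: encode/decode of the channel-metadata protocol.
__device__ __forceinline__ int encode_channel_meta(int value) {
    return -value - 1;              // 0 -> -1, 5 -> -6: always negative once written
}
__device__ __forceinline__ bool channel_meta_arrived(int stored) {
    return stored < 0;              // a negative value signals that the sender has written it
}
__device__ __forceinline__ int decode_channel_meta(int stored) {
    return -stored - 1;             // -1 -> 0, -6 -> 5
}
```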
2 changes: 1 addition & 1 deletion csrc/kernels/internode_ll.cu
@@ -167,7 +167,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
EP_DEVICE_ASSERT(num_sms > 1);
if (sm_id == 0) {
// The first SM is also responsible for checking QPs
EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_local_experts);
EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_local_experts);

// The first SM is also responsible for cleaning the next buffer
#pragma unroll
12 changes: 5 additions & 7 deletions csrc/kernels/runtime.cu
@@ -58,14 +58,12 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID);
}

// Normal operations use IBRC, while low-latency operations use IBGDA
if (low_latency_mode) {
nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
// TODO: we still use `nvshmem_barrier` under IBRC mode, which should be switched to IBGDA mode later
nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));

bool ibgda_is_initialized = false;
CUDA_CHECK(cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice));
}
bool ibgda_is_initialized = false;
CUDA_CHECK(cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice));
nvshmem_barrier_all();
return nvshmem_my_pe();
}
23 changes: 11 additions & 12 deletions deep_ep/buffer.py
@@ -31,7 +31,7 @@ class Buffer:

def __init__(self, group: dist.ProcessGroup,
num_nvl_bytes: int = 0, num_rdma_bytes: int = 0,
low_latency_mode: bool = False, num_qps_per_rank: int = 1) -> None:
low_latency_mode: bool = False, num_qps_per_rank: int = 12) -> None:
"""
Initialize the communication buffer.

@@ -66,17 +66,16 @@ def __init__(self, group: dist.ProcessGroup,
# Synchronize NVSHMEM unique IDs
root_unique_id = None
if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
# Enable IBGDA for the low latency mode, which refers to "no package forwarding between NVLink and RDMA"
if low_latency_mode:
assert num_qps_per_rank > 0
os.environ['NVSHMEM_DISABLE_P2P'] = '1'
os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'
# Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check
os.environ['NVSHMEM_QP_DEPTH'] = '1024'
# NOTES: NVSHMEM initialization requires at least 256 MiB
os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}'
# Enable IBGDA
assert num_qps_per_rank > 0
os.environ['NVSHMEM_DISABLE_P2P'] = '1'
os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'
# Make sure QP depth is always larger than the number of in-flight WRs, so that we can skip WQ slot check
os.environ['NVSHMEM_QP_DEPTH'] = '1024'
# NOTES: NVSHMEM initialization requires at least 256 MiB
os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}'

# Synchronize using the root ID
nvshmem_unique_ids = [None, ] * self.group_size
9 changes: 6 additions & 3 deletions tests/test_internode.py
@@ -219,16 +219,19 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
def test_loop(local_rank: int, num_local_ranks: int):
num_nodes = int(os.getenv('WORLD_SIZE', 1))
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
test_ll_compatibility = False
test_ll_compatibility = True
if test_ll_compatibility:
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9

num_sms = 24
num_qps_per_rank = max(num_sms // 2, ll_num_experts // num_ranks if test_ll_compatibility else 0)

buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility,
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
num_qps_per_rank=num_qps_per_rank)
assert num_local_ranks == 8 and num_ranks > 8
torch.manual_seed(rank)

for i in (24, ):
for i in (num_sms, ):
test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group)
if local_rank == 0:
print('', flush=True)
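For reference (not part of the diff): with `num_sms = 24` and `ll_num_experts = 256`, a minimal two-node run of 16 ranks gives `num_qps_per_rank = max(24 // 2, 256 // 16) = max(12, 16) = 16`. That value satisfies both assertions introduced in this PR: `num_rc_per_pe >= num_channels` in `internode.cu` (16 >= 12, since the internode kernels pair each channel with a forwarder SM, so 24 SMs give 12 channels) and `num_rc_per_pe >= num_local_experts` in `internode_ll.cu` (16 >= 256 // 16 = 16).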