diff --git a/README.md b/README.md
index 435ccea8..c9f95522 100644
--- a/README.md
+++ b/README.md
@@ -18,8 +18,10 @@ We test normal kernels on H800 (~160 GB/s NVLink maximum bandwidth), with each c
 |:---------:|:------------:|:--------------------:|:-----------:|:--------------------:|
 | Intranode | 8 | 153 GB/s (NVLink) | 8 | 158 GB/s (NVLink) |
 | Internode | 16 | 43 GB/s (RDMA) | 16 | 43 GB/s (RDMA) |
-| Internode | 32 | 44 GB/s (RDMA) | 32 | 47 GB/s (RDMA) |
-| Internode | 64 | 46 GB/s (RDMA) | 64 | 45 GB/s (RDMA) |
+| Internode | 32 | 58 GB/s (RDMA) | 32 | 57 GB/s (RDMA) |
+| Internode | 64 | 51 GB/s (RDMA) | 64 | 50 GB/s (RDMA) |
+
+**News (2025.04.22)**: with optimizations from Tencent Network Platform Department, performance was enhanced by up to 30%, see [#130](https://github.com/deepseek-ai/DeepEP/pull/130) for more details. Thanks for the contribution!
 
 ### Low-latency kernels with pure RDMA
 
diff --git a/csrc/kernels/ibgda_device.cuh b/csrc/kernels/ibgda_device.cuh
index 1e8d8715..9f8c37c0 100644
--- a/csrc/kernels/ibgda_device.cuh
+++ b/csrc/kernels/ibgda_device.cuh
@@ -325,6 +325,7 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
     st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<int4*>(&data_seg));
 }
 
+template <bool kAlwaysDoPostSend = false>
 __device__ static __forceinline__ void
 nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {
     // Get lkey and rkey, store them into lanes
@@ -365,7 +366,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
 
     // Submit
     if (lane_id == 0)
-        ibgda_submit_requests(qp, base_wqe_idx, num_wqes, message_idx);
+        ibgda_submit_requests<kAlwaysDoPostSend>(qp, base_wqe_idx, num_wqes, message_idx);
     __syncwarp();
 }
 
@@ -410,20 +411,25 @@ __device__ static __forceinline__ void ibgda_write_amo_add_wqe(
     st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<int4*>(&data_seg));
 }
 
-__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id) {
-    nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id);
+__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id, bool is_local_copy = false) {
+    if (is_local_copy) {
+        // Fallback to NVSHMEM legacy API
+        nvshmemx_signal_op(static_cast<uint64_t*>(rptr), value, NVSHMEM_SIGNAL_ADD, pe);
+    } else {
+        nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id);
 
-    __be32 rkey;
-    uint64_t raddr;
-    ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey);
+        __be32 rkey;
+        uint64_t raddr;
+        ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey);
 
-    uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
-    void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx);
+        uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
+        void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx);
 
-    ibgda_write_amo_add_wqe(qp, value, reinterpret_cast<uint64_t>(qp->ibuf.buf),
-                            qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs);
+        ibgda_write_amo_add_wqe(qp, value, reinterpret_cast<uint64_t>(qp->ibuf.buf),
+                                qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs);
 
-    ibgda_submit_requests(qp, my_wqe_idx, 1);
+        ibgda_submit_requests<true>(qp, my_wqe_idx, 1);
+    }
 }
 
 } // namespace deep_ep
diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu
index d6ad5837..2e774606 100644
--- a/csrc/kernels/internode.cu
+++ b/csrc/kernels/internode.cu
@@ -3,6 +3,7 @@
 #include "exception.cuh"
 #include "launch.cuh"
 #include "utils.cuh"
+#include "ibgda_device.cuh"
 
 namespace deep_ep {
 
@@ -479,6 +480,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
     const bool is_forwarder = sm_id % 2 == 0;
     const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
 
+    EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_channels);
+
     const auto role_meta = [=]() -> std::pair<WarpRole, int> {
         if (is_forwarder) {
             if (warp_id < NUM_MAX_NVL_PEERS) {
@@ -555,19 +558,27 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
         // Send number of tokens in this channel by `-value - 1`
         EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers");
         for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
+            auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank);
             if (lane_id < NUM_MAX_NVL_PEERS) {
-                rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;
+                dst_ptr[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;
             } else if (lane_id < NUM_MAX_NVL_PEERS * 2) {
-                rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;
+                dst_ptr[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;
             } else if (lane_id == NUM_MAX_NVL_PEERS * 2) {
-                rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;
+                dst_ptr[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;
             } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) {
-                rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
+                dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
+            }
+            __syncwarp();
+
+            // Issue RDMA for non-local ranks
+            if (dst_rdma_rank != rdma_rank) {
+                nvshmemi_ibgda_put_nbi_warp<true>(reinterpret_cast<uint64_t>(rdma_channel_meta.recv_buffer(rdma_rank)),
+                                                  reinterpret_cast<uint64_t>(rdma_channel_meta.send_buffer(dst_rdma_rank)),
+                                                  sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2),
+                                                  translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
+                                                  channel_id, lane_id, 0);
             }
-            nvshmemx_int_put_nbi_warp(rdma_channel_meta.recv_buffer(rdma_rank), rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2,
-                                      translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
         }
-        nvshmem_fence();
         sync_rdma_sender_smem();
 
         // Iterate over tokens and copy into buffer
@@ -710,11 +721,11 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                 if (dst_rdma_rank != rdma_rank) {
                     auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
                     EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
-                    nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
-                                               rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
-                                               num_bytes_per_rdma_token * num_tokens_to_issue,
-                                               translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
-                    nvshmem_fence();
+                    const size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue;
+                    const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
+                    const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
+                    nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg,
+                                                translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0);
                 } else {
                     // Lighter fence for local RDMA rank
                     memory_fence();
@@ -725,8 +736,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                 if (lane_id == dst_rdma_rank) {
                     last_issued_tail += num_tokens_to_issue;
                     num_tokens_to_send -= num_tokens_to_issue;
-                    nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD,
-                                       translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
+                    nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
+                                                    translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank);
                 }
             }
         }
@@ -926,8 +937,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
 
             // Update remote head
             if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
-                nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD,
-                                   translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
+                nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head,
+                                                translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank), channel_id, lane_id == rdma_rank);
                 last_head = min_head;
             }
 
@@ -1558,20 +1569,21 @@ combine(int4* combined_x, float* combined_topk_weights,
             if (sub_warp_id == kNumWarpsPerForwarder - 1) {
                 if (dst_rdma_rank != rdma_rank) {
                     auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
-                    nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
-                                               rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
-                                               num_chunked_tokens * num_bytes_per_rdma_token,
-                                               translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
-                    nvshmem_fence();
+                    const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_rdma_token;
+                    const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
+                    const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
+                    nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg,
+                                                translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0);
                 } else {
                     memory_fence();
                 }
 
                 // Write new RDMA tail
                 __syncwarp();
-                if (lane_id == 0)
-                    nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD,
-                                       translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
+                if (lane_id == 0) {
+                    nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
+                                                    translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank);
+                }
             }
         }
 
@@ -1656,8 +1668,8 @@ combine(int4* combined_x, float* combined_topk_weights,
                 for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i])
                     min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
                 if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
-                    nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD,
-                                       translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
+                    nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
+                                                    translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank);
                     last_rdma_head = min_head;
                 }
             } else {
diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index c33e0621..8e0d9e4b 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -167,7 +167,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
     EP_DEVICE_ASSERT(num_sms > 1);
     if (sm_id == 0) {
         // The first SM is also responsible for checking QPs
-        EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_local_experts);
+        EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_local_experts);
 
         // The first SM is also responsible for cleaning the next buffer
         #pragma unroll
diff --git a/csrc/kernels/runtime.cu b/csrc/kernels/runtime.cu
index c9f58797..8e536cac 100644
--- a/csrc/kernels/runtime.cu
+++ b/csrc/kernels/runtime.cu
@@ -58,14 +58,12 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
         EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID);
     }
 
-    // Normal operations use IBRC, while low-latency operations use IBGDA
-    if (low_latency_mode) {
-        nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
-        CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
+    // TODO: we still use `nvshmem_barrier` under IBRC mode, which should be switch to IBGDA mode later
+    nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
+    CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
 
-        bool ibgda_is_initialized = false;
-        CUDA_CHECK(cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice));
-    }
+    bool ibgda_is_initialized = false;
+    CUDA_CHECK(cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice));
     nvshmem_barrier_all();
     return nvshmem_my_pe();
 }
diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py
index 831a2e60..feeb3866 100644
--- a/deep_ep/buffer.py
+++ b/deep_ep/buffer.py
@@ -31,7 +31,7 @@ class Buffer:
 
     def __init__(self, group: dist.ProcessGroup,
                  num_nvl_bytes: int = 0, num_rdma_bytes: int = 0,
-                 low_latency_mode: bool = False, num_qps_per_rank: int = 1) -> None:
+                 low_latency_mode: bool = False, num_qps_per_rank: int = 12) -> None:
         """
         Initialize the communication buffer.
 
@@ -66,17 +66,16 @@ def __init__(self, group: dist.ProcessGroup,
         # Synchronize NVSHMEM unique IDs
         root_unique_id = None
         if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
-            # Enable IBGDA for the low latency mode, which refers to "no package forwarding between NVLink and RDMA"
-            if low_latency_mode:
-                assert num_qps_per_rank > 0
-                os.environ['NVSHMEM_DISABLE_P2P'] = '1'
-                os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
-                os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
-                os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'
-                # Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check
-                os.environ['NVSHMEM_QP_DEPTH'] = '1024'
-                # NOTES: NVSHMEM initialization requires at least 256 MiB
-                os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}'
+            # Enable IBGDA
+            assert num_qps_per_rank > 0
+            os.environ['NVSHMEM_DISABLE_P2P'] = '1'
+            os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
+            os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
+            os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'
+            # Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check
+            os.environ['NVSHMEM_QP_DEPTH'] = '1024'
+            # NOTES: NVSHMEM initialization requires at least 256 MiB
+            os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}'
 
             # Synchronize using the root ID
             nvshmem_unique_ids = [None, ] * self.group_size
diff --git a/tests/test_internode.py b/tests/test_internode.py
index 5884a16a..e9b3d57b 100644
--- a/tests/test_internode.py
+++ b/tests/test_internode.py
@@ -219,16 +219,19 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
 def test_loop(local_rank: int, num_local_ranks: int):
     num_nodes = int(os.getenv('WORLD_SIZE', 1))
     rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
-    test_ll_compatibility = False
+    test_ll_compatibility = True
     if test_ll_compatibility:
         ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
 
+    num_sms = 24
+    num_qps_per_rank = max(num_sms // 2, ll_num_experts // num_ranks if test_ll_compatibility else 0)
+
     buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility,
-                            num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
+                            num_qps_per_rank=num_qps_per_rank)
     assert num_local_ranks == 8 and num_ranks > 8
     torch.manual_seed(rank)
 
-    for i in (24, ):
+    for i in (num_sms, ):
         test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group)
         if local_rank == 0:
             print('', flush=True)
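
For context, the test change above sizes `num_qps_per_rank` as `max(num_sms // 2, ll_num_experts // num_ranks)`: after this patch the normal internode kernels also issue RDMA through IBGDA and assert `num_rc_per_pe >= num_channels` (the normal kernels use `num_sms // 2` channels), while the low-latency kernels keep asserting `num_rc_per_pe >= num_local_experts`. Below is a minimal usage sketch (not part of the patch), assuming a torchrun-style launch; the expert count, buffer sizes, and process-group setup are illustrative placeholders.

```python
import torch.distributed as dist
import deep_ep

# Illustrative values only; mirrors the logic in tests/test_internode.py above.
num_sms = 24              # the normal kernels run num_sms // 2 channels
num_experts = 256         # hypothetical global expert count
use_low_latency = True

# Assumes torchrun-style environment variables (MASTER_ADDR, RANK, WORLD_SIZE, ...).
dist.init_process_group(backend='nccl')
num_ranks = dist.get_world_size()
num_local_experts = num_experts // num_ranks

# Cover both device-side asserts: one RC QP per normal-kernel channel, and at
# least one per local expert when the low-latency kernels are used as well.
num_qps_per_rank = max(num_sms // 2, num_local_experts if use_low_latency else 0)

buffer = deep_ep.Buffer(dist.group.WORLD, num_nvl_bytes=int(1e9), num_rdma_bytes=int(1e9),
                        low_latency_mode=use_low_latency, num_qps_per_rank=num_qps_per_rank)
```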