diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp
index 954c9ffb..ab305952 100644
--- a/csrc/deep_ep.cpp
+++ b/csrc/deep_ep.cpp
@@ -1588,7 +1588,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x,
     auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
                                       x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16));
     auto packed_recv_src_info =
-        torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
+        torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt64).device(torch::kCUDA));
     auto packed_recv_layout_range =
         torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
     auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
@@ -1618,7 +1618,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x,
         internode_ll::dispatch(packed_recv_x.data_ptr(),
                                packed_recv_x_scales_ptr,
-                               packed_recv_src_info.data_ptr<int>(),
+                               packed_recv_src_info.data_ptr<int64_t>(),
                                packed_recv_layout_range.data_ptr<int64_t>(),
                                packed_recv_count.data_ptr<int>(),
                                mask_buffer_ptr,
@@ -1677,6 +1677,12 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
                              const torch::Tensor& topk_weights,
                              const torch::Tensor& src_info,
                              const torch::Tensor& layout_range,
+                             bool overlap,
+                             const std::optional<torch::Tensor>& packed_recv_count,
+                             const std::optional<torch::Tensor>& comp_signal,
+                             int block_m,
+                             int threshold,
+                             int num_sms,
                              const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
                              int num_max_dispatch_tokens_per_rank,
                              int num_experts,
@@ -1687,6 +1693,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
                              const std::optional<torch::Tensor>& out) {
 #ifndef DISABLE_NVSHMEM
     EP_HOST_ASSERT(low_latency_mode);
+    EP_HOST_ASSERT((!overlap || return_recv_hook) and "Overlap mode requires return_recv_hook=True");
 
     // Tensor checks
     EP_HOST_ASSERT(x.dim() == 3 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16);
@@ -1700,11 +1707,17 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
+    if (comp_signal.has_value()) {
+        EP_HOST_ASSERT(comp_signal->dim() == 1 and comp_signal->is_contiguous());
+        EP_HOST_ASSERT(comp_signal->scalar_type() == torch::kInt32);
+        EP_HOST_ASSERT(comp_signal->size(0) == num_experts / num_ranks * ceil_div(num_ranks * num_max_dispatch_tokens_per_rank, 64));
+    }
+
     if (combine_wait_recv_cost_stats.has_value()) {
         EP_HOST_ASSERT(combine_wait_recv_cost_stats->scalar_type() == torch::kInt64);
         EP_HOST_ASSERT(combine_wait_recv_cost_stats->dim() == 1 and combine_wait_recv_cost_stats->is_contiguous());
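For readers of the new `comp_signal` checks above, here is a small host-side Python sketch (not part of the patch; `expected_comp_signal_numel` and the example numbers are illustrative) that reproduces the element count the assert enforces, with `block_m` defaulting to the hard-coded 64:

```python
# Host-side sketch of the comp_signal sizing rule enforced by the new EP_HOST_ASSERT.
def expected_comp_signal_numel(num_experts: int, num_ranks: int,
                               num_max_dispatch_tokens_per_rank: int, block_m: int = 64) -> int:
    ceil_div = lambda a, b: (a + b - 1) // b
    num_local_experts = num_experts // num_ranks
    # One int32 slot per block_m-token tile, per local expert
    return num_local_experts * ceil_div(num_ranks * num_max_dispatch_tokens_per_rank, block_m)

# Example: 288 experts over 32 ranks, 128 max dispatch tokens per rank
# -> 9 local experts * ceil_div(32 * 128, 64) = 9 * 64 = 576 signal slots
assert expected_comp_signal_numel(288, 32, 128) == 576
```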
@@ -1750,8 +1763,13 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
             topk_weights.data_ptr<float>(),
-            src_info.data_ptr<int>(),
+            src_info.data_ptr<int64_t>(),
             layout_range.data_ptr<int64_t>(),
+            overlap,
+            packed_recv_count.has_value() ? packed_recv_count->data_ptr<int>() : nullptr,
+            comp_signal.has_value() ? comp_signal->data_ptr<int>() : nullptr,
+            block_m,
+            threshold,
             mask_buffer_ptr,
             combine_wait_recv_cost_stats.has_value() ? combine_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
             next_clean_meta.first,
@@ -1766,6 +1784,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
+                             const std::optional<torch::Tensor>& packed_recv_count,
+                             const std::optional<torch::Tensor>& comp_signal,
+                             int block_m,
+                             int threshold,
+                             int num_sms,
                              const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
                              int num_max_dispatch_tokens_per_rank,
                              int num_experts,
diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh
index 9bbe096a..95639e8e 100644
--- a/csrc/kernels/api.cuh
+++ b/csrc/kernels/api.cuh
@@ -284,7 +284,7 @@ void clean_low_latency_buffer(int* clean_0,
 void dispatch(void* packed_recv_x,
               void* packed_recv_x_scales,
-              int* packed_recv_src_info,
+              int64_t* packed_recv_src_info,
               int64_t* packed_recv_layout_range,
               int* packed_recv_count,
               int* mask_buffer,
@@ -319,8 +319,13 @@ void combine(void* combined_x,
              const void* x,
              const topk_idx_t* topk_idx,
              const float* topk_weights,
-             const int* src_info,
+             const int64_t* src_info,
              const int64_t* layout_range,
+             bool overlap,
+             int* packed_recv_count,
+             int* comp_signal,
+             int block_m,
+             int threshold,
              int* mask_buffer,
              int64_t* combine_wait_recv_cost_stats,
              int* next_clean,
@@ -335,6 +340,7 @@ void combine(void* combined_x,
              bool use_logfmt,
              void* workspace,
              int num_device_sms,
+             int num_sms,
              cudaStream_t stream,
              int phases,
              bool zero_copy);
diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index e9fd473b..9215b1cc 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -129,7 +129,7 @@ void clean_low_latency_buffer(int* clean_0,
 template <bool kUseFP8, bool kUseUE8M0, int kHidden>
 __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x,
                                                     void* packed_recv_x_scales,
-                                                    int* packed_recv_src_info,
+                                                    int64_t* packed_recv_src_info,
                                                     int64_t* packed_recv_layout_range,
                                                     int* packed_recv_count,
                                                     int* mask_buffer_ptr,
@@ -427,7 +427,7 @@ LOW_LATENCY_DISPATCH_RECV:
         // Copy source info
         const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
         if (lane_id == 0)
-            recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
+            recv_src_info[recv_token_begin_idx + i] = pack2<int, int64_t>(ld_nc_global(src_src_idx), src_rank);
         __syncwarp();
 
         // Copy data
@@ -464,7 +464,7 @@
 void dispatch(void* packed_recv_x,
               void* packed_recv_x_scales,
-              int* packed_recv_src_info,
+              int64_t* packed_recv_src_info,
              int64_t* packed_recv_layout_range,
              int* packed_recv_count,
              int* mask_buffer_ptr,
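The dispatch-side change above widens `packed_recv_src_info` to 64 bits so that each entry carries both the source token index and the source rank. A reference sketch in Python (not part of the patch; helper names are illustrative) mirroring the `pack2(...)` write here and the `>> 32` / `& 0xffffffff` reads in the combine kernel below:

```python
# Reference sketch of the new 64-bit src_info layout: low 32 bits hold the source
# token index, high 32 bits hold the source rank.
def pack_src_info(src_token_idx: int, src_rank: int) -> int:
    return (src_rank << 32) | (src_token_idx & 0xFFFFFFFF)

def unpack_src_info(packed: int) -> tuple:
    # Mirrors the combine kernel: `& 0xffffffff` recovers the token index, `>> 32` the rank
    return packed & 0xFFFFFFFF, packed >> 32

packed = pack_src_info(src_token_idx=123, src_rank=7)
assert unpack_src_info(packed) == (123, 7)
```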
@@ -720,13 +720,19 @@ __global__ __launch_bounds__(1024, 1) void combine(void* combined_x,
                                                    const void* x,
                                                    const topk_idx_t* topk_idx,
                                                    const float* topk_weights,
-                                                   const int* src_info,
+                                                   const int64_t* src_info,
                                                    const int64_t* layout_range,
+                                                   bool overlap,
+                                                   int* packed_recv_count,
+                                                   int* comp_signal,
+                                                   int block_m,
+                                                   int threshold,
                                                    int* mask_buffer_ptr,
                                                    int64_t* combine_wait_recv_cost_stats,
                                                    int* next_clean,
                                                    int num_next_clean_int,
                                                    int* atomic_clean_flag,
+                                                   int* atomic_finish_counter_per_expert,
                                                    int num_combined_tokens,
                                                    int hidden,
                                                    int num_topk,
@@ -736,6 +742,7 @@ __global__ __launch_bounds__(1024, 1) void combine(void* combined_x,
                                                    int num_ranks,
                                                    int num_warp_groups,
                                                    int num_warps_per_group,
+                                                   int smem_send_size,
                                                    int phases,
                                                    bool zero_copy) {
     const auto sm_id = __shfl_sync(0xffffffff, static_cast<int>(blockIdx.x), 0);
@@ -770,6 +777,11 @@ __global__ __launch_bounds__(1024, 1) void combine(void* combined_x,
     constexpr size_t num_bytes_per_slot = kHidden * sizeof(nv_bfloat16) + kNumMetaBytes;
     EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
 
+    // Parameters for the IBGDA sends outer loop, declared up front so the `goto` below does not skip their initialization
+    int initial_idx, loop_bound, step_size;
+    // Shared between the warps of an SM in overlap mode, where each SM has only one warp group
+    auto shared_vaild_signal_prefix_sum = reinterpret_cast<int*>(smem_buffer + smem_send_size);
+
     // Sending phase
     if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
         goto LOW_LATENCY_COMBINE_RECV;
@@ -786,10 +798,40 @@ __global__ __launch_bounds__(1024, 1) void combine(void* combined_x,
         atomic_add_release_global(atomic_clean_flag, num_experts);
     }
 
-    // Issue IBGDA sends
-    if (responsible_expert_idx < num_experts) {
-        const auto dst_rank = responsible_expert_idx / num_local_experts;
-        const auto local_expert_idx = responsible_expert_idx % num_local_experts;
+    __shared__ int shared_vaild_signal_sum, shared_local_expert_idx;
+
+    // Compute prefix sums of valid signal counts per local expert
+    if (overlap) {
+        if (sub_warp_id == 0 and lane_id == 0) {
+            shared_vaild_signal_prefix_sum[0] = (packed_recv_count[0] == 0 ? 1 : ceil_div(packed_recv_count[0], block_m));
+            shared_local_expert_idx = 0;
+            #pragma unroll
+            for (int i = 1; i < num_local_experts; i++) {
+                shared_vaild_signal_prefix_sum[i] = shared_vaild_signal_prefix_sum[i-1] +
+                                                    (packed_recv_count[i] == 0 ? 1 : ceil_div(packed_recv_count[i], block_m));
+            }
+            shared_vaild_signal_sum = shared_vaild_signal_prefix_sum[num_local_experts-1];
+        }
+        __syncthreads();
+    }
+
+    // Issue IBGDA sends; non-overlap mode only loops once
+    initial_idx = overlap ? sm_id : responsible_expert_idx;
+    loop_bound = overlap ? shared_vaild_signal_sum : num_experts;
+    step_size = overlap ? num_sms : num_experts;
+    for (int vaild_signal_idx = initial_idx; vaild_signal_idx < loop_bound; vaild_signal_idx += step_size) {
+
+        // Find the owning local_expert_idx by scanning the prefix-sum array
+        if (overlap) {
+            if (sub_warp_id == 0 and lane_id == 0) {
+                while (vaild_signal_idx >= shared_vaild_signal_prefix_sum[shared_local_expert_idx])
+                    shared_local_expert_idx++;
+            }
+            __syncthreads();
+        }
+
+        auto dst_rank = responsible_expert_idx / num_local_experts;
+        const auto local_expert_idx = overlap ? shared_local_expert_idx : responsible_expert_idx % num_local_experts;
         const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
         const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
         const auto local_x =
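The send phase above derives, for each flat `vaild_signal_idx`, the owning local expert by scanning an inclusive prefix sum of per-expert signal counts. A host-side Python sketch of the same bookkeeping (not part of the patch; function names are illustrative):

```python
# Every local expert contributes max(1, ceil_div(count, block_m)) "valid signals";
# a flat signal index is mapped back to its owning expert by scanning the inclusive
# prefix sum, exactly like the kernel's while-loop.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def build_prefix_sum(packed_recv_count: list, block_m: int) -> list:
    prefix, total = [], 0
    for count in packed_recv_count:
        total += 1 if count == 0 else ceil_div(count, block_m)
        prefix.append(total)
    return prefix

def owning_expert(valid_signal_idx: int, prefix: list) -> tuple:
    local_expert_idx = 0
    while valid_signal_idx >= prefix[local_expert_idx]:
        local_expert_idx += 1
    base = 0 if local_expert_idx == 0 else prefix[local_expert_idx - 1]
    return local_expert_idx, valid_signal_idx - base  # (expert, signal slot within expert)

prefix = build_prefix_sum([130, 0, 64], block_m=64)   # -> [3, 4, 5]
assert owning_expert(3, prefix) == (1, 0)
assert owning_expert(4, prefix) == (2, 0)
```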
@@ -802,6 +844,22 @@ __global__ __launch_bounds__(1024, 1) void combine(void* combined_x,
         int offset, num_tokens_to_send;
         unpack2(layout, num_tokens_to_send, offset);
 
+        // Wait for the corresponding comp_signal to reach the threshold
+        int num_tokens_per_expert, num_signal_per_expert, local_expert_signal_idx;
+        const int* gemm_comp_signal;
+        if (overlap) {
+            num_tokens_per_expert = packed_recv_count[local_expert_idx];
+            num_signal_per_expert = ceil_div(num_ranks * num_max_dispatch_tokens_per_rank, block_m);
+            local_expert_signal_idx = (local_expert_idx == 0) ? vaild_signal_idx :
+                                      vaild_signal_idx - shared_vaild_signal_prefix_sum[local_expert_idx-1];
+            gemm_comp_signal = comp_signal + num_signal_per_expert * local_expert_idx + local_expert_signal_idx;
+
+            if (sub_warp_id == 0 and lane_id == 0 and num_tokens_per_expert != 0) {
+                while (ld_acquire_global(gemm_comp_signal) != threshold);
+            }
+            __syncthreads();
+        }
+
         // TMA stuffs
         constexpr int kNumTMABufferBytes = sizeof(int4) * 32 * kNumSendUnrolls;
         constexpr int kNumStages = 3;
@@ -833,14 +891,19 @@ __global__ __launch_bounds__(1024, 1) void combine(void* combined_x,
         };
 
         // Issue IBGDA send
-        if (not is_rank_masked(mask_buffer_ptr, dst_rank)) {
-            for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += num_warps_per_group) {
+        if (overlap or (not is_rank_masked(mask_buffer_ptr, dst_rank))) {
+            auto token_start_idx = overlap ? local_expert_signal_idx * block_m : offset;
+            auto token_end_idx = overlap ? min((local_expert_signal_idx + 1) * block_m, num_tokens_per_expert) : (offset + num_tokens_to_send);
+            for (int token_idx = sub_warp_id + token_start_idx; token_idx < token_end_idx; token_idx += num_warps_per_group) {
                 const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
                 const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
                 const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row);
 
                 // Copy directly to local rank, or copy to buffer and issue RDMA
-                const auto src_idx = __shfl_sync(0xffffffff, __ldg(local_src_info + token_idx), 0);
+                overlap ? (dst_rank = __shfl_sync(0xffffffff, static_cast<int>(__ldg(local_src_info + token_idx) >> 32), 0)) : 0;
+                if (overlap and is_rank_masked(mask_buffer_ptr, dst_rank))
+                    continue;
+                const auto src_idx = __shfl_sync(0xffffffff, static_cast<int>(__ldg(local_src_info + token_idx) & 0xffffffff), 0);
                 const auto buf_ptr = reinterpret_cast<uint64_t>(rdma_send_x_vec_row);
                 const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
@@ -909,14 +972,13 @@ __global__ __launch_bounds__(1024, 1) void combine(void* combined_x,
                 // Issue RDMA
                 // NOTES: for zero-copy mode, we assume the data is already in the send buffer
                 if (dst_p2p_ptr == 0)
-                    nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, num_send_bytes, dst_rank, local_expert_idx, lane_id, token_idx - offset);
+                    nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, num_send_bytes, dst_rank, local_expert_idx, lane_id, token_idx);
             }
         }
 
-        // Put the finishing flag
-        EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 16);
         asm volatile("bar.sync %0, %1;" ::"r"(warp_group_id + 1), "r"(num_warps_per_group * 32));
-        if (sub_warp_id == 1 and lane_id == 0) {
+
+        auto send_finish_flag = [&](int dst_rank) {
             while (ld_acquire_global(atomic_clean_flag) == 0)
                 ;
             auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_flag + global_expert_idx);
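In overlap mode each valid signal gates one `block_m`-sized slice of an expert's received tokens, which is the token range the warp group iterates over above before issuing the IBGDA sends. A Python sketch of the slicing (not part of the patch):

```python
# Token range covered by one valid signal in overlap mode.
def token_range_for_signal(local_expert_signal_idx: int, num_tokens_per_expert: int, block_m: int) -> range:
    start = local_expert_signal_idx * block_m
    end = min((local_expert_signal_idx + 1) * block_m, num_tokens_per_expert)
    return range(start, end)

# 130 received tokens with block_m = 64 are split into slices [0, 64), [64, 128), [128, 130)
assert list(map(len, (token_range_for_signal(i, 130, 64) for i in range(3)))) == [64, 64, 2]
```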
@@ -929,8 +991,38 @@ __global__ __launch_bounds__(1024, 1) void combine(void* combined_x,
                 }
             }
             atomic_add_release_global(atomic_clean_flag, -1);
+        };
+
+        if (overlap) {
+            // Put the finishing flag for overlap mode
+            bool put_finish_flag = false;
+            if (sub_warp_id == 0) {
+                if (lane_id == 0) {
+                    const auto finish_counter = (num_tokens_per_expert == 0 ? 1 : ceil_div(num_tokens_per_expert, block_m));
+                    if ((atomicAdd(atomic_finish_counter_per_expert + local_expert_idx, 1) + 1) == finish_counter)
+                        put_finish_flag = true;
+                }
+                put_finish_flag = __shfl_sync(0xffffffff, put_finish_flag, 0);
+            }
+            __syncthreads();
+
+            if (sub_warp_id == 0 and put_finish_flag) {
+                for (int dst_rank = lane_id; dst_rank < num_ranks; dst_rank += 32) {
+                    send_finish_flag(dst_rank);
+                }
+                if (lane_id == 0)
+                    atomic_finish_counter_per_expert[local_expert_idx] = 0;
+            }
+            __syncthreads();
+        }
+        else {
+            // Put the finishing flag for non-overlap mode
+            EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 16);
+            if (sub_warp_id == 1 and lane_id == 0) {
+                send_finish_flag(dst_rank);
+            }
+            __syncwarp();
         }
-        __syncwarp();
 
     // Destroy m-barriers
     if (lane_id < kNumStages) {
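The overlap branch above uses a per-expert atomic counter so that exactly one of the `ceil_div(count, block_m)` slices (or the single slice of an empty expert) puts the finishing flags and then resets the counter. A host-side Python sketch of that protocol (not part of the patch; class and function names are illustrative):

```python
# Sketch of the per-expert finish-counter protocol used by the overlap send path.
def finish_counter_target(num_tokens_per_expert: int, block_m: int) -> int:
    return 1 if num_tokens_per_expert == 0 else (num_tokens_per_expert + block_m - 1) // block_m

class FinishCounter:
    def __init__(self) -> None:
        self.value = 0

    def arrive(self, num_tokens_per_expert: int, block_m: int) -> bool:
        """Returns True for exactly one arriving slice, mirroring the atomicAdd check."""
        self.value += 1
        if self.value == finish_counter_target(num_tokens_per_expert, block_m):
            self.value = 0  # reset, as the kernel does after putting the flags
            return True
        return False

counter = FinishCounter()
arrivals = [counter.arrive(130, 64) for _ in range(3)]
assert arrivals == [False, False, True]
```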
@@ -1145,8 +1237,13 @@ void combine(void* combined_x,
              const void* x,
              const topk_idx_t* topk_idx,
              const float* topk_weights,
-             const int* src_info,
+             const int64_t* src_info,
              const int64_t* layout_range,
+             bool overlap,
+             int* packed_recv_count,
+             int* comp_signal,
+             int block_m,
+             int threshold,
              int* mask_buffer_ptr,
              int64_t* combine_wait_recv_cost_stats,
              int* next_clean,
@@ -1161,22 +1258,37 @@ void combine(void* combined_x,
              bool use_logfmt,
              void* workspace,
              int num_device_sms,
+             int num_sms,
              cudaStream_t stream,
              int phases,
              bool zero_copy) {
     constexpr int kNumMaxTopk = 11;
-    const int num_warp_groups = ceil_div(num_experts, num_device_sms);
-    const int num_warps_per_group = 32 / num_warp_groups;
-    const int num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);
-    EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0);
+    int num_warp_groups, num_warps_per_group, num_recv_per_sm, num_warps;
 
-    const auto num_warps = num_warp_groups * num_warps_per_group;
-    const auto num_sms =
-        max(ceil_div(num_experts, num_warp_groups), num_recv_per_sm == 0 ? 1 : ceil_div(num_combined_tokens, num_recv_per_sm));
+    if (overlap == true and phases == LOW_LATENCY_SEND_PHASE) {
+        num_warp_groups = 1;
+        num_warps_per_group = 32;
+        num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);
+        EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0 and block_m > 0 and threshold > 0);
+
+        num_warps = num_warp_groups * num_warps_per_group;
+    }
+    else {
+        num_warp_groups = ceil_div(num_experts, num_device_sms);
+        num_warps_per_group = 32 / num_warp_groups;
+        num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);
+        EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0);
+
+        num_warps = num_warp_groups * num_warps_per_group;
+        num_sms =
+            max(ceil_div(num_experts, num_warp_groups), num_recv_per_sm == 0 ? 1 : ceil_div(num_combined_tokens, num_recv_per_sm));
+    }
 
     // Check workspace
+    // 1 int: clean flag + `num_experts` ints: per-expert atomic finish counter for overlap mode
     auto atomic_clean_flag = static_cast<int*>(workspace);
-    EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
+    auto atomic_finish_counter_per_expert = atomic_clean_flag + 1;
+    EP_HOST_ASSERT((1 + num_experts) * sizeof(int) <= NUM_WORKSPACE_BYTES);
     EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
 
     // Online cast cannot use zero-copy
@@ -1191,12 +1303,16 @@ void combine(void* combined_x,
     const int num_send_tma_bytes = 32 * sizeof(int4) * kNumMaxUnrolls + 16;
     const int smem_send_size = num_warps * (kNumStages * num_send_tma_bytes + num_meta_bytes);
 
+    // prefix_sum size, used for shared_vaild_signal_prefix_sum
+    const int num_local_experts = num_experts / num_ranks;
+    const int smem_prefix_sum_size = num_local_experts * sizeof(int);
+
     // Receive buffer size
     const int num_recv_tma_bytes = 16 + hidden * 2;
     const int smem_recv_size = kMaxNumGroups * (kNumStages * num_recv_tma_bytes + hidden * 2 + kNumStages * num_meta_bytes * 3);
 
     // Total requirement
-    const int smem_size = max(smem_send_size, smem_recv_size);
+    const int smem_size = max(smem_send_size + smem_prefix_sum_size, smem_recv_size);
 
 #define COMBINE_LAUNCH_CASE(hidden) \
     { \
@@ -1214,11 +1330,17 @@ void combine(void* combined_x,
             topk_weights, \
             src_info, \
             layout_range, \
+            overlap, \
+            packed_recv_count, \
+            comp_signal, \
+            block_m, \
+            threshold, \
             mask_buffer_ptr, \
             combine_wait_recv_cost_stats, \
             next_clean, \
             num_next_clean_int, \
             atomic_clean_flag, \
+            atomic_finish_counter_per_expert, \
            num_combined_tokens, \
            hidden, \
            num_topk, \
@@ -1228,6 +1350,7 @@ void combine(void* combined_x,
            num_ranks, \
            num_warp_groups, \
            num_warps_per_group, \
+            smem_send_size, \
            phases, \
            zero_copy); \
    } \
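The launcher above only grows the send-side shared-memory footprint by one `int` per local expert for the prefix-sum array. A Python sketch of the adjusted budget (not part of the patch; the byte counts in the example are made up):

```python
# Sketch of the adjusted shared-memory budget: the send buffer gains one int per
# local expert for the prefix-sum array, and the kernel still takes the max of the
# send and receive footprints.
def combine_smem_size(smem_send_size: int, smem_recv_size: int, num_experts: int, num_ranks: int) -> int:
    num_local_experts = num_experts // num_ranks
    smem_prefix_sum_size = num_local_experts * 4  # sizeof(int)
    return max(smem_send_size + smem_prefix_sum_size, smem_recv_size)

# e.g. 9 local experts only add 36 bytes on top of the send buffer
assert combine_smem_size(96 * 1024, 100 * 1024, 288, 32) == 100 * 1024
```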
diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py
index 37512ee9..da17b806 100644
--- a/deep_ep/buffer.py
+++ b/deep_ep/buffer.py
@@ -614,8 +614,10 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
                EventOverlap(event, tensors_to_record if async_finish else None), hook
 
     # noinspection PyTypeChecker
-    def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
-                            handle: tuple, use_logfmt: bool = False, zero_copy: bool = False, async_finish: bool = False,
+    def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: tuple,
+                            overlap: bool = False, packed_recv_count: torch.Tensor = None, comp_signal: torch.Tensor = None,
+                            block_m: int = 64, threshold: int = 0, num_sms: int = 3,
+                            use_logfmt: bool = False, zero_copy: bool = False, async_finish: bool = False,
                             return_recv_hook: bool = False, out: Optional[torch.Tensor] = None,
                             combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \
             Tuple[torch.Tensor, EventOverlap, Callable]:
@@ -635,6 +637,17 @@ def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weig
            topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
                tokens. The received tokens will be reduced with the weights in this tensor.
            handle: the communication handle given by the `dispatch` function.
+            overlap: whether to overlap the down GEMM with the combine send phase.
+            packed_recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
+                expert receives.
+            comp_signal: `[num_local_experts * ceil_div(num_ranks * num_max_dispatch_tokens_per_rank, block_m)]` with `torch.int32`,
+                each element indicates the processing progress of `block_m` tokens in DeepGEMM.
+                Note that the fixed-length tensor is used to support CUDA graphs;
+                only the first `ceil_div(num_tokens * num_ranks, block_m)` elements of each local expert are valid.
+            block_m: set by DeepGEMM.
+            threshold: set by DeepGEMM. When a valid element in comp_signal reaches this threshold, all the tokens
+                corresponding to this element have been computed by DeepGEMM and can be sent.
+            num_sms: the number of SMs used by the low_latency_combine send phase; only needs to be set when overlap is `True`.
            use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits).
            zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
                with `get_next_low_latency_combine_buffer`.
@@ -655,6 +668,7 @@ def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weig
        src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
        assert self.nvshmem_qp_depth >= (num_max_dispatch_tokens_per_rank + 1) * 2
        combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
+                                                                   overlap, packed_recv_count, comp_signal, block_m, threshold, num_sms,
                                                                    combine_wait_recv_cost_stats,
                                                                    num_max_dispatch_tokens_per_rank, num_experts,
                                                                    use_logfmt, zero_copy, async_finish, return_recv_hook, out)
        tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
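A hypothetical end-to-end usage sketch of the new keyword arguments (not part of the patch; `buffer` and the surrounding training code are assumed, and the defaults shown are illustrative). `comp_signal` is expected to be filled by the down GEMM as `block_m` tiles complete, and `return_recv_hook=True` is required by the new host assert:

```python
# Usage sketch for the overlap path of low_latency_combine.
def overlapped_combine(buffer, x, topk_idx, topk_weights, handle,
                       packed_recv_count, comp_signal, block_m=64, threshold=1, num_sms=3):
    combined_x, event, hook = buffer.low_latency_combine(
        x, topk_idx, topk_weights, handle,
        overlap=True,
        packed_recv_count=packed_recv_count,
        comp_signal=comp_signal,        # filled by the GEMM as block_m tiles complete
        block_m=block_m,
        threshold=threshold,
        num_sms=num_sms,
        return_recv_hook=True)          # required: overlap mode asserts return_recv_hook
    # ... run / wait for the down GEMM here, then drain the receive phase ...
    hook()
    return combined_x
```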
diff --git a/tests/test_low_latency.py b/tests/test_low_latency.py
index 456dcf27..19634a94 100644
--- a/tests/test_low_latency.py
+++ b/tests/test_low_latency.py
@@ -6,7 +6,7 @@
 from typing import Literal, Set
 
 import deep_ep
-from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back
+from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back, ceil_div
 
 
 def simulate_failure_and_skip(rank: int, api: Literal["dispatch", "combine", "clean"], expected_masked_ranks: Set[int]):
@@ -127,12 +127,12 @@ def test_main(num_tokens: int,
             if current_x is x:
                 recv_x = recv_x[:num_valid_tokens]
                 recv_x_amin = recv_x[:, :-128].amin(dim=-1)
-                recv_src_info = recv_src_info[:num_valid_tokens]
+                src_token_idx = recv_src_info[:num_valid_tokens] & int_mask
                 assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
                 if round_scale:
-                    assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
+                    assert calc_diff(recv_x[:, -1], src_token_idx.view(-1)) < 0.007
                 else:
-                    assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
+                    assert (recv_x[:, -128:] - src_token_idx.view(-1, 1) % num_tokens).sum().item() == 0
             for j in range(num_ranks):
                 if shrink_test and mask_status[j]:
                     continue
@@ -150,34 +150,58 @@ def test_main(num_tokens: int,
            if shrink_test and simulate_failure_and_skip(rank, "combine", expected_masked_ranks):
                break
            for zero_copy in (False, ) if use_logfmt else (False, True):
-                if zero_copy:
-                    buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
-                out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
-                combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
-                                                                      topk_idx,
-                                                                      topk_weights,
-                                                                      handle,
-                                                                      use_logfmt=use_logfmt,
-                                                                      async_finish=not return_recv_hook,
-                                                                      zero_copy=zero_copy,
-                                                                      return_recv_hook=return_recv_hook,
-                                                                      out=out)
-                hook() if return_recv_hook else event.current_stream_wait()
-                if shrink_test:
-                    query_mask_buffer_and_check("combine", buffer, mask_status, expected_masked_ranks)
-                if do_check:
+                for overlap in (False, True) if return_recv_hook else (False, ):
+                    if zero_copy:
+                        buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
+                    out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
+                    if overlap:
+                        block_m, threshold, num_sms = 64, 10, 3
+                        total_num_per_expert = ceil_div(num_tokens * num_ranks, block_m)
+                        comp_signal = torch.zeros(num_local_experts * total_num_per_expert, dtype=torch.int32, device='cuda')
+                        for i in range(num_local_experts):
+                            vaild_num = ceil_div(packed_recv_count[i], block_m)
+                            comp_signal[i * total_num_per_expert : i * total_num_per_expert + vaild_num] = threshold
+                        combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
+                                                                             topk_idx,
+                                                                             topk_weights,
+                                                                             handle,
+                                                                             overlap=True,
+                                                                             packed_recv_count=packed_recv_count,
+                                                                             comp_signal=comp_signal,
+                                                                             block_m=block_m,
+                                                                             threshold=threshold,
+                                                                             num_sms=num_sms,
+                                                                             use_logfmt=use_logfmt,
+                                                                             async_finish=not return_recv_hook,
+                                                                             zero_copy=zero_copy,
+                                                                             return_recv_hook=return_recv_hook,
+                                                                             out=out)
+                    else:
+                        combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
+                                                                             topk_idx,
+                                                                             topk_weights,
+                                                                             handle,
+                                                                             use_logfmt=use_logfmt,
+                                                                             async_finish=not return_recv_hook,
+                                                                             zero_copy=zero_copy,
+                                                                             return_recv_hook=return_recv_hook,
+                                                                             out=out)
+                    hook() if return_recv_hook else event.current_stream_wait()
                    if shrink_test:
-                        owner_by_expert = (torch.arange(num_experts, device='cuda') // num_local_experts)
-                        fail_owner_mask = (mask_status == 1).index_select(0, owner_by_expert)
-                        valid_topk_idx = topk_idx >= 0
-                        failed_topk_idx = torch.zeros_like(topk_idx, device='cuda', dtype=torch.bool)
-                        failed_topk_idx[valid_topk_idx] = fail_owner_mask.index_select(0, topk_idx[valid_topk_idx])
-                        topk_idx[failed_topk_idx] = -1
-                    diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
-                    assert torch.isnan(combined_x).sum().item() == 0
-                    if not round_scale:
-                        assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {dispatch_use_fp8=}, {zero_copy=}'
-                    hash_value ^= hash_tensor(combined_x)
+                        query_mask_buffer_and_check("combine", buffer, mask_status, expected_masked_ranks)
+                    if do_check:
+                        if shrink_test:
+                            owner_by_expert = (torch.arange(num_experts, device='cuda') // num_local_experts)
+                            fail_owner_mask = (mask_status == 1).index_select(0, owner_by_expert)
+                            valid_topk_idx = topk_idx >= 0
+                            failed_topk_idx = torch.zeros_like(topk_idx, device='cuda', dtype=torch.bool)
+                            failed_topk_idx[valid_topk_idx] = fail_owner_mask.index_select(0, topk_idx[valid_topk_idx])
+                            topk_idx[failed_topk_idx] = -1
+                        diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
+                        assert torch.isnan(combined_x).sum().item() == 0
+                        if not round_scale:
+                            assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {dispatch_use_fp8=}, {zero_copy=}'
+                        hash_value ^= hash_tensor(combined_x)
 
            # Clean buffer API
            if shrink_test:
diff --git a/tests/utils.py b/tests/utils.py
index 1390b2b9..cfdd6083 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -240,3 +240,7 @@ def bench_kineto(fn,
 
 def hash_tensor(t: torch.Tensor):
     return t.view(torch.int).sum().item()
+
+
+def ceil_div(a, b):
+    return (a + b - 1) // b