17 changes: 17 additions & 0 deletions csrc/deep_ep.cpp
@@ -1090,6 +1090,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0,
bool async, bool return_recv_hook) {
@@ -1110,6 +1111,12 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->size(0) == num_experts / num_ranks);
}

if (dispatch_wait_recv_cost_stats.has_value()) {
EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->scalar_type() == torch::kInt64);
EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->dim() == 1 and dispatch_wait_recv_cost_stats->is_contiguous());
EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->size(0) == num_ranks);
}

auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
auto num_topk = static_cast<int>(topk_idx.size(1));
auto num_local_experts = num_experts / num_ranks;
@@ -1162,6 +1169,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
packed_recv_count.data_ptr<int>(),
cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
buffer.dispatch_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(),
@@ -1200,6 +1208,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out) {
@@ -1222,6 +1231,13 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous());
EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks);

if (combine_wait_recv_cost_stats.has_value()) {
EP_HOST_ASSERT(combine_wait_recv_cost_stats->scalar_type() == torch::kInt64);
EP_HOST_ASSERT(combine_wait_recv_cost_stats->dim() == 1 and combine_wait_recv_cost_stats->is_contiguous());
EP_HOST_ASSERT(combine_wait_recv_cost_stats->size(0) == num_ranks);
}

auto hidden = static_cast<int>(x.size(2));
auto num_topk = static_cast<int>(topk_weights.size(1));
auto num_combined_tokens = static_cast<int>(topk_weights.size(0));
@@ -1259,6 +1275,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
buffer.combine_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(),
src_info.data_ptr<int>(), layout_range.data_ptr<int64_t>(),
combine_wait_recv_cost_stats.has_value() ? combine_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
next_clean_meta.first, next_clean_meta.second,
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
2 changes: 2 additions & 0 deletions csrc/deep_ep.hpp
@@ -146,13 +146,15 @@ struct Buffer {
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0,
bool async, bool return_recv_hook);

std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out = std::nullopt);
2 changes: 2 additions & 0 deletions csrc/kernels/api.cuh
@@ -143,6 +143,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int,
@@ -156,6 +157,7 @@ void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int64_t* combine_wait_recv_cost_stats,
int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
17 changes: 17 additions & 0 deletions csrc/kernels/internode_ll.cu
@@ -42,6 +42,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
@@ -272,14 +273,20 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int num_recv_tokens, recv_token_begin_idx;
EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 15);
if (sub_warp_id == 1 and lane_id == 0) {
auto start_time = clock64();
while ((num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
auto wait_recv_cost = clock64() - start_time;
num_recv_tokens = -num_recv_tokens - 1;
recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
if (cumulative_local_expert_recv_stats != nullptr)
atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens);

if (dispatch_wait_recv_cost_stats != nullptr)
atomicAdd(reinterpret_cast<unsigned long long*>(dispatch_wait_recv_cost_stats + src_rank),
wait_recv_cost);
}
asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(num_warps_per_group * 32));
num_recv_tokens = shared_num_recv_tokens[warp_group_id];
@@ -330,6 +337,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int,
Expand Down Expand Up @@ -368,6 +376,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
packed_recv_src_info, packed_recv_layout_range, \
packed_recv_count, \
cumulative_local_expert_recv_stats, \
dispatch_wait_recv_cost_stats, \
rdma_recv_x, rdma_recv_count, rdma_x, \
x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
@@ -388,6 +397,7 @@ combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int64_t* combine_wait_recv_cost_stats,
int* next_clean, int num_next_clean_int,
int* atomic_clean_flag,
int num_combined_tokens, int hidden, int num_topk,
@@ -618,7 +628,12 @@ combine(void* combined_x,
if (responsible_expert_idx < num_experts) {
EP_DEVICE_ASSERT(num_warps_per_group > 1);
if (sub_warp_id == 0 and lane_id == 0) {
auto start_time = clock64();
while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0);
auto wait_recv_cost = clock64() - start_time;
if (combine_wait_recv_cost_stats != nullptr)
atomicAdd(reinterpret_cast<unsigned long long*>(combine_wait_recv_cost_stats
+ responsible_expert_idx / num_local_experts), wait_recv_cost);
}
}
cg::this_grid().sync();
@@ -667,6 +682,7 @@ void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int64_t* combine_wait_recv_cost_stats,
int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
@@ -701,6 +717,7 @@ LAUNCH_KERNEL(&cfg, combine_func, \
combined_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
combine_wait_recv_cost_stats, \
next_clean, num_next_clean_int, \
atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, \
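In both kernels, the receiving lane brackets its spin-wait on the remote count/flag with `clock64()` and atomically adds the elapsed GPU clock cycles into the per-peer-rank slot of the stats tensor; the kernels only accumulate and never reset, so the host should zero or snapshot the tensors when it wants per-interval numbers. Below is a minimal host-side sketch of how such counters might be interpreted; the helper name, the `num_local_experts` normalization, and the outlier threshold are illustrative assumptions, not part of this change.

```python
import torch

def flag_slow_peers(wait_cycles: torch.Tensor, num_local_experts: int, ratio: float = 3.0):
    """Illustrative helper: `wait_cycles` is the `[num_ranks]` int64 tensor filled by the
    dispatch (or combine) kernel with raw accumulated GPU clock cycles. Each call adds one
    sample per local expert, so normalize by `num_local_experts` before comparing ranks."""
    per_rank = wait_cycles.double() / max(num_local_experts, 1)
    median = per_rank.median()
    # Crude anomaly signal: peers whose accumulated wait exceeds `ratio` times the median.
    return (per_rank > ratio * median).nonzero(as_tuple=False).flatten().tolist()
```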
12 changes: 11 additions & 1 deletion deep_ep/buffer.py
@@ -515,6 +515,7 @@ def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden
def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
num_max_dispatch_tokens_per_rank: int, num_experts: int,
cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
dispatch_wait_recv_cost_stats: Optional[torch.Tensor] = None,
use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False,
async_finish: bool = False, return_recv_hook: bool = False) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
Expand All @@ -535,6 +536,9 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
cumulative_local_expert_recv_stats: a cumulative expert count tensor for statistics, which should have shape
`[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance
monitoring.
dispatch_wait_recv_cost_stats: a cumulative statistics tensor recording the time (in GPU clock cycles) spent
waiting to receive tokens from each source rank, which should have shape `[num_ranks]` and be typed as
`torch.int64`. This is useful for detecting and precisely localizing slow anomalies.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
round_scale: whether round the scaling factors into power of 2.
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
@@ -565,6 +569,7 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
self.runtime.low_latency_dispatch(x, topk_idx,
cumulative_local_expert_recv_stats,
dispatch_wait_recv_cost_stats,
num_max_dispatch_tokens_per_rank, num_experts,
use_fp8, round_scale, use_ue8m0,
async_finish, return_recv_hook)
@@ -579,7 +584,8 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
# noinspection PyTypeChecker
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
handle: tuple, use_logfmt: bool = False, zero_copy: bool = False, async_finish: bool = False,
return_recv_hook: bool = False, out: Optional[torch.Tensor] = None) -> \
return_recv_hook: bool = False, out: Optional[torch.Tensor] = None,
combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \
Tuple[torch.Tensor, EventOverlap, Callable]:
"""
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
@@ -605,6 +611,9 @@ def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weig
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you do not set this flag, the kernel will ensure the data's arrival.
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
combine_wait_recv_cost_stats: a cumulative statistics tensor recording the time (in GPU clock cycles) spent
waiting to receive tokens from each rank, which should have shape `[num_ranks]` and be typed as
`torch.int64`. This is useful for detecting and precisely localizing slow anomalies.

Returns:
combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `torch.bfloat16`.
@@ -613,6 +622,7 @@ def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weig
"""
src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
combine_wait_recv_cost_stats,
num_max_dispatch_tokens_per_rank, num_experts,
use_logfmt, zero_copy, async_finish, return_recv_hook,
out)
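For reference, a minimal usage sketch of the two new keyword arguments follows. It assumes an already-constructed `deep_ep.Buffer` named `buffer`, gate outputs `x`, `topk_idx`, `topk_weights`, expert outputs `expert_out`, and the dispatch sizing parameters from the surrounding code; `buffer.group_size` is likewise an assumption about the host code rather than something introduced by this PR.

```python
import torch

# Per-peer wait-time accumulators; the host-side asserts require a contiguous 1-D
# int64 tensor of length `num_ranks`, allocated on the GPU.
num_ranks = buffer.group_size  # assumption: process-group size tracked by the buffer
dispatch_wait = torch.zeros(num_ranks, dtype=torch.int64, device='cuda')
combine_wait = torch.zeros(num_ranks, dtype=torch.int64, device='cuda')

recv_x, recv_count, handle, event, hook = buffer.low_latency_dispatch(
    x, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,
    dispatch_wait_recv_cost_stats=dispatch_wait)

# ... run the local experts on `recv_x` to produce `expert_out` ...

combined_x, event, hook = buffer.low_latency_combine(
    expert_out, topk_idx, topk_weights, handle,
    combine_wait_recv_cost_stats=combine_wait)

# Both tensors now hold accumulated GPU clock cycles spent waiting on each peer rank.
```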