Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions csrc/deep_ep.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1465,7 +1465,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x,
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16));
auto packed_recv_src_info =
torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));

Expand Down Expand Up @@ -1495,7 +1495,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x,
internode_ll::dispatch(
packed_recv_x.data_ptr(),
packed_recv_x_scales_ptr,
packed_recv_src_info.data_ptr<int>(),
packed_recv_src_info.data_ptr<int64_t>(),
packed_recv_layout_range.data_ptr<int64_t>(),
packed_recv_count.data_ptr<int>(),
mask_buffer_ptr,
Expand Down Expand Up @@ -1554,6 +1554,12 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
const torch::Tensor& topk_weights,
const torch::Tensor& src_info,
const torch::Tensor& layout_range,
bool overlap,
const std::optional<torch::Tensor>& packed_recv_count,
const std::optional<torch::Tensor>& comp_signal,
int block_m,
int threshold,
int num_sms,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank,
int num_experts,
Expand All @@ -1564,6 +1570,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
const std::optional<torch::Tensor>& out) {
#ifndef DISABLE_NVSHMEM
EP_HOST_ASSERT(low_latency_mode);
EP_HOST_ASSERT((!overlap || return_recv_hook) and "Overlap mode requires return_recv_hook=True");

// Tensor checks
EP_HOST_ASSERT(x.dim() == 3 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16);
Expand All @@ -1577,11 +1584,17 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank);
EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32);
EP_HOST_ASSERT(src_info.dim() == 2 and src_info.is_contiguous());
EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt32 and x.size(0) == src_info.size(0));
EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt64 and x.size(0) == src_info.size(0));
EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous());
EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks);

if (comp_signal.has_value()) {
EP_HOST_ASSERT(comp_signal->dim() == 1 and comp_signal->is_contiguous());
EP_HOST_ASSERT(comp_signal->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(comp_signal->size(0) == num_experts / num_ranks * ceil_div(num_ranks * num_max_dispatch_tokens_per_rank, 64));
}

if (combine_wait_recv_cost_stats.has_value()) {
EP_HOST_ASSERT(combine_wait_recv_cost_stats->scalar_type() == torch::kInt64);
EP_HOST_ASSERT(combine_wait_recv_cost_stats->dim() == 1 and combine_wait_recv_cost_stats->is_contiguous());
Expand Down Expand Up @@ -1627,8 +1640,13 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
x.data_ptr(),
topk_idx.data_ptr<topk_idx_t>(),
topk_weights.data_ptr<float>(),
src_info.data_ptr<int>(),
src_info.data_ptr<int64_t>(),
layout_range.data_ptr<int64_t>(),
overlap,
packed_recv_count.has_value() ? packed_recv_count->data_ptr<int>() : nullptr,
comp_signal.has_value() ? comp_signal->data_ptr<int>() : nullptr,
block_m,
threshold,
mask_buffer_ptr,
combine_wait_recv_cost_stats.has_value() ? combine_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
next_clean_meta.first,
Expand All @@ -1643,6 +1661,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
use_logfmt,
workspace,
num_device_sms,
num_sms,
launch_stream,
phases,
zero_copy);
Expand Down
6 changes: 6 additions & 0 deletions csrc/deep_ep.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,12 @@ struct Buffer {
const torch::Tensor& topk_weights,
const torch::Tensor& src_info,
const torch::Tensor& layout_range,
bool overlap,
const std::optional<torch::Tensor>& packed_recv_count,
const std::optional<torch::Tensor>& comp_signal,
int block_m,
int threshold,
int num_sms,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank,
int num_experts,
Expand Down
10 changes: 8 additions & 2 deletions csrc/kernels/api.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ void clean_low_latency_buffer(int* clean_0,

void dispatch(void* packed_recv_x,
void* packed_recv_x_scales,
int* packed_recv_src_info,
int64_t* packed_recv_src_info,
int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* mask_buffer,
Expand Down Expand Up @@ -317,8 +317,13 @@ void combine(void* combined_x,
const void* x,
const topk_idx_t* topk_idx,
const float* topk_weights,
const int* src_info,
const int64_t* src_info,
const int64_t* layout_range,
bool overlap,
int* packed_recv_count,
int* comp_signal,
int block_m,
int threshold,
int* mask_buffer,
int64_t* combine_wait_recv_cost_stats,
int* next_clean,
Expand All @@ -333,6 +338,7 @@ void combine(void* combined_x,
bool use_logfmt,
void* workspace,
int num_device_sms,
int num_sms,
cudaStream_t stream,
int phases,
bool zero_copy);
Expand Down
Loading