diff --git a/csrc/deepep/deep_ep.cpp b/csrc/deepep/deep_ep.cpp index 51d7840bf..dbbdcba8f 100644 --- a/csrc/deepep/deep_ep.cpp +++ b/csrc/deepep/deep_ep.cpp @@ -938,11 +938,12 @@ Buffer::low_latency_dispatch(const at::Tensor &x, const at::Tensor &topk_idx, this->new_topk_idx = torch::cat(topk_blocks, 0); } + EP_HOST_ASSERT(num_max_dispatch_tokens_per_rank >= new_x.size(0)); + auto num_tokens = static_cast(new_x.size(0)), hidden = static_cast(new_x.size(1)); auto num_scales = hidden / 128, num_topk = static_cast(new_topk_idx.size(1)); auto num_local_experts = num_experts / (num_ranks - shared_expert_rank_num); - - int64_t global_bs = std::max(new_topk_idx.size(0), num_max_dispatch_tokens_per_rank) * num_ranks; + int64_t global_bs = num_max_dispatch_tokens_per_rank * num_ranks; auto num_max_tokens = 0; if (rank < shared_expert_rank_num) { num_max_tokens = global_bs / shared_expert_rank_num; @@ -1059,6 +1060,7 @@ std::tuple, std::optional= new_idx.size(0)); // EP_HOST_ASSERT(x.size(0) == num_experts / num_ranks); // get ep & tp name @@ -1082,7 +1084,7 @@ std::tuple, std::optional