diff --git a/csrc/deepep/deep_ep.cpp b/csrc/deepep/deep_ep.cpp index d6ef3541d..b6b99db17 100644 --- a/csrc/deepep/deep_ep.cpp +++ b/csrc/deepep/deep_ep.cpp @@ -942,14 +942,14 @@ Buffer::low_latency_dispatch(const at::Tensor &x, const at::Tensor &topk_idx, auto num_tokens = static_cast(new_x.size(0)), hidden = static_cast(new_x.size(1)); auto num_scales = hidden / 128, num_topk = static_cast(new_topk_idx.size(1)); - auto num_local_experts = num_experts / (num_ranks - shared_expert_rank_num); + int32_t num_local_experts = num_experts / (num_ranks - shared_expert_rank_num); int64_t global_bs = num_max_dispatch_tokens_per_rank * num_ranks; auto num_max_tokens = 0; if (rank < shared_expert_rank_num) { num_max_tokens = global_bs / shared_expert_rank_num; num_local_experts = 1; } else { // moe expert - num_max_tokens = global_bs * num_local_experts; + num_max_tokens = global_bs * std::min(num_topk, num_local_experts); } auto max_size = std::max(num_tokens * num_topk, num_max_tokens * 128);