From 3d317d514c066c550a5485253c51b73549832bf8 Mon Sep 17 00:00:00 2001
From: chen <103103266+ckl117@users.noreply.github.com>
Date: Fri, 12 Sep 2025 21:11:16 +0800
Subject: [PATCH] ep support logprob (#4089)

---
 .../gpu_ops/get_output_msg_with_topk.cc |  3 ---
 fastdeploy/engine/args_utils.py         |  2 --
 fastdeploy/output/token_processor.py    | 23 +++++++++----------
 3 files changed, 11 insertions(+), 17 deletions(-)

diff --git a/custom_ops/gpu_ops/get_output_msg_with_topk.cc b/custom_ops/gpu_ops/get_output_msg_with_topk.cc
index 5da88dc1d65..4d6b5f56bba 100644
--- a/custom_ops/gpu_ops/get_output_msg_with_topk.cc
+++ b/custom_ops/gpu_ops/get_output_msg_with_topk.cc
@@ -39,9 +39,6 @@ void GetOutputTopK(const paddle::Tensor& x,
                    int k,
                    int64_t rank_id,
                    bool wait_flag) {
-  if (rank_id > 0) {
-    return;
-  }
   static struct msgdata msg_rcv;
 
   int msg_queue_id = 1;
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 4040a6f92a1..e43c2557465 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -401,8 +401,6 @@ def __post_init__(self):
         if self.enable_logprob:
             if self.speculative_config is not None:
                 raise NotImplementedError("Logprob does not support speculation_config.")
-            if self.enable_expert_parallel:
-                raise NotImplementedError("Logprob does not support enable_expert_parallel.")
             if not current_platform.is_cuda():
                 raise NotImplementedError("Only CUDA platform supports logprob.")
         if self.speculative_config is not None:
diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py
index 5173e339828..dede9d21c98 100644
--- a/fastdeploy/output/token_processor.py
+++ b/fastdeploy/output/token_processor.py
@@ -304,24 +304,23 @@ def process_sampling_results(self):
                         continue
 
                 else:
-                    if (
+                    if self.use_logprobs:
+                        get_output_topk(
+                            self.output_tokens,
+                            self.output_scores,
+                            self.output_ranks,
+                            K,
+                            rank_id,
+                            is_blocking,
+                        )
+                    elif (
                         self.cfg.parallel_config.enable_expert_parallel
                         and self.cfg.parallel_config.data_parallel_size > 1
                     ):
                         get_output_ep(self.output_tokens, rank_id, is_blocking)
 
                     else:
-                        if self.use_logprobs:
-                            get_output_topk(
-                                self.output_tokens,
-                                self.output_scores,
-                                self.output_ranks,
-                                K,
-                                rank_id,
-                                is_blocking,
-                            )
-                        else:
-                            get_output(self.output_tokens, rank_id, is_blocking)
+                        get_output(self.output_tokens, rank_id, is_blocking)
 
                     if self.output_tokens[0, 0] == -2:
                         continue
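
Taken together, the three hunks above give logprob retrieval priority over the
expert-parallel (EP) output path: GetOutputTopK no longer returns early for
rank_id > 0, EngineArgs.__post_init__ no longer rejects enable_logprob together
with enable_expert_parallel, and process_sampling_results now checks
use_logprobs before the EP branch. The snippet below is a minimal standalone
sketch of that dispatch order, not FastDeploy code; select_output_op is a
hypothetical helper named only for illustration.

    # Standalone sketch (assumption: mirrors the patched branch order above).
    def select_output_op(use_logprobs: bool,
                         enable_expert_parallel: bool,
                         data_parallel_size: int) -> str:
        """Name of the custom op the patched token_processor would call."""
        if use_logprobs:
            # Top-K ids/scores/ranks are fetched even when EP is enabled.
            return "get_output_topk"
        elif enable_expert_parallel and data_parallel_size > 1:
            return "get_output_ep"
        else:
            return "get_output"

    if __name__ == "__main__":
        # Before this patch, EP + logprob was rejected in __post_init__;
        # now the logprob path wins over the EP path.
        assert select_output_op(True, True, 2) == "get_output_topk"
        assert select_output_op(False, True, 2) == "get_output_ep"
        assert select_output_op(False, False, 1) == "get_output"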