Commit 66a98b4

ep support logprob (#4089) (#4151)
1 parent a685e5a commit 66a98b4
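
This commit enables logprob output under expert parallelism (EP): the per-rank early return in the GetOutputTopK custom op is removed, the engine-argument check that rejected enable_logprob together with enable_expert_parallel is dropped, and the token processor now fetches top-k tokens and scores whenever logprobs are requested, falling back to the expert-parallel or plain output paths otherwise.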

File tree

3 files changed: 11 additions & 17 deletions

custom_ops/gpu_ops/get_output_msg_with_topk.cc
fastdeploy/engine/args_utils.py
fastdeploy/output/token_processor.py

custom_ops/gpu_ops/get_output_msg_with_topk.cc

Lines changed: 0 additions & 3 deletions
```diff
@@ -39,9 +39,6 @@ void GetOutputTopK(const paddle::Tensor& x,
                     int k,
                     int64_t rank_id,
                     bool wait_flag) {
-  if (rank_id > 0) {
-    return;
-  }
 
   static struct msgdata msg_rcv;
   int msg_queue_id = 1;
```
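Dropping the early return for rank_id > 0 means GetOutputTopK no longer no-ops on non-zero ranks: every rank can now block on the top-k message queue, which the expert-parallel logprob path in token_processor.py (below) relies on. A minimal per-rank usage sketch, assuming the op is exposed to Python as get_output_topk; the import path, buffer shapes, and K value here are illustrative assumptions, not taken from this commit:

```python
import paddle

# Assumed import path for the custom-op binding; adjust to wherever
# your FastDeploy build exposes it.
from fastdeploy.model_executor.ops.gpu import get_output_topk

K = 20         # top-k depth (illustrative)
MAX_BSZ = 256  # max batch size (illustrative)

# Pre-allocated receive buffers; shapes are assumptions for this sketch.
output_tokens = paddle.full([MAX_BSZ + 2, K + 1], -1, dtype="int64")
output_scores = paddle.full([MAX_BSZ + 2, K + 1], 0.0, dtype="float32")
output_ranks = paddle.full([MAX_BSZ], 0, dtype="int64")

rank_id = 1  # before this commit the op returned immediately for rank_id > 0

# Blocks (wait_flag=True) until this rank's message arrives, then fills
# the buffers with top-k tokens, scores, and sampling ranks in place.
get_output_topk(output_tokens, output_scores, output_ranks, K, rank_id, True)
```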

fastdeploy/engine/args_utils.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -401,8 +401,6 @@ def __post_init__(self):
         if self.enable_logprob:
             if self.speculative_config is not None:
                 raise NotImplementedError("Logprob does not support speculation_config.")
-            if self.enable_expert_parallel:
-                raise NotImplementedError("Logprob does not support enable_expert_parallel.")
             if not current_platform.is_cuda():
                 raise NotImplementedError("Only CUDA platform supports logprob.")
         if self.speculative_config is not None:
```
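With the enable_expert_parallel guard removed, enable_logprob and enable_expert_parallel may now be set together; the remaining checks still reject speculative decoding and non-CUDA platforms. A standalone sketch of the resulting validation, with a stubbed platform check and a hypothetical Args class standing in for the real engine-args dataclass:

```python
from dataclasses import dataclass
from typing import Optional


def is_cuda() -> bool:
    # Stub for current_platform.is_cuda(); assume a CUDA build here.
    return True


@dataclass
class Args:
    # Hypothetical stand-in for the real engine-args class.
    enable_logprob: bool = False
    enable_expert_parallel: bool = False
    speculative_config: Optional[dict] = None

    def __post_init__(self):
        if self.enable_logprob:
            if self.speculative_config is not None:
                raise NotImplementedError("Logprob does not support speculation_config.")
            # The enable_expert_parallel check was removed by this commit,
            # so logprob plus expert parallelism is now accepted.
            if not is_cuda():
                raise NotImplementedError("Only CUDA platform supports logprob.")


# Before this commit this raised NotImplementedError; now it constructs fine.
Args(enable_logprob=True, enable_expert_parallel=True)
```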

fastdeploy/output/token_processor.py

Lines changed: 11 additions & 12 deletions
```diff
@@ -303,24 +303,23 @@ def process_sampling_results(self):
                 continue
 
             else:
-                if (
+                if self.use_logprobs:
+                    get_output_topk(
+                        self.output_tokens,
+                        self.output_scores,
+                        self.output_ranks,
+                        K,
+                        rank_id,
+                        is_blocking,
+                    )
+                elif (
                     self.cfg.parallel_config.enable_expert_parallel
                     and self.cfg.parallel_config.data_parallel_size > 1
                 ):
                     get_output_ep(self.output_tokens, rank_id, is_blocking)
 
                 else:
-                    if self.use_logprobs:
-                        get_output_topk(
-                            self.output_tokens,
-                            self.output_scores,
-                            self.output_ranks,
-                            K,
-                            rank_id,
-                            is_blocking,
-                        )
-                    else:
-                        get_output(self.output_tokens, rank_id, is_blocking)
+                    get_output(self.output_tokens, rank_id, is_blocking)
 
             if self.output_tokens[0, 0] == -2:
                 continue
```
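The substantive change is branch order: use_logprobs is now tested first, so when logprobs are enabled even expert-parallel ranks receive top-k tokens and scores through get_output_topk, and get_output_ep only serves the EP case without logprobs. A runnable sketch of the new dispatch, with print stubs standing in for the real ops; only the branch ordering is taken from this diff:

```python
# Print stubs standing in for the real custom ops.
def get_output_topk(tokens, scores, ranks, k, rank_id, is_blocking):
    print(f"rank {rank_id}: top-{k} tokens + logprob scores")


def get_output_ep(tokens, rank_id, is_blocking):
    print(f"rank {rank_id}: expert-parallel token fetch")


def get_output(tokens, rank_id, is_blocking):
    print(f"rank {rank_id}: plain token fetch")


def fetch_outputs(use_logprobs, enable_expert_parallel, data_parallel_size,
                  tokens=None, scores=None, ranks=None, k=20, rank_id=0,
                  is_blocking=True):
    if use_logprobs:
        # New first branch: logprobs win even under expert parallelism.
        get_output_topk(tokens, scores, ranks, k, rank_id, is_blocking)
    elif enable_expert_parallel and data_parallel_size > 1:
        get_output_ep(tokens, rank_id, is_blocking)
    else:
        get_output(tokens, rank_id, is_blocking)


# EP + logprobs now takes the top-k path instead of get_output_ep:
fetch_outputs(use_logprobs=True, enable_expert_parallel=True, data_parallel_size=2)
```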
