[scheduler] fix: correcting extend_logprob_start_len calculation#15922
[scheduler] fix: correcting extend_logprob_start_len calculation#15922
extend_logprob_start_len calculation#15922Conversation
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
|
/tag-and-rerun-ci |
b3fa171 to
49734da
Compare
49734da to
6ade519
Compare
| # When return_logprob = False, only need last token per request | ||
| num_tokens_for_logprob = local_batch.batch_size() | ||
| ) | ||
| if not local_batch.return_logprob: |
There was a problem hiding this comment.
it should be removed before this PR getting merged
extend_logprob_start_len calculation
|
/rerun-failed-ci |
| f"extend_logprob_start_lens={local_batch.extend_logprob_start_lens}" | ||
| ) | ||
| print(f"extend_lens={local_batch.extend_lens}") | ||
| assert ( |
There was a problem hiding this comment.
add a test case so the old code will fail / crash
| # If return_logprob is True, return the logprobs for output tokens by default | ||
| req.logprob_start_len = len(req.origin_input_ids) - 1 | ||
| else: | ||
| # If return_logprob is False, only the last token requires logprob computation |
There was a problem hiding this comment.
In the previous code, we didn't propogate logprob_start_len = -1 beyond the Scheduler. The value was reset to len(req.origin_input_ids) - 1. Hence, for prefill-only, we followed the same way.
Given, with this change we are propogating logprob_start_len = -1 into scheduler_batch, we can do the same logic for prefill_only as well instead of resetting to len(req.origin_input_ids) in line: 1531.
But, looks like, the current change doesn't break the prefill-only logprobs computation, so we are good.
Motivation
The original calculation for
extend_logprob_start_lenwas incorrect (code). When kv retraction happens, it should be determined based onlen(fill_ids)(the total number of original input ids and output ids) whenreturn_logprobisFalsebutlogprob_start_lenis set aslen(origin_input_ids) - 1.Modifications
We set
logprob_start_lenas -1 by default whenreturn_logprobisFalseso thatextend_logprob_start_lencan be set aslen(fill_ids) - 1duringprepare_for_extend. We also reverted changes in #12115 as it's no longer needed.Accuracy Tests, Benchmarking and Profiling
Tests in #9748 and #12115 are reproduced for verifying the robustness, accuracy and performance of this PR.
Test
test_logprobs.py:Checklist