
[Bugfix] Fix token loss in PP mode which causes degraded accuracy#41133

Merged
yewentao256 merged 9 commits into
vllm-project:mainfrom
starkwj:bugfix/pp-token-loss
May 6, 2026
@starkwj

@starkwj starkwj commented Apr 28, 2026

Contributor

Purpose

Issue: PP (pipeline parallel) mode currently shows degraded accuracy under concurrent requests.

gsm8k results with Qwen3-8B, num_concurrent 256:

| Configuration | gsm8k (strict-match) |
|---|---|
| TP | 0.8741 |
| PP 2 + no-async-scheduling | 0.8757 |
| PP 2 (num_concurrent 2) | 0.8741 |
| PP 2 | 0.8324 |

Maybe related issue

Root cause

# For the last rank, we don't need to update the token_ids_cpu
# because the sampled tokens are already cached.
if not is_last_rank:
    # Add new_token_ids to token_ids_cpu.
    start_token_index = num_computed_tokens
    end_token_index = num_computed_tokens + len(new_token_ids)
    self.input_batch.token_ids_cpu[
        req_index, start_token_index:end_token_index
    ] = new_token_ids
    self.input_batch.num_tokens_no_spec[req_index] = end_token_index

In gpu_model_runner.py, num_tokens_no_spec should hold the number of tokens a request currently has, excluding speculative tokens.
However, in _update_states in PP mode, the code above incorrectly derives it from num_computed_tokens (it sets it to num_computed_tokens + len(new_token_ids)).
During chunked prefill, num_computed_tokens is the number of prompt tokens computed so far, which is smaller than the total number of prompt tokens (i.e., the real num_tokens_no_spec).
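To make the effect concrete, here is a minimal sketch of the pre-fix update on toy NumPy buffers. The buffer shapes, token values, and the standalone `buggy_update` function are assumptions made for illustration; only the arithmetic mirrors the snippet quoted above.

```python
import numpy as np

# Toy stand-ins for the InputBatch buffers (sizes and values are made up).
max_model_len = 16
token_ids_cpu = np.zeros((1, max_model_len), dtype=np.int64)
num_tokens_no_spec = np.zeros(1, dtype=np.int64)

# A 10-token prompt is stored in slot 0 when the request is added.
prompt = np.arange(100, 110)
token_ids_cpu[0, : len(prompt)] = prompt
num_tokens_no_spec[0] = len(prompt)


def buggy_update(req_index, num_computed_tokens, new_token_ids):
    """Hypothetical helper mirroring the pre-fix non-last-rank update."""
    start = num_computed_tokens
    end = num_computed_tokens + len(new_token_ids)
    token_ids_cpu[req_index, start:end] = new_token_ids
    num_tokens_no_spec[req_index] = end


# Chunked prefill: 4 of the 10 prompt tokens are computed and there are
# no new tokens yet (the async-scheduling case).
buggy_update(0, num_computed_tokens=4, new_token_ids=[])

print(int(num_tokens_no_spec[0]))  # 4, although the prompt has 10 tokens
```

The prompt itself is still intact in `token_ids_cpu` at this point; only the recorded length has shrunk.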

The index used to update token_ids_cpu is also wrong during chunked prefill, because it is smaller than the number of prompt tokens. What this code does during chunked prefill:

  • With async scheduling, new_token_ids is empty, so token_ids_cpu is not touched.
  • Without async scheduling, new_token_ids is the prompt chunk to be computed, so token_ids_cpu is overwritten with the same prompt chunk, which is also pointless.

The update is correct in _bookkeeping_sync, which handles non-PP mode and the last rank in PP mode, as shown below. In fact, during chunked prefill the operations below are skipped because there are no valid sampled tokens (they are filtered out by discard_request_mask).

code in _bookkeeping_sync

start_idx = self.input_batch.num_tokens_no_spec[req_idx]
end_idx = start_idx + num_sampled_ids
assert end_idx <= self.max_model_len, (
    "Sampled token IDs exceed the max model length. "
    f"Total number of tokens: {end_idx} > max_model_len: "
    f"{self.max_model_len}"
)
self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
self.input_batch.num_tokens_no_spec[req_idx] = end_idx

Then, in _update_states, condense() is called. It reorders the requests in the input batch by moving valid requests from the tail into the empty slots at the front.

num_tokens = self._get_active_token_count(last_req_index)
(self.spec_token_ids[last_req_index], self.spec_token_ids[empty_index]) = (
    self.spec_token_ids[empty_index],
    self.spec_token_ids[last_req_index],
)
self.spec_token_ids[last_req_index].clear()
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
    last_req_index, :num_tokens
]
self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
    last_req_index, :num_tokens
]

As shown above, tokens are copied into the empty slot.
However, num_tokens is calculated from num_tokens_no_spec.
Thus, after the incorrect update in _update_states, only num_computed_tokens tokens are copied rather than all prompt tokens, and the remaining prompt tokens are lost (replaced by whatever invalid tokens were in the empty slot).
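The truncation can be reproduced with a small sketch of the copy step alone. The two-slot buffer, the `INVALID` marker, and the `condense_one` helper are hypothetical and exist only for this illustration; the slice copy is the one from the condense() excerpt.

```python
import numpy as np

INVALID = -1  # hypothetical marker for stale contents of an empty slot


def condense_one(num_tokens):
    """Move the request in slot 1 into empty slot 0, copying num_tokens tokens."""
    token_ids_cpu = np.full((2, 8), INVALID, dtype=np.int64)
    token_ids_cpu[1, :6] = [11, 12, 13, 14, 15, 16]  # 6-token prompt
    empty_index, last_req_index = 0, 1
    token_ids_cpu[empty_index, :num_tokens] = token_ids_cpu[
        last_req_index, :num_tokens
    ]
    return token_ids_cpu[empty_index].tolist()


# Correct count: the whole prompt is moved.
intact = condense_one(num_tokens=6)
# Count shrunk to num_computed_tokens (2 here): the tail of the prompt is lost.
truncated = condense_one(num_tokens=2)

print(intact)     # [11, 12, 13, 14, 15, 16, -1, -1]
print(truncated)  # [11, 12, -1, -1, -1, -1, -1, -1]
```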

The code is also incorrect without async scheduling. Why does its gsm8k score look normal?

In no-async-scheduling, at each decode step the request is removed from input_batch because it is unscheduled, and is re-added via add_request before condense() runs.
add_request fills empty slots first, which leaves fewer empty slots at the front and greatly reduces the number of swaps in condense().
However, the bug can still occur, especially when prompts are long and max-num-batched-tokens is relatively small.

This PR fixes the incorrect update in _update_states by adopting logic similar to that in _bookkeeping_sync.
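As a rough sketch of that idea (not the merged code), the update can append at the current num_tokens_no_spec offset instead of at num_computed_tokens. The standalone function below and its name are hypothetical, and the way it decides how many tokens are new (`num_computed_tokens + len(new_token_ids)` minus the current count, capped by the list length) is an assumption made for this illustration.

```python
import numpy as np


def update_non_last_rank(token_ids_cpu, is_token_ids, num_tokens_no_spec,
                         req_index, num_computed_tokens, new_token_ids):
    """Append only the new tokens, starting at the current token count."""
    current = int(num_tokens_no_spec[req_index])
    # Assumption: tokens beyond the current count are the new ones.
    num_new = num_computed_tokens + len(new_token_ids) - current
    num_new = min(num_new, len(new_token_ids))
    if num_new <= 0:
        return  # chunked prefill: nothing to append, count is left alone
    start, end = current, current + num_new
    token_ids_cpu[req_index, start:end] = new_token_ids[-num_new:]
    is_token_ids[req_index, start:end] = True
    num_tokens_no_spec[req_index] = end


# Toy buffers holding one request with a 10-token prompt.
token_ids_cpu = np.zeros((1, 16), dtype=np.int64)
is_token_ids = np.zeros((1, 16), dtype=bool)
num_tokens_no_spec = np.zeros(1, dtype=np.int64)
prompt = list(range(100, 110))
token_ids_cpu[0, :10] = prompt
is_token_ids[0, :10] = True
num_tokens_no_spec[0] = 10

args = (token_ids_cpu, is_token_ids, num_tokens_no_spec, 0)
update_non_last_rank(*args, num_computed_tokens=4, new_token_ids=[])
update_non_last_rank(*args, num_computed_tokens=4, new_token_ids=prompt[4:8])
count_after_prefill = int(num_tokens_no_spec[0])  # still 10

update_non_last_rank(*args, num_computed_tokens=10, new_token_ids=[7])
print(count_after_prefill, int(num_tokens_no_spec[0]))  # 10 11
```

With this shape, a prefill chunk never lowers the recorded length, so a later condense() copies the full prompt.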

Test Plan

lm_eval with gsm8k

lm_eval --model local-completions --model_args "model=qwen,base_url=http://localhost:8000/v1/completions,tokenized_requests=False,tokenizer_backend=None,num_concurrent=256,max_length=40960" --tasks gsm8k --num_fewshot 5

Performance was also benchmarked; no change was observed.

Test Result

before:

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.8370 | ± 0.0102 |
| | | strict-match | 5 | exact_match | 0.8324 | ± 0.0103 |

after:

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.8855 | ± 0.0088 |
| | | strict-match | 5 | exact_match | 0.8772 | ± 0.0090 |

and it does not affect no-async-scheduling:

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.8810 | ± 0.0089 |
| | | strict-match | 5 | exact_match | 0.8757 | ± 0.0091 |

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Signed-off-by: Jing Wang <jingwang96@qq.com>
Copilot AI review requested due to automatic review settings April 28, 2026 12:50
@starkwj starkwj requested a review from njhill as a code owner April 28, 2026 12:50

@claude claude Bot left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify Bot added v1 bug Something isn't working labels Apr 28, 2026
@starkwj

starkwj commented Apr 28, 2026

Contributor Author

@yewentao256 Hi, Could you please review this PR? It's an issue related to PP mode.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request updates the _update_states method in gpu_model_runner.py to adjust how token_ids_cpu is updated for non-last ranks. A logic error was identified where the condition num_new_tokens > 0 likely fails because req_state.num_tokens is incremented before the check, and a simpler check for the presence of new_token_ids was suggested to ensure the update occurs correctly.

Comment thread vllm/v1/worker/gpu_model_runner.py Outdated

Copilot AI left a comment

Contributor

Pull request overview

Fixes a pipeline-parallel (PP) mode state bookkeeping bug that can drop prompt tokens during InputBatch.condense(), degrading accuracy under high concurrency.

Changes:

  • Adjust _update_states handling of token_ids_cpu / num_tokens_no_spec updates for non-last PP ranks to avoid losing tokens during batch condensation.


Comment thread vllm/v1/worker/gpu_model_runner.py Outdated
Comment thread vllm/v1/worker/gpu_model_runner.py Outdated
starkwj and others added 2 commits April 28, 2026 21:24
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Jing Wang <jingwang96@qq.com>
Signed-off-by: Jing Wang <jingwang96@qq.com>

@yewentao256 yewentao256 left a comment

Member

LGTM, thanks for the work!

@yewentao256 yewentao256 added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 28, 2026
@starkwj

starkwj commented Apr 29, 2026

Contributor Author

@yewentao256 Thanks for your review and approval! The failed CI checks seem unrelated to this PR.

@yewentao256 yewentao256 left a comment

Member

Let's retry before we force merge

@yewentao256 yewentao256 enabled auto-merge (squash) April 29, 2026 14:48

@hsliuustc0106 hsliuustc0106 left a comment

Contributor

Thanks for the fix! The core logic of using num_tokens_no_spec as the write offset is correct and well-reasoned. However, I have a few concerns before this is ready to merge:


1. is_token_ids not updated (Medium)

Copilot raised this and the response was "the original code doesn't update it either." But the original code has a bug, so that's not a valid justification. When enable_prompt_embeds is used, the newly written token positions will remain False in is_token_ids, causing _prepare_inputs to treat output tokens as embed-provided tokens.

Could you confirm whether the PP non-last-rank path can ever reach the enable_prompt_embeds branch? If not, please add a comment explaining why. If yes, this should mirror _bookkeeping_sync (L3456) and set:

self.input_batch.is_token_ids[
    req_index, start_token_index:end_token_index
] = True

2. else branch for num_tokens_no_spec update (Low-Medium)

else:
    self.input_batch.num_tokens_no_spec[req_index] = max(
        self.input_batch.num_tokens_no_spec[req_index],
        num_computed_tokens,
    )

This handles the case where new_token_ids is empty or num_new_tokens == 0. The logic is correct — in chunked prefill for non-last chunks, num_computed_tokens advances but there are no new sampled tokens. The max() is safe for decode cases too. But please add a brief comment explaining the rationale.

3. Missing regression test (Medium)

The PR shows gsm8k accuracy validation, which is great. However, this bug only manifests under high concurrency + PP, so it would be valuable to add a unit test that verifies token_ids_cpu and num_tokens_no_spec consistency between PP ranks after _update_states, e.g. PP=2, non-async-scheduling, multiple concurrent requests.

auto-merge was automatically disabled May 1, 2026 12:44

Head branch was pushed to by a user without write access

@starkwj

starkwj commented May 1, 2026

Contributor Author

@hsliuustc0106 Thanks for your suggestions. is_token_ids is updated accordingly now, and two unit tests have been added and passed.

@yewentao256 yewentao256 left a comment

Member

Thanks @hsliuustc0106 ! Could you take another look if the fix looks good

Also, @starkwj could you re-test the lm_eval result to make sure we don't hurt the acc?

@starkwj

starkwj commented May 1, 2026

Contributor Author

> Thanks @hsliuustc0106 ! Could you take another look if the fix looks good
>
> Also, @starkwj could you re-test the lm_eval result to make sure we don't hurt the acc?

Hi, I have re-tested the lm_eval and results are fine (consistent).

@yewentao256 yewentao256 left a comment

Member

From my understanding, the comment by @hsliuustc0106 has been resolved, please take another look before we merge

@hsliuustc0106

Contributor

> From my understanding, the comment by @hsliuustc0106 has been resolved, please take another look before we merge

lgtm

@yewentao256 yewentao256 merged commit 27702f6 into vllm-project:main May 6, 2026
55 checks passed
libinta pushed a commit to libinta/vllm that referenced this pull request May 8, 2026
…lm-project#41133)

Signed-off-by: Jing Wang <jingwang96@qq.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Signed-off-by: Libin Tang <libin.tang@intel.com>
wi-adam added a commit to wi-adam/vllm that referenced this pull request May 11, 2026
Under spec-decode + non-async PP, the scheduler can advance
num_computed_tokens by 1 (rejection-bookkeeping for a verified token
that lives only on the last PP rank) while sending an empty
new_token_ids[i] slice to non-last ranks. The previous code crashed
with

    IndexError: list index out of range
    File "gpu_model_runner.py", line 1305, in _update_states
        req_state.output_token_ids.append(new_token_ids[-1])

on the very first chat completion of an MTP+PP=2 deployment. Upstream
PR vllm-project#41133 fixed the parallel `token_ids_cpu` write path but didn't
extend the same tolerance to `output_token_ids` here.

The fix bounds the number of tokens we copy by the actual list
length, which also makes the `==1` and `>1` branches handle empty
input consistently — the elif branch already silently took the
empty slice, only the `==1` branch crashed.

Discovered while shipping MTP+PP on the wi-adam/vllm rebase fork
(2026-05-11). With this guard, the IndexError on first inference is
eliminated; bookkeeping on non-last PP ranks remains correct because
the verified token is only consumed by the last-rank sampler.

Signed-off-by: Adam Winstanley <adam@winstanley.industries>
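
The guard described in the commit message above can be sketched as follows. The helper name and signature are hypothetical and this is not the fork's actual code; it only illustrates bounding the copy by the real list length so that an empty new_token_ids is a no-op instead of an IndexError.

```python
def append_new_output_tokens(output_token_ids, new_token_ids, num_new_tokens):
    """Append at most num_new_tokens tokens, never more than were received."""
    num_to_copy = min(num_new_tokens, len(new_token_ids))
    if num_to_copy > 0:
        output_token_ids.extend(new_token_ids[-num_to_copy:])
    return output_token_ids


outputs = [5, 6]
# Scheduler advanced by one token but sent nothing to this rank: no crash.
append_new_output_tokens(outputs, [], 1)
# Ordinary single-token and multi-token cases take the same path.
append_new_output_tokens(outputs, [7], 1)
append_new_output_tokens(outputs, [1, 2, 8, 9], 2)

print(outputs)  # [5, 6, 7, 8, 9]
```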
weifang231 pushed a commit to weifang231/eb-vllm that referenced this pull request May 13, 2026
…lm-project#41133)

Signed-off-by: Jing Wang <jingwang96@qq.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
mfylcek pushed a commit to mfylcek/vllm that referenced this pull request May 19, 2026
…lm-project#41133)

Signed-off-by: Jing Wang <jingwang96@qq.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
jhu960213 pushed a commit to jhu960213/vllm that referenced this pull request May 20, 2026
…lm-project#41133)

Signed-off-by: Jing Wang <jingwang96@qq.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
mvanhorn pushed a commit to mvanhorn/vllm that referenced this pull request Jun 4, 2026
…lm-project#41133)

Signed-off-by: Jing Wang <jingwang96@qq.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>

Labels

bug Something isn't working ready ONLY add when PR is ready to merge/full CI is needed v1


4 participants