
[Core] Support logprobs with spec decode + async scheduling #29223

Merged
njhill merged 1 commit into vllm-project:main from njhill:as-sd-logprobs
Nov 25, 2025

Conversation

@njhill (Member) commented Nov 22, 2025

No description provided.

@mergify mergify bot added the v1 label Nov 22, 2025
@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request adds support for logprobs with speculative decoding and asynchronous scheduling. The changes involve refactoring how cumulative token counts are calculated and passed to correctly process logprobs in these scenarios. The modifications in vllm/v1/sample/rejection_sampler.py and vllm/v1/worker/gpu_model_runner.py seem correct and well-structured. New tests are added to cover these cases. My main concern is the significant increase in tolerance in tests/v1/sample/test_logprobs.py for comparing logprobs, which might hide numerical precision issues. Please see the specific comment for details.

@njhill njhill requested a review from benchislett November 22, 2025 05:13
mergify bot commented Nov 22, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @njhill.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 22, 2025
Signed-off-by: Nick Hill <nhill@redhat.com>
```python
cu_num_tokens = None
if return_cu_num_tokens:
    cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist()
if len(discard_req_indices) > 0:
```
A Collaborator commented:

Why is this done after computing cu_num_tokens?

@njhill (Member, Author) replied Nov 25, 2025:

Because cu_num_tokens is used to index into the logprobs tensors that don't take the discarded indices into account, so doing it beforehand results in incorrect output.

Originally it was done beforehand, which was a bug fixed by #29216.
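The ordering question above can be illustrated with a minimal sketch: the cumulative counts must be derived from the valid mask over all requests, matching the layout of the logprobs tensors, before any discarded request rows are handled. Shapes and names here are illustrative, not vLLM's exact internals.

```python
# Minimal sketch of the ordering discussed above, using an illustrative
# 3-request batch. cu_num_tokens must be computed BEFORE discarded request
# indices are handled, because it indexes into logprobs tensors that still
# include those rows.
import numpy as np

valid_mask = np.array([
    [True, True, False],   # request 0: 2 accepted tokens
    [True, False, False],  # request 1: 1 accepted token (to be discarded)
    [True, True, True],    # request 2: 3 accepted tokens
])

# Cumulative token counts over ALL requests, matching the logprobs layout.
cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist()
print(cu_num_tokens)  # [0, 2, 3, 6]
```

Dropping request 1 first would instead yield `[0, 2, 5]`, misaligning every slice taken from the logprobs tensors after that row.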

@benchislett (Collaborator) left a comment:

Looks good overall

@njhill njhill merged commit 4e57c65 into vllm-project:main Nov 25, 2025
49 checks passed
@njhill njhill deleted the as-sd-logprobs branch November 25, 2025 20:55
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
kitaekatt pushed a commit to kitaekatt/vllm that referenced this pull request Dec 1, 2025
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Dec 16, 2025
### What this PR does / why we need it?

Currently, we use `AscendRejectionSampler`, which extends `RejectionSampler`, in spec decoding. `AscendRejectionSampler` overrides `RejectionSampler.forward` only to replace the `rejection_sample` function. This prevents much of `RejectionSampler`'s code from being reused, for example:
- vllm-project/vllm#19482
- vllm-project/vllm#26060
- vllm-project/vllm#29223

#### Proposed Change:
- Delete `AscendRejectionSampler` and use `RejectionSampler` directly in the model runner.
- Patch `RejectionSampler.expand_batch_to_tokens` and `RejectionSampler.rejection_sample`; a better approach may be to make them custom ops.
- Modify `NPUModelRunner` following vllm-project/vllm#26060
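The "patch instead of subclass" idea in the bullets above can be sketched as follows: replace only the hardware-specific function on the upstream class rather than overriding `forward` wholesale, so the rest of the upstream logic stays shared. The class and function bodies below are stand-ins, not vllm-ascend's actual code.

```python
# Hedged sketch: patching a single method on an upstream sampler class
# instead of subclassing and overriding forward(). All names and bodies
# are illustrative placeholders.

class RejectionSampler:
    def rejection_sample(self, logits):
        return "cpu/gpu path"

    def forward(self, logits):
        # Shared logic (logprobs handling, etc.) stays reusable upstream.
        return f"forward -> {self.rejection_sample(logits)}"

def npu_rejection_sample(self, logits):
    return "npu path"

# Patch only the function that needs a hardware-specific implementation;
# forward() and everything else continue to come from upstream.
RejectionSampler.rejection_sample = npu_rejection_sample

print(RejectionSampler().forward(None))  # forward -> npu path
```

The trade-off noted in the bullet stands: patching is fragile against upstream signature changes, which is why promoting these functions to custom ops may be cleaner.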

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
- [x] test logits processor for spec decoding
- [x] test logprobs for spec decoding
- [x] test logprobs for spec decoding + async scheduling (tested with #4893)


- vLLM version: v0.12.0
- vLLM main:
vllm-project/vllm@ad32e3e

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
chenaoxuan pushed a commit to chenaoxuan/vllm-ascend that referenced this pull request Dec 20, 2025
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…ject#29223)

Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026

Labels: `ready`, `v1`

Projects: None yet

Development

Successfully merging this pull request may close these issues.

2 participants