[V1] Logit processors for rejection sampler #19482
simon-mo merged 20 commits into vllm-project:main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; instead, only a limited set of checks runs. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀
Summary of Changes
Hello @southfreebird, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request addresses issue #18902 by integrating logit processors into the rejection sampling mechanism used for speculative decoding. This enables the use of features like bad word lists, various penalties, and allowed token IDs, which were previously incompatible with speculative decoding, by correctly applying these constraints to the target model's logits during the verification and sampling steps.
Highlights
- Speculative Decoding Logit Processors: Implemented support for applying logit processors (bad words, penalties, allowed token IDs) within the rejection sampling process used in speculative decoding. Previously, these parameters were ignored.
- Rejection Sampler Logic Update: Modified the `RejectionSampler` to include a new `apply_logits_processors` method that orchestrates the application of various constraints and penalties to the target model's logits before sampling, correctly accounting for the structure of speculative draft tokens.
- Penalty and Bad Word Ops Updates: Updated the underlying penalty and bad word application functions (`apply_min_token_penalties`, `apply_bad_words_with_drafts`) to handle the batched structure of speculative decoding, where multiple draft tokens per request need processing.
- Metadata and Batch Management: Added tracking for `last_spec_token_ids` in `SamplingMetadata` and `GPUInputBatch` to provide the necessary context (previous speculative tokens) for applying logit processors during rejection sampling.
- New Test Cases: Added specific test cases (`test_frequency_penalties`, `test_bad_words`, `test_allowed_token_ids`, `test_min_tokens_penalty`) to `test_rejection_sampler.py` to verify the correct application of logit processors in the speculative decoding flow.
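As a rough illustration of what one of these processors does, here is a toy, pure-Python version of a frequency penalty applied to a single row of logits. vLLM applies this to batched GPU tensors; the function name and shapes below are illustrative only, not the actual implementation.

```python
from collections import Counter

def apply_frequency_penalty(logits, history, penalty):
    """Subtract penalty * count(token) from each token's logit, where
    count is how often that token already appears in the history."""
    counts = Counter(history)
    return [logit - penalty * counts.get(tok, 0)
            for tok, logit in enumerate(logits)]

# Token 0 appeared once and token 2 twice in the history.
row = apply_frequency_penalty([1.0, 1.0, 1.0], [2, 2, 0], penalty=0.5)
print(row)  # [0.5, 1.0, 0.0]
```

The point of this PR is that, with speculative decoding, the `history` for each draft position must include the earlier draft tokens of the same request, which is what the metadata changes above enable.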
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
Code Review
This pull request introduces logit processors to the rejection sampler, enabling features like bad words, penalties, and allowed token IDs for speculative decoding. The changes span test files, sampling operations, metadata structures, and the GPU model runner.
Key changes include:
- Modifications to `SamplingMetadata` and `InputBatch` to store `last_spec_token_ids`.
- New methods in `RejectionSampler` and `Sampler` (`apply_logits_processors`, `_combine_outputs_with_spec_tokens`, etc.) to handle the application of these processors, considering speculative tokens.
- Updates to penalty and bad word application logic to support draft tokens.
- New test cases to verify the functionality of frequency penalties, bad words, allowed token IDs, and min_tokens penalty with the rejection sampler.
Overall, the changes seem well-structured to integrate these features. I've pointed out a few areas for minor improvements or clarification, mainly concerning type hints and debugging statements. The core logic for applying processors in the context of speculative decoding appears sound.
The logic for handling penalties seems to correctly initialize tensors based on whether penalties are provided. However, consider using torch.empty for frequency_penalties, presence_penalties, and repetition_penalties when no_penalties is True, similar to how it's done in SamplingMetadata elsewhere. This avoids creating small, potentially unnecessary tensors with torch.tensor([]) if the downstream code handles empty tensors appropriately. If torch.tensor([]) is specifically required for compatibility, this can be ignored.
vllm/v1/sample/ops/bad_words.py (Outdated)
The function apply_bad_words_with_drafts iterates through bad_words_token_ids.items(). The key i from this iteration seems to correspond to the request index. Inside the inner loop, past_tokens_ids[start_idx + draft_idx] is used. Ensure that start_idx + draft_idx correctly maps to the sequence within past_tokens_ids that corresponds to the i-th request's draft_idx-th draft token. The logic seems to assume a flattened structure for past_tokens_ids where segments for each request are contiguous and ordered by num_draft_tokens[i]. This is a common pattern but worth double-checking for correctness under all conditions, especially if past_tokens_ids could be structured differently or if num_draft_tokens might not align perfectly with the segments in past_tokens_ids.
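For context, the core bad-words check on a single logits row can be sketched in pure Python like this. The real function is batched over draft positions and operates on tensors; the helper name and list-based signature below are hypothetical.

```python
NEG_INF = float("-inf")

def mask_bad_word(logits, history, bad_words_token_ids):
    """Ban the last token of each bad-word sequence whose prefix matches
    the tail of the generated history (single-token bad words are always
    banned)."""
    for seq in bad_words_token_ids:
        prefix, last = seq[:-1], seq[-1]
        if not prefix or history[-len(prefix):] == prefix:
            logits[last] = NEG_INF
    return logits

# The history ends in token 1, so completing the bad word [1, 2] is banned.
masked = mask_bad_word([0.1, 0.2, 0.3], [5, 1], [[1, 2]])
print(masked)  # [0.1, 0.2, -inf]
```

With speculative decoding, this check must be run once per draft position, each time with a history extended by the preceding draft tokens, which is why the flattened `past_tokens_ids` layout discussed above matters.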
vllm/v1/sample/ops/penalties.py (Outdated)
The logic for handling num_draft_tokens seems correct. When num_draft_tokens is None, it operates in standard mode. When provided, it iterates through draft tokens for each request. Ensure that output_token_ids is structured such that output_token_ids[start_idx + draft_idx] correctly refers to the token history for the specific draft token being processed. This implies output_token_ids is a flattened list of lists, where each inner list is the token history for a specific (request, draft_token_index) pair.
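The flattened layout described above can be made concrete with a small helper. The names here are hypothetical and merely mirror the indexing in the comment: rows are contiguous per request and ordered by `num_draft_tokens`.

```python
def row_index(num_draft_tokens, req_idx, draft_idx):
    """Map a (request, draft position) pair to a row in the flattened
    output_token_ids structure, assuming each request's rows are
    contiguous and ordered by num_draft_tokens."""
    start_idx = sum(num_draft_tokens[:req_idx])
    return start_idx + draft_idx

# Three requests with 2, 3, and 1 draft tokens -> 6 flattened rows.
print(row_index([2, 3, 1], 0, 1))  # 1
print(row_index([2, 3, 1], 1, 0))  # 2
print(row_index([2, 3, 1], 2, 0))  # 5
```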
vllm/v1/sample/rejection_sampler.py (Outdated)
The _combine_outputs_with_spec_tokens method constructs a list of token histories. For each original output out and its corresponding spec tokens, it first appends out. Then, for each token in spec (except the last), it creates a new history by appending the spec token to the previously appended history (result[-1]). This creates a sequence of histories: [out, [*out, spec[0]], [*out, spec[0], spec[1]], ...]. This seems correct for applying penalties/bad words progressively for each speculative token. Ensure that output_token_ids and last_spec_token_ids arguments to this function have the expected types (they are annotated as torch.Tensor but used as iterables of lists/sequences). If they are indeed tensors, they might need conversion to lists of lists first.
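The progression described can be sketched as follows, a pure-Python approximation of `_combine_outputs_with_spec_tokens` that assumes the inputs are already lists of lists rather than tensors:

```python
def combine_outputs_with_spec_tokens(output_token_ids, spec_token_ids):
    """Build one token history per logits row: the base history, then the
    base history extended by each successive draft token except the last."""
    result = []
    for out, spec in zip(output_token_ids, spec_token_ids):
        result.append(list(out))
        for tok in spec[:-1]:
            result.append(result[-1] + [tok])
    return result

print(combine_outputs_with_spec_tokens([[1, 2]], [[7, 8, 9]]))
# [[1, 2], [1, 2, 7], [1, 2, 7, 8]]
```

Each produced history corresponds to one draft position being verified, so penalties and bad words see progressively longer contexts.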
vllm/v1/sample/sampler.py (Outdated)
The logic to combine output_token_ids with last_spec_token_ids when predict_bonus_token is true and penalties/bad words are active is crucial for speculative decoding. This ensures that penalties are applied considering the full sequence including speculative tokens. The _combine_outputs_with_spec_tokens method seems to correctly append speculative tokens to the base output tokens for this purpose.
(force-pushed 85cfd21 to d64f7dd)
vllm/v1/sample/rejection_sampler.py (Outdated)
I'm not sure about this change.
I added it because of the new logic in https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/utils.py#L56-L57.
It requires logits and repetition_penalties to have the same dtype. Without casting logits to float32, we get logits in bfloat16 and repetition_penalties in float32. To fix that, we could cast repetition_penalties to bfloat16, but we would lose some accuracy here. Casting logits to float32 makes it consistent with the sampler logic.
This seems reasonable, but I'm not too sure about the potential consequences. FWIW, I don't think we would lose too much accuracy by casting repetition_penalties to BF16 since they are usually very small values (often < 2)
Hi team!
Do you still observe benefits from spec decode when penalties are applied?
(force-pushed cd2fd0d to c144313)
Yeah, you are right. This PR just applies penalties to the bonus token when penalties are applied.
benchislett left a comment:
Why do you need to track last_spec_token_ids instead of using scheduler_output.scheduled_spec_decode_tokens?
Hey @benchislett
benchislett left a comment:
I see now. It is somewhat confusing to me that self.input_batch.last_spec_token_ids is named this way, as it looks like it's just the current spec token ids that need to be verified. If possible, I would suggest renaming it to spec_token_ids for clarity.
Also, please include a benchmark with performance before/after this change. It is unclear to me how much (if any) slowdown will be seen when we add this feature.
Otherwise, LGTM. I think someone else should take a look at the logits-casting to make sure it won't have consequences in other areas.
@southfreebird in your benchmark, please ensure that some logit processors are enabled. Otherwise I worry this might not capture the performance change.
JFYI @southfreebird the approach for supporting built-in logits processors in vLLM v1 is changing, as described here: https://docs.vllm.ai/en/latest/design/logits_processors.html So in the long term, bad words, penalties and allowed token ids should be updated to be subclasses of [the new logits processor base class]. However, #25957 depends on #25389 (revision to the programming model for defining subclasses of [that base class]). So for the purposes of this PR I think the approach taken (augment the existing bad words/penalties/allowed token ids support to be compatible with spec decoding) is fine, since the logits processor programming model will be revised soon.
(force-pushed 58f41ad to 4028195)
Added benchmarks with logit processors enabled. As far as I can see, there is a performance degradation when logit processors are enabled.
@22quinn @houseroad please help review logits/sampler changes
I'll review this today/tomorrow
(force-pushed 4028195 to 4e38206)
This pull request has merge conflicts that must be resolved before it can be merged.
OMG
(force-pushed 17fa609 to f79d8b4)
Sorry, everyone, for the mention.
Hi team,
Thank you for your amazing work!
This PR was prepared by the Nebius team.
This is a draft of the PR, so your suggestions are welcome! I would be happy to improve it or add anything I missed.
Purpose
This PR adds logit processors to the rejection sampler, which allows the use of bad_words, penalties, and allowed_token_ids with speculative decoding. Previously, these parameters were silently ignored. The problem is described in issue #18902.
Test Plan
You can test this functionality using the following pytest command:
Or use the script provided in the issue.
Test Result
Server command:
Client command:
Baseline:
Measurements with a frequency-penalty enabled:
(Optional) Documentation Update