
[V1] Logit processors for rejection sampler #19482

Merged: simon-mo merged 20 commits into vllm-project:main from southfreebird:feature/logit-processors-for-rejection-sampler on Oct 7, 2025
Conversation

southfreebird (Contributor) commented Jun 11, 2025

Hi team,
Thank you for your amazing work!
This PR was prepared by the Nebius team.

This is a draft of the PR, so your suggestions are welcome! I would be happy to improve it or add anything I missed.

Purpose

This PR adds logit processors to the rejection sampler, which allows the use of bad_words, penalties, and allowed_token_ids with speculative decoding. Previously, these parameters were silently ignored. The problem is described in issue #18902.
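What this enables, conceptually: at each draft-token position, the target model's logits are adjusted by the request's logit processors before acceptance is decided. A minimal pure-Python sketch of a frequency penalty (illustrative only; the function name and list-based shapes are assumptions, not vLLM's implementation):

```python
def apply_frequency_penalty(logits, token_history, penalty):
    """Subtract penalty * count(token) from each token's logit."""
    counts = {}
    for tok in token_history:
        counts[tok] = counts.get(tok, 0) + 1
    return [lg - penalty * counts.get(i, 0) for i, lg in enumerate(logits)]

# One such call happens per draft position; each position's history is
# the accepted tokens so far plus the earlier draft tokens in the window.
history = [3, 3, 1]
logits = [0.0, 2.0, 1.0, 5.0]
print(apply_frequency_penalty(logits, history, penalty=1.0))
# token 3 appeared twice, so its logit drops by 2.0; token 1 drops by 1.0
```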

Test Plan

You can test this functionality using the following pytest command:

pytest tests/v1/sample/test_rejection_sampler.py -k "test_frequency_penalties or test_bad_words or test_allowed_token_ids"

Or use the script provided in the issue.

Test Result

Server command:

vllm serve meta-llama/Llama-3.1-8B-Instruct --speculative-config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 4}' --max-model-len 2048 --no-enable-prefix-caching

Client command:

vllm bench serve --model meta-llama/Llama-3.1-8B-Instruct --dataset-name sharegpt --num-prompts 200 --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json --seed 0

Baseline:

============ Serving Benchmark Result ============
Successful requests:                     200       
Benchmark duration (s):                  10.76     
Total input tokens:                      42659     
Total generated tokens:                  43382     
Request throughput (req/s):              18.59     
Output token throughput (tok/s):         4032.80   
Peak output token throughput (tok/s):    4726.00   
Peak concurrent requests:                200.00    
Total Token throughput (tok/s):          7998.39   
---------------Time to First Token----------------
Mean TTFT (ms):                          887.85    
Median TTFT (ms):                        854.29    
P99 TTFT (ms):                           1306.54   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          28.76     
Median TPOT (ms):                        18.41     
P99 TPOT (ms):                           159.41    
---------------Inter-token Latency----------------
Mean ITL (ms):                           26.70     
Median ITL (ms):                         24.23     
P99 ITL (ms):                            207.50    
==================================================

Measurements with a frequency-penalty enabled:

============ Serving Benchmark Result ============
Successful requests:                     200       
Benchmark duration (s):                  11.50     
Total input tokens:                      42659     
Total generated tokens:                  43052     
Request throughput (req/s):              17.39     
Output token throughput (tok/s):         3743.33   
Peak output token throughput (tok/s):    4310.00   
Peak concurrent requests:                200.00    
Total Token throughput (tok/s):          7452.49   
---------------Time to First Token----------------
Mean TTFT (ms):                          902.73    
Median TTFT (ms):                        874.15    
P99 TTFT (ms):                           1341.88   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          31.91     
Median TPOT (ms):                        21.44     
P99 TPOT (ms):                           161.59    
---------------Inter-token Latency----------------
Mean ITL (ms):                           27.22     
Median ITL (ms):                         23.93     
P99 ITL (ms):                            207.16    
==================================================

(Optional) Documentation Update

github-actions bot commented:

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

gemini-code-assist bot left a comment:

Summary of Changes

Hello @southfreebird, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses issue #18902 by integrating logit processors into the rejection sampling mechanism used for speculative decoding. This enables the use of features like bad word lists, various penalties, and allowed token IDs, which were previously incompatible with speculative decoding, by correctly applying these constraints to the target model's logits during the verification and sampling steps.

Highlights

  • Speculative Decoding Logit Processors: Implemented support for applying logit processors (bad words, penalties, allowed token IDs) within the rejection sampling process used in speculative decoding. Previously, these parameters were ignored.
  • Rejection Sampler Logic Update: Modified the RejectionSampler to include a new apply_logits_processors method that orchestrates the application of various constraints and penalties to the target model's logits before sampling, correctly accounting for the structure of speculative draft tokens.
  • Penalty and Bad Word Ops Updates: Updated the underlying penalty and bad word application functions (apply_min_token_penalties, apply_bad_words_with_drafts) to handle the batched structure of speculative decoding, where multiple draft tokens per request need processing.
  • Metadata and Batch Management: Added tracking for last_spec_token_ids in SamplingMetadata and GPUInputBatch to provide the necessary context (previous speculative tokens) for applying logit processors during rejection sampling.
  • New Test Cases: Added specific test cases (test_frequency_penalties, test_bad_words, test_allowed_token_ids, test_min_tokens_penalty) to test_rejection_sampler.py to verify the correct application of logit processors in the speculative decoding flow.
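The batched structure mentioned in these highlights can be illustrated with a small sketch (names like num_draft_tokens and flat_row are illustrative, not vLLM's internals): with speculative decoding, the sampler sees one logits row per draft token, flattened across the batch, so request i's j-th draft maps to row cumsum(num_draft_tokens[:i]) + j.

```python
def flat_row(num_draft_tokens, req_idx, draft_idx):
    """Map (request index, draft position) to a row in the flattened logits."""
    start_idx = sum(num_draft_tokens[:req_idx])  # rows used by earlier requests
    return start_idx + draft_idx

num_draft_tokens = [3, 1, 4]  # drafts per request in this batch
# rows 0-2 belong to request 0, row 3 to request 1, rows 4-7 to request 2
print(flat_row(num_draft_tokens, 2, 1))
```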

@mergify mergify bot added the v1 label Jun 11, 2025
gemini-code-assist bot left a comment:

Code Review

This pull request introduces logit processors to the rejection sampler, enabling features like bad words, penalties, and allowed token IDs for speculative decoding. The changes span across test files, sampling operations, metadata structures, and the GPU model runner.

Key changes include:

  • Modifications to SamplingMetadata and InputBatch to store last_spec_token_ids.
  • New methods in RejectionSampler and Sampler (apply_logits_processors, _combine_outputs_with_spec_tokens, etc.) to handle the application of these processors, considering speculative tokens.
  • Updates to penalty and bad word application logic to support draft tokens.
  • New test cases to verify the functionality of frequency penalties, bad words, allowed token IDs, and min_tokens penalty with the rejection sampler.

Overall, the changes seem well-structured to integrate these features. I've pointed out a few areas for minor improvements or clarification, mainly concerning type hints and debugging statements. The core logic for applying processors in the context of speculative decoding appears sound.

gemini-code-assist bot review comment (medium):

The logic for handling penalties seems to correctly initialize tensors based on whether penalties are provided. However, consider using torch.empty for frequency_penalties, presence_penalties, and repetition_penalties when no_penalties is True, similar to how it's done in SamplingMetadata elsewhere. This avoids creating small, potentially unnecessary tensors with torch.tensor([]) if the downstream code handles empty tensors appropriately. If torch.tensor([]) is specifically required for compatibility, this can be ignored.

gemini-code-assist bot review comment (medium):

The function apply_bad_words_with_drafts iterates through bad_words_token_ids.items(). The key i from this iteration seems to correspond to the request index. Inside the inner loop, past_tokens_ids[start_idx + draft_idx] is used. Ensure that start_idx + draft_idx correctly maps to the sequence within past_tokens_ids that corresponds to the i-th request's draft_idx-th draft token. The logic seems to assume a flattened structure for past_tokens_ids where segments for each request are contiguous and ordered by num_draft_tokens[i]. This is a common pattern but worth double-checking for correctness under all conditions, especially if past_tokens_ids could be structured differently or if num_draft_tokens might not align perfectly with the segments in past_tokens_ids.
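To make the indexing concern concrete, here is a simplified, assumed sketch of bad-word masking for one flattened row (not vLLM's apply_bad_words_with_drafts itself): if the row's token history ends with a bad word's prefix, the word's final token is masked to -inf.

```python
NEG_INF = float("-inf")

def mask_bad_words(logits_row, past_token_ids, bad_words_ids):
    """Mask any token that would complete a bad-word sequence."""
    for bad_word in bad_words_ids:
        prefix, last = bad_word[:-1], bad_word[-1]
        # Single-token bad words have an empty prefix and are always masked.
        if len(prefix) == 0 or past_token_ids[-len(prefix):] == prefix:
            logits_row[last] = NEG_INF
    return logits_row

# The history [7, 1, 2] ends with [1, 2], so token 3 would complete the
# bad word [1, 2, 3] and gets masked in this row.
print(mask_bad_words([0.1, 0.2, 0.3, 0.4], [7, 1, 2], [[1, 2, 3]]))
```

In the drafted case, this check runs once per flattened row, with each row's past_token_ids including the earlier draft tokens of the same request.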

gemini-code-assist bot review comment (medium):

The logic for handling num_draft_tokens seems correct. When num_draft_tokens is None, it operates in standard mode. When provided, it iterates through draft tokens for each request. Ensure that output_token_ids is structured such that output_token_ids[start_idx + draft_idx] correctly refers to the token history for the specific draft token being processed. This implies output_token_ids is a flattened list of lists, where each inner list is the token history for a specific (request, draft_token_index) pair.

gemini-code-assist bot review comment (medium):

The _combine_outputs_with_spec_tokens method constructs a list of token histories. For each original output out and its corresponding spec tokens, it first appends out. Then, for each token in spec (except the last), it creates a new history by appending the spec token to the previously appended history (result[-1]). This creates a sequence of histories: [out, [*out, spec[0]], [*out, spec[0], spec[1]], ...]. This seems correct for applying penalties/bad words progressively for each speculative token. Ensure that output_token_ids and last_spec_token_ids arguments to this function have the expected types (they are annotated as torch.Tensor but used as iterables of lists/sequences). If they are indeed tensors, they might need conversion to lists of lists first.
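The progressive-history construction described here can be sketched in plain Python (assumed behavior, using lists rather than tensors): for each request, emit one history per draft position, each extending the previous one with the next speculative token.

```python
def combine_outputs_with_spec_tokens(output_token_ids, spec_token_ids):
    """Build one token history per draft position for each request."""
    result = []
    for out, spec in zip(output_token_ids, spec_token_ids):
        result.append(list(out))
        for tok in spec[:-1]:  # the last draft token starts no new history
            result.append(result[-1] + [tok])
    return result

print(combine_outputs_with_spec_tokens([[10, 11]], [[7, 8, 9]]))
# -> [[10, 11], [10, 11, 7], [10, 11, 7, 8]]
```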

gemini-code-assist bot review comment (medium):

The logic to combine output_token_ids with last_spec_token_ids when predict_bonus_token is true and penalties/bad words are active is crucial for speculative decoding. This ensures that penalties are applied considering the full sequence including speculative tokens. The _combine_outputs_with_spec_tokens method seems to correctly append speculative tokens to the base output tokens for this purpose.

southfreebird force-pushed the feature/logit-processors-for-rejection-sampler branch from 85cfd21 to d64f7dd on June 11, 2025 09:16
southfreebird (Author):

I'm not sure about this change.
I added it because of the new logic in https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/utils.py#L56-L57.

It requires logits and repetition_penalties to have the same dtype. Without casting logits to float32, we get logits in bfloat16 and repetition_penalties in float32. To fix that, we could cast repetition_penalties to bfloat16, but we would lose some accuracy here. Casting logits to float32 makes it consistent with the sampler logic.

A collaborator replied:

This seems reasonable, but I'm not too sure about the potential consequences. FWIW, I don't think we would lose too much accuracy by casting repetition_penalties to BF16 since they are usually very small values (often < 2)

southfreebird (Author):

Hi team!
Any chance this PR could be reviewed?

ilyal-cerebras (Contributor):

Do you still observe benefits from spec decode when penalties are applied?
There is an opinion that when penalties are applied, the draft and target models become desynchronized and the number of accepted candidates is reduced.

southfreebird force-pushed the feature/logit-processors-for-rejection-sampler branch from cd2fd0d to c144313 on September 2, 2025 14:59
southfreebird (Author):

> Do you still observe benefits from spec decode when penalties are applied?
> There is an opinion that when penalties are applied, draft and target models become dis-synchronized and a number of confirmed candidates is reduced.

Yeah, you are right. This PR just applies penalties to the bonus token when penalties are applied.

benchislett (Collaborator) left a comment:

Why do you need to track last_spec_token_ids instead of using scheduler_output.scheduled_spec_decode_tokens?

southfreebird (Author):

Hey @benchislett
My main idea was to keep the interface of Samplers consistent and to pass all arguments required for sampling through sampling_metadata. Do you suggest passing them directly to the samplers as scheduler_output.scheduled_spec_decode_tokens instead?

benchislett (Collaborator) left a comment:

I see now. It is somewhat confusing to me that self.input_batch.last_spec_token_ids is named this way, as it looks like it's just the current spec token ids that need to be verified. If possible, I would suggest renaming it to spec_token_ids for clarity.

Also, please include a benchmark with performance before/after this change. It is unclear to me how much (if any) slowdown will be seen when we add this feature.

Otherwise, LGTM. I think someone else should take a look at the logits-casting to make sure it won't have consequences in other areas.

benchislett (Collaborator):

@southfreebird in your benchmark, please ensure that some logit processors are enabled. Otherwise I worry this might not capture the performance change.

afeldman-nm (Contributor):

JFYI @southfreebird the approach for supporting built-in logits processors in vLLM v1 is changing, as described here:

https://docs.vllm.ai/en/latest/design/logits_processors.html

So in the long term, bad words, penalties and allowed token ids should be updated to be subclasses of LogitsProcessor as described in #25957

However, #25957 depends on #25389 (revision to the programming model for defining subclasses of LogitsProcessor) which is still WIP.

So for the purposes of this PR I think the approach taken (augment the existing bad words/penalties/allowed token ids support to be compatible with spec decoding) is fine, since the logits processor programming model will be revised soon.

southfreebird force-pushed the feature/logit-processors-for-rejection-sampler branch from 58f41ad to 4028195 on September 30, 2025 17:29
southfreebird (Author):

Added benchmarks with logit processors enabled. As the results show, there is some performance degradation when logit processors are enabled.

simon-mo (Collaborator):

@22quinn @houseroad please help review logits/sampler changes

njhill (Member) commented Oct 1, 2025:

I'll review this today/tomorrow

southfreebird force-pushed the feature/logit-processors-for-rejection-sampler branch from 4028195 to 4e38206 on October 1, 2025 09:33
mergify bot added labels on Oct 7, 2025: deepseek, frontend, llama, multi-modality, new-model, performance, qwen, gpt-oss, rocm
mergify bot commented Oct 7, 2025:

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @southfreebird.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

southfreebird (Author):

OMG.
I just tried to sign off your ruff commit with a simple git rebase -i HEAD~X.

southfreebird force-pushed the feature/logit-processors-for-rejection-sampler branch from 17fa609 to f79d8b4 on October 7, 2025 11:32
@mergify mergify bot removed the tpu Related to Google TPUs label Oct 7, 2025
southfreebird (Author):

Sorry, everyone, for the mentions.
I was just trying to fix the DCO check.
Everything has been reverted (except the hundred mentions and new reviewers).

Labels

ci/build, deepseek, documentation, frontend, gpt-oss, kv-connector, llama, multi-modality, new-model, performance, qwen, ready, rocm, speculative-decoding, structured-output, tool-calling, v1
