
Conversation

@eitanturok

@eitanturok eitanturok commented Sep 5, 2025

Implement FR-Spec: Accelerating Large-Vocabulary Language Models via Frequency-Ranked Speculative Sampling to speedup speculative decoding. @keyboardAnt @Achazwl @jmamou.
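For context, a minimal sketch of the core FR-Spec idea (illustration only, not this PR's implementation; the tensor sizes, frequency data, and 25% keep fraction below are placeholders): the drafter keeps only the most frequent vocabulary tokens, so its LM-head projection and sampling run over a much smaller vocabulary, and drafted token ids are mapped back to full-vocab ids before the target model verifies them.

```python
import torch

# Placeholder sizes and frequencies; the real setup uses the target model's
# lm_head and a corpus-derived token-frequency file.
vocab_size, hidden = 128_256, 64
lm_head = torch.randn(vocab_size, hidden)
token_freq = torch.rand(vocab_size)

# Keep the top 25% most frequent tokens (fraction chosen arbitrarily here).
keep = int(0.25 * vocab_size)
pruned_vocab = torch.argsort(token_freq, descending=True)[:keep]  # kept full-vocab ids
pruned_lm_head = lm_head[pruned_vocab]                            # (keep, hidden): cheaper matmul

# Draft step: score only the reduced vocabulary, then map back to full-vocab ids.
hidden_state = torch.randn(hidden)
draft_logits = pruned_lm_head @ hidden_state
draft_token_id = pruned_vocab[draft_logits.argmax()]
print(int(draft_token_id))
```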

Purpose

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@eitanturok eitanturok changed the title Implement fr-spec to speedup speculative decoding Implement fr-spec to speedup speculative decoding Sep 5, 2025
@mergify

mergify bot commented Sep 7, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @eitanturok.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 7, 2025
@mergify mergify bot removed the needs-rebase label Sep 7, 2025
@mergify mergify bot added the performance Performance-related issues label Sep 8, 2025
@mergify mergify bot added the documentation Improvements or additions to documentation label Sep 8, 2025
@mergify

mergify bot commented Sep 8, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @eitanturok.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 8, 2025
@eitanturok
Author

I fixed a couple of issues with the previous benchmark. Turns out, we were compute bound, not memory bound. I re-ran the benchmark and got:

  • num-spec-tokens=1, max-seq-len=1
    • eagle-2 is 61% faster than vanilla
    • fr-spec is 68% faster than vanilla
  • num-spec-tokens=39, max-seq-len=1
    • eagle-2 is ??% faster than vanilla
    • fr-spec is ??% faster than vanilla

Like before, I benchmarked vanilla, eagle-2, and fr-spec on mt-bench with llama-3.1-8b-instruct on 100 prompts.

Speculative Decoding Benchmark Results

| Method  | Depth | Branching | Num Spec Tokens | Mean Acceptance Length | Decoding Throughput (tokens/s) | Total Time (s) | Forward Ratio |
|---------|-------|-----------|-----------------|------------------------|--------------------------------|----------------|---------------|
| Eagle   | 3     | 1         | 3               | 2.31                   | 119.31                         | 181.29         | 0.137         |
| fr-spec | 3     | 1         | 3               | 2.23                   | 121.46                         | 178.13         | 0.138         |
| Eagle   | 1     | 3         | 3               | 2.31                   | 120.16                         | 180.02         | 0.138         |
| fr-spec | 1     | 3         | 3               | 2.23                   | 121.99                         | 177.36         | 0.139         |
| fr-spec | 1     | 1         | 1               | 1.68                   | 104.59                         | 206.23         | 0.044         |
| Eagle   | 1     | 1         | 1               | 1.71                   | 100.01                         | 215.60         | 0.044         |
| Vanilla | N/A   | N/A       | 0               | 1.00                   | 61.74                          | 347.07         | 0.000         |
Commands to reproduce the table
| Method  | Depth | Branching | Num Spec Tokens | Command |
|---------|-------|-----------|-----------------|---------|
| Eagle   | 3     | 1         | 3               | `VLLM_USE_V1=1 python3 examples/offline_inference/spec_decode.py --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 100 --max-num-seqs 1 --compilation-config '{"level": "0"}' --num-spec-tokens 3` |
| fr-spec | 3     | 1         | 3               | `VLLM_USE_V1=1 python3 examples/offline_inference/spec_decode.py --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 100 --max-num-seqs 1 --compilation-config '{"level": "0"}' --num-spec-tokens 3 --draft-vocab-frequency-path 'eturok/llama-3.1-8b-instruct-vocab-freq/vocab_freq.pt' --draft-vocab-frequency-keep-threshold 0.25` |
| Eagle   | 1     | 3         | 3               | `VLLM_USE_V1=1 python3 examples/offline_inference/spec_decode.py --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 100 --max-num-seqs 1 --compilation-config '{"level": "0"}' --num-spec-tokens 3 --spec-token-tree-depth 1 --spec-token-tree-branching 3` |
| fr-spec | 1     | 3         | 3               | `VLLM_USE_V1=1 python3 examples/offline_inference/spec_decode.py --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 100 --max-num-seqs 1 --compilation-config '{"level": "0"}' --num-spec-tokens 3 --spec-token-tree-depth 1 --spec-token-tree-branching 3 --draft-vocab-frequency-path 'eturok/llama-3.1-8b-instruct-vocab-freq/vocab_freq.pt' --draft-vocab-frequency-keep-threshold 0.25` |
| fr-spec | 1     | 1         | 1               | `VLLM_USE_V1=1 python3 examples/offline_inference/spec_decode.py --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 100 --max-num-seqs 1 --compilation-config '{"level": "0"}' --num-spec-tokens 1 --draft-vocab-frequency-path 'eturok/llama-3.1-8b-instruct-vocab-freq/vocab_freq.pt' --draft-vocab-frequency-keep-threshold 0.25` |
| Eagle   | 1     | 1         | 1               | `VLLM_USE_V1=1 python3 examples/offline_inference/spec_decode.py --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 100 --max-num-seqs 1 --compilation-config '{"level": "0"}' --num-spec-tokens 1` |
| Vanilla | N/A   | N/A       | 0               | `VLLM_USE_V1=1 python3 examples/offline_inference/spec_decode.py --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 100 --max-num-seqs 1 --compilation-config '{"level": "0"}' --num-spec-tokens 0` |

Observations:

  1. We definitely get a speedup over vanilla.
  2. In eagle with num-spec-tokens=1, the drafter forward pass takes 4% of the time of the target forward pass, so the drafter is not a major bottleneck and we don't expect fr-spec to speed things up much in this regime (see the rough estimate after this list). Another perspective: if the drafter forward pass is cheap, then maybe we can "sneak in" more speculative tokens to get a higher acceptance length without incurring much additional compute.
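As a rough sanity check on observation 2 (my own back-of-the-envelope, assuming total decode time is roughly target forward + drafter forward, and taking Forward Ratio as drafter time / target time):

```python
# Rough upper bound on what shrinking the drafter can buy at num-spec-tokens=1,
# using the Forward Ratio from the table above (drafter forward time / target
# forward time) and assuming total decode time ~= target forward + drafter forward.
forward_ratio = 0.044
max_speedup = 1 + forward_ratio  # if the drafter forward pass were completely free
print(f"upper bound on end-to-end speedup: {max_speedup:.3f}x")  # ~1.044x, i.e. <= ~4.4%
```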

I made several fixes:

  1. We previously ran the benchmark with max-seq-length=100. I had accidentally been calling max-seq-length the batch size, but they are different: max-seq-length is the number of user prompts we process at once. In that setting we are compute bound and don't really see a speedup, because speculative decoding is faster only when we are memory bound. So this time we run with max-seq-length=1, as @Achazwl suggested. Also, in the fr-spec repo, it looks like they ran with a batch size of one (here and here).
  2. It seems like CUDA graphs are broken for speculative decoding drafters, which gives vanilla an unfair advantage in comparisons, so I turned off model compilation entirely for a fair comparison.
  3. When I set num-speculative-tokens to 32, eagle-2 became incredibly slow.

@mergify mergify bot added the frontend label Sep 19, 2025
old_weight = self.model.lm_head.weight

# In-place pruning of the weight
self.model.lm_head.weight.data = self.model.lm_head.weight.data[self.pruned_vocab].clone().detach()

@jmamou jmamou Sep 29, 2025

@eitanturok

Since you are selecting part of the indices, you should ensure that self.model.lm_head.weight.data[self.pruned_vocab].clone().detach() is contiguous.
I guess self.model.lm_head.weight.data[self.pruned_vocab].clone().detach().contiguous() should work.
Look at
https://docs.pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html#torch-tensor-contiguous

@eitanturok
Author

eitanturok commented Sep 30, 2025

@jmamou self.model.lm_head.weight.data is already contiguous so we don't need to add .contiguous().

I added a print to the code

self.model.lm_head.weight.data = self.model.lm_head.weight.data[self.pruned_vocab].clone().detach()
print(f"self.model.lm_head.weight.data.is_contiguous(): {self.model.lm_head.weight.data.is_contiguous()}")

ran the cmd

VLLM_USE_V1=1 python3 examples/offline_inference/spec_decode.py \
  --dataset-name hf \
  --dataset-path philschmid/mt-bench \
  --num-prompts 100 \
  --compilation-config '{"level": "0"}' \
  --max-num-seqs 1 \
  --num-spec-tokens 1 \
  --draft-vocab-frequency-path 'thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt' \
  --draft-vocab-frequency-keep-threshold 0.25

and got the output

self.model.lm_head.weight.data.is_contiguous(): True

For more details, see this pytorch PR.
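A standalone way to see this (toy tensor sizes, not the vLLM code path): advanced indexing with an index tensor copies the selected rows into a fresh tensor, which is already contiguous.

```python
import torch

# Toy-sized illustration: selecting rows with an index tensor (advanced indexing)
# copies them into a new tensor, so the result is contiguous without an extra
# .contiguous() call.
weight = torch.randn(1000, 64)
pruned_vocab = torch.randperm(1000)[:250]

pruned = weight[pruned_vocab].clone().detach()
print(pruned.is_contiguous())  # True
```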

@mergify

mergify bot commented Oct 8, 2025

Documentation preview: https://vllm--24343.org.readthedocs.build/en/24343/

@keyboardAnt

keyboardAnt commented Nov 12, 2025

@eitanturok -

Forward ratios:
What fraction of the vocab remains after pruning here? (Why does the forward ratio of frspec equal the forward ratio of eagle in this benchmark?)

Complete benchmark:
frspec seems to increase throughput by 1.52% (=100*(121.99-120.16)/120.16) in the benchmark above. Are there any blockers to running a sweep over different configurations? For example, (eagle_model, eagle_params, dataset, batch_size, hardware) combinations, where eagle_params defines the draft-tree shape (max_depth, max_width, max_num_of_nodes).

My intuition is that (i) large batches evaluated on datasets with (ii) long inputs that induce (iii) long outputs (e.g., GovReport, BookSum) are more likely to demonstrate significant improvements, based on this microbenchmark: #24506 (comment).
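A minimal sketch of such a sweep, for illustration only, using just the flags that already appear in the commands above (the grid, dataset, and output handling are placeholders):

```python
import itertools
import os
import subprocess

# Placeholder sweep over batch sizes and draft-tree shapes for the spec_decode.py
# example; num-spec-tokens is set to depth * branching, mirroring the rows in the
# table above. Extend the grids (datasets, eagle checkpoints, hardware) as needed.
BASE = [
    "python3", "examples/offline_inference/spec_decode.py",
    "--dataset-name", "hf", "--dataset-path", "philschmid/mt-bench",
    "--num-prompts", "100", "--compilation-config", '{"level": "0"}',
]

for max_num_seqs, (depth, branching) in itertools.product(
    [1, 8, 32],                # batch sizes
    [(1, 1), (1, 3), (3, 1)],  # draft-tree shapes (depth, branching)
):
    num_spec_tokens = depth * branching
    cmd = BASE + [
        "--max-num-seqs", str(max_num_seqs),
        "--num-spec-tokens", str(num_spec_tokens),
        "--spec-token-tree-depth", str(depth),
        "--spec-token-tree-branching", str(branching),
    ]
    subprocess.run(cmd, env={**os.environ, "VLLM_USE_V1": "1"}, check=True)
```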

