feat: add parallel token drafting for EAGLE #32628

Closed

hai-meh-cs wants to merge 5 commits into vllm-project:main from hai-meh-cs:pard-pr

Conversation

@hai-meh-cs (Contributor) commented Jan 20, 2026

Purpose

PTD (Parallel Token Decoding) generates K draft tokens in a single forward pass instead of K sequential passes, reducing draft overhead for EAGLE speculative decoding.

vllm serve openai/gpt-oss-120b \
  --speculative-config '{"model": "<ptd-draft>", "method": "eagle3-ptd", "num_speculative_tokens": 4}' \
  --tensor-parallel-size 1 \
  --max-model-len 4096 \
  --no-enable-prefix-caching \
  --async-scheduling

Performance

  • Figure: TPOT comparison (constant vs. linear growth)
  • Figure: Throughput comparison (C=1)
  • Figure: Speedup vs. vanilla EAGLE3

Status

Output throughput improvement at comparable acceptance rates (within ±2%):

  • C=1: +12% (MT-Bench) to +16% (HumanEval) comparing optimal K for each method
  • C=8: +6.5% (MT-Bench) to +10% (HumanEval) comparing optimal K for each method

TPOT remains constant at ~2.7-2.9ms regardless of K, while vanilla increases linearly from 3.15ms (K=3) to 4.33ms (K=8).

TTFT is comparable: PTD 33-38ms vs Vanilla 34-44ms.

Best Throughput Comparison

| Scenario | PTD | Vanilla | Improvement |
|---|---|---|---|
| MT-Bench C=1 | 345 tok/s (K=4) | 308 tok/s (K=3) | +12% |
| MT-Bench C=8 | 1165 tok/s (K=3) | 1094 tok/s (K=3) | +6.5% |
| HumanEval C=1 | 391 tok/s (K=4) | 336 tok/s (K=3) | +16% |
| HumanEval C=8 | 1301 tok/s (K=6) | 1180 tok/s (K=3) | +10% |

Vanilla peaks at K=3; higher K incurs sequential overhead. PTD EAGLE3 benefits from K=4-6 due to parallel drafting.

Speedup by K Value

| K | MT-Bench C=1 | HumanEval C=1 |
|---|---|---|
| 3 | +11% | +13% |
| 4 | +19% | +21% |
| 5 | +24% | +28% |
| 6 | +33% | +36% |
| 7 | +39% | +43% |
| 8 | +47% | +50% |

TPOT Analysis

| K | PTD | Vanilla | Delta |
|---|---|---|---|
| 3 | 2.79ms | 3.15ms | -11% |
| 4 | 2.74ms | 3.33ms | -18% |
| 5 | 2.77ms | 3.52ms | -21% |
| 6 | 2.79ms | 3.80ms | -27% |
| 7 | 2.83ms | 4.05ms | -30% |
| 8 | 2.88ms | 4.33ms | -33% |

PTD EAGLE3 TPOT is constant; vanilla grows linearly with K.

Implications for K Selection

Selecting K involves balancing two factors: higher K can yield more accepted tokens per verification, but also increases draft cost when acceptance is low. With vanilla EAGLE, generating K tokens requires K sequential forward passes, so the optimal K depends on expected acceptance rates.

PTD generates all K tokens in a single forward pass. For the K values tested (3-8), the additional tokens processed per forward have minimal impact on draft latency, as shown in the TPOT data above. This sublinear scaling reduces the cost of higher K settings relative to vanilla EAGLE.

This has a practical benefit: K selection becomes less sensitive to workload characteristics. Acceptance rates vary by prompt type. The benchmark data shows HumanEval averaging ~50% acceptance while MT-Bench averages ~35%. With PTD, choosing a higher K captures more tokens on high-acceptance prompts without proportionally increasing cost on lower-acceptance ones.

The benchmark results reflect this: vanilla EAGLE performs best at K=3, while PTD peaks at K=4-6 depending on workload.
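As a rough sanity check, this tradeoff can be sketched with a toy model. The endpoints below come from the TPOT tables in this PR; treating vanilla TPOT as exactly linear in K, and PTD TPOT as exactly constant, is an assumption of the sketch, not a measurement.

```python
# Toy draft-cost model fit to the headline TPOT numbers in this PR.
# Assumption: vanilla TPOT is linear in K between the measured endpoints
# (3.15 ms at K=3, 4.33 ms at K=8); PTD TPOT is constant at ~2.8 ms.
def vanilla_tpot_ms(k: int) -> float:
    return 3.15 + (4.33 - 3.15) / (8 - 3) * (k - 3)

def ptd_tpot_ms(k: int) -> float:
    return 2.8  # independent of K: all drafts share one forward pass

for k in range(3, 9):
    gain = 1 - ptd_tpot_ms(k) / vanilla_tpot_ms(k)
    print(f"K={k}: vanilla {vanilla_tpot_ms(k):.2f} ms, "
          f"PTD {ptd_tpot_ms(k):.2f} ms ({gain:.0%} lower)")
```

Under this model the relative TPOT advantage grows from roughly 11% at K=3 to roughly 35% at K=8, matching the trend in the TPOT table above.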

PTD EAGLE3 Results (MT-Bench)

| K | C=1 | C=2 | C=4 | C=8 | TTFT | TPOT | Accept |
|---|---|---|---|---|---|---|---|
| 3 | 342 | 538 | 794 | 1165 | 33ms | 2.79ms | 44.7% |
| 4 | 345 | 532 | 792 | 1148 | 33ms | 2.74ms | 37.3% |
| 5 | 341 | 526 | 779 | 1133 | 33ms | 2.77ms | 31.7% |
| 6 | 335 | 510 | 774 | 1133 | 33ms | 2.79ms | 26.9% |
| 7 | 331 | 503 | 728 | 1103 | 33ms | 2.83ms | 23.7% |
| 8 | 324 | 493 | 731 | 1083 | 35ms | 2.88ms | 21.0% |

Per-Position Acceptance

| K | P0 | P1 | P2 | P3 | P4 | P5 | P6 | P7 |
|---|---|---|---|---|---|---|---|---|
| 3 | 65% | 42% | 27% | - | - | - | - | - |
| 4 | 65% | 41% | 26% | 17% | - | - | - | - |
| 5 | 64% | 40% | 25% | 17% | 12% | - | - | - |
| 6 | 64% | 39% | 24% | 16% | 11% | 7% | - | - |
| 7 | 64% | 39% | 24% | 16% | 11% | 7% | 5% | - |
| 8 | 63% | 39% | 24% | 15% | 11% | 7% | 5% | 4% |
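One way to read the per-position table: assuming each entry is the marginal probability that that draft position is accepted (an assumption about how these stats are computed), linearity of expectation gives the mean accepted drafts per verification step.

```python
# Expected accepted draft tokens per verification step, using the K=3 and K=8
# rows of the MT-Bench per-position table above. Assumption: each percentage
# is the marginal acceptance probability of that position, so the expectation
# of the sum of indicators is the sum of the marginals.
per_position = {
    3: [0.65, 0.42, 0.27],
    8: [0.63, 0.39, 0.24, 0.15, 0.11, 0.07, 0.05, 0.04],
}
expected = {k: sum(rates) for k, rates in per_position.items()}
for k, e in expected.items():
    print(f"K={k}: ~{e:.2f} accepted draft tokens per step")
```

Going from K=3 to K=8 raises the expected yield only from ~1.34 to ~1.68 tokens, which is why higher K only pays off when the marginal draft cost is near zero, as with parallel drafting.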

PTD EAGLE3 Results (HumanEval)

| K | C=1 | C=2 | C=4 | C=8 | TTFT | TPOT | Accept |
|---|---|---|---|---|---|---|---|
| 3 | 379 | 596 | 867 | 1254 | 37ms | 2.57ms | 56.8% |
| 4 | 391 | 599 | 902 | 1284 | 38ms | 2.49ms | 48.7% |
| 5 | 391 | 597 | 875 | 1283 | 38ms | 2.48ms | 42.2% |
| 6 | 389 | 590 | 870 | 1301 | 38ms | 2.50ms | 36.7% |
| 7 | 385 | 576 | 832 | 1269 | 38ms | 2.52ms | 32.5% |
| 8 | 381 | 564 | 827 | 1241 | 38ms | 2.55ms | 29.3% |

Per-Position Acceptance

| K | P0 | P1 | P2 | P3 | P4 | P5 | P6 | P7 |
|---|---|---|---|---|---|---|---|---|
| 3 | 77% | 54% | 40% | - | - | - | - | - |
| 4 | 76% | 52% | 38% | 28% | - | - | - | - |
| 5 | 75% | 51% | 36% | 27% | 21% | - | - | - |
| 6 | 74% | 50% | 35% | 26% | 20% | 15% | - | - |
| 7 | 74% | 50% | 35% | 25% | 19% | 14% | 11% | - |
| 8 | 74% | 49% | 34% | 25% | 19% | 14% | 11% | 9% |

Vanilla EAGLE3 Results (MT-Bench)

| K | C=1 | C=2 | C=4 | C=8 | TTFT | TPOT | Accept |
|---|---|---|---|---|---|---|---|
| 3 | 308 | 490 | 743 | 1094 | 34ms | 3.15ms | 43.0% |
| 4 | 289 | 466 | 715 | 1057 | 35ms | 3.33ms | 35.2% |
| 5 | 276 | 440 | 671 | 1005 | 37ms | 3.52ms | 30.9% |
| 6 | 252 | 401 | 626 | 951 | 38ms | 3.80ms | 25.5% |
| 7 | 238 | 380 | 588 | 916 | 39ms | 4.05ms | 22.4% |
| 8 | 221 | 352 | 560 | 859 | 40ms | 4.33ms | 19.4% |

Per-Position Acceptance

| K | P0 | P1 | P2 | P3 | P4 | P5 | P6 | P7 |
|---|---|---|---|---|---|---|---|---|
| 3 | 64% | 40% | 25% | - | - | - | - | - |
| 4 | 64% | 39% | 24% | 15% | - | - | - | - |
| 5 | 64% | 40% | 25% | 15% | 10% | - | - | - |
| 6 | 63% | 38% | 23% | 14% | 9% | 5% | - | - |
| 7 | 63% | 38% | 23% | 15% | 9% | 5% | 3% | - |
| 8 | 63% | 38% | 23% | 14% | 9% | 5% | 3% | 2% |

Vanilla EAGLE3 Results (HumanEval)

| K | C=1 | C=2 | C=4 | C=8 | TTFT | TPOT | Accept |
|---|---|---|---|---|---|---|---|
| 3 | 336 | 536 | 803 | 1180 | 38ms | 2.91ms | 53.5% |
| 4 | 324 | 516 | 795 | 1167 | 39ms | 3.02ms | 45.2% |
| 5 | 305 | 488 | 756 | 1140 | 41ms | 3.20ms | 38.5% |
| 6 | 286 | 459 | 712 | 1090 | 42ms | 3.41ms | 33.3% |
| 7 | 269 | 431 | 673 | 1053 | 43ms | 3.63ms | 29.1% |
| 8 | 254 | 411 | 639 | 1018 | 44ms | 3.86ms | 25.7% |

Per-Position Acceptance

| K | P0 | P1 | P2 | P3 | P4 | P5 | P6 | P7 |
|---|---|---|---|---|---|---|---|---|
| 3 | 73% | 52% | 36% | - | - | - | - | - |
| 4 | 72% | 51% | 35% | 23% | - | - | - | - |
| 5 | 72% | 50% | 34% | 22% | 15% | - | - | - |
| 6 | 72% | 50% | 33% | 22% | 14% | 9% | - | - |
| 7 | 72% | 50% | 33% | 21% | 14% | 9% | 6% | - |
| 8 | 71% | 49% | 32% | 21% | 14% | 9% | 5% | 3% |

Benchmark Configuration

Hardware: P5e (H200), TP=1

Draft Models:

  • PTD EAGLE3: Internally trained PTD draft model
  • Vanilla EAGLE3: nvidia/gpt-oss-120b-Eagle3-short-context

Server:

vllm serve openai/gpt-oss-120b \
  --tensor-parallel-size 1 \
  --max-model-len 4096 \
  --speculative-config '{"model": "<draft>", "method": "<method>", "num_speculative_tokens": <K>}' \
  --no-enable-prefix-caching \
  --async-scheduling

Benchmark (MT-Bench):

vllm bench serve \
  --backend openai-chat \
  --dataset-name hf \
  --dataset-path philschmid/mt-bench \
  --hf-split train \
  --hf-output-len 2048 \
  --num-prompts 80 \
  --max-concurrency <C> \
  --num-warmups 5 \
  --request-rate inf

Benchmark (HumanEval):

vllm bench serve \
  --backend openai-chat \
  --dataset-name custom \
  --dataset-path /tmp/humaneval_custom.jsonl \
  --custom-output-len 512 \
  --num-prompts 164 \
  --max-concurrency <C> \
  --num-warmups 5 \
  --request-rate inf

Changes

| File | Description |
|---|---|
| vllm/v1/spec_decode/ptd_eagle.py | New file: PtdEagleProposer with Triton kernel |
| vllm/config/speculative.py | Add eagle-ptd, eagle3-ptd to EagleModelTypes |
| vllm/v1/worker/gpu_model_runner.py | PTD proposer initialization and selection |
| vllm/model_executor/models/llama_eagle3.py | Skip mask_hidden during weight loading |
| vllm/transformers_utils/configs/eagle.py | Config validation for PTD fields |
| examples/offline_inference/spec_decode.py | Add PTD method to examples |
| tests/v1/e2e/test_spec_decode.py | Add test_ptd_correctness |

Test Plan

  • MT-Bench benchmark (K=3-8, C=1,2,4,8)
  • HumanEval benchmark (K=3-8, C=1,2,4,8)
  • Vanilla EAGLE3 comparison
  • End-to-end inference test

mergify bot commented Jan 20, 2026

Documentation preview: https://vllm--32628.org.readthedocs.build/en/32628/

mergify bot added labels: documentation (Improvements or additions to documentation), llama (Related to Llama models), speculative-decoding, v1 (Jan 20, 2026)
@gemini-code-assist bot left a comment

Code Review

This pull request introduces Parallel Token Decoding (PTD) for EAGLE speculative decoding, which is a significant performance enhancement. The implementation is well-structured, with a new PtdEagleProposer and a Triton kernel to handle the parallel draft token generation. The changes are consistently integrated across the configuration, model files, and worker. The addition of a new e2e test for PTD correctness is also a good practice, even if it's currently skipped.

My main feedback is regarding the handling of sequences that exceed max_model_len during drafting. The current implementation wraps around the position embeddings, which could lead to poor draft quality and reduced performance. I've left a specific comment with a suggestion on how to address this.

Comment on lines +88 to +96
out_pos = tl.where(out_pos >= max_model_len, 0, out_pos)
tl.store(out_positions_ptr + out_idx, out_pos)

if is_verified:
slot = tl.load(original_slot_mapping_ptr + in_start + local_idx)
else:
last_pos = tl.load(target_positions_ptr + in_start + last_idx)
draft_pos = last_pos + (local_idx - last_idx)
draft_pos = tl.where(draft_pos >= max_model_len, 0, draft_pos)
Severity: high

The position wrapping logic tl.where(out_pos >= max_model_len, 0, out_pos) for draft tokens that exceed max_model_len could lead to poor draft quality and consequently low acceptance rates for long sequences. When a sequence's length plus the number of speculative tokens (K) exceeds max_model_len, the positions for draft tokens are reset to 0. This provides incorrect positional information to the model, especially for models using positional embeddings like RoPE.

While this prevents out-of-bounds errors, it results in generating low-quality draft tokens that are unlikely to be accepted. This is computationally wasteful.

A better approach would be to cap the number of draft tokens for each sequence to ensure the total length does not exceed max_model_len. This could be done within the propose method before preparing inputs for the kernel.
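A minimal sketch of that suggestion (pure Python; `seq_lens`, `k`, and `max_model_len` are illustrative names, not vLLM's actual variables):

```python
# Cap each request's draft count so draft position ids never reach
# max_model_len; a request already at the limit drafts nothing.
def capped_draft_lens(seq_lens: list[int], k: int, max_model_len: int) -> list[int]:
    # A request at length L has room for at most max_model_len - L drafts.
    return [max(0, min(k, max_model_len - length)) for length in seq_lens]

print(capped_draft_lens([4090, 4096, 100], k=4, max_model_len=4096))
```

This keeps position embeddings valid without touching the kernel's hot path: the cap is computed once per batch before inputs are prepared.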

@hai-meh-cs (author):

haven't confirmed this case, but can see how it would play out. when sequence length approaches max_model_len, draft positions that exceed the limit get wrapped, which would result in low-quality drafts that get rejected. correctness should be preserved but compute would be wasted. will address in a follow-up by capping draft count near max_model_len

Collaborator:

We might have some unit tests for this from previous PRs. Worth taking a look to see if this would be covered by our existing test cases

mergify bot commented Jan 20, 2026

Hi @hai-meh-cs, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@cursor bot left a comment

Cursor Bugbot has reviewed your changes and found 4 potential issues.

Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.

Comment @cursor review or bugbot run to trigger another review on this PR

in_qsl_cpu[1:batch_size+1] - in_qsl_cpu[:batch_size]
)
out_qsl_cpu = torch.zeros(batch_size + 1, dtype=torch.int32)
out_qsl_cpu[1:] = torch.cumsum(accepted_lengths_cpu + draft_len, dim=0)

CPU/GPU mismatch in query_start_loc calculation

High Severity

The GPU out_qsl is computed using last_token_indices to determine accepted lengths, but out_qsl_cpu ignores last_token_indices and instead uses in_qsl_cpu[1:batch_size+1] - in_qsl_cpu[:batch_size] (the original query lengths). When last_token_indices is explicitly provided (in padded batch mode with rejected tokens), these tensors will have inconsistent values. This mismatch in query_start_loc vs query_start_loc_cpu can cause incorrect attention metadata construction, particularly in backends like FlashInfer that use both CPU and GPU tensors.
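The invariant the bot is pointing at: the CPU and GPU copies of `query_start_loc` must be derived from the same per-request lengths. A pure-Python stand-in (function and argument names are assumed for illustration):

```python
# Build the output query_start_loc prefix sums from accepted lengths plus the
# fixed draft_len per request. Host and device copies should both be derived
# from this same input, or attention metadata will disagree between them.
def out_query_start_loc(accepted_lens: list[int], draft_len: int) -> list[int]:
    qsl = [0]
    for n in accepted_lens:
        qsl.append(qsl[-1] + n + draft_len)
    return qsl

print(out_query_start_loc([2, 1, 3], draft_len=4))
```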


@hai-meh-cs (author):

thanks cursor bot for flagging! will take a look at the cpu/gpu query_start_loc calculation to verify whether there's an actual mismatch and what the impact would be

block_offset = draft_pos % block_size
block_id = tl.load(block_table_ptr + req_idx * max_blocks + block_num)
slot = block_id * block_size + block_offset
tl.store(out_slot_mapping_ptr + out_idx, slot)

Missing PADDING_SLOT_ID for exceeded max_model_len positions

Medium Severity

When draft positions exceed max_model_len, the Triton kernel clamps draft_pos to 0 and computes a slot based on block 0, instead of using PADDING_SLOT_ID = -1. The base EagleProposer uses PADDING_SLOT_ID to prevent KV cache corruption for out-of-bounds positions. Without this, draft tokens with positions beyond max_model_len could incorrectly update the KV cache at block 0.
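A sketch of the fix the bot describes, with pure Python standing in for the Triton kernel logic (block-table layout and sizes here are made up for illustration):

```python
# Out-of-range draft positions map to PADDING_SLOT_ID instead of a real slot,
# so they cannot overwrite block 0 of the KV cache after a clamp-to-zero.
PADDING_SLOT_ID = -1

def slot_for(draft_pos: int, block_table: list[int], block_size: int,
             max_model_len: int) -> int:
    if draft_pos >= max_model_len:
        return PADDING_SLOT_ID  # sentinel: skip the KV-cache write
    block_id = block_table[draft_pos // block_size]
    return block_id * block_size + draft_pos % block_size

print(slot_for(5, [7, 9], 16, 32))   # in range: real slot in block 7
print(slot_for(40, [7, 9], 16, 32))  # past max_model_len: padded
```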


@hai-meh-cs (author):

related to the max_model_len edge case above. will look into using PADDING_SLOT_ID for positions that exceed max_model_len to avoid unnecessary KV cache writes


total_out = (
    common_attn_metadata.num_actual_tokens + batch_size * draft_len
)

Incorrect total_out when rejected tokens present

High Severity

The total_out calculation uses num_actual_tokens (all input tokens including rejected ones), but out_qsl is computed using accepted_lengths (only accepted tokens). When tokens are rejected, total_out > out_qsl[-1]. The kernel grid uses total_out, so for positions beyond out_qsl[-1], the request lookup loop defaults to request 0, computing garbage slot mappings. These garbage values are included in slot_mapping and num_actual_tokens, potentially corrupting the KV cache during attention.


@hai-meh-cs (author):

will investigate. need to trace through the rejected tokens path to verify whether total_out and output_query_start_loc can actually diverge

self.K = self.num_speculative_tokens
self.slot_buffer = torch.zeros(
    self.max_num_tokens, dtype=torch.int64, device=device
)

Buffer overflow when K > 2 in PTD

High Severity

PtdEagleProposer inherits buffer sizes from the parent class where max_num_tokens = max_num_batched_tokens + max_batch_size. However, PTD processes total_out = num_actual_tokens + batch_size * (K-1) positions in a single pass. When K > 2, total_out can exceed max_num_tokens. The Triton kernel writes to input_ids, positions, hidden_states, and slot_buffer up to position total_out - 1, causing out-of-bounds memory writes that could corrupt memory or crash the system.


@hai-meh-cs (author):

don't believe this is correct. we've tested with K = 3 ... 8 without buffer issues. the parent class buffer sizing should account for speculative tokens. will double check the allocation logic to confirm

Collaborator:

I think that this does need to account for the new tokens, since we can add up to num_spec_tokens * num_reqs new elements into the batch. I fixed this here on the other branch:

https://github.com/vllm-project/vllm/pull/32887/files#diff-a4809a837fbf535a8f0999b11087a53ec1c53948b50c0a1fe64396bc86de9461R105
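A minimal sketch of the sizing this implies (variable names are illustrative, not vLLM's actual fields):

```python
# Draft-buffer sizing, per the discussion above. Sequential drafting needs
# max_num_batched_tokens + max_num_seqs slots; parallel drafting can add up
# to num_spec_tokens new elements per request in a single pass.
def draft_buffer_tokens(max_num_batched_tokens: int, max_num_seqs: int,
                        num_spec_tokens: int, parallel_draft: bool) -> int:
    extra = max_num_seqs * (num_spec_tokens if parallel_draft else 1)
    return max_num_batched_tokens + extra

# Matches the overflow arithmetic quoted in this PR: 8192 + 7 * 3 = 8213,
# which falls outside an 8192-token compile range.
print(draft_buffer_tokens(8192, 7, 3, parallel_draft=True))
```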

@@ -104,13 +105,16 @@ def main(args):
else:
prompts = get_custom_mm_prompts(args.num_prompts)

-    if args.method == "eagle" or args.method == "eagle3":
+    if args.method in ("eagle", "eagle3", "eagle-ptd", "eagle3-ptd"):
Collaborator:

It might be simpler to add something like "parallel_draft" to the speculative config instead of duplicating all the methods that support this technique, since it's largely independent of the architecture used to draft. We will want to support PARD (parallel drafting for external draft models) at some point, which also shares most of this implementation.

@hai-meh-cs (author):

done. replaced method names with parallel_draft: bool config flag
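With that change, enabling parallel drafting would presumably look something like this (hypothetical invocation; the `parallel_draft` key follows the discussion above, not a released config schema, and `<ptd-draft>` is a placeholder):

```shell
vllm serve openai/gpt-oss-120b \
  --speculative-config \
    '{"model": "<ptd-draft>", "method": "eagle3", "parallel_draft": true, "num_speculative_tokens": 4}'
```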

(
    "eagle3-ptd",
    "openai/gpt-oss-120b",
    "PATH_TO_PTD_MODEL",  # Replace with actual PTD model path
Collaborator:

Are you planning to release a model alongside this PR? If not, could you upload a dummy checkpoint or find a compatible EAGLE3 checkpoint to use as a placeholder so that the tests are able to run?


We are trying to release a checkpoint that we are currently re-training with limited datasets. We will upload a dummy checkpoint for now.

@robertgshaw2-redhat (Collaborator):

this is very exciting!

mergify bot commented Jan 21, 2026

The pre-commit checks have failed again; please run the same steps as in the earlier mergify comment.

@benchislett (Collaborator):

Does your PTD model have an equivalent number of params to the "vanilla EAGLE3" baseline? Or does it have additional layers?

Add PTD proposer that generates K draft tokens in a single forward pass
using mask tokens, enabling more efficient speculative decoding.

- Add PtdEagleProposer with Triton kernel for input preparation
- Support eagle-ptd and eagle3-ptd speculative methods
- Add test and offline inference example for PTD
- Replace eagle-ptd/eagle3-ptd methods with parallel_draft bool flag
- Reuse parent class load_model() in PtdEagleProposer
- Load mask_hidden via normal weight loading in model
- Improve naming and add Triton kernel documentation
@hai-meh-cs hai-meh-cs changed the title feat: add Parallel Token Decoding for EAGLE feat: add parallel token drafting for EAGLE Jan 22, 2026
if self.method == "eagle3" and self.eagle3_use_aux_hidden_state:
    expected_aux_size = self.hidden_size * 3
    if self.mask_hidden.shape[-1] == expected_aux_size:
        self.mask_hidden = self.model.combine_hidden_states(self.mask_hidden)
Collaborator:

This does not seem to be necessary. It seems much more effective to do this projection when preparing the model, or omit it entirely during training

@hai-meh-cs (author):

we can do that

@laviier (Contributor) commented Jan 22, 2026

Do you plan to include this bug fix in this PR as well? Or do you prefer to keep the bugfix PR separate, like https://github.com/hai-meh-cs/vllm-pard-eagle/pull/4?

I've created a new PR, https://github.com/hai-meh-cs/vllm/pull/1, which targets a new bugfix branch to be merged into your feature branch.

@benchislett (Collaborator):

That bug has been fixed in the draft-model side and likely just needs to be propagated into this branch as well.

@laviier (Contributor) commented Jan 22, 2026

> That bug has been fixed in the draft-model side and likely just needs to be propagated into this branch as well.

It's a slightly different issue, since parallel drafting needs more than one extra token per request. As a result, I pulled in the bug fix you mentioned and iterated on it. The new logic is below:

if self.speculative_config is not None and (
    self.speculative_config.uses_draft_model()
    or self.speculative_config.use_eagle()
):
    multiplier = (
        self.speculative_config.num_speculative_tokens
        if self.speculative_config.parallel_draft
        else 1
    )
    compile_range_end += multiplier * self.scheduler_config.max_num_seqs

See https://github.com/hai-meh-cs/vllm/pull/1/files

Li Zhang and others added 3 commits January 22, 2026 18:52

Parallel draft methods (PTD EAGLE) generate K draft tokens in a single
forward pass using mask tokens, which requires larger buffers than sequential
drafting. The inherited buffer allocation formula was insufficient, causing
crashes under load.

Bug manifestation:
- Sequential EAGLE: needs max_num_batched_tokens + max_num_seqs tokens
- Parallel draft: needs max_num_batched_tokens + max_num_seqs * num_speculative_tokens tokens
- Error: "AssertionError: Shape: 8213 out of considered ranges: [(1, 8192)]"

This fix addresses three critical issues:

1. Buffer Allocation (ptd_eagle.py):
   - Corrects max_num_tokens formula for parallel draft generation pattern
   - Reallocates all buffers (input_ids, positions, hidden_states, slot_buffer)
   - Adds ~6MB memory overhead (negligible for 3-4x speedup)

2. Compilation Ranges (vllm.py):
   - Extends compile_ranges_split_points when parallel_draft=True
   - Ensures CUDA graph compilation handles expanded token counts
   - Adds informative logging for parallel draft detection

The bug was caught during benchmarking with 100 prompts (1600 input, 600 output
tokens) where batch size reached 8192 tokens + 7 requests * 3 masks = 8213 tokens,
exceeding the compilation range of 8192.

Tested-by: Load testing with max batch size configurations

Signed-off-by: Li Zhang <lzhanga@amazon.com>

Simplify updates to eagle files

Signed-off-by: Li Zhang <lzhanga@amazon.com>

Minor format updates
[Bugfix][ptd_eagle] Fix buffer overflow in PTD EAGLE speculative decoding
- Use PADDING_SLOT_ID (-1) for overflow draft positions to avoid KV cache writes
- Pre-project mask_hidden through fc layer for eagle3 architecture
- Add comprehensive unit tests for PTD kernel and proposer

Signed-off-by: Jaime Campos Salas <jaime.campos.salas@gmail.com>
mergify bot commented Jan 23, 2026

The pre-commit checks have failed again; please run the same steps as in the earlier mergify comment.

mergify bot commented Jan 28, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @hai-meh-cs.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label (Jan 28, 2026)
@benchislett (Collaborator):

Closing now that #32887 is merged.
