
[Bugfix] Fix Hybrid KV cache hit length computation for eagle#33270

Closed
xyang16 wants to merge 3 commits into vllm-project:main from xyang16:fix

Conversation

@xyang16
Contributor

@xyang16 xyang16 commented Jan 28, 2026

Purpose

We observed a 20% performance regression for gpt-oss with eagle in the 0.14.0 release. We found it is caused by a 0% prefix cache hit rate.

  • Recently, find_longest_cache_hit was changed to use a while loop that checks all attention groups until a prefix gets cache hits from all of them, in [Feat][Core] Support multiple KV cache groups in Hybrid KV Coordinator #31707. However, if eagle is enabled, the last matched block is dropped, see here and here. In every while-loop iteration, curr_hit_length is reduced by block_size, so curr_hit_length < hit_length is always true here, until hit_length reaches 0. The returned hit_blocks therefore become empty and the returned hit_length becomes 0, making the prefix cache hit rate always 0%.
  • Therefore, if eagle is enabled, this PR fixes it by reusing the cached blocks, avoiding the recalculation that empties the hit blocks. This restores the previous behavior, which did not iterate in the while loop.
  • Tested gpt-oss with eagle (full attention with sliding window). Mamba does not support prefix caching with speculative decoding, see here, so mamba is not impacted by this PR.
  • Added tests in test_prefix_caching.py.
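The failure mode described above can be illustrated with a minimal sketch. This is not vLLM's actual code: the function names mirror the real ones, but the caching logic, the number of groups, and the block size are simplified stand-ins chosen only to show how eagle's dropped last block interacts with the new while loop.

```python
# Minimal sketch (NOT vLLM's actual implementation) of the regression:
# eagle drops the last matched block, so each pass over the groups
# shrinks the hit length by one block until it collapses to 0.
BLOCK_SIZE = 16  # hypothetical block size for illustration

def find_longest_cache_hit(max_length: int, use_eagle: bool) -> int:
    """Pretend every block up to max_length is cached; with eagle,
    drop the last matched block (the behavior described above)."""
    num_hit_blocks = max_length // BLOCK_SIZE
    if use_eagle and num_hit_blocks > 0:
        num_hit_blocks -= 1  # eagle drops the last matched block
    return num_hit_blocks * BLOCK_SIZE

def hybrid_hit_length(num_groups: int, request_length: int,
                      use_eagle: bool) -> int:
    """Iterate until all attention groups agree on a common cached
    prefix, loosely mirroring the while loop introduced in #31707."""
    hit_length = request_length
    while True:
        shrunk = False
        for _ in range(num_groups):
            curr_hit_length = find_longest_cache_hit(hit_length, use_eagle)
            if curr_hit_length < hit_length:
                # Prefix must shrink; retry all groups with the shorter one.
                hit_length = curr_hit_length
                shrunk = True
        if not shrunk:
            return hit_length

# Without eagle the loop converges immediately; with eagle every pass
# drops another block, so the hit length collapses to 0.
print(hybrid_hit_length(2, 64, use_eagle=False))
print(hybrid_hit_length(2, 64, use_eagle=True))
```

In the eagle case each call returns one block fewer than was asked for, so the convergence condition can never hold until hit_length is 0, which matches the observed 0% prefix cache hit rate.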

Test Plan

pytest -s -v tests/v1/core/test_prefix_caching.py

Test Result

Unit tests passed.

Benchmarking

vllm serve openai/gpt-oss-120b \
  --tensor-parallel-size 8 \
  --max-num-seqs 128 \
  --speculative-config '{"method": "eagle3", "model": "nvidia/gpt-oss-120b-Eagle3-short-context", "num_speculative_tokens": 3}'
vllm bench serve \
    --model openai/gpt-oss-120b \
    --dataset-name sonnet \
    --dataset-path /tmp/sonnet.txt \
    --sonnet-input-len 1000 \
    --sonnet-output-len 200 \
    --max-concurrency 64 \
    --num-prompts 1000 \
    --num-warmups 100 \
    --ignore-eos

Main:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Failed requests:                         0         
Maximum request concurrency:             64        
Benchmark duration (s):                  74.13     
Total input tokens:                      987283    
Total generated tokens:                  200000    
Request throughput (req/s):              13.49     
Output token throughput (tok/s):         2697.94   
Peak output token throughput (tok/s):    4182.00   
Peak concurrent requests:                94.00     
Total token throughput (tok/s):          16016.12  
---------------Time to First Token----------------
Mean TTFT (ms):                          322.33    
Median TTFT (ms):                        197.75    
P99 TTFT (ms):                           2132.18   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          22.06     
Median TPOT (ms):                        21.37     
P99 TPOT (ms):                           39.94     
---------------Inter-token Latency----------------
Mean ITL (ms):                           45.93     
Median ITL (ms):                         12.75     
P99 ITL (ms):                            98.15     
---------------Speculative Decoding---------------
Acceptance rate (%):                     36.36     
Acceptance length:                       2.09      
Drafts:                                  95564     
Draft tokens:                            286692    
Accepted tokens:                         104244    
Per-position acceptance (%):
  Position 0:                            56.52     
  Position 1:                            33.46     
  Position 2:                            19.10     
==================================================

PR:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Failed requests:                         0         
Maximum request concurrency:             64        
Benchmark duration (s):                  61.46     
Total input tokens:                      987283    
Total generated tokens:                  200000    
Request throughput (req/s):              16.27     
Output token throughput (tok/s):         3253.98   
Peak output token throughput (tok/s):    4086.00   
Peak concurrent requests:                108.00    
Total token throughput (tok/s):          19316.95  
---------------Time to First Token----------------
Mean TTFT (ms):                          229.40    
Median TTFT (ms):                        179.43    
P99 TTFT (ms):                           1254.31   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          18.46     
Median TPOT (ms):                        18.75     
P99 TPOT (ms):                           25.86     
---------------Inter-token Latency----------------
Mean ITL (ms):                           38.34     
Median ITL (ms):                         12.80     
P99 ITL (ms):                            91.50     
---------------Speculative Decoding---------------
Acceptance rate (%):                     36.17     
Acceptance length:                       2.09      
Drafts:                                  95828     
Draft tokens:                            287484    
Accepted tokens:                         103985    
Per-position acceptance (%):
  Position 0:                            56.30     
  Position 1:                            33.17     
  Position 2:                            19.04     
==================================================

Output token throughput improves by about 20% (2697.94 → 3253.98 tok/s).
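As a quick sanity check on the headline number, the improvement can be computed directly from the two benchmark tables above:

```python
# Output token throughput (tok/s) from the two benchmark runs above.
main_tput = 2697.94  # main branch
pr_tput = 3253.98    # this PR

improvement = pr_tput / main_tput - 1
print(f"{improvement:.1%}")  # roughly 20%
```

The benchmark duration shrinks by the same factor (74.13 s → 61.46 s), consistent with the throughput gain.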

Accuracy Testing

OPENAI_API_KEY=EMPTY python3 -m gpt_oss.evals --model openai/gpt-oss-120b --eval gpqa --n-threads 200 --reasoning-effort low

Main:

[{'eval_name': 'gpqa', 'model_name': '__opt__dlami__nvme__models__gpt-oss-120b-low_temp1.0_20260128_223750', 'metric': 0.6596212121212122}]

PR:

[{'eval_name': 'gpqa', 'model_name': '__opt__dlami__nvme__models__gpt-oss-120b-low_temp1.0_20260128_222331', 'metric': 0.6565656565656566}]

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Xin Yang <xyangx@amazon.com>
@xyang16 xyang16 changed the title [Bugfix] Fix cache hit computation for eagle [Bugfix] Fix KV cache hit computation for eagle Jan 28, 2026
@mergify mergify bot added v1 bug Something isn't working labels Jan 28, 2026
@xyang16 xyang16 changed the title [Bugfix] Fix KV cache hit computation for eagle [Bugfix] Fix Hybrid KV cache hit length computation for eagle Jan 28, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a performance regression caused by an infinite loop in cache hit computation for hybrid models using Eagle speculative decoding. The fix involves breaking the iterative refinement loop after one pass when Eagle is enabled, which correctly resolves the issue for hybrid attention models. My review focuses on the potential for silent correctness issues when Mamba layers are also present, as the Eagle-specific logic is not applied consistently across all layer types.


@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.

Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.

Comment @cursor review or bugbot run to trigger another review on this PR

@mergify

mergify bot commented Jan 28, 2026

Hi @xyang16, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

Signed-off-by: Xin Yang <xyangx@amazon.com>


Signed-off-by: Xin Yang <xyangx@amazon.com>
@xyang16 xyang16 force-pushed the fix branch 5 times, most recently from 3219b8d to 63e16e4 Compare January 29, 2026 08:03
# (downward-closed property)
cached_blocks = hit_blocks_by_group[group_ids[0]]
-if is_full_attn and cached_blocks is not None:
+if (is_full_attn or self.use_eagle) and cached_blocks is not None:
Collaborator

Contributor Author

Thanks for review! I will try this.

Contributor Author

@xyang16 xyang16 Jan 29, 2026

I just tried the approach, but it still has the same problem: curr_hit_length is reduced by block_size in every while-loop iteration until it reaches 0. Also, curr_hit_length no longer matches the length of hit_blocks.

                    if self.use_eagle:
                        hit_blocks = manager_cls.find_longest_cache_hit(
                            block_hashes=_get_block_hashes(spec),
                            max_length=curr_hit_length + spec.block_size,
                            kv_cache_group_ids=group_ids,
                            block_pool=self.block_pool,
                            kv_cache_spec=spec,
                            use_eagle=self.use_eagle,
                            alignment_tokens=self.lcm_block_size,
                        )
                        curr_hit_length = max(0, len(hit_blocks[0]) * spec.block_size - spec.block_size)

Please let me know what you think. Thanks!

@mergify

mergify bot commented Feb 2, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xyang16.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 2, 2026
@xyang16 xyang16 closed this Feb 2, 2026
@xyang16 xyang16 deleted the fix branch February 6, 2026 17:50

Labels

bug Something isn't working needs-rebase v1
