
Conversation

@TheEpicDolphin (Collaborator) commented Jul 2, 2025

Purpose

Add support for the tree attention v1 backend. Tree attention is used in EAGLE speculative decoding by the target model to validate a set of draft tokens. Draft tokens only attend to their ancestor tokens, so an attention bias must be applied to mask out attention between tokens that are not ancestors/descendants of one another. To support that, I added a new parameter, qq_bias, to the Triton unified_attention kernel. It applies a query-on-query attention bias given as a 2D (q_len, q_len) tensor. The feature is only enabled when a non-None value is provided for that parameter; otherwise it is disabled (the default case).
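
For illustration only (this is not the PR's kernel code; the helper name and the additive 0 / -inf convention are my assumptions), a (q_len, q_len) query-on-query bias like the one qq_bias expects could be built from per-node parent indices roughly as follows:

import torch

def build_tree_attention_bias(parent: list[int]) -> torch.Tensor:
    """Build a (q_len, q_len) additive bias in which draft token i may attend
    only to itself and its ancestors; every other pair is masked with -inf."""
    q_len = len(parent)
    bias = torch.full((q_len, q_len), float("-inf"))
    for i in range(q_len):
        j = i
        while j != -1:           # walk up the tree until we pass the root
            bias[i, j] = 0.0     # allow attention from i to its ancestor j (and to itself)
            j = parent[j]
    return bias

# Example: nodes 1 and 2 are children of node 0; node 3 is a child of node 1.
# parent[i] == -1 marks the root of the draft tree.
print(build_tree_attention_bias([-1, 0, 0, 1]))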

I also implemented the logic for tree draft proposal in eagle.py. For chain drafts, it behaves the same as before. However, if a tree of speculative tokens is specified (via the speculative config), the system can leverage TreeAttentionBackend for drafting. Top-K is used to select the drafted child tokens at each level of the tree.
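
As a rough sketch of that per-level Top-K expansion (illustrative only, not the actual EagleProposer code; the function name and shapes are my assumptions):

import torch

def expand_level(logits: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Pick the top-k child tokens for every node at the current tree level.

    logits: (num_nodes, vocab_size) drafter logits, one row per node at this level.
    Returns (child_token_ids, child_logprobs), each of shape (num_nodes, k).
    """
    logprobs = torch.log_softmax(logits, dim=-1)
    topk_logprobs, topk_ids = logprobs.topk(k, dim=-1)
    return topk_ids, topk_logprobs

# Example: 3 nodes at the current level, a toy vocab of 10, branch factor 2.
ids, lps = expand_level(torch.randn(3, 10), k=2)
print(ids.shape)  # torch.Size([3, 2])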

NOTE: This PR does NOT change the existing behavior of v1 EAGLE. It simply adds the capability to use TreeAttentionBackend, which can validate a tree of draft tokens. However, since tree scoring is not yet implemented (I am working on it now), only chain drafts are supported at the moment. This is the first step toward unlocking tree drafting and scoring functionality!

Test Plan

Benchmark

I used the following commands to run the LLM service and benchmark TreeAttentionBackend against FlashAttentionBackend:
Server

export VLLM_TORCH_PROFILER_DIR=~/traces/vllm
export LLAMA_MODEL=meta-llama/Llama-3.1-8B-Instruct
export DRAFT_MODEL=yuhuili/EAGLE-LLaMA3.1-Instruct-8B
export VLLM_USE_V1=1
export VLLM_ATTENTION_BACKEND=<backend>
export SPEC_DEC_CONFIG='{"method": "eagle", "model": "'$DRAFT_MODEL'", "num_speculative_tokens": 3, "draft_tensor_parallel_size": 1, "max_model_len": 2048, "speculative_token_tree": "[(0,), (0, 0), (0, 0, 0)]"}'
python -m vllm.entrypoints.openai.api_server --model $LLAMA_MODEL --disable-log-requests --tensor-parallel-size=1 --max-num-seqs=64 --max-model-len=32768 --block-size=128 --no-enable-prefix-caching --speculative-config="$SPEC_DEC_CONFIG" 2>&1 | tee ~/server_logs/vllm_server.log

Client

export LLAMA_MODEL=meta-llama/Llama-3.1-8B-Instruct
python benchmarks/benchmark_serving.py --model $LLAMA_MODEL --tokenizer $LLAMA_MODEL --host 0.0.0.0 --dataset-name random --ignore-eos --request-rate inf --random-input-len 1000 --random-output-len 300 --max-concurrency 64 --num-prompts 128

Results

| Serving Benchmark Result | Flash Attention (Baseline) | Flash Attention (After) | Tree Attention |
|---|---|---|---|
| Successful requests | 128 | 128 | 128 |
| Benchmark duration (s) | 11.82 | 11.6 | 12.80 |
| Total input tokens | 127731 | 127731 | 127731 |
| Total generated tokens | 38400 | 38400 | 38400 |
| Request throughput (req/s) | 10.83 | 11.04 | 10.00 |
| Output token throughput (tok/s) | 3248.45 | 3311.34 | 3000.36 |
| Total token throughput (tok/s) | 14053.85 | 14325.95 | 12980.52 |
| Time to First Token | | | |
| Mean TTFT (ms) | 659.81 | 608.03 | 711.49 |
| Median TTFT (ms) | 238.31 | 212.95 | 181.93 |
| P99 TTFT (ms) | 2016.53 | 1929.07 | 2261.90 |
| Time per Output Token (excl. 1st token) | | | |
| Mean TPOT (ms) | 13.6 | 13.52 | 15.77 |
| Median TPOT (ms) | 12.89 | 12.66 | 14.94 |
| P99 TPOT (ms) | 26.36 | 26.3 | 27.57 |
| Inter-token Latency | | | |
| Mean ITL (ms) | 30.94 | 30.96 | 30.43 |
| Median ITL (ms) | 21.53 | 21.27 | 21.91 |
| P99 ITL (ms) | 230.79 | 236.71 | 206.91 |

This benchmarking verified that this PR does NOT regress performance on v1 spec decoding.
Improvements still need to be made for tree attention; I will investigate further how to close the gap.

Manual Testing

I used the code below to send a chat completion request to the vLLM service running with the TREE_ATTN backend:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "system", "content": "You are a helpful assistant."},
              {"role": "user", "content": "Explain the theory of relativity in simple terms."}],
    temperature=0.2,
)
print(response)

Flash Attention Output

ChatCompletion(id='chatcmpl-6fcc6d98bce64d45b18dc795faa788f5', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content="The theory of relativity, developed by Albert Einstein, is a fundamental concept in modern physics. I'll break it down in simple terms:\n\n**What is the theory of relativity?**\n\nThe theory of relativity is a way of understanding how the universe works, particularly when it comes to space and time. It's divided into two main parts: special relativity and general relativity.\n\n**Special Relativity (1905)**\n\nSpecial relativity says that how we measure time and space can be different depending on how fast we're moving and where we are. Here are the key points:\n\n1. **Time dilation**: Time can appear to slow down or speed up depending on your speed. The faster you move, the slower time passes.\n2. **Length contraction**: Objects can appear shorter when you're moving really fast.\n3. **The speed of light is always the same**: No matter how fast you're moving, the speed of light remains constant.\n4. **Relativity of simultaneity**: Two events that happen at the same time for one observer might not happen at the same time for another observer in a different state of motion.\n\n**General Relativity (1915)**\n\nGeneral relativity builds on special relativity and adds gravity to the mix. It says that:\n\n1. **Gravity is not a force**: Gravity is actually the curvature of spacetime caused by massive objects.\n2. **Spacetime is flexible**: The presence of massive objects warps spacetime, creating gravitational fields.\n3. **Equivalence principle**: The effects of gravity are equivalent to the effects of acceleration.\n\n**Key Takeaways**\n\nThe theory of relativity revolutionized our understanding of space, time, and gravity. Some of the key implications include:\n\n* Time and space are not absolute, but relative to the observer.\n* The laws of physics are the same everywhere in the universe.\n* Gravity is not a force, but a result of the curvature of spacetime.\n\n**In Simple Terms**\n\nImagine you're on a train, and you throw a ball straight up in the air. To you, on the train, the ball goes straight up and comes straight back down. But to someone watching from the platform, the ball looks like it's moving in a curved path because the train is moving really fast.\n\nThat's kind of like what's happening with time and space in the theory of relativity. The faster you move, the more time and space can appear to change. And gravity is like a big, cosmic curve that warps spacetime, affecting how objects move and interact.\n\nI hope this helps you understand the basics of the theory of relativity!", refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=[], reasoning_content=None), stop_reason=None)], created=1752615338, model='meta-llama/Llama-3.1-8B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=536, prompt_tokens=52, total_tokens=588, completion_tokens_details=None, prompt_tokens_details=None), prompt_logprobs=None, kv_transfer_params=None)

Tree Attention Output

ChatCompletion(id='chatcmpl-1ff4447ed33e4a91b89fa3f1e25d1b14', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content="The theory of relativity, developed by Albert Einstein, is a fundamental concept in modern physics. I'll break it down in simple terms:\n\n**What is the theory of relativity?**\n\nThe theory of relativity is a way of understanding how the universe works, particularly when it comes to space and time. It's based on two main ideas: special relativity and general relativity.\n\n**Special Relativity (1905)**\n\nSpecial relativity says that how we measure time and space can be different depending on how fast we're moving and where we are. Here are some key points:\n\n1. **Time dilation**: Time can seem to pass slower for someone moving really fast compared to someone who is standing still.\n2. **Length contraction**: Objects can appear shorter to someone moving really fast compared to someone who is standing still.\n3. **The speed of light is always the same**: No matter how fast you're moving, the speed of light remains the same.\n\n**General Relativity (1915)**\n\nGeneral relativity builds on special relativity and adds a new idea: gravity is not a force, but rather a curvature of space and time caused by massive objects. Here are some key points:\n\n1. **Gravity warps space and time**: The more massive an object is, the more it warps the fabric of space and time around it.\n2. **Gravity is not a force**: Objects don't attract each other with a force called gravity; instead, they follow the curvature of space and time.\n\n**Key Takeaways**\n\n1. **Time and space are relative**: They can be affected by motion and gravity.\n2. **The speed of light is always the same**: It's a universal constant that doesn't change.\n3. **Gravity is a curvature of space and time**: It's not a force, but rather a result of massive objects warping the fabric of the universe.\n\nThe theory of relativity has been extensively tested and confirmed through numerous experiments and observations. It's a fundamental concept in modern physics and has had a profound impact on our understanding of the universe.", refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=[], reasoning_content=None), stop_reason=None)], created=1752621554, model='meta-llama/Llama-3.1-8B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=423, prompt_tokens=52, total_tokens=475, completion_tokens_details=None, prompt_tokens_details=None), prompt_logprobs=None, kv_transfer_params=None)

Tree Drafts

I tested generating a tree with the following structure:

ROOT
├── 0
│   ├── 0
│   │   └── 0 ── 0 ── 0
│   └── 1
│       └── 0 ── 0 ── 0
├── 1
│   ├── 0
│   │   └── 0 ── 0 ── 0
│   └── 1
│       └── 0 ── 0 ── 0
└── 2
    ├── 0
    │   └── 0 ── 0 ── 0
    └── 1
        └── 0 ── 0 ── 0

Represented by the following list of tuples:

[(0,), (1,), (2,), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (0, 0, 0), (0, 1, 0), (1, 0, 0), (1, 1, 0), (2, 0, 0), (2, 1, 0), (0, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0), (1, 1, 0, 0), (2, 0, 0, 0), (2, 1, 0, 0), (0, 0, 0, 0, 0), (0, 1, 0, 0, 0), (1, 0, 0, 0, 0), (1, 1, 0, 0, 0), (2, 0, 0, 0, 0), (2, 1, 0, 0, 0)]
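
As a small illustration of how these path tuples map onto the tree (a hypothetical helper, not the PR's parsing code), each node's parent is simply the path with its last element dropped:

def paths_to_parents(paths: list[tuple[int, ...]]) -> list[int]:
    """Map each drafted node (a path of child indices from the root) to the index
    of its parent node in the list; -1 means the parent is the root token."""
    index = {path: i for i, path in enumerate(paths)}
    return [index.get(path[:-1], -1) for path in paths]

paths = [(0,), (1,), (2,), (0, 0), (0, 1), (1, 0)]
print(paths_to_parents(paths))  # [-1, -1, -1, 0, 0, 1]

The resulting parent indices are exactly what a mask-construction helper like the one sketched earlier would consume.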

For the input prompt, "Explain the theory of relativity in simple terms.", the backend proposed the following speculative tokens:

"The"
├── " theory"
│   ├── " of"
│   │   └── " rel" ── "ativity" ── ","
│   └── " is"
│       └── " a" ── " fundamental" ── " theory"
├── " Theory"
│   ├── " of"
│   │   └── " Rel" ── "ativity" ── ","
│   └── " Of"
│       └── " Rel" ── "ativity" ── ","
└── " Einstein"
    ├── " Theory"
    │   └── "," ── " " ── " Albert"
    └── "'s"
        └── " Theory" ── " of" ── " Rel"

And also for the input prompt, "Write the first line of a novel that doesn’t exist yet.":

"As"
├── " the"
│   ├── " city"
│   │   └── " stood" ── "at" ── " the"
│   └── " storm"
│       └── " began" ── " its'" ── " whispered"
├── " she"
│   ├── " stood"
│   │   └── " in" ── " the" ── " dim"
│   └── " walked"
│       └── " through" ── " the" ── " mist"
└── " I"
    ├── " stood"
    │   └── " in" ── " the" ── " dim"
    └── " walked"
        └── " through" ── " the" ── " bustling"

The paths in both draft trees are coherent.

NOTE: There is currently no way to sample tokens from a tree, so when this token tree was used, only the first few tokens were ever accepted.

Eagle Test

Added a test case for the tree attention backend. All tests pass:

(py312conda) bash-5.1$ pytest tests/v1/spec_decode/test_eagle.py -k test_propose
============================================================================= test session starts ==============================================================================
platform linux -- Python 3.12.9, pytest-8.4.1, pluggy-1.6.0
rootdir: /data/users/gdelfin/gitrepos/vllm
configfile: pyproject.toml
plugins: anyio-4.9.0
collected 15 items / 9 deselected / 6 selected                                                                                                                                 

tests/v1/spec_decode/test_eagle.py ......                                                                                                                                [100%]

=============================================================================== warnings summary ===============================================================================
../../../../../home/gdelfin/.conda/envs/py312conda/lib/python3.12/site-packages/triton/runtime/autotuner.py:108
../../../../../home/gdelfin/.conda/envs/py312conda/lib/python3.12/site-packages/triton/runtime/autotuner.py:108
../../../../../home/gdelfin/.conda/envs/py312conda/lib/python3.12/site-packages/triton/runtime/autotuner.py:108
../../../../../home/gdelfin/.conda/envs/py312conda/lib/python3.12/site-packages/triton/runtime/autotuner.py:108
  /home/gdelfin/.conda/envs/py312conda/lib/python3.12/site-packages/triton/runtime/autotuner.py:108: DeprecationWarning: warmup, rep, and use_cuda_graph parameters are deprecated. See https://github.com/triton-lang/triton/pull/4496 for details.
    warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================= 6 passed, 9 deselected, 4 warnings in 9.82s ==================================================================

Tree Attention Correctness Test

(py312conda) bash-5.1$ pytest tests/v1/spec_decode/test_tree_attention.py -k test_tree_attn_correctness
================================================================================================================ test session starts ================================================================================================================
platform linux -- Python 3.12.9, pytest-8.4.1, pluggy-1.6.0
rootdir: /data/users/gdelfin/gitrepos/vllm
configfile: pyproject.toml
plugins: anyio-4.9.0
collected 1 item                                                                                                                                                                                                                                    

tests/v1/spec_decode/test_tree_attention.py .                                                                                                                                                                                                 [100%]

================================================================================================================= 1 passed in 3.78s =================================================================================================================

Also added a test case to test_attention_backends for tree attention.

(py312conda) bash-5.1$ pytest tests/v1/attention/test_attention_backends.py -k test_backend_correctness
========================================================================= test session starts =========================================================================
platform linux -- Python 3.12.9, pytest-8.4.1, pluggy-1.6.0
rootdir: /data/users/gdelfin/gitrepos/vllm
configfile: pyproject.toml
plugins: anyio-4.9.0
collected 6 items                                                                                                                                                     

tests/v1/attention/test_attention_backends.py ......                                                                                                            [100%]

========================================================================= 6 passed in 24.79s ==========================================================================

Tree Attention vs Triton Attention

Given that the tree attention backend currently uses Triton attention under the hood, but with a custom query-on-query tree attention bias, I measured the performance difference between the two for various batch sizes, sequence lengths, and query lengths. In this case, Seqlen Q represents the tree of tokens being validated by the target model. Here are the results:

| Backend | Batch Size | Sequence Position | Seqlen Q | Time (us) |
|---|---|---|---|---|
| TritonAttentionBackend | 1 | 16 | 1 | 13.82344496 |
| TreeAttentionBackend | 1 | 16 | 1 | 14.91896063 |
| TritonAttentionBackend | 1 | 16 | 4 | 11.25741284 |
| TreeAttentionBackend | 1 | 16 | 4 | 12.08161376 |
| TritonAttentionBackend | 1 | 16 | 13 | 11.28191594 |
| TreeAttentionBackend | 1 | 16 | 13 | 12.12395076 |
| TritonAttentionBackend | 1 | 16 | 40 | 11.40984613 |
| TreeAttentionBackend | 1 | 16 | 40 | 12.24565692 |
| TritonAttentionBackend | 1 | 1024 | 1 | 14.24317993 |
| TreeAttentionBackend | 1 | 1024 | 1 | 15.28975647 |
| TritonAttentionBackend | 1 | 1024 | 4 | 29.86434661 |
| TreeAttentionBackend | 1 | 1024 | 4 | 33.38198364 |
| TritonAttentionBackend | 1 | 1024 | 13 | 29.9381651 |
| TreeAttentionBackend | 1 | 1024 | 13 | 33.45626593 |
| TritonAttentionBackend | 1 | 1024 | 40 | 30.08378111 |
| TreeAttentionBackend | 1 | 1024 | 40 | 33.47317129 |
| TritonAttentionBackend | 1 | 2048 | 1 | 16.25721902 |
| TreeAttentionBackend | 1 | 2048 | 1 | 17.85656624 |
| TritonAttentionBackend | 1 | 2048 | 4 | 48.68557677 |
| TreeAttentionBackend | 1 | 2048 | 4 | 54.80957031 |
| TritonAttentionBackend | 1 | 2048 | 13 | 48.73485863 |
| TreeAttentionBackend | 1 | 2048 | 13 | 54.84380573 |
| TritonAttentionBackend | 1 | 2048 | 40 | 48.88690636 |
| TreeAttentionBackend | 1 | 2048 | 40 | 54.83540148 |
| TritonAttentionBackend | 16 | 16 | 1 | 150.4525393 |
| TreeAttentionBackend | 16 | 16 | 1 | 151.191473 |
| TritonAttentionBackend | 16 | 16 | 4 | 147.2443491 |
| TreeAttentionBackend | 16 | 16 | 4 | 147.4563181 |
| TritonAttentionBackend | 16 | 16 | 13 | 147.6421207 |
| TreeAttentionBackend | 16 | 16 | 13 | 147.7924883 |
| TritonAttentionBackend | 16 | 16 | 40 | 152.5316983 |
| TreeAttentionBackend | 16 | 16 | 40 | 153.0877352 |
| TritonAttentionBackend | 16 | 1024 | 1 | 161.1640006 |
| TreeAttentionBackend | 16 | 1024 | 1 | 162.6019776 |
| TritonAttentionBackend | 16 | 1024 | 4 | 163.6626571 |
| TreeAttentionBackend | 16 | 1024 | 4 | 163.9055759 |
| TritonAttentionBackend | 16 | 1024 | 13 | 163.8339609 |
| TreeAttentionBackend | 16 | 1024 | 13 | 163.6852324 |
| TritonAttentionBackend | 16 | 1024 | 40 | 185.215503 |
| TreeAttentionBackend | 16 | 1024 | 40 | 187.3141825 |
| TritonAttentionBackend | 16 | 2048 | 1 | 169.8984355 |
| TreeAttentionBackend | 16 | 2048 | 1 | 171.3221073 |
| TritonAttentionBackend | 16 | 2048 | 4 | 179.2972535 |
| TreeAttentionBackend | 16 | 2048 | 4 | 179.6866953 |
| TritonAttentionBackend | 16 | 2048 | 13 | 179.8290014 |
| TreeAttentionBackend | 16 | 2048 | 13 | 179.4920713 |
| TritonAttentionBackend | 16 | 2048 | 40 | 218.9038545 |
| TreeAttentionBackend | 16 | 2048 | 40 | 223.4884799 |
| TritonAttentionBackend | 64 | 16 | 1 | 552.5404811 |
| TreeAttentionBackend | 64 | 16 | 1 | 552.8869033 |
| TritonAttentionBackend | 64 | 16 | 4 | 552.7619123 |
| TreeAttentionBackend | 64 | 16 | 4 | 553.4328222 |
| TritonAttentionBackend | 64 | 16 | 13 | 557.2404861 |
| TreeAttentionBackend | 64 | 16 | 13 | 557.3884845 |
| TritonAttentionBackend | 64 | 16 | 40 | 575.0955939 |
| TreeAttentionBackend | 64 | 16 | 40 | 577.0654082 |
| TritonAttentionBackend | 64 | 1024 | 1 | 589.0027285 |
| TreeAttentionBackend | 64 | 1024 | 1 | 588.8522267 |
| TritonAttentionBackend | 64 | 1024 | 4 | 591.3416147 |
| TreeAttentionBackend | 64 | 1024 | 4 | 591.4412737 |
| TritonAttentionBackend | 64 | 1024 | 13 | 596.9859958 |
| TreeAttentionBackend | 64 | 1024 | 13 | 597.3982215 |
| TritonAttentionBackend | 64 | 1024 | 40 | 662.3853445 |
| TreeAttentionBackend | 64 | 1024 | 40 | 666.0534143 |
| TritonAttentionBackend | 64 | 2048 | 1 | 619.3435192 |
| TreeAttentionBackend | 64 | 2048 | 1 | 619.8448539 |
| TritonAttentionBackend | 64 | 2048 | 4 | 621.242404 |
| TreeAttentionBackend | 64 | 2048 | 4 | 621.1918592 |
| TritonAttentionBackend | 64 | 2048 | 13 | 633.0024004 |
| TreeAttentionBackend | 64 | 2048 | 13 | 634.4519854 |
| TritonAttentionBackend | 64 | 2048 | 40 | 752.9329062 |
| TreeAttentionBackend | 64 | 2048 | 40 | 760.2900267 |

This demonstrates that adding the custom tree attention bias does not significantly regress overall performance. I expect that the increase in the average accepted token length from tree draft tokens will more than compensate for the minor increase in attention latency.
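
For context, a minimal sketch of one way such per-call GPU latencies can be collected (not necessarily how the numbers above were produced; the attention call in the usage comment is a hypothetical stand-in):

import torch

def time_gpu(fn, iters: int = 100, warmup: int = 10) -> float:
    """Return the mean per-call latency of fn() in microseconds, measured with CUDA events."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1000.0 / iters  # elapsed_time is in ms

# Hypothetical usage:
# time_gpu(lambda: attn_impl.forward(q, k, v, kv_cache, attn_metadata))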

TODOs

The following actions still need to be taken to fully enable this backend:

  • Fix paged KV when a tree branch is selected
  • Add general support for setting the draft model attention backend. It is currently forced to FlashAttentionBackend.

As of this diff, only chain drafts are supported by TreeAttentionBackend. This is because EagleProposer still only generates draft chains.

@github-actions

github-actions bot commented Jul 2, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @TheEpicDolphin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates the initial phase of a Tree Attention backend into v1 of the attention system, specifically to support EAGLE speculative decoding. The changes enable the efficient validation of draft tokens by implementing a tree-based attention mechanism that correctly applies necessary attention biases. This work involves significant additions to the attention backend infrastructure, updates to model architecture to utilize the new backend, and includes a correctness test to ensure functionality.

Highlights

  • New Tree Attention Backend: Introduced TreeAttentionBackend and TreeAttentionImpl to add support for tree attention, which is a key component for EAGLE speculative decoding in v1 of the attention system.
  • Attention Bias Implementation: The TreeAttentionImpl leverages xformers.ops.tree_attention and correctly applies both prefix and speculative (suffix) attention biases, essential for managing attention between draft tokens and their ancestors or prompt tokens.
  • Dynamic Backend Selection and Draft Model Support: The attention backend selection logic has been updated to include TREE_ATTN and now incorporates an is_draft flag, allowing the system to differentiate and select appropriate attention backends for draft models within the speculative decoding framework.
  • Optimized Batch Processing: A new TreeAttentionMetadataBuilder was added to reorder batches, prioritizing decode requests, and to efficiently construct attention metadata for both prefill (handled by FlashAttention) and speculative decode phases.
  • Correctness Validation: A new test, test_tree_attn_correctness, was implemented to verify the numerical correctness of the TreeAttentionBackend by comparing its output against FlashAttentionBackend across various configurations.

@mergify bot added the llama (Related to Llama models), speculative-decoding, and v1 labels on Jul 2, 2025
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a new TreeAttentionBackend for speculative decoding, which is a significant feature addition. The implementation is well-structured, reusing FlashAttentionImpl for prefill requests and using xformers for the tree attention part. The new test file provides good coverage for correctness verification.

I've identified a critical issue with duplicated fields in a dataclass and a few medium-severity issues related to code correctness, performance, and maintainability. Addressing these will improve the quality and robustness of the new backend. Overall, this is a great first step towards enabling tree attention.

@TheEpicDolphin force-pushed the tree_attention_v1 branch 4 times, most recently from 5a37c78 to bfa883a (July 2, 2025 21:54)
@TheEpicDolphin marked this pull request as ready for review (July 2, 2025 22:09)
@TheEpicDolphin force-pushed the tree_attention_v1 branch 2 times, most recently from 3ff7ebe to da6c40b (July 3, 2025 18:32)

mergify bot commented Jul 8, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @TheEpicDolphin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 8, 2025
@sgrigory left a comment

Thanks for integrating tree attention! Left a few comments. Regarding the performance, maybe look at the profiles to see what takes the most time: it could be the tree attention itself, but it could also be metadata processing (which we can then take out of the decoding loop, at least partially).


This simulates a situation when pages are actually ordered contiguously in physical memory. Would the test also work in a more complex scenario? For example, you can swap two pages

https://github.com/facebookresearch/xformers/blob/80250b32516b019b72bb44be04ca9a8741b42faa/tests/test_mem_eff_attention.py#L2696-L2699

or even shuffle them all

https://github.com/Dao-AILab/flash-attention/blob/adf27d1db38223288981c4dc3509efafbddd3422/tests/test_flash_attn.py#L2151-L2155

@TheEpicDolphin (Collaborator, Author) replied:

Fixed!

Nit: is the comment above "No XFormers so far" still true if you are importing tree attention from xFormers?

A collaborator commented:
Why not just call it XFORMERS?

@TheEpicDolphin (Collaborator, Author) replied:

I tried this at first, but received the following error:

File "/data/users/gdelfin/gitrepos/vllm/vllm/v1/attention/backends/tree_attn.py", line 515, in forward
output[:num_decode_tokens] = tree_attention(
^^^^^^^^^^^^^^^
File "/home/gdelfin/.conda/envs/py312conda/lib/python3.12/site-packages/xformers/ops/tree_attention.py", line 606, in tree_attention
prefix_op = select_prefix_op(
^^^^^^^^^^^^^^^^^
File "/home/gdelfin/.conda/envs/py312conda/lib/python3.12/site-packages/xformers/ops/tree_attention.py", line 491, in select_prefix_op
fa3_supported = isinstance(attn_bias, flash3.FwOp.SUPPORTED_ATTN_BIAS_TYPES) # type: ignore
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: isinstance() arg 2 must be a type, a tuple of types, or a union

It seems like the current xformers version (0.0.30) pinned in the vllm/requirements/cuda.txt file has a type error that needs to be fixed to enable the prefix_op heuristic. Either that, or we bump the required xformers version. I can look into doing this later, after finishing up the xformers v1 backend!

@sgrigory

cc @bottler

@LucasWilkinson enabled auto-merge (squash) on August 4, 2025 01:56
@TheEpicDolphin (Collaborator, Author) commented:

The failure is due to a flaky test discussed in https://vllm-dev.slack.com/archives/C07R5PAL2L9/p1754127415660409; it is not caused by this PR. I will need help to force-merge this.

@vllm-bot merged commit aa7012e into vllm-project:main on Aug 4, 2025 (43 of 46 checks passed)
@DarkLight1337 (Member)

@TheEpicDolphin (Collaborator, Author) commented Aug 4, 2025:

PTAL https://buildkite.com/vllm/ci/builds/25909/steps/canvas?jid=01987380-9186-492e-9655-3f1376603328

Thanks for the heads up, I'll look into this test issue.

Edit: I have a draft PR with the fix here #22207. Will publish shortly

@TheEpicDolphin (Collaborator, Author) commented:

@DarkLight1337 Fix is ready here: #22207

@wangjiahe0915 commented:

I have a question. vLLM uses mixed scheduling for the prefill and decode stages, but your current operator completely separates prefill and decode. So, as I understand it, when the target model is validating draft tokens (and a new prefill request arrives at that time), it cannot be handled properly, right?

npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
@TheEpicDolphin (Collaborator, Author) replied:

I believe that this method (https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/tree_attn.py#L186) reorders the batch so that decodes and prefills are contiguous, with the prefills at the end. During the forward pass, tree attention only applies masking to the decode batches.
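
As a rough sketch of that reordering idea (illustrative only, not the actual reorder_batch implementation in tree_attn.py):

def decode_first_order(is_decode: list[bool]) -> list[int]:
    """Return a permutation of request indices with all decode requests first and
    all prefill requests last, preserving relative order within each group."""
    decodes = [i for i, d in enumerate(is_decode) if d]
    prefills = [i for i, d in enumerate(is_decode) if not d]
    return decodes + prefills

print(decode_first_order([False, True, True, False]))  # [1, 2, 0, 3]

The tree-attention bias is then applied only to the contiguous decode slice at the front of the reordered batch.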

jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
noamgat pushed a commit to noamgat/vllm that referenced this pull request Aug 9, 2025
@tjtanaa tjtanaa mentioned this pull request Aug 11, 2025
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
@luccafong luccafong mentioned this pull request Sep 1, 2025
@TheEpicDolphin (Collaborator, Author) commented:

Part 2 which enables end-to-end support for tree spec decoding in V1 is here: #22752
