
Commit 7ba6a65

TheEpicDolphin authored and jinzhen-lin committed
Add tree attention backend for v1 (part 1) (vllm-project#20401)
Signed-off-by: Giancarlo Delfin <[email protected]> Signed-off-by: Jinzhen Lin <[email protected]>
1 parent c362f3c commit 7ba6a65

File tree

12 files changed (+1098, -25 lines)

tests/v1/attention/test_attention_backends.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@

 BACKENDS_TO_TEST = [
     _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1,
-    _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1
+    _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN
 ]

 # Remove flashinfer from the list if it's not available

tests/v1/attention/utils.py

Lines changed: 4 additions & 2 deletions
@@ -109,11 +109,11 @@ def create_common_attn_metadata(

 def get_attention_backend(backend_name: _Backend):
     """Set up attention backend classes for testing.
-
+
     Args:
         backend_name: Name of the backend ("flash_attn", "flashinfer", etc.)
         vllm_config: VllmConfig instance
-
+
     Returns:
         Tuple of (backend_builder_class, backend_impl_class)
     """
@@ -126,6 +126,8 @@ def get_attention_backend(backend_name: _Backend):
         "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",
         _Backend.TRITON_ATTN_VLLM_V1:
         "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
+        _Backend.TREE_ATTN:
+        "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
     }

     if backend_name not in backend_map:
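The new TREE_ATTN entry is resolved the same way as the existing backends. A minimal usage sketch (illustrative, not part of the diff), assuming the test utilities above are importable:

from tests.v1.attention.utils import _Backend, get_attention_backend

# Look up the metadata-builder and implementation classes registered for the
# tree attention backend added by this commit.
builder_cls, impl_cls = get_attention_backend(_Backend.TREE_ATTN)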

tests/v1/spec_decode/test_eagle.py

Lines changed: 4 additions & 3 deletions
@@ -202,7 +202,9 @@ class _TargetModelStub(LlamaForCausalLM):


 @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
-def test_propose(num_speculative_tokens):
+@pytest.mark.parametrize("backend",
+                         [_Backend.FLASH_ATTN_VLLM_V1, _Backend.TREE_ATTN])
+def test_propose(num_speculative_tokens, backend):
     # Use GPU device
     device = torch.device(current_platform.device_type)

@@ -301,8 +303,7 @@ def create_deterministic_logits(token_ids):
                                    device=device)
     sampling_metadata = mock.MagicMock()

-    attn_metadata_builder_cls, _ = get_attention_backend(
-        _Backend.FLASH_ATTN_VLLM_V1)
+    attn_metadata_builder_cls, _ = get_attention_backend(backend)
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
         layer_names=proposer.attn_layer_names,
Lines changed: 299 additions & 0 deletions
@@ -0,0 +1,299 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from typing import Optional

import torch

from tests.v1.attention.utils import (_Backend, create_standard_kv_cache_spec,
                                      create_vllm_config,
                                      get_attention_backend)
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.utils import CommonAttentionMetadata


class MockAttentionLayer(torch.nn.Module):
    _q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    _k_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    _v_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


def forward_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    slot_mapping: torch.Tensor,
    seqlen_k: int,
    backend: _Backend,
    spec_token_tree: Optional[str] = None,
    num_spec_tokens: int = 0,
) -> torch.Tensor:
    batch_size, q_len, num_heads, dim_per_head = q.shape
    num_kv_heads = k.shape[-2]
    # Initialize the query and KV sequence lengths.
    query_start_loc = q_len * torch.arange(
        batch_size + 1, device=q.device, dtype=torch.int32)
    query_lens = torch.diff(query_start_loc)
    seq_lens = torch.full(
        (batch_size, ),
        seqlen_k,
        device=q.device,
        dtype=torch.int32,
    )
    context_lens = seq_lens - query_lens
    max_query_len = q_len
    num_actual_tokens = query_start_loc[-1]

    softmax_scale = q.shape[-1]**(-0.5)
    layer = MockAttentionLayer()

    # Build common metadata.
    model_name = "meta-llama/Meta-Llama-3-8B"
    builder_cls, impl_cls = get_attention_backend(backend)
    vllm_config = create_vllm_config(model_name=model_name,
                                     max_model_len=max(seq_lens))
    if spec_token_tree is not None:
        # Create speculative config if token tree is specified.
        vllm_config.speculative_config = SpeculativeConfig(
            target_model_config=vllm_config.model_config,
            target_parallel_config=ParallelConfig(),
            model=model_name,
            method="eagle",
            num_speculative_tokens=num_spec_tokens,
            speculative_token_tree=spec_token_tree)
    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
    builder = builder_cls(kv_cache_spec, [], vllm_config, q.device)
    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc.cpu(),
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens.cpu(),
        num_computed_tokens_cpu=context_lens.cpu(),
        num_reqs=batch_size,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        block_table_tensor=block_table,
        slot_mapping=slot_mapping,
    )

    # Build attention metadata.
    attn_metadata = builder.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
    )

    # Initialize the backend implementation.
    instance = impl_cls(
        num_heads=num_heads,
        head_size=dim_per_head,
        scale=softmax_scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
    )

    # Run forward pass and return output.
    query = q.view(-1, num_heads, dim_per_head)
    key = k.view(-1, num_kv_heads, dim_per_head)
    value = v.view(-1, num_kv_heads, dim_per_head)
    output = torch.empty_like(query)
    return instance.forward(
        layer=layer,
        query=query,
        key=key,
        value=value,
        kv_cache=kv_cache.clone(),
        attn_metadata=attn_metadata,
        output=output,
    )


def test_tree_attn_correctness() -> None:
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    device = "cuda"
    tree_attn_masks = {
        # Chain.
        "[(0,), (0, 0), (0, 0, 0)]":
        torch.tensor(
            [
                [1, 0, 0, 0],
                [1, 1, 0, 0],
                [1, 1, 1, 0],
                [1, 1, 1, 1],
            ],
            device=device,
            dtype=torch.int32,
        ),
        # Tree.
        "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]":
        torch.tensor(
            [
                [1, 0, 0, 0, 0, 0, 0],
                [1, 1, 0, 0, 0, 0, 0],
                [1, 0, 1, 0, 0, 0, 0],
                [1, 1, 0, 1, 0, 0, 0],
                [1, 1, 0, 0, 1, 0, 0],
                [1, 0, 1, 0, 0, 1, 0],
                [1, 0, 1, 0, 0, 0, 1],
            ],
            device=device,
            dtype=torch.int32,
        ),
    }

    dim_per_head = 128
    num_kv_heads = 2
    block_size = 128
    max_sequence_length = 8192
    randomize_blocks = True
    for batch_size in [1, 16, 32]:
        for num_heads in [2, 4]:
            for sequence_position in [16, 1024, 2048]:
                for spec_token_tree, tree_attn_mask in tree_attn_masks.items():
                    # Assert that the number of heads is divisible
                    # by the number of KV heads.
                    assert num_heads % num_kv_heads == 0

                    # Initialize q, k, and v.
                    tree_size_q = tree_attn_mask.shape[0]
                    seqlen_k = sequence_position + tree_size_q
                    q = torch.randn(
                        (batch_size, tree_size_q, num_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )
                    k = torch.randn(
                        (batch_size, tree_size_q, num_kv_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )
                    v = torch.randn(
                        (batch_size, tree_size_q, num_kv_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )

                    # Setup the block table and KV cache for paged KV.
                    assert max_sequence_length % block_size == 0
                    max_blocks_per_batch = max_sequence_length // block_size
                    kv_cache = torch.randn(
                        (
                            2,
                            batch_size * max_blocks_per_batch,
                            block_size,
                            num_kv_heads,
                            dim_per_head,
                        ),
                        device=q.device,
                        dtype=torch.bfloat16,
                    )
                    num_alloc_blocks_per_batch = math.ceil(seqlen_k /
                                                           block_size)
                    block_table = torch.zeros(
                        (batch_size, max_blocks_per_batch),
                        device=q.device,
                        dtype=torch.int32,
                    )
                    block_ids = torch.arange(
                        0,
                        batch_size * num_alloc_blocks_per_batch,
                        device=q.device,
                        dtype=torch.int32,
                    )
                    if randomize_blocks:
                        # Randomize the block ids.
                        block_ids = block_ids[torch.randperm(
                            block_ids.numel())]
                    block_table[:, :
                                num_alloc_blocks_per_batch] = block_ids.view(
                                    -1, num_alloc_blocks_per_batch)

                    # Setup the slot mapping for the input KVs.
                    tree_positions = sequence_position + torch.arange(
                        0,
                        tree_size_q,
                        device=q.device,
                        dtype=torch.int64,
                    ).repeat(batch_size, 1)
                    tree_slot_mapping = _gen_slot_mapping(
                        tree_positions, block_table, block_size)

                    # Compute attention for the tree.
                    tree_attn_output = forward_attention(
                        q=q,
                        k=k,
                        v=v,
                        kv_cache=kv_cache,
                        block_table=block_table,
                        slot_mapping=tree_slot_mapping,
                        seqlen_k=seqlen_k,
                        backend=_Backend.TREE_ATTN,
                        spec_token_tree=spec_token_tree,
                        num_spec_tokens=tree_size_q - 1,
                    ).view(batch_size, -1, num_heads, dim_per_head)

                    # Verify that the chain attention output for each
                    # branch of the tree (computed using FA3) matches
                    # the tree attention output.
                    for q_index in range(tree_size_q):
                        # Get the q, k, and v for the branch.
                        branch_mask = tree_attn_mask[q_index, :]
                        branch_indices = torch.nonzero(branch_mask,
                                                       as_tuple=True)[0]
                        q_len = branch_indices.shape[0]
                        q_branch = q[:, branch_indices]
                        k_branch = k[:, branch_indices]
                        v_branch = v[:, branch_indices]

                        # Setup slot mapping for the branch.
                        branch_positions = sequence_position + torch.arange(
                            0,
                            q_len,
                            device=q.device,
                            dtype=torch.int64,
                        ).repeat(batch_size, 1)
                        branch_slot_mapping = _gen_slot_mapping(
                            branch_positions, block_table, block_size)

                        # Compute flash attention for the branch.
                        flash_attn_output = forward_attention(
                            q=q_branch,
                            k=k_branch,
                            v=v_branch,
                            kv_cache=kv_cache,
                            block_table=block_table,
                            slot_mapping=branch_slot_mapping,
                            seqlen_k=sequence_position + q_len,
                            backend=_Backend.FLASH_ATTN_VLLM_V1,
                        ).view(batch_size, -1, num_heads, dim_per_head)

                        # Compare the outputs.
                        assert torch.allclose(
                            tree_attn_output[:, branch_indices],
                            flash_attn_output,
                            atol=7.81e-3,
                        ), (f"outputs are not close for "
                            f"batch_size: {batch_size}, "
                            f"num_heads: {num_heads}, "
                            f"sequence_position: {sequence_position}, "
                            f"tree_attn_mask: {tree_attn_mask}, "
                            f"q_index: {q_index}.")


def _gen_slot_mapping(positions: torch.Tensor, block_table: torch.Tensor,
                      block_size: int):
    block_indices = positions // block_size
    blocks = block_table.gather(dim=1, index=block_indices)
    return (blocks * block_size + positions % block_size).view(-1)
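The hard-coded tree_attn_masks above encode, for each node of the speculative token tree, which positions it may attend to: the implicit root, its ancestors, and itself. As an illustration only (this helper is not part of the commit), the same masks can be derived from the spec_token_tree strings, where each tuple is a path of child indices from the root and a node attends to every prefix of its own path:

import ast

import torch


def build_tree_attn_mask(spec_token_tree: str) -> torch.Tensor:
    # Parse the tree spec and prepend the implicit root node (empty path).
    paths = [tuple(p) for p in ast.literal_eval(spec_token_tree)]
    nodes = [()] + paths
    size = len(nodes)
    mask = torch.zeros((size, size), dtype=torch.int32)
    for i, path in enumerate(nodes):
        for j, other in enumerate(nodes):
            # Node i attends to node j iff j's path is a prefix of i's path.
            if path[:len(other)] == other:
                mask[i, j] = 1
    return mask


# Reproduces the 4x4 chain mask used in the test above.
assert build_tree_attn_mask("[(0,), (0, 0), (0, 0, 0)]").tolist() == [
    [1, 0, 0, 0],
    [1, 1, 0, 0],
    [1, 1, 1, 0],
    [1, 1, 1, 1],
]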

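For completeness, a short worked example (with made-up block ids; not part of the diff) of the paged-KV arithmetic that _gen_slot_mapping performs: a token at position p lives in logical block p // block_size, the block table maps that logical block to a physical block id, and its flat cache slot is physical_block * block_size + p % block_size.

import torch

block_size = 128
# Hypothetical block table for one request:
# logical block 0 -> physical block 5, logical block 1 -> physical block 2.
block_table = torch.tensor([[5, 2]], dtype=torch.int32)
positions = torch.tensor([[130]], dtype=torch.int64)  # token at position 130

block_indices = positions // block_size                    # logical block 1
blocks = block_table.gather(dim=1, index=block_indices)    # physical block 2
slots = (blocks * block_size + positions % block_size).view(-1)
assert slots.tolist() == [2 * 128 + 2]                     # flat slot 258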