DO_NOT_COMMIT: benchmark fa2 vs tree attention #23143
```diff
@@ -11,6 +11,8 @@
     get_attention_backend)
 from vllm.config import ParallelConfig, SpeculativeConfig
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.flash_attn import (FlashAttentionMetadataBuilder, TreeMetadata)
+from vllm.vllm_flash_attn.utils.benchmark import benchmark_forward
```
```diff
@@ -35,7 +37,9 @@ def forward_attention(
     seqlen_k: int,
     backend: _Backend,
     spec_token_tree: Optional[str] = None,
+    tree_mask: Optional[torch.Tensor] = None,
     num_spec_tokens: int = 0,
+    bench_desc: str = "",
 ) -> torch.Tensor:
     batch_size, q_len, num_heads, dim_per_head = q.shape
     num_kv_heads = k.shape[-2]
```
```diff
@@ -86,10 +90,18 @@ def forward_attention(
     )
 
     # Build attention metadata.
-    attn_metadata = builder.build(
-        common_prefix_len=0,
-        common_attn_metadata=common_attn_metadata,
-    )
+    if isinstance(builder, FlashAttentionMetadataBuilder):
+        fa_tree_metadata = TreeMetadata(mask=tree_mask.repeat(batch_size), lens=torch.arange(0, (batch_size + 1) * num_spec_tokens, num_spec_tokens, dtype=torch.int32, device=q.device)) if tree_mask is not None else None
+        attn_metadata = builder.build(
+            common_prefix_len=0,
+            common_attn_metadata=common_attn_metadata,
+            tree_metadata=fa_tree_metadata,
+        )
+    else:
+        attn_metadata = builder.build(
+            common_prefix_len=0,
+            common_attn_metadata=common_attn_metadata,
+        )
 
     # Initialize the backend implementation.
     instance = impl_cls(
```
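For context on the `TreeMetadata` constructed above: `mask` tiles the packed tree mask once per request, and `lens` is built with a strided `torch.arange`, which reads like a prefix-offset table over the speculative tokens (every request contributes the same `num_spec_tokens` entries here). A quick check with illustrative values — `batch_size=2`, `num_spec_tokens=3` are made-up numbers, not from the PR:

```python
import torch

batch_size, num_spec_tokens = 2, 3  # illustrative values
lens = torch.arange(0, (batch_size + 1) * num_spec_tokens,
                    num_spec_tokens, dtype=torch.int32)
print(lens)  # tensor([0, 3, 6], dtype=torch.int32)
# Reading: lens[i] would be the offset of request i's first entry in the
# flattened per-token mask, and lens[-1] the total entry count.
```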
```diff
@@ -107,6 +119,23 @@ def forward_attention(
     key = k.view(-1, num_kv_heads, dim_per_head)
     value = v.view(-1, num_kv_heads, dim_per_head)
     output = torch.empty_like(query)
 
+    measurement = None
+
+    if bench_desc != "":
+        _, measurement = benchmark_forward(
+            instance.forward,
+            layer=layer,
+            query=query,
+            key=key,
+            value=value,
+            kv_cache=kv_cache.clone(),
+            attn_metadata=attn_metadata,
+            output=output,
+            desc=bench_desc,
+            verbose=False,
+        )
+
     return instance.forward(
         layer=layer,
         query=query,
```
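`benchmark_forward` is imported from the vendored flash-attention benchmarking utilities and evidently returns a `(timer, measurement)` pair, which is why the call above keeps only the second element. For readers without that module, a rough stand-in with the same shape of interface can be sketched on top of `torch.utils.benchmark` (the real helper's warmup/repeat policy is not reproduced here):

```python
import torch.utils.benchmark as benchmark

def benchmark_forward(fn, *args, desc="", verbose=True, **kwargs):
    """Rough stand-in: time fn(*args, **kwargs), return (timer, measurement)."""
    timer = benchmark.Timer(
        stmt="fn(*args, **kwargs)",
        globals={"fn": fn, "args": args, "kwargs": kwargs},
        label=desc,
    )
    measurement = timer.blocked_autorange()  # measurement.mean is s/call
    if verbose:
        print(f"{desc}: {measurement.mean:.6f} s")
    return timer, measurement
```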
```diff
@@ -115,7 +144,7 @@ def forward_attention(
         kv_cache=kv_cache.clone(),
         attn_metadata=attn_metadata,
         output=output,
-    )
+    ), measurement
 
 
 def test_tree_attn_correctness() -> None:
```
```diff
@@ -125,31 +154,56 @@ def test_tree_attn_correctness() -> None:
     device = "cuda"
     tree_attn_masks = {
         # Chain.
-        "[(0,), (0, 0), (0, 0, 0)]":
-        torch.tensor(
-            [
-                [1, 0, 0, 0],
-                [1, 1, 0, 0],
-                [1, 1, 1, 0],
-                [1, 1, 1, 1],
-            ],
-            device=device,
-            dtype=torch.int32,
+        ("[(0,), (0, 0), (0, 0, 0)]", (-1, 0, 1)):
+        (
+            torch.tensor(
+                [
+                    [1, 0, 0, 0],
+                    [1, 1, 0, 0],
+                    [1, 1, 1, 0],
+                    [1, 1, 1, 1],
+                ],
+                device=device,
+                dtype=torch.int32,
+            ),
+            torch.tensor(
+                [
+                    0b100,
+                    0b110,
+                    0b111,
+                ],
+                device=device,
+                dtype=torch.uint64,
+            ),
         ),
         # Tree.
-        "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]":
-        torch.tensor(
-            [
-                [1, 0, 0, 0, 0, 0, 0],
-                [1, 1, 0, 0, 0, 0, 0],
-                [1, 0, 1, 0, 0, 0, 0],
-                [1, 1, 0, 1, 0, 0, 0],
-                [1, 1, 0, 0, 1, 0, 0],
-                [1, 0, 1, 0, 0, 1, 0],
-                [1, 0, 1, 0, 0, 0, 1],
-            ],
-            device=device,
-            dtype=torch.int32,
+        ("[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]", (-1, -1, 0, 0, 1, 1)):
+        (
+            torch.tensor(
+                [
+                    [1, 0, 0, 0, 0, 0, 0],
+                    [1, 1, 0, 0, 0, 0, 0],
+                    [1, 0, 1, 0, 0, 0, 0],
+                    [1, 1, 0, 1, 0, 0, 0],
+                    [1, 1, 0, 0, 1, 0, 0],
+                    [1, 0, 1, 0, 0, 1, 0],
+                    [1, 0, 1, 0, 0, 0, 1],
+                ],
+                device=device,
+                dtype=torch.int32,
+            ),
+            torch.tensor(
+                [
+                    0b100000,
+                    0b010000,
+                    0b101000,
+                    0b100100,
+                    0b010010,
+                    0b010001,
+                ],
+                device=device,
+                dtype=torch.uint64,
+            ),
         ),
     }
```
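The new dict values pair the dense int32 tree mask with a packed per-token bitmask: one `uint64` per speculative token and, judging from the literals, speculative position 0 in the most significant of the `num_spec_tokens` bits. A sketch of how the packed form can be derived from the dense mask — `pack_tree_mask` is an illustrative helper, not part of this PR, and the bit layout is inferred from the literals above:

```python
import torch

def pack_tree_mask(dense_mask: torch.Tensor) -> torch.Tensor:
    """Pack an (n+1, n+1) dense tree mask (row/col 0 = root) into n bitmasks."""
    spec = dense_mask[1:, 1:]  # drop the root row and column
    n = spec.shape[0]
    packed = torch.zeros(n, dtype=torch.int64)
    for i in range(n):
        bits = 0
        for j in range(n):
            if spec[i, j]:
                bits |= 1 << (n - 1 - j)  # position 0 -> most significant bit
        packed[i] = bits
    return packed
```

Applied to the chain mask above, this reproduces `[0b100, 0b110, 0b111]`, and on the tree mask it reproduces the six `0b…` literals, which is what the inferred layout is based on.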
```diff
@@ -167,7 +221,7 @@ def test_tree_attn_correctness() -> None:
     assert num_heads % num_kv_heads == 0
 
     # Initialize q, k, and v.
-    tree_size_q = tree_attn_mask.shape[0]
+    tree_size_q = tree_attn_mask[0].shape[0]
     seqlen_k = sequence_position + tree_size_q
     q = torch.randn(
         (batch_size, tree_size_q, num_heads, dim_per_head),
```
```diff
@@ -231,7 +285,7 @@ def test_tree_attn_correctness() -> None:
         tree_positions, block_table, block_size)
 
     # Compute attention for the tree.
-    tree_attn_output = forward_attention(
+    forward_attn_output1, tree_attn_measurement = forward_attention(
         q=q,
         k=k,
         v=v,
```
```diff
@@ -240,16 +294,49 @@ def test_tree_attn_correctness() -> None:
         slot_mapping=tree_slot_mapping,
         seqlen_k=seqlen_k,
         backend=_Backend.TREE_ATTN,
-        spec_token_tree=spec_token_tree,
+        spec_token_tree=spec_token_tree[0],
         num_spec_tokens=tree_size_q - 1,
-    ).view(batch_size, -1, num_heads, dim_per_head)
+        bench_desc="tree_attention",
+    )
+    tree_attn_output = forward_attn_output1.view(batch_size, -1, num_heads, dim_per_head)
+
+    tree_time = tree_attn_measurement.mean
+    print(f"tree_attention average time: {tree_time:.6f} seconds")
+
+    forward_attn_output2, fa2_attn_measurement = forward_attention(
+        q=q,
+        k=k,
+        v=v,
+        kv_cache=kv_cache,
+        block_table=block_table,
+        slot_mapping=tree_slot_mapping,
+        seqlen_k=seqlen_k,
+        backend=_Backend.FLASH_ATTN_VLLM_V1,
+        tree_mask=tree_attn_mask[1],
+        num_spec_tokens=tree_size_q - 1,
+        bench_desc="fa2_tree_attention",
+    )
+    fa2_tree_attn_output = forward_attn_output2.view(batch_size, -1, num_heads, dim_per_head)
+
+    fa2_time = fa2_attn_measurement.mean
+    print(f"fa2_tree_attention average time: {fa2_time:.6f} seconds")
+
+    # Calculate speedup.
+    speedup = tree_time / fa2_time
+    print(f"Speedup (tree/fa2): {speedup:.2f}x")
+    if speedup > 1:
+        print(f"fa2_tree_attention is {speedup:.2f}x faster\n")
+    else:
+        print(f"tree_attention is {1/speedup:.2f}x faster\n")
+
+    assert torch.allclose(tree_attn_output, fa2_tree_attn_output, atol=7.81e-3)
 
     # Verify that the chain attention output for each
     # branch of the tree (computed using FA3) matches
     # the tree attention output.
     for q_index in range(tree_size_q):
         # Get the q, k, and v for the branch.
-        branch_mask = tree_attn_mask[q_index, :]
+        branch_mask = tree_attn_mask[0][q_index, :]
         branch_indices = torch.nonzero(branch_mask,
                                        as_tuple=True)[0]
         q_len = branch_indices.shape[0]
```

> **Contributor** (inline comment on the speedup calculation): "The calculation …" — a suggested change was attached.
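Two small notes on the checks above. The `atol=7.81e-3` tolerance is presumably 2⁻⁷ (= 0.0078125), i.e. on the order of half-precision rounding error. And the speedup arithmetic reads as follows — hypothetical timings for illustration, not measured results:

```python
tree_time, fa2_time = 2.0e-3, 1.0e-3  # hypothetical mean times, in seconds
speedup = tree_time / fa2_time        # 2.0 -> the FA2 path took half the time
assert speedup > 1                    # printed as "fa2_tree_attention is 2.00x faster"
```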
```diff
@@ -268,7 +355,7 @@ def test_tree_attn_correctness() -> None:
         branch_positions, block_table, block_size)
 
     # Compute flash attention for the branch.
-    flash_attn_output = forward_attention(
+    forward_attn_output3, _ = forward_attention(
         q=q_branch,
         k=k_branch,
         v=v_branch,
```
```diff
@@ -277,8 +364,8 @@ def test_tree_attn_correctness() -> None:
         slot_mapping=branch_slot_mapping,
         seqlen_k=sequence_position + q_len,
         backend=_Backend.FLASH_ATTN_VLLM_V1,
-    ).view(batch_size, -1, num_heads, dim_per_head)
-
+    )
+    flash_attn_output = forward_attn_output3.view(batch_size, -1, num_heads, dim_per_head)
     # Compare the outputs.
     assert torch.allclose(
         tree_attn_output[:, branch_indices],
```
> **Reviewer comment:** The dependency for `vllm-flash-attn` has been changed to a personal fork (`samsung-cnct/flash-attention`) from the official `vllm-project/flash-attention`. This introduces a potential security risk and dependency issue. While this might be intentional for benchmarking purposes in a "DO_NOT_COMMIT" context, it's a critical change that should not be merged into a production branch. Using unverified third-party forks can expose the project to vulnerabilities.