Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cmake/external_projects/vllm_flash_attn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ if(VLLM_FLASH_ATTN_SRC_DIR)
else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 57b4e68b9f9d94750b46de8f8dbd2bfcc86edd4f
GIT_REPOSITORY https://github.com/samsung-cnct/flash-attention.git
GIT_TAG feaab457d8d58243f19bf234a42a498647de0e6f
Comment on lines +40 to +41
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The dependency for vllm-flash-attn has been changed from the official vllm-project/flash-attention repository to a third-party fork (samsung-cnct/flash-attention). This introduces a potential security risk and a dependency-management issue. While this might be intentional for benchmarking purposes in a "DO_NOT_COMMIT" context, it is a critical change that should not be merged into a production branch: using an unverified third-party fork can expose the project to vulnerabilities.

GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
Expand Down
14 changes: 9 additions & 5 deletions tests/v1/spec_decode/test_eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.spec_decode.eagle import EagleProposer
import time

model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
Expand Down Expand Up @@ -237,7 +238,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
"multi-token eagle spec decode on current platform")

if (attn_backend == "TREE_ATTN"):
if (attn_backend in ("TREE_ATTN", "FLASH_ATTN_VLLM_V1")):
pytest.skip("TREE_ATTN is tested separately in test_propose_tree"
"because it requires special input mocking.")

Expand Down Expand Up @@ -394,7 +395,9 @@ def create_deterministic_logits(token_ids):
# Verify all tokens match our expectations
assert torch.equal(result, expected_tokens)


@pytest.mark.parametrize(
"backend", [_Backend.TREE_ATTN, _Backend.FLASH_ATTN_VLLM_V1]
)
@pytest.mark.parametrize(
"spec_token_tree",
[
Expand All @@ -404,7 +407,7 @@ def create_deterministic_logits(token_ids):
[(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0),
(2, 1)], # Tree
])
def test_propose_tree(spec_token_tree):
def test_propose_tree(spec_token_tree, backend):
# Get GPU device.
device = torch.device(current_platform.device_type)

Expand Down Expand Up @@ -477,7 +480,7 @@ def create_deterministic_logits(token_ids, k: int):
proposer.attn_layer_names = ["layer.0"]

# Get the tree attention metadata builder.
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN)
attn_metadata_builder_cls, _ = get_attention_backend(backend)
attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names,
Expand Down Expand Up @@ -517,14 +520,15 @@ def create_deterministic_logits(token_ids, k: int):
sampling_metadata = mock.MagicMock()

# Propose draft tokens.
start = time.perf_counter()
result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)
assert result.shape == (batch_size, num_speculative_tokens)

print(f"backend {backend} took {time.perf_counter()-start}")
# The tokens are expected to be consecutive integers starting
# from the base token IDs.
expected_tokens = base_token_ids[:, None] + torch.arange(
Expand Down
159 changes: 123 additions & 36 deletions tests/v1/spec_decode/test_tree_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
get_attention_backend)
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.attention.backends.flash_attn import (FlashAttentionMetadataBuilder, TreeMetadata)
from vllm.vllm_flash_attn.utils.benchmark import benchmark_forward


class MockAttentionLayer(torch.nn.Module):
Expand All @@ -35,7 +37,9 @@ def forward_attention(
seqlen_k: int,
backend: _Backend,
spec_token_tree: Optional[str] = None,
tree_mask: Optional[torch.Tensor] = None,
num_spec_tokens: int = 0,
bench_desc: str = "",
) -> torch.Tensor:
batch_size, q_len, num_heads, dim_per_head = q.shape
num_kv_heads = k.shape[-2]
Expand Down Expand Up @@ -86,10 +90,18 @@ def forward_attention(
)

# Build attention metadata.
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
if isinstance(builder, FlashAttentionMetadataBuilder):
fa_tree_metadata = TreeMetadata(mask=tree_mask.repeat(batch_size), lens=torch.arange(0, (batch_size+1)*num_spec_tokens, num_spec_tokens, dtype=torch.int32, device=q.device)) if tree_mask is not None else None
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
tree_metadata=fa_tree_metadata,
)
else:
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)

# Initialize the backend implementation.
instance = impl_cls(
Expand All @@ -107,6 +119,23 @@ def forward_attention(
key = k.view(-1, num_kv_heads, dim_per_head)
value = v.view(-1, num_kv_heads, dim_per_head)
output = torch.empty_like(query)

measurement=None

if bench_desc != "":
_, measurement = benchmark_forward(
instance.forward,
layer=layer,
query=query,
key=key,
value=value,
kv_cache=kv_cache.clone(),
attn_metadata=attn_metadata,
output=output,
desc=bench_desc,
verbose=False,
)

return instance.forward(
layer=layer,
query=query,
Expand All @@ -115,7 +144,7 @@ def forward_attention(
kv_cache=kv_cache.clone(),
attn_metadata=attn_metadata,
output=output,
)
), measurement


def test_tree_attn_correctness() -> None:
Expand All @@ -125,31 +154,56 @@ def test_tree_attn_correctness() -> None:
device = "cuda"
tree_attn_masks = {
# Chain.
"[(0,), (0, 0), (0, 0, 0)]":
torch.tensor(
[
[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1],
],
device=device,
dtype=torch.int32,
("[(0,), (0, 0), (0, 0, 0)]", (-1, 0, 1)):
(
torch.tensor(
[
[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1],
],
device=device,
dtype=torch.int32,
),
torch.tensor(
[
0b100,
0b110,
0b111,
],
device=device,
dtype=torch.uint64,
),
),
# Tree.
"[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]":
torch.tensor(
[
[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 0, 1, 0, 0, 0, 0],
[1, 1, 0, 1, 0, 0, 0],
[1, 1, 0, 0, 1, 0, 0],
[1, 0, 1, 0, 0, 1, 0],
[1, 0, 1, 0, 0, 0, 1],
],
device=device,
dtype=torch.int32,
("[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]", (-1, -1, 0, 0, 1, 1)):
(
torch.tensor(
[
[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 0, 1, 0, 0, 0, 0],
[1, 1, 0, 1, 0, 0, 0],
[1, 1, 0, 0, 1, 0, 0],
[1, 0, 1, 0, 0, 1, 0],
[1, 0, 1, 0, 0, 0, 1],
],
device=device,
dtype=torch.int32,
),
torch.tensor(
[
0b100000,
0b010000,
0b101000,
0b100100,
0b010010,
0b010001,
],
device=device,
dtype=torch.uint64
),
),
}

Expand All @@ -167,7 +221,7 @@ def test_tree_attn_correctness() -> None:
assert num_heads % num_kv_heads == 0

# Initialize q, k, and v.
tree_size_q = tree_attn_mask.shape[0]
tree_size_q = tree_attn_mask[0].shape[0]
seqlen_k = sequence_position + tree_size_q
q = torch.randn(
(batch_size, tree_size_q, num_heads, dim_per_head),
Expand Down Expand Up @@ -231,7 +285,7 @@ def test_tree_attn_correctness() -> None:
tree_positions, block_table, block_size)

# Compute attention for the tree.
tree_attn_output = forward_attention(
forward_attn_output1, tree_attn_measurement = forward_attention(
q=q,
k=k,
v=v,
Expand All @@ -240,16 +294,49 @@ def test_tree_attn_correctness() -> None:
slot_mapping=tree_slot_mapping,
seqlen_k=seqlen_k,
backend=_Backend.TREE_ATTN,
spec_token_tree=spec_token_tree,
spec_token_tree=spec_token_tree[0],
num_spec_tokens=tree_size_q - 1,
).view(batch_size, -1, num_heads, dim_per_head)
bench_desc="tree_attention",
)
tree_attn_output = forward_attn_output1.view(batch_size, -1, num_heads, dim_per_head)

tree_time = tree_attn_measurement.mean
print(f"tree_attention average time: {tree_time:.6f} seconds")

forward_attn_output2, fa2_attn_measurement = forward_attention(
q=q,
k=k,
v=v,
kv_cache=kv_cache,
block_table=block_table,
slot_mapping=tree_slot_mapping,
seqlen_k=seqlen_k,
backend=_Backend.FLASH_ATTN_VLLM_V1,
tree_mask=tree_attn_mask[1],
num_spec_tokens=tree_size_q - 1,
bench_desc="fa2_tree_attention",
)
fa2_tree_attn_output = forward_attn_output2.view(batch_size, -1, num_heads, dim_per_head)

fa2_time = fa2_attn_measurement.mean
print(f"fa2_tree_attention average time: {fa2_time:.6f} seconds")

# Calculate speedup
speedup = tree_time / fa2_time
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The calculation speedup = tree_time / fa2_time could lead to a division-by-zero error if fa2_time is zero, or to a meaninglessly large ratio if fa2_time is very close to zero. This can happen in benchmarks, especially with very fast operations or measurement noise. To prevent this, a small epsilon should be added to the denominator.

Suggested change
speedup = tree_time / fa2_time
speedup = tree_time / (fa2_time + 1e-9)

print(f"Speedup (tree/fa2): {speedup:.2f}x")
if speedup > 1:
print(f"fa2_tree_attention is {speedup:.2f}x faster\n")
else:
print(f"tree_attention is {1/speedup:.2f}x faster\n")

assert torch.allclose(tree_attn_output, fa2_tree_attn_output, atol=7.81e-3)

# Verify that the chain attention output for each
# branch of the tree (computed using FA3) matches
# the tree attention output.
for q_index in range(tree_size_q):
# Get the q, k, and v for the branch.
branch_mask = tree_attn_mask[q_index, :]
branch_mask = tree_attn_mask[0][q_index, :]
branch_indices = torch.nonzero(branch_mask,
as_tuple=True)[0]
q_len = branch_indices.shape[0]
Expand All @@ -268,7 +355,7 @@ def test_tree_attn_correctness() -> None:
branch_positions, block_table, block_size)

# Compute flash attention for the branch.
flash_attn_output = forward_attention(
forward_attn_output3, _ = forward_attention(
q=q_branch,
k=k_branch,
v=v_branch,
Expand All @@ -277,8 +364,8 @@ def test_tree_attn_correctness() -> None:
slot_mapping=branch_slot_mapping,
seqlen_k=sequence_position + q_len,
backend=_Backend.FLASH_ATTN_VLLM_V1,
).view(batch_size, -1, num_heads, dim_per_head)

)
flash_attn_output = forward_attn_output3.view(batch_size, -1, num_heads, dim_per_head)
# Compare the outputs.
assert torch.allclose(
tree_attn_output[:, branch_indices],
Expand Down
3 changes: 2 additions & 1 deletion vllm/attention/utils/fa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from vllm import _custom_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
get_scheduler_metadata)
get_scheduler_metadata,
tree_attention)
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
Expand Down
Loading