Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions tests/kernels/test_flex_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
create_vllm_config,
)
from vllm.v1.attention.backends.flex_attention import (
BlockSparsityHint,
FlexAttentionMetadataBuilder,
physical_to_logical_mapping,
)
Expand Down Expand Up @@ -223,5 +224,55 @@ def test_physical_to_logical_mapping_handles_reused_blocks():
assert out2[0, 2].item() == 1


@pytest.mark.skipif(
    not torch.cuda.is_available() or TORCH_VERSION < DIRECT_BUILD_VERSION,
    reason="CUDA not available or PyTorch version < 2.9",
)
def test_block_sparsity_hint_prunes_blocks():
    """Verify that BlockSparsityHint removes KV blocks in the direct build path.

    A diagonal-only hint (q_block == kv_block) is applied and the resulting
    BlockMask is checked to keep at most one KV block per query row, whereas
    the baseline mask built without a hint keeps more than one.
    """
    dev = torch.device("cuda")

    cfg = create_vllm_config(
        model_name="facebook/opt-125m",
        block_size=16,
        max_model_len=1024,
    )
    cache_spec = create_standard_kv_cache_spec(cfg)

    spec = BatchSpec(
        seq_lens=[256],
        query_lens=[256],
        name="test_sparsity_hint",
    )

    common_md = create_common_attn_metadata(
        spec, cfg.cache_config.block_size, dev
    )

    builder = FlexAttentionMetadataBuilder(cache_spec, [], cfg, dev)

    # Baseline: with no hint, the directly-built mask retains multiple
    # KV blocks for at least one query row.
    baseline = builder.build(
        common_prefix_len=0, common_attn_metadata=common_md
    )
    baseline.block_mask = baseline._build_block_mask_direct()
    assert baseline.block_mask.kv_num_blocks.max().item() > 1

    def keep_diagonal(q_block_idx, kv_block_idx, block_size):
        # Only block pairs on the diagonal may contain unmasked elements.
        return q_block_idx == kv_block_idx

    hinted = builder.build(
        common_prefix_len=0, common_attn_metadata=common_md
    )
    hinted.block_sparsity_hint = BlockSparsityHint(hint_fn=keep_diagonal)
    hinted.block_mask = hinted._build_block_mask_direct()
    # With the diagonal hint, every query row keeps at most one KV block.
    assert hinted.block_mask.kv_num_blocks.max().item() <= 1


# Allow running this test module directly (e.g. `python test_flex_attention.py`)
# instead of invoking it through the pytest CLI.
if __name__ == "__main__":
    pytest.main([__file__])
86 changes: 70 additions & 16 deletions vllm/v1/attention/backends/flex_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
"""Attention layer with FlexAttention."""

import math
from collections.abc import Callable
from dataclasses import dataclass
from functools import cached_property
from typing import ClassVar
from typing import ClassVar, NamedTuple

import torch
import torch._dynamo.decorators
Expand Down Expand Up @@ -294,6 +295,27 @@ def causal_mask_mod(
return q_idx >= kv_idx


# Type alias for the block sparsity hint callable signature:
# (q_block_idx, kv_block_idx, block_size) -> boolean keep-mask tensor.
_block_sparsity_hint_signature = Callable[
    [torch.Tensor, torch.Tensor, int], torch.Tensor
]


class BlockSparsityHint(NamedTuple):
    """Caller-provided predicate for pruning KV blocks from the BlockMask.

    The hint is applied before the flex_attention kernel is invoked, so KV
    blocks it marks as fully masked are never loaded. Pair this with a sparse
    custom mask_mod to avoid the kernel iterating over all KV blocks
    unnecessarily.

    Attributes:
        hint_fn: (q_block_idx [num_tokens, 1], kv_block_idx [1, num_kv_blocks],
            block_size int) -> bool Tensor [num_tokens, num_kv_blocks].
            Returns True for block pairs that may contain non-masked elements;
            pairs mapped to False are pruned from the mask.
    """

    hint_fn: _block_sparsity_hint_signature


@dataclass
class FlexAttentionMetadata:
causal: bool
Expand Down Expand Up @@ -335,6 +357,7 @@ class FlexAttentionMetadata:
transformed_score_mod: _score_mod_signature | None = None
sliding_window: int | None = None
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
block_sparsity_hint: BlockSparsityHint | None = None

@cached_property
def logical_block_ids(self):
Expand Down Expand Up @@ -378,7 +401,7 @@ def _convert_physical_to_logical(

return is_valid, logical_q_idx, logical_kv_idx

def get_causal_mask_mod(self) -> _mask_mod_signature:
def get_paged_mask_mod(self) -> _mask_mod_signature:
"""Creates the mask_mod function for FlexAttention.

This function creates the combined mask mod function that handles:
Expand Down Expand Up @@ -504,8 +527,9 @@ def final_mask_mod(
def get_mask_mod(self):
# Stage-1: initialize the base mask_mod
# (causal mask for decoder or bidirectional mask for encoder)
if self.causal:
mask_mod = self.get_causal_mask_mod()
has_custom_mask = self.logical_mask_mod is not causal_mask_mod
if self.causal or has_custom_mask:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The condition self.causal or has_custom_mask will always evaluate to True if has_custom_mask is True. This means that the code will always use self.get_causal_mask_mod() when a custom mask is present, regardless of the value of self.causal. This might not be the intended behavior, as the user might want to use a bidirectional mask with a custom modification. This could lead to unexpected or incorrect attention patterns.

To fix this, the logic should ensure that self.causal is only considered when a custom mask is not present. If a custom mask is present, it should override the causal mask behavior.

Suggested change
if self.causal or has_custom_mask:
if has_custom_mask:
mask_mod = self.logical_mask_mod
elif self.causal:
mask_mod = self.get_causal_mask_mod()
else:
mask_mod = self.get_bidirectional_mask_mod()

mask_mod = self.get_paged_mask_mod()
else:
mask_mod = self.get_bidirectional_mask_mod()
# stage-2: add external mask_mod for special attention during
Expand Down Expand Up @@ -591,7 +615,9 @@ def _build_block_mask_direct(self) -> BlockMask:
self.doc_ids, : cdiv(self.max_seq_len, self.block_size)
]

if self.sliding_window and self.causal:
custom_hint = self.block_sparsity_hint is not None

if self.sliding_window or custom_hint:
device = used_pages.device
assert self.doc_ids is not None
token_indices = torch.arange(
Expand All @@ -602,10 +628,24 @@ def _build_block_mask_direct(self) -> BlockMask:
- self.query_start_loc[self.doc_ids]
+ self.decode_offset[self.doc_ids]
)
min_kv_idx = torch.clamp(logical_q_idx - (self.sliding_window - 1), min=0)
min_block_idx = min_kv_idx // self.block_size
sliding_mask = self.logical_block_ids >= min_block_idx[:, None]
used_pages.masked_fill_(~sliding_mask, 0)

if self.sliding_window:
assert self.sliding_window is not None
min_kv_idx = torch.clamp(
logical_q_idx - (self.sliding_window - 1), min=0
)
min_block_idx = min_kv_idx // self.block_size
sliding_mask = self.logical_block_ids >= min_block_idx[:, None]
used_pages.masked_fill_(~sliding_mask, 0)
if custom_hint:
assert self.block_sparsity_hint is not None
q_block_idx = logical_q_idx // self.block_size
hint_mask = self.block_sparsity_hint.hint_fn(
q_block_idx[:, None],
self.logical_block_ids[None, :],
self.block_size,
)
used_pages.masked_fill_(~hint_mask, 0)

used_pages_padded = pad_to_multiple(
used_pages, multiple=self.q_block_size, dim=0
Expand Down Expand Up @@ -660,11 +700,6 @@ def __post_init__(self):
self.mask_mod = self.get_mask_mod()
self.transformed_score_mod = self.get_transformed_score_mod()

if self.direct_build and self.causal:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

confirming; intentional right

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah — instead of getting built in the post-init, I moved it to the forward pass to avoid needing to rebuild it when it differs per layer for custom mask mods.

self.block_mask = self._build_block_mask_direct()
else:
self.block_mask = self.build_block_mask()


class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadata]):
def __init__(
Expand Down Expand Up @@ -770,6 +805,8 @@ class FlexAttentionImpl(AttentionImpl):
alibi_slopes: torch.Tensor | None
logits_soft_cap: float | None
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
logical_mask_mod: _mask_mod_signature | None = None
block_sparsity_hint: BlockSparsityHint | None = None

def __init__(
self,
Expand Down Expand Up @@ -907,8 +944,25 @@ def forward(
attn_metadata.mask_mod = attn_metadata.get_mask_mod()
needs_rebuild_block_mask = True

if needs_rebuild_block_mask:
if attn_metadata.direct_build and attn_metadata.causal:
layer_mask_mod = getattr(layer, "logical_mask_mod", None)
if (
layer_mask_mod is not None
and attn_metadata.logical_mask_mod is not layer_mask_mod
):
attn_metadata.logical_mask_mod = layer_mask_mod
attn_metadata.mask_mod = attn_metadata.get_mask_mod()
needs_rebuild_block_mask = True

layer_hint = getattr(layer, "block_sparsity_hint", None)
if (
layer_hint is not None
and attn_metadata.block_sparsity_hint is not layer_hint
):
attn_metadata.block_sparsity_hint = layer_hint
needs_rebuild_block_mask = True

if needs_rebuild_block_mask or attn_metadata.block_mask is None:
if attn_metadata.direct_build:
attn_metadata.block_mask = attn_metadata._build_block_mask_direct()
else:
attn_metadata.block_mask = attn_metadata.build_block_mask()
Expand Down
Loading