diff --git a/tests/kernels/test_flex_attention.py b/tests/kernels/test_flex_attention.py
index b04f5c62c79b..69113b57c74e 100644
--- a/tests/kernels/test_flex_attention.py
+++ b/tests/kernels/test_flex_attention.py
@@ -14,6 +14,7 @@
     create_vllm_config,
 )
 from vllm.v1.attention.backends.flex_attention import (
+    BlockSparsityHint,
     FlexAttentionMetadataBuilder,
     physical_to_logical_mapping,
 )
@@ -223,5 +224,55 @@ def test_physical_to_logical_mapping_handles_reused_blocks():
     assert out2[0, 2].item() == 1
 
 
+@pytest.mark.skipif(
+    not torch.cuda.is_available() or TORCH_VERSION < DIRECT_BUILD_VERSION,
+    reason="CUDA not available or PyTorch version < 2.9",
+)
+def test_block_sparsity_hint_prunes_blocks():
+    """Test that BlockSparsityHint prunes KV blocks from the direct build path.
+
+    Uses a hint that only keeps the diagonal (q_block == kv_block) to verify
+    that off-diagonal blocks are excluded from the resulting BlockMask.
+    """
+    device = torch.device("cuda")
+
+    vllm_config = create_vllm_config(
+        model_name="facebook/opt-125m",
+        block_size=16,
+        max_model_len=1024,
+    )
+    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
+
+    batch_spec = BatchSpec(
+        seq_lens=[256],
+        query_lens=[256],
+        name="test_sparsity_hint",
+    )
+
+    common_attn_metadata = create_common_attn_metadata(
+        batch_spec, vllm_config.cache_config.block_size, device
+    )
+
+    builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config, device)
+
+    metadata_no_hint = builder.build(
+        common_prefix_len=0, common_attn_metadata=common_attn_metadata
+    )
+    metadata_no_hint.block_mask = metadata_no_hint._build_block_mask_direct()
+    assert metadata_no_hint.block_mask.kv_num_blocks.max().item() > 1
+
+    def diagonal_hint(q_block_idx, kv_block_idx, block_size):
+        return q_block_idx == kv_block_idx
+
+    metadata_with_hint = builder.build(
+        common_prefix_len=0, common_attn_metadata=common_attn_metadata
+    )
+    metadata_with_hint.block_sparsity_hint = BlockSparsityHint(
+        hint_fn=diagonal_hint,
+    )
+    metadata_with_hint.block_mask = metadata_with_hint._build_block_mask_direct()
+    assert metadata_with_hint.block_mask.kv_num_blocks.max().item() <= 1
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
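Note on the `hint_fn` contract exercised by `diagonal_hint` above: as the backend diff below shows, `_build_block_mask_direct` calls the hint with per-token query block indices shaped `[num_tokens, 1]` and logical KV block indices shaped `[1, num_kv_blocks]`, so the boolean result broadcasts to `[num_tokens, num_kv_blocks]`. A minimal standalone sketch of that contract; `banded_hint` and the toy sizes are illustrative, not part of this change:

```python
import torch

def banded_hint(
    q_block_idx: torch.Tensor, kv_block_idx: torch.Tensor, block_size: int
) -> torch.Tensor:
    # Return True for block pairs that may contain non-masked elements:
    # here, the diagonal plus one neighboring block on each side.
    return (q_block_idx - kv_block_idx).abs() <= 1

# The backend passes [num_tokens, 1] query block indices and
# [1, num_kv_blocks] KV block indices; the result broadcasts to
# [num_tokens, num_kv_blocks]. Toy sizes below.
q_block_idx = torch.arange(4)[:, None]
kv_block_idx = torch.arange(4)[None, :]
mask = banded_hint(q_block_idx, kv_block_idx, block_size=16)
assert mask.shape == (4, 4)
assert mask[0].tolist() == [True, True, False, False]
```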
+ """ + + hint_fn: _block_sparsity_hint_signature + + @dataclass class FlexAttentionMetadata: causal: bool @@ -335,6 +357,7 @@ class FlexAttentionMetadata: transformed_score_mod: _score_mod_signature | None = None sliding_window: int | None = None mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None + block_sparsity_hint: BlockSparsityHint | None = None @cached_property def logical_block_ids(self): @@ -378,7 +401,7 @@ def _convert_physical_to_logical( return is_valid, logical_q_idx, logical_kv_idx - def get_causal_mask_mod(self) -> _mask_mod_signature: + def get_paged_mask_mod(self) -> _mask_mod_signature: """Creates the mask_mod function for FlexAttention. This function creates the combined mask mod function that handles: @@ -504,8 +527,9 @@ def final_mask_mod( def get_mask_mod(self): # Stage-1: initialize the base mask_mod # (causal mask for decoder or bidirectional mask for encoder) - if self.causal: - mask_mod = self.get_causal_mask_mod() + has_custom_mask = self.logical_mask_mod is not causal_mask_mod + if self.causal or has_custom_mask: + mask_mod = self.get_paged_mask_mod() else: mask_mod = self.get_bidirectional_mask_mod() # stage-2: add external mask_mod for special attention during @@ -591,7 +615,9 @@ def _build_block_mask_direct(self) -> BlockMask: self.doc_ids, : cdiv(self.max_seq_len, self.block_size) ] - if self.sliding_window and self.causal: + custom_hint = self.block_sparsity_hint is not None + + if self.sliding_window or custom_hint: device = used_pages.device assert self.doc_ids is not None token_indices = torch.arange( @@ -602,10 +628,24 @@ def _build_block_mask_direct(self) -> BlockMask: - self.query_start_loc[self.doc_ids] + self.decode_offset[self.doc_ids] ) - min_kv_idx = torch.clamp(logical_q_idx - (self.sliding_window - 1), min=0) - min_block_idx = min_kv_idx // self.block_size - sliding_mask = self.logical_block_ids >= min_block_idx[:, None] - used_pages.masked_fill_(~sliding_mask, 0) + + if self.sliding_window: + assert self.sliding_window is not None + min_kv_idx = torch.clamp( + logical_q_idx - (self.sliding_window - 1), min=0 + ) + min_block_idx = min_kv_idx // self.block_size + sliding_mask = self.logical_block_ids >= min_block_idx[:, None] + used_pages.masked_fill_(~sliding_mask, 0) + if custom_hint: + assert self.block_sparsity_hint is not None + q_block_idx = logical_q_idx // self.block_size + hint_mask = self.block_sparsity_hint.hint_fn( + q_block_idx[:, None], + self.logical_block_ids[None, :], + self.block_size, + ) + used_pages.masked_fill_(~hint_mask, 0) used_pages_padded = pad_to_multiple( used_pages, multiple=self.q_block_size, dim=0 @@ -660,11 +700,6 @@ def __post_init__(self): self.mask_mod = self.get_mask_mod() self.transformed_score_mod = self.get_transformed_score_mod() - if self.direct_build and self.causal: - self.block_mask = self._build_block_mask_direct() - else: - self.block_mask = self.build_block_mask() - class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadata]): def __init__( @@ -770,6 +805,8 @@ class FlexAttentionImpl(AttentionImpl): alibi_slopes: torch.Tensor | None logits_soft_cap: float | None mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None + logical_mask_mod: _mask_mod_signature | None = None + block_sparsity_hint: BlockSparsityHint | None = None def __init__( self, @@ -907,8 +944,25 @@ def forward( attn_metadata.mask_mod = attn_metadata.get_mask_mod() needs_rebuild_block_mask = True - if needs_rebuild_block_mask: - if attn_metadata.direct_build and attn_metadata.causal: + 
@@ -907,8 +944,25 @@ def forward(
                 attn_metadata.mask_mod = attn_metadata.get_mask_mod()
                 needs_rebuild_block_mask = True
 
-            if needs_rebuild_block_mask:
-                if attn_metadata.direct_build and attn_metadata.causal:
+            layer_mask_mod = getattr(layer, "logical_mask_mod", None)
+            if (
+                layer_mask_mod is not None
+                and attn_metadata.logical_mask_mod is not layer_mask_mod
+            ):
+                attn_metadata.logical_mask_mod = layer_mask_mod
+                attn_metadata.mask_mod = attn_metadata.get_mask_mod()
+                needs_rebuild_block_mask = True
+
+            layer_hint = getattr(layer, "block_sparsity_hint", None)
+            if (
+                layer_hint is not None
+                and attn_metadata.block_sparsity_hint is not layer_hint
+            ):
+                attn_metadata.block_sparsity_hint = layer_hint
+                needs_rebuild_block_mask = True
+
+            if needs_rebuild_block_mask or attn_metadata.block_mask is None:
+                if attn_metadata.direct_build:
                     attn_metadata.block_mask = attn_metadata._build_block_mask_direct()
                 else:
                     attn_metadata.block_mask = attn_metadata.build_block_mask()
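With this change, `FlexAttentionImpl.forward` picks up both hooks from the layer object via `getattr`, so a layer opts in simply by exposing `logical_mask_mod` and `block_sparsity_hint` as attributes. A minimal sketch under that assumption; `attach_diagonal_sparsity`, `BLOCK_SIZE`, and the two callables are illustrative, not part of this diff:

```python
from vllm.v1.attention.backends.flex_attention import BlockSparsityHint

BLOCK_SIZE = 16  # assumed to match the KV cache block_size

def diagonal_mask_mod(b, h, q_idx, kv_idx):
    # Token-level mask_mod in logical coordinates, same (b, h, q_idx, kv_idx)
    # signature as causal_mask_mod: attend only within the token's own block.
    return q_idx // BLOCK_SIZE == kv_idx // BLOCK_SIZE

def diagonal_hint(q_block_idx, kv_block_idx, block_size):
    # Block-level hint agreeing with the mask above: off-diagonal block
    # pairs are fully masked, so they can be pruned before the kernel runs.
    return q_block_idx == kv_block_idx

def attach_diagonal_sparsity(attn_layer) -> None:
    # forward() reads these attributes off the layer with getattr and
    # rebuilds the mask_mod / BlockMask when they change.
    attn_layer.logical_mask_mod = diagonal_mask_mod
    attn_layer.block_sparsity_hint = BlockSparsityHint(hint_fn=diagonal_hint)
```

Keeping the token-level mask_mod and the block-level hint consistent matters: the hint only prunes whole KV blocks, while the mask_mod still decides per-token masking inside the blocks that survive.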