-
-
Notifications
You must be signed in to change notification settings - Fork 15k
[FlexAttention] allow custom mask mod #37692
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,9 +3,10 @@ | |
| """Attention layer with FlexAttention.""" | ||
|
|
||
| import math | ||
| from collections.abc import Callable | ||
| from dataclasses import dataclass | ||
| from functools import cached_property | ||
| from typing import ClassVar | ||
| from typing import ClassVar, NamedTuple | ||
|
|
||
| import torch | ||
| import torch._dynamo.decorators | ||
|
|
@@ -294,6 +295,27 @@ def causal_mask_mod( | |
| return q_idx >= kv_idx | ||
|
|
||
|
|
||
| # Type alias for the block sparsity hint callable signature. | ||
| _block_sparsity_hint_signature = Callable[ | ||
| [torch.Tensor, torch.Tensor, int], torch.Tensor | ||
| ] | ||
|
|
||
|
|
||
| class BlockSparsityHint(NamedTuple): | ||
| """This prunes KV blocks from the BlockMask before the flex_attention kernel | ||
| is invoked, so that blocks that are fully masked never get loaded. | ||
| Use this with custom mask_mods that are sparse to avoid | ||
| the kernel iterating over all KV blocks unnecessarily. | ||
|
|
||
| Attributes: | ||
| hint_fn: (q_block_idx [num_tokens, 1], kv_block_idx [1, num_kv_blocks], | ||
| block_size int) -> bool Tensor [num_tokens, num_kv_blocks]. | ||
| Returns True for block pairs that may contain non-masked elements. | ||
| """ | ||
|
|
||
| hint_fn: _block_sparsity_hint_signature | ||
|
|
||
|
|
||
| @dataclass | ||
| class FlexAttentionMetadata: | ||
| causal: bool | ||
|
|
@@ -335,6 +357,7 @@ class FlexAttentionMetadata: | |
| transformed_score_mod: _score_mod_signature | None = None | ||
| sliding_window: int | None = None | ||
| mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None | ||
| block_sparsity_hint: BlockSparsityHint | None = None | ||
|
|
||
| @cached_property | ||
| def logical_block_ids(self): | ||
|
|
@@ -378,7 +401,7 @@ def _convert_physical_to_logical( | |
|
|
||
| return is_valid, logical_q_idx, logical_kv_idx | ||
|
|
||
| def get_causal_mask_mod(self) -> _mask_mod_signature: | ||
| def get_paged_mask_mod(self) -> _mask_mod_signature: | ||
| """Creates the mask_mod function for FlexAttention. | ||
|
|
||
| This function creates the combined mask mod function that handles: | ||
|
|
@@ -504,8 +527,9 @@ def final_mask_mod( | |
| def get_mask_mod(self): | ||
| # Stage-1: initialize the base mask_mod | ||
| # (causal mask for decoder or bidirectional mask for encoder) | ||
| if self.causal: | ||
| mask_mod = self.get_causal_mask_mod() | ||
| has_custom_mask = self.logical_mask_mod is not causal_mask_mod | ||
| if self.causal or has_custom_mask: | ||
| mask_mod = self.get_paged_mask_mod() | ||
| else: | ||
| mask_mod = self.get_bidirectional_mask_mod() | ||
| # stage-2: add external mask_mod for special attention during | ||
|
|
@@ -591,7 +615,9 @@ def _build_block_mask_direct(self) -> BlockMask: | |
| self.doc_ids, : cdiv(self.max_seq_len, self.block_size) | ||
| ] | ||
|
|
||
| if self.sliding_window and self.causal: | ||
| custom_hint = self.block_sparsity_hint is not None | ||
|
|
||
| if self.sliding_window or custom_hint: | ||
| device = used_pages.device | ||
| assert self.doc_ids is not None | ||
| token_indices = torch.arange( | ||
|
|
@@ -602,10 +628,24 @@ def _build_block_mask_direct(self) -> BlockMask: | |
| - self.query_start_loc[self.doc_ids] | ||
| + self.decode_offset[self.doc_ids] | ||
| ) | ||
| min_kv_idx = torch.clamp(logical_q_idx - (self.sliding_window - 1), min=0) | ||
| min_block_idx = min_kv_idx // self.block_size | ||
| sliding_mask = self.logical_block_ids >= min_block_idx[:, None] | ||
| used_pages.masked_fill_(~sliding_mask, 0) | ||
|
|
||
| if self.sliding_window: | ||
| assert self.sliding_window is not None | ||
| min_kv_idx = torch.clamp( | ||
| logical_q_idx - (self.sliding_window - 1), min=0 | ||
| ) | ||
| min_block_idx = min_kv_idx // self.block_size | ||
| sliding_mask = self.logical_block_ids >= min_block_idx[:, None] | ||
| used_pages.masked_fill_(~sliding_mask, 0) | ||
| if custom_hint: | ||
| assert self.block_sparsity_hint is not None | ||
| q_block_idx = logical_q_idx // self.block_size | ||
| hint_mask = self.block_sparsity_hint.hint_fn( | ||
| q_block_idx[:, None], | ||
| self.logical_block_ids[None, :], | ||
| self.block_size, | ||
| ) | ||
| used_pages.masked_fill_(~hint_mask, 0) | ||
|
|
||
| used_pages_padded = pad_to_multiple( | ||
| used_pages, multiple=self.q_block_size, dim=0 | ||
|
|
@@ -660,11 +700,6 @@ def __post_init__(self): | |
| self.mask_mod = self.get_mask_mod() | ||
| self.transformed_score_mod = self.get_transformed_score_mod() | ||
|
|
||
| if self.direct_build and self.causal: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Confirming — this is intentional, right?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah — instead of getting built in the post-init, I moved it to the forward pass to avoid needing to rebuild it if it's different per layer for custom mask mods. |
||
| self.block_mask = self._build_block_mask_direct() | ||
| else: | ||
| self.block_mask = self.build_block_mask() | ||
|
|
||
|
|
||
| class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadata]): | ||
| def __init__( | ||
|
|
@@ -770,6 +805,8 @@ class FlexAttentionImpl(AttentionImpl): | |
| alibi_slopes: torch.Tensor | None | ||
| logits_soft_cap: float | None | ||
| mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None | ||
| logical_mask_mod: _mask_mod_signature | None = None | ||
| block_sparsity_hint: BlockSparsityHint | None = None | ||
|
|
||
| def __init__( | ||
| self, | ||
|
|
@@ -907,8 +944,25 @@ def forward( | |
| attn_metadata.mask_mod = attn_metadata.get_mask_mod() | ||
| needs_rebuild_block_mask = True | ||
|
|
||
| if needs_rebuild_block_mask: | ||
| if attn_metadata.direct_build and attn_metadata.causal: | ||
| layer_mask_mod = getattr(layer, "logical_mask_mod", None) | ||
| if ( | ||
| layer_mask_mod is not None | ||
| and attn_metadata.logical_mask_mod is not layer_mask_mod | ||
| ): | ||
| attn_metadata.logical_mask_mod = layer_mask_mod | ||
| attn_metadata.mask_mod = attn_metadata.get_mask_mod() | ||
| needs_rebuild_block_mask = True | ||
|
|
||
| layer_hint = getattr(layer, "block_sparsity_hint", None) | ||
| if ( | ||
| layer_hint is not None | ||
| and attn_metadata.block_sparsity_hint is not layer_hint | ||
| ): | ||
| attn_metadata.block_sparsity_hint = layer_hint | ||
| needs_rebuild_block_mask = True | ||
|
|
||
| if needs_rebuild_block_mask or attn_metadata.block_mask is None: | ||
| if attn_metadata.direct_build: | ||
| attn_metadata.block_mask = attn_metadata._build_block_mask_direct() | ||
| else: | ||
| attn_metadata.block_mask = attn_metadata.build_block_mask() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The condition `self.causal or has_custom_mask` will always evaluate to `True` if `has_custom_mask` is `True`. This means that the code will always use `self.get_causal_mask_mod()` when a custom mask is present, regardless of the value of `self.causal`. This might not be the intended behavior, as the user might want to use a bidirectional mask with a custom modification. This could lead to unexpected or incorrect attention patterns.

To fix this, the logic should ensure that `self.causal` is only considered when a custom mask is not present. If a custom mask is present, it should override the causal mask behavior.