[FlexAttention] allow custom mask mod #37692
Conversation
Code Review
This pull request introduces a new block_sparsity_hint parameter to the FlexAttentionMetadata class and modifies the attention mechanism to allow for custom mask modifications. The changes aim to provide more flexibility in defining attention patterns, including support for custom sparsity hints. The code has been reviewed and a critical issue has been identified.
```diff
  # (causal mask for decoder or bidirectional mask for encoder)
- if self.causal:
+ has_custom_mask = self.logical_mask_mod is not causal_mask_mod
+ if self.causal or has_custom_mask:
```
The condition self.causal or has_custom_mask will always evaluate to True if has_custom_mask is True. This means that the code will always use self.get_causal_mask_mod() when a custom mask is present, regardless of the value of self.causal. This might not be the intended behavior, as the user might want to use a bidirectional mask with a custom modification. This could lead to unexpected or incorrect attention patterns.
To fix this, the logic should ensure that self.causal is only considered when a custom mask is not present. If a custom mask is present, it should override the causal mask behavior.
Suggested change:
```diff
- if self.causal or has_custom_mask:
+ if has_custom_mask:
+     mask_mod = self.logical_mask_mod
+ elif self.causal:
+     mask_mod = self.get_causal_mask_mod()
+ else:
+     mask_mod = self.get_bidirectional_mask_mod()
```
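The difference between the two conditions can be shown with a minimal sketch (the branch labels below are illustrative stand-ins, not the actual vLLM mask-mod objects):

```python
def pick_mask_mod_original(causal: bool, has_custom_mask: bool) -> str:
    # Original condition: any custom mask falls into the same branch as causal.
    if causal or has_custom_mask:
        return "causal"
    return "bidirectional"


def pick_mask_mod_suggested(causal: bool, has_custom_mask: bool) -> str:
    # Suggested fix: a custom mask overrides the causal/bidirectional choice.
    if has_custom_mask:
        return "custom"
    elif causal:
        return "causal"
    return "bidirectional"


# An encoder (bidirectional) model with a custom mask mod hits the wrong
# branch under the original condition:
assert pick_mask_mod_original(causal=False, has_custom_mask=True) == "causal"
assert pick_mask_mod_suggested(causal=False, has_custom_mask=True) == "custom"
```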
LucasWilkinson left a comment
@drisspg do you think you can help review this?
```python
causal_sliding_window = self.sliding_window and self.causal
custom_hint = self.block_sparsity_hint is not None

if causal_sliding_window or custom_hint:
```
nit: looking at this again, do we even need `causal` to be true? I take it it always is, but if we have a lookback window the same logical truncation applies :thinking:
```python
self.mask_mod = self.get_mask_mod()
self.transformed_score_mod = self.get_transformed_score_mod()

if self.direct_build and self.causal:
```
confirming: this is intentional, right?
yeah, instead of getting built in the post-init I moved it to the forward pass, to avoid needing to rebuild it when it's different per layer for custom mask mods
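The design described above can be sketched roughly as a build-on-demand cache (class and method names here are illustrative, not the actual vLLM code):

```python
class MetadataSketch:
    """Illustrative only: build the mask mod lazily at forward time and
    cache it per logical mask mod, so layers that share a mask mod reuse
    one build instead of rebuilding in __post_init__ for every case."""

    def __init__(self):
        self._built = {}
        self.build_count = 0

    def _build(self, logical_mask_mod):
        # Stand-in for the (expensive) block-mask construction.
        self.build_count += 1
        return ("block_mask_for", logical_mask_mod)

    def forward(self, logical_mask_mod):
        if logical_mask_mod not in self._built:
            self._built[logical_mask_mod] = self._build(logical_mask_mod)
        return self._built[logical_mask_mod]


meta = MetadataSketch()
meta.forward("causal")
meta.forward("causal")         # cached, no rebuild
meta.forward("custom_layer3")  # a different per-layer mask mod: one more build
assert meta.build_count == 2
```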
drisspg left a comment
A few things:
- Describe the sparsity hint in more detail (its shape, attributes, etc.). Maybe make it a NamedTuple.
- Add a small test showing how it is used; it seems like both at the per-layer and per-model level.

Can you confirm where direct_build gets set these days? Is it expected that it will always work for custom mask mods?
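One possible shape for the first point, sketched as a NamedTuple (the field names are assumptions for illustration, not the actual vLLM API):

```python
from typing import NamedTuple, Sequence


class BlockSparsityHint(NamedTuple):
    """Hypothetical container for the hint; fields are illustrative."""
    block_size: int                         # granularity of the sparsity pattern
    kept_blocks: Sequence[tuple[int, int]]  # (q_block, kv_block) pairs to keep


hint = BlockSparsityHint(block_size=128, kept_blocks=[(0, 0), (1, 0), (1, 1)])
assert hint.block_size == 128
assert (1, 1) in hint.kept_blocks
```

A named container like this would make the hint self-documenting at both the per-layer and per-model call sites.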
Force-pushed from c489993 to 907d60d (compare)
Force-pushed from 0ae34bf to 49cc770 (compare)
tests/kernels/test_flex_attention.py (outdated)
```python
device = torch.device("cuda")

vllm_config = create_vllm_config(
    model_name="meta-llama/Meta-Llama-3-8B",
```
probs a smaller one for CI; well, I guess it's never run, so choose whatever makes a small config
+1, it's not clear to me what this is for
I need vllm_config for FlexAttentionMetadataBuilder, but the model size doesn't affect the actual test since the model is never loaded; I changed it to a smaller config anyway.
zou3519 left a comment
LGTM minus the testing nit
Signed-off-by: Angel Li <liangel@meta.com>
Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
Signed-off-by: Nithin Chalapathi <nithin.ch10@gmail.com>
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
Updating the FlexAttention impl to accept a custom mask mod from users.