[FlexAttention] allow custom mask mod#37692

Merged
zou3519 merged 1 commit into vllm-project:main from liangel-02:flex
Mar 24, 2026

Conversation

@liangel-02
Contributor

@liangel-02 liangel-02 commented Mar 20, 2026

Updates the FlexAttention implementation to accept a custom mask mod from users.
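For context, a FlexAttention mask mod is a callable over (batch, head, query, key) indices that returns whether a query position may attend to a key position. A minimal sketch of the convention, written with plain integers for illustration (the real PyTorch API passes index tensors; `prefix_lm_mask_mod` is a hypothetical example, not part of this PR):

```python
# Sketch of the mask-mod convention used by FlexAttention:
# fn(b, h, q_idx, kv_idx) -> bool, True meaning "q_idx may attend to kv_idx".
# Plain ints are used here for illustration; the real API passes index tensors.

def causal_mask_mod(b, h, q_idx, kv_idx):
    # Standard decoder mask: a token attends to itself and earlier tokens.
    return q_idx >= kv_idx

def prefix_lm_mask_mod(prefix_len):
    # Hypothetical custom mask mod: bidirectional within a prefix,
    # causal afterwards (a prefix-LM attention pattern).
    def mask_mod(b, h, q_idx, kv_idx):
        return kv_idx < prefix_len or q_idx >= kv_idx
    return mask_mod

mm = prefix_lm_mask_mod(prefix_len=2)
assert mm(0, 0, 0, 1)       # inside the prefix: may look ahead
assert not mm(0, 0, 2, 3)   # outside the prefix: causal only
```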

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new block_sparsity_hint parameter to the FlexAttentionMetadata class and modifies the attention mechanism to allow for custom mask modifications. The changes aim to provide more flexibility in defining attention patterns, including support for custom sparsity hints. The code has been reviewed and a critical issue has been identified.

# (causal mask for decoder or bidirectional mask for encoder)
if self.causal:
has_custom_mask = self.logical_mask_mod is not causal_mask_mod
if self.causal or has_custom_mask:
Contributor


critical

The condition self.causal or has_custom_mask will always evaluate to True if has_custom_mask is True. This means that the code will always use self.get_causal_mask_mod() when a custom mask is present, regardless of the value of self.causal. This might not be the intended behavior, as the user might want to use a bidirectional mask with a custom modification. This could lead to unexpected or incorrect attention patterns.

To fix this, the logic should ensure that self.causal is only considered when a custom mask is not present. If a custom mask is present, it should override the causal mask behavior.

Suggested change
if self.causal or has_custom_mask:
if has_custom_mask:
mask_mod = self.logical_mask_mod
elif self.causal:
mask_mod = self.get_causal_mask_mod()
else:
mask_mod = self.get_bidirectional_mask_mod()
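The precedence the suggestion encodes can be exercised in isolation. A toy sketch (stub mask mods and a free function standing in for vLLM's `FlexAttentionMetadata`; names mirror the diff but are not the actual implementation) showing that a custom mask mod overrides the causal flag:

```python
# Toy model of the suggested selection logic: a custom logical_mask_mod
# takes precedence over the causal flag.

def causal(b, h, q, kv):
    return q >= kv

def bidirectional(b, h, q, kv):
    return True

def select_mask_mod(logical_mask_mod, is_causal):
    # Mirrors the suggested change: custom mod wins, even when is_causal.
    has_custom_mask = logical_mask_mod is not causal
    if has_custom_mask:
        return logical_mask_mod
    if is_causal:
        return causal
    return bidirectional

custom = lambda b, h, q, kv: abs(q - kv) <= 1  # e.g. a local-window mod
assert select_mask_mod(custom, is_causal=True) is custom
assert select_mask_mod(causal, is_causal=True) is causal
assert select_mask_mod(causal, is_causal=False) is bidirectional
```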

Collaborator

@LucasWilkinson LucasWilkinson left a comment


@drisspg do you think you can help review this?

causal_sliding_window = self.sliding_window and self.causal
custom_hint = self.block_sparsity_hint is not None

if causal_sliding_window or custom_hint:
Contributor


nit: looking at this again, do we even need causal to be true? I take it it always is, but if we have a lookback window the same logical truncation applies :thinking:
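The point above can be sketched directly: a sliding window bounds how far apart query and key may be, whether or not the mask is also causal, so the same block-level truncation applies either way. A plain-integer sketch (illustrative only, not vLLM's implementation):

```python
# Sketch: a sliding window truncates attention independently of causality.

def sliding_window_mask_mod(window, is_causal):
    def mask_mod(b, h, q_idx, kv_idx):
        if is_causal:
            # Causal variant: look back at most `window` positions.
            return 0 <= q_idx - kv_idx <= window
        # Bidirectional variant: truncate in both directions.
        return abs(q_idx - kv_idx) <= window
    return mask_mod

causal_sw = sliding_window_mask_mod(window=2, is_causal=True)
bidir_sw = sliding_window_mask_mod(window=2, is_causal=False)
assert causal_sw(0, 0, 5, 3) and not causal_sw(0, 0, 5, 2)
# Even without causal, far-apart block pairs are still fully masked,
# so the same sparsity shortcut would apply.
assert bidir_sw(0, 0, 3, 5) and not bidir_sw(0, 0, 2, 5)
```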

self.mask_mod = self.get_mask_mod()
self.transformed_score_mod = self.get_transformed_score_mod()

if self.direct_build and self.causal:
Contributor


Confirming: this is intentional, right?

Contributor Author


Yeah, instead of getting built in the post-init, I moved it to the forward pass to avoid needing to rebuild it if it's different per layer for custom mask mods.

Contributor

@drisspg drisspg left a comment


A few things:

  1. Describe the sparsity hint in more detail (its shape, attributes, etc.). Maybe make it a NamedTuple.
  2. Add a small test showing how it is used, it seems like both at the per-layer and per-model level.

Can you confirm where direct_build gets set these days? Is it expected that it will always work for custom mask mods?
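One way to act on point 1 is a small NamedTuple. The sketch below is hypothetical: the field names and semantics are assumptions for illustration, not vLLM's actual `block_sparsity_hint` definition.

```python
from typing import Callable, NamedTuple, Optional

class BlockSparsityHint(NamedTuple):
    """Hypothetical container for a block-sparsity hint; field names
    here are illustrative guesses, not vLLM's actual definition."""
    num_q_blocks: int    # number of query blocks in the hint grid
    num_kv_blocks: int   # number of key/value blocks
    block_size: int      # tokens per block
    # Optional predicate over (q_block, kv_block) saying whether a block
    # pair can contain any unmasked entries.
    block_visible: Optional[Callable[[int, int], bool]] = None

hint = BlockSparsityHint(
    num_q_blocks=4, num_kv_blocks=4, block_size=128,
    block_visible=lambda qb, kvb: qb >= kvb,  # block-causal pattern
)
assert hint.block_size == 128
assert hint.block_visible(3, 1) and not hint.block_visible(1, 3)
```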

@liangel-02 liangel-02 force-pushed the flex branch 2 times, most recently from c489993 to 907d60d Compare March 24, 2026 01:57
@liangel-02 liangel-02 force-pushed the flex branch 4 times, most recently from 0ae34bf to 49cc770 Compare March 24, 2026 02:10
device = torch.device("cuda")

vllm_config = create_vllm_config(
model_name="meta-llama/Meta-Llama-3-8B",
Contributor

@drisspg drisspg Mar 24, 2026


Probably use a smaller one for CI. Well, I guess it's never actually run, so choose whatever makes a small config.

Collaborator


+1, it's not clear to me what this is for

Contributor Author


I need vllm_config for FlexAttentionMetadataBuilder, but the size of the model doesn't affect the actual test since it's never loaded. I changed it to a smaller config.

Contributor

@drisspg drisspg left a comment


Looks good

Collaborator

@zou3519 zou3519 left a comment


LGTM minus the testing nit

@zou3519 zou3519 added the `ready` label (ONLY add when PR is ready to merge/full CI is needed) Mar 24, 2026
Signed-off-by: Angel Li <liangel@meta.com>
@zou3519 zou3519 merged commit 8c47fdf into vllm-project:main Mar 24, 2026
57 checks passed
RhizoNymph pushed a commit to RhizoNymph/vllm that referenced this pull request Mar 26, 2026
Signed-off-by: Angel Li <liangel@meta.com>
HenryTangDev pushed a commit to HenryTangMain/vllm that referenced this pull request Mar 27, 2026
Signed-off-by: Angel Li <liangel@meta.com>
malaiwah pushed a commit to malaiwah/vllm that referenced this pull request Mar 27, 2026
Signed-off-by: Angel Li <liangel@meta.com>
Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
Signed-off-by: Angel Li <liangel@meta.com>
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
Signed-off-by: Angel Li <liangel@meta.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
nithinvc pushed a commit to nithinvc/vllm that referenced this pull request Mar 27, 2026
Signed-off-by: Angel Li <liangel@meta.com>

Signed-off-by: Nithin Chalapathi <nithin.ch10@gmail.com>
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
Signed-off-by: Angel Li <liangel@meta.com>
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Mar 30, 2026
Signed-off-by: Angel Li <liangel@meta.com>
Signed-off-by: Vinay Damodaran <vrdn@hey.com>

Labels

ready (ONLY add when PR is ready to merge/full CI is needed), v1


4 participants