
Use Boolean attention mask and enable FusedSDPA slicing for long sequences #1149

Closed

afierka-intel wants to merge 5 commits into vllm-project:main from afierka-intel:port-boolean-attn-mask-to-main

Conversation

afierka-intel (Collaborator) commented Mar 12, 2026

Motivation

This PR enables 256K context length support for models like Qwen3-30B-A3B-Thinking-2507 on Gaudi 2 with TP=1.
It ports two key patches from the aice branch to main:

  1. Boolean attention mask (port of aice PR Use Boolean attention mask #1032): Converts attention masks from float (bf16 with -inf) to boolean format, reducing memory usage and enabling proper interaction with the FusedSDPA kernel for long sequences.

  2. FusedSDPA slicing (port of aice PR Enable slicing for the BF16 FusedSDPA #1034): Splits the FusedSDPA kernel into smaller chunks for long sequences to fit into SRAM, improve TPC/MME pipelining, and reduce attention-mask usage for padded regions.

Changes

Boolean attention mask

  • _naive_prompt_attention: Handle bool attn_bias via masked_fill (see the sketch after this list)
  • _fsdpa_prompt_attention: Remove causal+attn_bias workaround (now supported with bool masks)
  • _make_attn_bias: Output ~attn_mask (bool) instead of float masked with -inf
  • _set_attn_bias: Output ~mask (bool)
  • _set_attn_bias_for_sliding_window: Use bool masks throughout
  • _set_attn_bias_for_chunked_attention: Use bool masks throughout
  • _set_block_mapping: Use mask < block_usage (bool) instead of float
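
For illustration, a minimal sketch of the boolean-mask handling described above (not this PR's actual code; the tensor names and shapes are assumptions for the example):

import torch

q_len, kv_len = 4, 4
attn_weights = torch.randn(1, 1, q_len, kv_len, dtype=torch.bfloat16)

# Boolean mask: True = attend, False = masked (stored as bool instead of bf16).
bool_mask = torch.ones(q_len, kv_len).tril().bool()

# Float-bias equivalent the PR moves away from: 0.0 where allowed, -inf where masked.
float_bias = torch.zeros(q_len, kv_len, dtype=torch.bfloat16).masked_fill(~bool_mask, float("-inf"))

# Naive-attention handling of a bool bias, mirroring the masked_fill path above.
masked_bool = attn_weights.masked_fill(~bool_mask, float("-inf"))
masked_float = attn_weights + float_bias
assert torch.equal(masked_bool, masked_float)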

FusedSDPA slicing

  • Add _sliced_fsdpa_fwd method to ModuleFusedSDPA for chunked attention computation (see the simplified sketch after this list)
  • Add _setup_slicing for configuring chunk parameters from env vars
  • New env vars: VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD, VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE, VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS
  • Only active with linear bucketing strategy
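
A simplified, runnable illustration of the slicing idea (assumptions: the hypothetical sliced_sdpa helper below chunks only the query dimension and uses torch's SDPA; the actual _sliced_fsdpa_fwd also chunks the KV length and calls the HPU FusedSDPA kernel):

import torch
import torch.nn.functional as F

def sliced_sdpa(q, k, v, chunk_size=512):
    # q, k, v: [batch, heads, seq, head_dim]; each query chunk attends to the
    # full K/V, so the concatenated result matches unsliced attention.
    outputs = []
    for start in range(0, q.shape[-2], chunk_size):
        q_chunk = q[..., start:start + chunk_size, :]
        outputs.append(F.scaled_dot_product_attention(q_chunk, k, v))
    return torch.cat(outputs, dim=-2)

q = torch.randn(1, 4, 2048, 64)
k = torch.randn(1, 4, 2048, 64)
v = torch.randn(1, 4, 2048, 64)
assert torch.allclose(sliced_sdpa(q, k, v), F.scaled_dot_product_attention(q, k, v), atol=1e-5)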

Dependencies

Testing

  • Tested functionally on Gaudi 2 with Qwen3-30B-A3B-Thinking-2507 at 256K context

Ref: GAUDISW-245533
Related: aice PRs #1032, #1034
Co-author: @yangulei

Copilot AI (Contributor) left a comment


Pull request overview

This PR targets long-context (up to 256K) execution on Gaudi 2 by switching attention masks to boolean format and adding optional slicing for the FusedSDPA kernel under linear bucketing.

Changes:

  • Convert multiple attention-bias/masking paths from float -inf masks to boolean masks.
  • Introduce FusedSDPA slicing (chunked execution) controlled by new environment variables and gated to linear bucketing.
  • Document and expose new env flags for tuning slicing behavior.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.

| File | Description |
| ---- | ----------- |
| vllm_gaudi/v1/worker/hpu_model_runner.py | Switches several attention bias/mask constructions to boolean masks. |
| vllm_gaudi/extension/utils.py | Adds chunked ("sliced") FusedSDPA forward path and env-based configuration. |
| vllm_gaudi/extension/ops.py | Updates naive attention to handle boolean masks and removes a causal+mask workaround in one path. |
| vllm_gaudi/extension/features.py | Registers new env vars for FusedSDPA slicing. |
| docs/configuration/env_variables.md | Documents new slicing-related env vars and behavior. |


if not enable_slicing and slice_thld > 0:
    logger().warning('Invalid slice sequence length threshold, the threshold should be '
                     f'>= min(max_num_batched_tokens, 8192), falling back to default {slice_thld_default}.')
    slice_thld = slice_thld_default

Copilot AI Mar 12, 2026


If the user sets VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD to a positive value below the default, the code logs that it is “falling back to default” but slicing remains disabled because enable_slicing is never re-enabled after resetting slice_thld. After assigning slice_thld = slice_thld_default, either set enable_slicing back to True (since prerequisites were met) or restructure the validation to compute enable_slicing only after sanitizing slice_thld.

Suggested change
-    slice_thld = slice_thld_default
+    slice_thld = slice_thld_default
+    enable_slicing = True

Comment on lines +186 to +189
if is_causal and attn_mask is not None:
    # TODO: causal + attn_bias is not yet supported
    is_causal = False
    valid_sequence_lengths = None

Copilot AI Mar 12, 2026


The PR description states the causal+attn_bias workaround is removed now that boolean masks are supported, but this ModuleFusedSDPA.forward path still forcibly disables is_causal when attn_mask is present. If boolean masking is intended to enable is_causal + mask in the fused kernel, this block should be removed or updated to only apply to unsupported mask dtypes/shapes; otherwise the description should be updated to reflect that the workaround still exists here.

Collaborator


Well, if the threshold is met then we do not enter here, so it's not that misleading.

Comment on lines +360 to +365
if attn_bias.dtype == torch.bool:
    attn_weights = attn_weights.masked_fill(~attn_bias, float("-inf"))
else:
    if attn_weights.dtype != attn_bias.dtype:
        attn_bias = attn_bias.to(dtype=attn_weights.dtype)
    attn_weights.add_(attn_bias)

Copilot AI Mar 12, 2026


For the boolean-mask path, masked_fill creates a new tensor and can significantly increase peak memory for large attention matrices. Prefer the in-place variant (attn_weights.masked_fill_(...)) when it’s safe in this function’s flow to avoid an extra allocation.
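
For reference, a tiny self-contained comparison of the two forms (assuming the in-place write is safe at that point in the flow):

import torch

attn_weights = torch.randn(2, 4)
attn_bias = torch.tensor([[True, True, False, False],
                          [True, False, True, False]])

out_of_place = attn_weights.masked_fill(~attn_bias, float("-inf"))   # allocates a new tensor
attn_weights.masked_fill_(~attn_bias, float("-inf"))                 # mutates in place, no extra allocation
assert torch.equal(out_of_place, attn_weights)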

Comment on lines +106 to +110
| Parameter name | Description | Default value |
| ---------------------------------------- | -------------------------------------------------------------------------------------------- | ------------------------------------------ |
| `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD` | KV length threshold above which slicing is applied. Set to `-1` to disable slicing. | `min(max_num_batched_tokens, 8192)` |
| `VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE` | Chunk size for `q_len` and `kv_len` in each chunk. Rounded up to the next multiple of 1024. | `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD // 2` |
| `VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS` | Places each chunk in a separate graph to reduce compilation time. | `true` for lazy mode and `false` otherwise |

Copilot AI Mar 12, 2026


The markdown table uses || at the start of each row, which typically renders as an extra empty column (or may not render as intended depending on the markdown renderer). Use single | delimiters for standard GitHub-flavored markdown tables.

Comment on lines +250 to +264
chunk_res = torch.ops.hpu.sdpa_recomp_fwd(
    q_chunk,
    k_chunk,
    v_chunk,
    mask_chunk,
    dropout_p,
    scale,
    is_causal_chunk,
    True, # requires_backward
    softmax_mode,
    None, # valid_seq_len
    padding_side,
)
chunk_out, chunk_m, chunk_linv = ((gqa_output_reshape(x) if gqa else x).to(torch.float32)
                                  for x in (chunk_res[:3]))

Copilot AI Mar 12, 2026


This sliced path hard-codes requires_backward=True and unconditionally casts chunk outputs (including chunk_out) to float32. For long contexts this can materially increase memory and runtime overhead in inference. If the kernel/API allows it, consider disabling backward requirements and limiting float32 to the accumulator numerics (e.g., keep m/linv and the running mix in float32, but avoid extra float32 copies of each chunk_out unless needed).

czhu15 (Collaborator) commented Mar 13, 2026

add @yangulei for review too.

afierka-intel and others added 3 commits March 17, 2026 10:07
Port of PR vllm-project#1032 from aice branch to main.

Converts attention masks from float (bf16 with -inf values) to boolean
format. This reduces memory usage (bool vs bf16) and is required for
the FusedSDPA slicing feature to get valid m and linv outputs from
the kernel.

Key changes:
- _naive_prompt_attention: handle bool attn_bias via masked_fill
- _fsdpa_prompt_attention: remove causal+attn_bias workaround (now supported)
- _make_attn_bias: output ~attn_mask (bool) instead of float masked with -inf
- _set_attn_bias: output ~mask (bool)
- _set_attn_bias_for_sliding_window: use bool masks throughout
- _set_attn_bias_for_chunked_attention: use bool masks throughout
- _set_block_mapping: use mask < block_usage (bool) instead of float

Ref: GAUDISW-245533

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Artur Fierka <artur.fierka@intel.com>
Co-authored-by: yangulei <24203353+yangulei@users.noreply.github.com>
Port of PR vllm-project#1034 from aice branch to main.

Splits FusedSDPA kernel into smaller chunks for long sequences to:
- Fit chunks into SRAM for better performance
- Improve TPC/MME pipelining
- Reduce attention-mask usage for padded regions

New env vars:
- VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD: KV length threshold for slicing
- VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE: chunk size (rounded to 1024)
- VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS: graph break control

Only active with linear bucketing strategy and boolean attention masks.

Depends on: Boolean attention mask (port of vllm-project#1032)
Ref: GAUDISW-245533

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Artur Fierka <artur.fierka@intel.com>
Co-authored-by: yangulei <24203353+yangulei@users.noreply.github.com>
The boolean attention mask change (port of vllm-project#1032) creates bool block_bias
in _set_block_mapping (True=valid, False=masked), but pipelined_pa()
blindly cast it to float (True→1.0, False→0.0) and added it to attention
scores. This broke masking: valid positions got +1.0 noise and masked
positions got no penalty (should be -inf).

Fix: detect bool dtype and convert to proper additive bias (0.0/-inf)
before use in both the block_softmax kernel path and the manual fallback.

Co-authored-by: yangulei <24203353+yangulei@users.noreply.github.com>

Signed-off-by: Artur Fierka <artur.fierka@intel.com>
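
A minimal sketch of the conversion this commit describes (illustrative only; the helper name and call site are assumptions, not the actual diff):

import torch

def to_additive_bias(block_bias, dtype):
    # Boolean block mask (True = valid, False = masked) -> additive bias (0.0 / -inf).
    if block_bias.dtype == torch.bool:
        return torch.zeros_like(block_bias, dtype=dtype).masked_fill_(~block_bias, float("-inf"))
    return block_bias  # already an additive float bias

scores = torch.randn(2, 4, dtype=torch.bfloat16)
mask = torch.tensor([[True, True, False, False],
                     [True, False, True, False]])
scores = scores + to_additive_bias(mask, scores.dtype)
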
afierka-intel force-pushed the port-boolean-attn-mask-to-main branch from 960c194 to 2e875e3 on March 17, 2026 09:28
Upstream vLLM renamed get_eagle3_aux_hidden_state_layers() to
get_eagle3_default_aux_hidden_state_layers(). Update the call in
hpu_model_runner.py to match.

Co-authored-by: yangulei <24203353+yangulei@users.noreply.github.com>

Signed-off-by: Artur Fierka <artur.fierka@intel.com>
afierka-intel force-pushed the port-boolean-attn-mask-to-main branch from 2e875e3 to 0cf717e on March 18, 2026 07:39
@github-actions

🚧 CI Blocked

The main CI workflow was not started for the following reason:

Your branch is behind the base branch. Please merge or rebase to get the latest changes.


afierka-intel (Collaborator, Author) commented

The PR depends on #1154. Let's merge the dependency first. I'll rebase this PR then.

if is_causal and attn_mask is not None:
    # TODO: causal + attn_bias is not yet supported
    is_causal = False
    valid_sequence_lengths = None
Collaborator


The whole if is_causal ... block could be removed because we should be able to support all of that. Slawomir Laba can help if this doesn't work. If there are issues, though, it's not required, so do not block this work on enabling it.

q_start = max(q_start, 0)
q_end = q_len - q_chunk_idx * self.chunk_size
q_chunk_size = q_end - q_start
q_chunk = q[..., q_start:q_end, :].clone() if self.with_graph_breaks else q[..., q_start:q_end, :]
Collaborator


I've never seen ... operator, TIL. I trust it's needed here
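
For reference, ... (Ellipsis) in indexing stands for "all the leading dimensions", so the slice stays independent of how many batch/head dims precede the sequence axis; a tiny illustrative example:

import torch

q = torch.randn(2, 8, 16, 64)   # [batch, heads, seq, head_dim]
a = q[..., 4:8, :]              # '...' expands to ':' for every leading dim
b = q[:, :, 4:8, :]
assert torch.equal(a, b) and a.shape == (2, 8, 4, 64)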

    return output.to(q.dtype)

def _setup_slicing(self) -> bool:
    from vllm_gaudi.extension.bucketing.common import get_bucketing_manager
Collaborator


Are the imports inside needed because they are not visible otherwise? If not, we could place them outside.

kamil-kaczor (Collaborator) left a comment


Added a few nitpicks, but if this passes CI + the customer benchmark tests we discussed on the call, this can go in imo.

logger().warning('Bucketing manager is not initialized, slicing in FSDPA will be disabled.')
return False

from vllm_gaudi.extension.bucketing.linear import LinearBucketingStrategy
Collaborator


Are the imports inside needed because they are not visible otherwise? If not, we could place them outside.

max_ctx_pad = int(os.getenv("VLLM_PROMPT_CTX_BUCKET_PAD_MAX", str(max_ctx_pad_default)))
self.num_padded_ctx_chunks = math.ceil(max_ctx_pad * block_size / self.chunk_size)

import habana_frameworks.torch as ht
Collaborator


Are the imports inside needed because they are not visible otherwise? If not, we could place them outside.

causal_mask = torch.triu(causal_mask, diagonal=1).unsqueeze(0)
attn_mask[:, :, context_len:].logical_or_(causal_mask)
attn_mask = attn_mask.to(dtype).masked_fill_(attn_mask, -math.inf)
attn_mask = ~attn_mask
Collaborator


Nice, but any arithmetic on the mask will fail later on, so make sure it's well tested.
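
Illustrative of the concern (an assumption about what "fail" means here: the failure is silent rather than an exception, since adding a bool mask to float scores promotes True/False to 1.0/0.0 instead of applying -inf, which is the pipelined_pa issue addressed in one of the commits above):

import torch

scores = torch.zeros(1, 4)
mask = torch.tensor([True, True, False, False])    # True = valid position

wrong = scores + mask                               # bool silently promotes to 1.0/0.0
right = scores.masked_fill(~mask, float("-inf"))    # intended semantics
print(wrong)   # tensor([[1., 1., 0., 0.]])
print(right)   # tensor([[0., 0., -inf, -inf]])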

