Use Boolean attention mask and enable FusedSDPA slicing for long sequences #1149
afierka-intel wants to merge 5 commits into main
Conversation
Pull request overview
This PR targets long-context (up to 256K) execution on Gaudi 2 by switching attention masks to boolean format and adding optional slicing for the FusedSDPA kernel under linear bucketing.
Changes:
- Convert multiple attention-bias/masking paths from float -inf masks to boolean masks.
- Introduce FusedSDPA slicing (chunked execution) controlled by new environment variables and gated to linear bucketing.
- Document and expose new env flags for tuning slicing behavior.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| vllm_gaudi/v1/worker/hpu_model_runner.py | Switches several attention bias/mask constructions to boolean masks. |
| vllm_gaudi/extension/utils.py | Adds chunked (“sliced”) FusedSDPA forward path and env-based configuration. |
| vllm_gaudi/extension/ops.py | Updates naive attention to handle boolean masks and removes a causal+mask workaround in one path. |
| vllm_gaudi/extension/features.py | Registers new env vars for FusedSDPA slicing. |
| docs/configuration/env_variables.md | Documents new slicing-related env vars and behavior. |
if not enable_slicing and slice_thld > 0:
    logger().warning('Invalid slice sequence length threshold, the threshold should be '
                     f'>= min(max_num_batched_tokens, 8192), falling back to default {slice_thld_default}.')
    slice_thld = slice_thld_default
If the user sets VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD to a positive value below the default, the code logs that it is “falling back to default” but slicing remains disabled because enable_slicing is never re-enabled after resetting slice_thld. After assigning slice_thld = slice_thld_default, either set enable_slicing back to True (since prerequisites were met) or restructure the validation to compute enable_slicing only after sanitizing slice_thld.
Suggested change:
    slice_thld = slice_thld_default
    enable_slicing = True
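For illustration, a minimal sketch (not the PR's actual code) of the restructured validation the comment suggests, assuming the threshold semantics documented later in this PR (`-1` disables slicing) and a hypothetical `SLICE_THLD_DEFAULT`:

```python
import logging
import os

logger = logging.getLogger(__name__)

SLICE_THLD_DEFAULT = 8192  # illustrative; the real default is min(max_num_batched_tokens, 8192)

def resolve_slicing_config() -> tuple[bool, int]:
    """Sketch: sanitize the threshold first, then derive enable_slicing from it."""
    slice_thld = int(os.getenv("VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD", str(SLICE_THLD_DEFAULT)))
    if 0 < slice_thld < SLICE_THLD_DEFAULT:
        # A too-small positive value falls back to the default but keeps slicing enabled.
        logger.warning("Invalid slice threshold, falling back to default %d.", SLICE_THLD_DEFAULT)
        slice_thld = SLICE_THLD_DEFAULT
    enable_slicing = slice_thld > 0  # -1 (or 0) still disables slicing explicitly
    return enable_slicing, slice_thld
```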
if is_causal and attn_mask is not None:
    # TODO: causal + attn_bias is not yet supported
    is_causal = False
    valid_sequence_lengths = None
The PR description states the causal+attn_bias workaround is removed now that boolean masks are supported, but this ModuleFusedSDPA.forward path still forcibly disables is_causal when attn_mask is present. If boolean masking is intended to enable is_causal + mask in the fused kernel, this block should be removed or updated to only apply to unsupported mask dtypes/shapes; otherwise the description should be updated to reflect that the workaround still exists here.
Well, if the threshold is met then we do not enter here, so it's not that misleading.
if attn_bias.dtype == torch.bool:
    attn_weights = attn_weights.masked_fill(~attn_bias, float("-inf"))
else:
    if attn_weights.dtype != attn_bias.dtype:
        attn_bias = attn_bias.to(dtype=attn_weights.dtype)
    attn_weights.add_(attn_bias)
For the boolean-mask path, masked_fill creates a new tensor and can significantly increase peak memory for large attention matrices. Prefer the in-place variant (attn_weights.masked_fill_(...)) when it’s safe in this function’s flow to avoid an extra allocation.
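As a small illustration of the suggested in-place variant (safe only if `attn_weights` is not needed unmodified elsewhere in the flow):

```python
import torch

attn_weights = torch.randn(2, 4, 8, 8, dtype=torch.bfloat16)  # attention scores
attn_bias = torch.tril(torch.ones(8, 8)).bool()                # True = keep, False = mask

# Out-of-place: allocates a second scores-sized tensor.
masked = attn_weights.masked_fill(~attn_bias, float("-inf"))

# In-place: reuses the existing storage and avoids the extra allocation.
attn_weights.masked_fill_(~attn_bias, float("-inf"))
```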
| Parameter name | Description | Default value |
| --- | --- | --- |
| `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD` | KV length threshold above which slicing is applied. Set to `-1` to disable slicing. | `min(max_num_batched_tokens, 8192)` |
| `VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE` | Chunk size for `q_len` and `kv_len` in each chunk. Rounded up to the next multiple of 1024. | `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD // 2` |
| `VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS` | Places each chunk in a separate graph to reduce compilation time. | `true` for lazy mode and `false` otherwise |
The markdown table uses || at the start of each row, which typically renders as an extra empty column (or may not render as intended depending on the markdown renderer). Use single | delimiters for standard GitHub-flavored markdown tables.
chunk_res = torch.ops.hpu.sdpa_recomp_fwd(
    q_chunk,
    k_chunk,
    v_chunk,
    mask_chunk,
    dropout_p,
    scale,
    is_causal_chunk,
    True,  # requires_backward
    softmax_mode,
    None,  # valid_seq_len
    padding_side,
)
chunk_out, chunk_m, chunk_linv = ((gqa_output_reshape(x) if gqa else x).to(torch.float32)
                                  for x in (chunk_res[:3]))
This sliced path hard-codes requires_backward=True and unconditionally casts chunk outputs (including chunk_out) to float32. For long contexts this can materially increase memory and runtime overhead in inference. If the kernel/API allows it, consider disabling backward requirements and limiting float32 to the accumulator numerics (e.g., keep m/linv and the running mix in float32, but avoid extra float32 copies of each chunk_out unless needed).
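For context, a plain-PyTorch reference sketch (not the HPU `sdpa_recomp_fwd` kernel) of how per-chunk statistics are merged in a sliced attention pass; the kernel's `m`/`linv` outputs play the role of the running max and (inverse) normalizer here, and only these statistics and the accumulator need float32:

```python
import torch

def sliced_attention_reference(q, k, v, chunk_size=1024, scale=None):
    """Flash-style attention computed over KV chunks.

    Each chunk contributes a partial output plus running statistics: a row max
    and a softmax denominator. Chunk inputs stay in q.dtype.
    """
    scale = scale if scale is not None else q.shape[-1] ** -0.5
    acc = torch.zeros_like(q, dtype=torch.float32)       # running weighted sum of V
    m = torch.full(q.shape[:-1] + (1,), float("-inf"))   # running row max, float32
    denom = torch.zeros(q.shape[:-1] + (1,))             # running softmax denominator

    for start in range(0, k.shape[-2], chunk_size):
        k_c = k[..., start:start + chunk_size, :]
        v_c = v[..., start:start + chunk_size, :]
        scores = (q @ k_c.transpose(-2, -1)).float() * scale
        m_new = torch.maximum(m, scores.amax(dim=-1, keepdim=True))
        correction = torch.exp(m - m_new)                # rescale earlier chunks
        p = torch.exp(scores - m_new)
        denom = denom * correction + p.sum(dim=-1, keepdim=True)
        acc = acc * correction + p @ v_c.float()
        m = m_new
    return (acc / denom).to(q.dtype)

q = k = v = torch.randn(1, 8, 4096, 64, dtype=torch.bfloat16)
out = sliced_attention_reference(q, k, v, chunk_size=1024)
```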
add @yangulei for review too.
Port of PR vllm-project#1032 from aice branch to main.

Converts attention masks from float (bf16 with -inf values) to boolean format. This reduces memory usage (bool vs bf16) and is required for the FusedSDPA slicing feature to get valid m and linv outputs from the kernel.

Key changes:
- _naive_prompt_attention: handle bool attn_bias via masked_fill
- _fsdpa_prompt_attention: remove causal+attn_bias workaround (now supported)
- _make_attn_bias: output ~attn_mask (bool) instead of float masked with -inf
- _set_attn_bias: output ~mask (bool)
- _set_attn_bias_for_sliding_window: use bool masks throughout
- _set_attn_bias_for_chunked_attention: use bool masks throughout
- _set_block_mapping: use mask < block_usage (bool) instead of float

Ref: GAUDISW-245533

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Artur Fierka <artur.fierka@intel.com>
Co-authored-by: yangulei <24203353+yangulei@users.noreply.github.com>
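As a toy illustration of the float-to-bool switch described above (not the actual `_make_attn_bias` implementation):

```python
import torch

seq_len = 8
valid = torch.tril(torch.ones(seq_len, seq_len)).bool()  # True = query may attend

# Old style: additive float bias in bf16, -inf on masked positions.
float_bias = torch.zeros(seq_len, seq_len, dtype=torch.bfloat16)
float_bias.masked_fill_(~valid, float("-inf"))

# New style: pass the boolean mask directly (True = keep).
bool_mask = valid

# 2 bytes vs 1 byte per mask element.
print(float_bias.element_size(), bool_mask.element_size())
```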
Port of PR vllm-project#1034 from aice branch to main.

Splits FusedSDPA kernel into smaller chunks for long sequences to:
- Fit chunks into SRAM for better performance
- Improve TPC/MME pipelining
- Reduce attention-mask usage for padded regions

New env vars:
- VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD: KV length threshold for slicing
- VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE: chunk size (rounded to 1024)
- VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS: graph break control

Only active with linear bucketing strategy and boolean attention masks.

Depends on: Boolean attention mask (port of vllm-project#1032)

Ref: GAUDISW-245533

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Artur Fierka <artur.fierka@intel.com>
Co-authored-by: yangulei <24203353+yangulei@users.noreply.github.com>
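A hypothetical way to set the new flags before launching a workload; the values are examples only:

```python
import os

# Example values only; the real defaults are derived from max_num_batched_tokens.
os.environ["VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD"] = "8192"       # slice once KV length exceeds 8K
os.environ["VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE"] = "4096"         # rounded up to a multiple of 1024
os.environ["VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS"] = "true"  # one graph per chunk
```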
The boolean attention mask change (port of vllm-project#1032) creates bool block_bias in _set_block_mapping (True=valid, False=masked), but pipelined_pa() blindly cast it to float (True→1.0, False→0.0) and added it to attention scores. This broke masking: valid positions got +1.0 noise and masked positions got no penalty (should be -inf).

Fix: detect bool dtype and convert to proper additive bias (0.0/-inf) before use in both the block_softmax kernel path and the manual fallback.

Co-authored-by: yangulei <24203353+yangulei@users.noreply.github.com>
Signed-off-by: Artur Fierka <artur.fierka@intel.com>
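A minimal sketch of the kind of conversion the fix describes, i.e. turning a True=valid boolean mask into an additive 0.0/-inf bias before adding it to scores (names are illustrative, not the actual `pipelined_pa` code):

```python
import torch

def to_additive_bias(block_bias: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Convert a boolean mask (True = valid, False = masked) into an additive bias."""
    if block_bias.dtype == torch.bool:
        # Valid positions contribute 0.0; masked positions get -inf before softmax.
        return torch.where(block_bias,
                           torch.zeros((), dtype=dtype),
                           torch.full((), float("-inf"), dtype=dtype))
    return block_bias.to(dtype)

scores = torch.randn(1, 4, 4, dtype=torch.bfloat16)
mask = torch.tensor([[True, True, False, False]] * 4).unsqueeze(0)
scores = scores + to_additive_bias(mask, scores.dtype)
```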
Force-pushed from 960c194 to 2e875e3.
Upstream vLLM renamed get_eagle3_aux_hidden_state_layers() to get_eagle3_default_aux_hidden_state_layers(). Update the call in hpu_model_runner.py to match.

Co-authored-by: yangulei <24203353+yangulei@users.noreply.github.com>
Signed-off-by: Artur Fierka <artur.fierka@intel.com>
Force-pushed from 2e875e3 to 0cf717e.
🚧 CI Blocked: The main CI workflow was not started for the following reason:
The PR depends on #1154. Let's merge the dependency first. I'll rebase this PR then.
if is_causal and attn_mask is not None:
    # TODO: causal + attn_bias is not yet supported
    is_causal = False
    valid_sequence_lengths = None
The whole `if is_causal ...` block could be removed because we should be able to support all of that. Slawomir Laba can help if this doesn't work. If there are issues, it's not strictly needed, so do not block this work on enabling it.
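If the fused kernel indeed supports causal plus boolean masks, the guard could also be narrowed rather than removed outright; a hedged sketch (not the actual `ModuleFusedSDPA` code, and the kernel behaviour should be verified on hardware):

```python
import torch

def maybe_disable_causal(is_causal: bool, attn_mask, valid_sequence_lengths):
    """Keep is_causal when the mask is boolean; fall back only for float biases."""
    if is_causal and attn_mask is not None and attn_mask.dtype != torch.bool:
        # Float additive biases combined with is_causal are the unsupported case here.
        is_causal = False
        valid_sequence_lengths = None
    return is_causal, valid_sequence_lengths
```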
q_start = max(q_start, 0)
q_end = q_len - q_chunk_idx * self.chunk_size
q_chunk_size = q_end - q_start
q_chunk = q[..., q_start:q_end, :].clone() if self.with_graph_breaks else q[..., q_start:q_end, :]
I've never seen the `...` operator, TIL. I trust it's needed here.
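For reference, `...` (Ellipsis) means "all leading dimensions", so the slicing is independent of how many batch/head dims precede the sequence dimension:

```python
import torch

q = torch.randn(2, 8, 4096, 128)   # (batch, heads, seq, head_dim)
chunk = q[..., 1024:2048, :]       # same as q[:, :, 1024:2048, :]
assert chunk.shape == (2, 8, 1024, 128)
```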
        return output.to(q.dtype)

    def _setup_slicing(self) -> bool:
        from vllm_gaudi.extension.bucketing.common import get_bucketing_manager
Are the imports inside needed because they are not visible otherwise? If not, we could place them outside.
kamil-kaczor
left a comment
Added a few nitpicks, but if this passes CI + customer benchmark tests like we discussed on the call, this can go in imo.
    logger().warning('Bucketing manager is not initialized, slicing in FSDPA will be disabled.')
    return False

from vllm_gaudi.extension.bucketing.linear import LinearBucketingStrategy
Are the imports inside needed because they are not visible otherwise? If not, we could place them outside.
max_ctx_pad = int(os.getenv("VLLM_PROMPT_CTX_BUCKET_PAD_MAX", str(max_ctx_pad_default)))
self.num_padded_ctx_chunks = math.ceil(max_ctx_pad * block_size / self.chunk_size)

import habana_frameworks.torch as ht
Are the imports inside needed because they are not visible otherwise? If not, we could place them outside.
causal_mask = torch.triu(causal_mask, diagonal=1).unsqueeze(0)
attn_mask[:, :, context_len:].logical_or_(causal_mask)
attn_mask = attn_mask.to(dtype).masked_fill_(attn_mask, -math.inf)
attn_mask = ~attn_mask
Nice, but any arithmetic on the mask will fail later on, so make sure it's well tested.
Motivation
This PR enables 256K context length support for models like Qwen3-30B-A3B-Thinking-2507 on Gaudi 2 with TP=1.
It ports two key patches from the aice branch to main:

- Boolean attention mask (port of aice PR "Use Boolean attention mask" #1032): Converts attention masks from float (bf16 with -inf) to boolean format, reducing memory usage and enabling proper interaction with the FusedSDPA kernel for long sequences.
- FusedSDPA slicing (port of aice PR "Enable slicing for the BF16 FusedSDPA" #1034): Splits the FusedSDPA kernel into smaller chunks for long sequences to fit into SRAM, improve TPC/MME pipelining, and reduce attention-mask usage for padded regions.
Changes
Boolean attention mask
- `_naive_prompt_attention`: Handle bool `attn_bias` via `masked_fill`
- `_fsdpa_prompt_attention`: Remove causal+attn_bias workaround (now supported with bool masks)
- `_make_attn_bias`: Output `~attn_mask` (bool) instead of float masked with -inf
- `_set_attn_bias`: Output `~mask` (bool)
- `_set_attn_bias_for_sliding_window`: Use bool masks throughout
- `_set_attn_bias_for_chunked_attention`: Use bool masks throughout
- `_set_block_mapping`: Use `mask < block_usage` (bool) instead of float

FusedSDPA slicing
- `_sliced_fsdpa_fwd` method added to `ModuleFusedSDPA` for chunked attention computation
- `_setup_slicing` for configuring chunk parameters from env vars
- New env vars: `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD`, `VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE`, `VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS`

Dependencies
Testing
Ref: GAUDISW-245533
Related: aice PRs #1032, #1034
Co-author: @yangulei