Enable slicing for the BF16 FusedSDPA#1034

Merged
czhu15 merged 2 commits into vllm-project:aice from yangulei:slice_fsdpa on Mar 5, 2026

Conversation

@yangulei
Collaborator

@yangulei yangulei commented Feb 25, 2026

Note that this PR depends on:
- the **Boolean** attention mask introduced by vllm-project#1032 to get valid `m` and `linv` from the FusedSDPA kernel,
- the default query/ctx bucketing config modified in vllm-project#1086

Contributor

Copilot AI left a comment


Pull request overview

This PR adds an optional “sliced” forward path for the HPU BF16 FusedSDPA kernel to improve long-context performance when using linear bucketing, and updates bucketing padding defaults/documentation to better align with attention-mask usage.

Changes:

  • Add a chunked/sliced FusedSDPA forward path (with optional graph breaks) for the causal + attention-mask case.
  • Adjust linear bucketing defaults for *_BUCKET_PAD_MAX (reduce default absolute padding).
  • Update environment variable documentation and remove redundant causal+attn_bias handling from the ops wrapper.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 7 comments.

File Description
vllm_gaudi/extension/utils.py Introduces sliced FusedSDPA execution + slicing setup/config via env vars.
vllm_gaudi/extension/ops.py Removes local workaround; relies on ModuleFusedSDPA behavior.
vllm_gaudi/extension/bucketing/linear.py Changes default pad_max values to reduce padding.
docs/configuration/env_variables.md Documents updated defaults and adds new FusedSDPA slicing tuning parameters.


Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.



Collaborator

@taotod taotod left a comment


LGTM, please fix the typos that Copilot found or flagged.

@yangulei
Collaborator Author

@yupengzh-intel @testdig
Please help to review, thanks!

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.



Comment thread vllm_gaudi/extension/utils.py Outdated
f'falling back to default {qkv_chunk_size_default}.')
qkv_chunk_size = qkv_chunk_size_default
if qkv_chunk_size % 1024 != 0:
qkv_chunk_size = (qkv_chunk_size + 1023) // 1024 * 1024
Contributor


Is `qkv_chunk_size = math.ceil(qkv_chunk_size / 1024) * 1024` easier to read?

Collaborator Author


We'd have to use `qkv_chunk_size = int(math.ceil(qkv_chunk_size / 1024)) * 1024` and import `math`. I prefer the integer arithmetic here.

Collaborator Author


I rechecked this and found that `math.ceil()` returns `int`, so I now use it for all the ceiling operations in the code.
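For reference, the two rounding styles discussed in this thread agree for positive sizes; a minimal sketch (function names are illustrative, not from the PR):

```python
import math

def round_up_int(n: int) -> int:
    # Integer arithmetic, as in the original diff: exact for any size.
    return (n + 1023) // 1024 * 1024

def round_up_ceil(n: int) -> int:
    # math.ceil() already returns int in Python 3, so no int() cast is needed.
    # Note this goes through float division, which only matters for huge n.
    return math.ceil(n / 1024) * 1024
```

For realistic chunk sizes both forms agree; the integer form additionally avoids any float-precision concern for very large values.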

Comment thread docs/configuration/env_variables.md Outdated
| Parameter name | Description | Default value |
| ---------------------------------------- | ----------------------------------------------------------------------------------------------- | ------------------------------------------ |
| `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD` | KV length threshold above which slicing is enabled. Set to `-1` to disable slicing. | `min(max_num_batched_tokens, 8192)` |
| `VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE` | Chunk size for `q_len` and `kv_len` in each chunk. Rounded up to the nearest multiple of 1024. | `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD // 2` |
Contributor


"Rounded up to the next" reads better than "nearest".

Collaborator Author


Will do, thank you.

Comment thread docs/configuration/env_variables.md Outdated

| Parameter name | Description | Default value |
| ---------------------------------------- | ----------------------------------------------------------------------------------------------- | ------------------------------------------ |
| `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD` | KV length threshold above which slicing is enabled. Set to `-1` to disable slicing. | `min(max_num_batched_tokens, 8192)` |
Contributor


does it mean "Enable slicing when the KV length exceeds this threshold."?

Collaborator Author


The enable/disable wording here is ambiguous. Setting VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD to a negative value disables slicing entirely, while setting it to a valid value enables slicing, which is then applied only when the KV length exceeds the threshold.

Will change "enabled" to "applied".

Comment thread vllm_gaudi/extension/utils.py Outdated
return False

max_num_batched_tokens = bucketing_manager.max_num_batched_tokens
qkv_slice_thld_default = min(max_num_batched_tokens, 8192)
Contributor


In the case where max_num_batched_tokens is 8k or 16k and I set VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD to 4k, slicing cannot be enabled. Is that by design?

Collaborator Author


qkv_slice_thld will be reset to min(max_num_batched_tokens, 8192) = 8192 in those cases. The optimization only yields a performance gain for kv_len >= (1 + num_padded_ctx_chunks) * max_num_batched_tokens, so setting VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD below max_num_batched_tokens would hurt performance.
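The clamping behaviour described in this thread could be sketched as follows (a hypothetical helper for illustration, not the actual code from the PR):

```python
def resolve_qkv_slice_thld(env_value, max_num_batched_tokens):
    """Resolve the slicing threshold from VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD.

    Hypothetical sketch: a negative value disables slicing entirely; values
    below max_num_batched_tokens are reset to the default, per the reply above.
    """
    default = min(max_num_batched_tokens, 8192)
    if env_value is None:
        return default
    thld = int(env_value)
    if thld < 0:
        return -1  # slicing disabled entirely
    if thld < max_num_batched_tokens:
        return default  # too small to ever help; reset to the default
    return thld
```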

Comment thread vllm_gaudi/extension/utils.py Outdated
bs = query.shape[0]
q_len = query.shape[-2]
kv_len = key.shape[-2]
if (self.enable_slicing and bs == 1 and q_len != kv_len and kv_len >= self.qkv_slice_thld and is_causal
Contributor


If slicing works only when bs == 1, shall we document it?

Collaborator Author


bs == 1 is always satisfied for prefills with context. Actually, I'm not sure whether bs > 1 prefills with context work at all.

Comment thread docs/configuration/env_variables.md Outdated
| Prompt | query length step (`VLLM_PROMPT_QUERY_BUCKET_STEP`) | `block_size` |
| Prompt | query length max (`VLLM_PROMPT_QUERY_BUCKET_MAX`) | `max_num_batched_tokens` |
| Prompt | query length max abs padding (`VLLM_PROMPT_QUERY_BUCKET_PAD_MAX`) | `max_num_batched_tokens` |
| Prompt | query length max abs padding (`VLLM_PROMPT_QUERY_BUCKET_PAD_MAX`) | `max_num_batched_tokens // 4` |
Contributor


IMHO, it would be better to have a standalone PR to change the two default max padding value.

Collaborator Author


Yes, but the max_query_pad and max_ctx_pad share the same ENVs and default values in linear bucketing, so they have to be modified simultaneously anyway.

Collaborator Author


In fact, the prompt_buckets are not generated yet when the __init__() of FusedSDPA is called, so we have to get the maximum possible padding from the configuration instead of the actual buckets.

Contributor

@testdig testdig left a comment


LGTM

Collaborator

@czhu15 czhu15 left a comment


A very important PR to further improve performance on Gaudi.
I left some comments on the old version of this PR; please kindly check whether they are still valid in the new one.

Comment thread docs/configuration/env_variables.md Outdated
| `VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS` | Places each chunk in a separate graph to reduce compilation time. | `true` |

!!! note
These parameters are effective only with the linear bucketing strategy, where the max absolute padding for `query` and `context` determines attention-mask usage.
Collaborator


What's the meaning of "where the max absolute padding for query and context determines attention-mask usage"? If it is purely an internal implementation detail, I suggest not exposing it to external users in this document.

Collaborator Author


The max absolute padding in the query and context determines the chunks for which FusedSDPA has to be called with an attention mask to handle the padding in them.

Collaborator

@czhu15 czhu15 Feb 28, 2026


It is a note message: what should users do if they enable slicing with the linear bucketing strategy?
To me, "where the max absolute padding for query and context determines attention-mask usage" is a pure internal implementation detail, and I guess most users won't understand it :) I would suggest removing it from the readme file. Keeping only "These parameters are effective only with the linear bucketing strategy" is enough.

Collaborator Author


I see, will make it clearer, thanks.

Comment thread vllm_gaudi/extension/utils.py Outdated
bs = query.shape[0]
q_len = query.shape[-2]
kv_len = key.shape[-2]
if (self.enable_slicing and bs == 1 and q_len != kv_len and kv_len >= self.qkv_slice_thld and is_causal
Collaborator


"(self.enable_slicing and bs == 1 and q_len != kv_len and kv_len >= self.qkv_slice_thld and is_causal..."
There are too many checks here; it may be better to wrap them in an internal function for readability.
BTW, why can't the sliced FSDPA be used when q_len == kv_len?

Collaborator Author


I will add comments here to make it clearer.
For q_len == kv_len it's a normal causal FSDPA, and the default call is the most efficient path, as it calls FusedSDPA with is_causal=True and valid_seq_length, avoiding the re-scaling overhead.
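To illustrate the re-scaling overhead mentioned above: a sliced forward path keeps a running max `m` and normalizer (online softmax) and rescales the accumulator whenever a new chunk raises the max. A NumPy sketch of the idea (illustrative only, not the HPU kernel; assumes kv_len >= q_len, i.e. a causal prefill with context):

```python
import numpy as np

def chunked_causal_attention(q, k, v, chunk=2):
    """Online-softmax attention over KV chunks.

    q: (q_len, d); k, v: (kv_len, d). Causal rule: query i may attend
    kv position j iff j <= i + (kv_len - q_len).
    """
    q_len, d = q.shape
    kv_len = k.shape[0]
    offset = kv_len - q_len
    scale = 1.0 / np.sqrt(d)
    m = np.full((q_len, 1), -np.inf)   # running row max
    l = np.zeros((q_len, 1))           # running softmax normalizer
    acc = np.zeros((q_len, d))         # unnormalized output accumulator
    for start in range(0, kv_len, chunk):
        ks, vs = k[start:start + chunk], v[start:start + chunk]
        s = (q @ ks.T) * scale
        kj = np.arange(start, start + ks.shape[0])[None, :]
        qi = np.arange(q_len)[:, None]
        s = np.where(kj <= qi + offset, s, -np.inf)   # causal mask
        new_m = np.maximum(m, s.max(axis=-1, keepdims=True))
        p = np.exp(s - new_m)
        correction = np.exp(m - new_m)  # rescale old state to the new max
        l = l * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ vs
        m = new_m
    return acc / l
```

When q_len == kv_len and a single causal kernel call suffices, none of this per-chunk correction work is needed, which is why the default path is faster there.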

Collaborator


Got it, thanks for the explanation.

Comment thread vllm_gaudi/extension/utils.py
last_m = new_m

if self.with_graph_breaks:
self.break_graph()
Collaborator


same comment as above.

Comment thread docs/configuration/env_variables.md Outdated
| Prompt | batch size step (`VLLM_PROMPT_BS_BUCKET_STEP`) | `1` |
| Prompt | batch size max (`VLLM_PROMPT_BS_BUCKET_MAX`) | `max_num_prefill_seqs` |
| Prompt | batch size max abs padding (`VLLM_PROMPT_BS_BUCKET_PAD_MAX`) | `16` |
| Prompt | batch size max abs padding (`VLLM_PROMPT_BS_BUCKET_PAD_MAX`) | `max_num_prefill_seqs / 4` |
Collaborator


Why does slicing change the prompt bs padding?

Collaborator Author

@yangulei yangulei Mar 3, 2026


It's not for slicing. I aligned *BUCKET_PAD_MAX = math.ceil(BUCKET_MAX / 4) across all the possible configurations.
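The alignment rule described above amounts to a one-liner (a hypothetical illustration, not the PR's code):

```python
import math

def default_pad_max(bucket_max: int) -> int:
    # *_BUCKET_PAD_MAX defaults to a quarter of *_BUCKET_MAX, rounded up.
    return math.ceil(bucket_max / 4)
```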

Comment thread docs/configuration/env_variables.md Outdated
| Prompt | sequence ctx max (`VLLM_PROMPT_CTX_BUCKET_MAX`) | `(max_model_len - block_size) // block_size` |
| Prompt | sequence ctx max abs padding (`VLLM_PROMPT_CTX_BUCKET_PAD_MAX`) | `max_num_batched_tokens // block_size` |
| Prompt | sequence ctx step (`VLLM_PROMPT_CTX_BUCKET_STEP`) | `2` |
| Prompt | sequence ctx max (`VLLM_PROMPT_CTX_BUCKET_MAX`) | `(max_model_len - block_size) / block_size` |
Collaborator


`//` should be correct, not `/`.

Collaborator Author


It's actually math.ceil((max_model_len - block_size) / block_size). Do you think we should submit the bucketing changes in a separate PR?

Collaborator


Yes. I prefer separate PRs for the bucketing optimization; otherwise this PR contains too many changes unrelated to the slice feature.

Comment thread docs/configuration/env_variables.md Outdated
| Prompt | sequence ctx max padding percent (`VLLM_PROMPT_CTX_BUCKET_PAD_PERCENT`) | `25` |
| Decode | batch size min (`VLLM_DECODE_BS_BUCKET_MIN`) | `1` |
| Decode | batch size step (`VLLM_DECODE_BS_BUCKET_STEP`) | `32` |
| Decode | batch size step (`VLLM_DECODE_BS_BUCKET_STEP`) | `2` |
Collaborator


Why change VLLM_DECODE_BS_BUCKET_STEP to such a small value?

Collaborator Author


The previous step of 32 introduced too much padding in some cases with max-concurrency < 32. The value in the code was already changed to 2 in the PR for linear bucketing with limits, but the doc here was unintentionally left unchanged.

Comment thread docs/configuration/env_variables.md Outdated
| Decode | block size max (`VLLM_DECODE_BLOCK_BUCKET_MAX`) | `max_model_len * max_num_seqs // block_size` <br>by default or `max_blocks` <br>if `VLLM_CONTIGUOUS_PA = True`|
| Decode | block size max abs padding (`VLLM_DECODE_BLOCK_BUCKET_PAD_MAX`) | `max_num_batched_tokens * max_num_seqs // block_size` |
| Decode | block size max (`VLLM_DECODE_BLOCK_BCUKET_MAX`) | `max_model_len * max_num_seqs / block_size` <br>by default or `max_blocks` <br>if `VLLM_CONTIGUOUS_PA = True` |
| Decode | block size max abs padding (`VLLM_DECODE_BLOCK_BUCKET_PAD_MAX`) | `VLLM_DECODE_BLOCK_BCUKET_MAX / 4` |
Collaborator


Should be `//`, not `/`.

pad_max=math.ceil(max_decode_blocks / 4),
pad_percent=25)
if decode_block_bucket_cfg[2] > max_blocks:
if contiguous_pa and decode_block_bucket_cfg[2] > max_blocks:
Collaborator


Why do we need to check contiguous_pa? Does the slice feature depend on contiguous PA?

Collaborator Author


decode_block_bucket_cfg[2] is VLLM_DECODE_BLOCK_BUCKET_MAX, and VLLM_DECODE_BLOCK_BUCKET_MAX <= max_blocks is satisfied for contiguous PA only. The original code may leave buckets not warmed up in non-contiguous PA cases.

Signed-off-by: Youlei Yang <youlei.yang@intel.com>
Contributor

@yupengzh-intel yupengzh-intel left a comment


LGTM

@yangulei
Collaborator Author

yangulei commented Mar 4, 2026

@czhu15 The changes for bucketing are submitted in #1086.

Signed-off-by: Youlei Yang <youlei.yang@intel.com>
Collaborator

@czhu15 czhu15 left a comment


LGTM

@czhu15 czhu15 merged commit b9db1af into vllm-project:aice Mar 5, 2026
1 check passed
@yangulei yangulei deleted the slice_fsdpa branch March 5, 2026 02:58
tvoas pushed a commit to tvoas/vllm-gaudi that referenced this pull request Mar 11, 2026
Note that this PR depends on:
- the **Boolean** attention mask introduced by
vllm-project#1032 to get valid `m`
and `linv` from the FusedSDPA kernel,
- the default query/ctx bucketing config modified in
vllm-project#1086

---------

Signed-off-by: Youlei Yang <youlei.yang@intel.com>
afierka-intel added a commit to afierka-intel/vllm-gaudi that referenced this pull request Mar 12, 2026
Port of PR vllm-project#1034 from aice branch to main.

Splits FusedSDPA kernel into smaller chunks for long sequences to:
- Fit chunks into SRAM for better performance
- Improve TPC/MME pipelining
- Reduce attention-mask usage for padded regions

New env vars:
- VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD: KV length threshold for slicing
- VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE: chunk size (rounded to 1024)
- VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS: graph break control

Only active with linear bucketing strategy and boolean attention masks.

Depends on: Boolean attention mask (port of vllm-project#1032)
Ref: GAUDISW-245533

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
afierka-intel added a commit to afierka-intel/vllm-gaudi that referenced this pull request Mar 17, 2026
Port of PR vllm-project#1034 from aice branch to main.

Splits FusedSDPA kernel into smaller chunks for long sequences to:
- Fit chunks into SRAM for better performance
- Improve TPC/MME pipelining
- Reduce attention-mask usage for padded regions

New env vars:
- VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD: KV length threshold for slicing
- VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE: chunk size (rounded to 1024)
- VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS: graph break control

Only active with linear bucketing strategy and boolean attention masks.

Depends on: Boolean attention mask (port of vllm-project#1032)
Ref: GAUDISW-245533

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Artur Fierka <artur.fierka@intel.com>
Co-authored-by: yangulei <24203353+yangulei@users.noreply.github.com>