Enable slicing for the BF16 FusedSDPA (#1034)
Conversation
Pull request overview
This PR adds an optional “sliced” forward path for the HPU BF16 FusedSDPA kernel to improve long-context performance when using linear bucketing, and updates bucketing padding defaults/documentation to better align with attention-mask usage.
Changes:
- Add a chunked/sliced FusedSDPA forward path (with optional graph breaks) for the causal + attention-mask case.
- Adjust linear bucketing defaults for `*_BUCKET_PAD_MAX` (reduce default absolute padding).
- Update environment variable documentation and remove redundant causal+attn_bias handling from the ops wrapper.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| `vllm_gaudi/extension/utils.py` | Introduces sliced FusedSDPA execution plus slicing setup/config via env vars. |
| `vllm_gaudi/extension/ops.py` | Removes local workaround; relies on ModuleFusedSDPA behavior. |
| `vllm_gaudi/extension/bucketing/linear.py` | Changes default pad_max values to reduce padding. |
| `docs/configuration/env_variables.md` | Documents updated defaults and adds new FusedSDPA slicing tuning parameters. |
taotod
left a comment
LGTM, please fix the typos that Copilot found or suggested.
@yupengzh-intel @testdig
```python
        f'falling back to default {qkv_chunk_size_default}.')
    qkv_chunk_size = qkv_chunk_size_default
if qkv_chunk_size % 1024 != 0:
    qkv_chunk_size = (qkv_chunk_size + 1023) // 1024 * 1024
```
Would `qkv_chunk_size = math.ceil(qkv_chunk_size / 1024) * 1024` be easier to read?
That would require `qkv_chunk_size = int(math.ceil(qkv_chunk_size / 1024)) * 1024` and an extra `import math`. I prefer the integer arithmetic here.
I rechecked this and found that `math.ceil()` already returns an `int`, so I now use it for all the ceiling operations in the code.
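For reference, the two rounding styles debated above are equivalent for positive integers; a quick self-contained check (function names are illustrative, not from the PR):

```python
import math

def round_up_int(x: int, multiple: int = 1024) -> int:
    # Integer-arithmetic rounding, as in the PR's original code.
    return (x + multiple - 1) // multiple * multiple

def round_up_ceil(x: int, multiple: int = 1024) -> int:
    # math.ceil already returns an int in Python 3, so no int() cast is needed.
    return math.ceil(x / multiple) * multiple

for x in (1, 1023, 1024, 1025, 4096, 5000):
    assert round_up_int(x) == round_up_ceil(x)

print(round_up_int(5000))  # -> 5120
```

One practical argument for the integer form: `x / multiple` goes through float division, which can lose precision for very large integers, while `//` stays exact.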
| Parameter name | Description | Default value |
| --- | --- | --- |
| `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD` | KV length threshold above which slicing is enabled. Set to `-1` to disable slicing. | `min(max_num_batched_tokens, 8192)` |
| `VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE` | Chunk size for `q_len` and `kv_len` in each chunk. Rounded up to the nearest multiple of 1024. | `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD // 2` |
Rounded up to the "next" multiple reads better than "nearest".
Will do, thank you.
| Parameter name | Description | Default value |
| --- | --- | --- |
| `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD` | KV length threshold above which slicing is enabled. Set to `-1` to disable slicing. | `min(max_num_batched_tokens, 8192)` |
Does it mean "enable slicing when the KV length exceeds this threshold"?
The enabled/disabled wording here is ambiguous. Setting `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD` to a negative value disables slicing entirely, while setting it to a valid value enables slicing, and slicing is then applied only when the KV length exceeds this threshold.
I will change "enabled" to "applied".
```python
    return False

max_num_batched_tokens = bucketing_manager.max_num_batched_tokens
qkv_slice_thld_default = min(max_num_batched_tokens, 8192)
```
In the case where `max_num_batched_tokens` is 8k or 16k and I set `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD` to 4k, slicing cannot be enabled. Is that by design?
`qkv_slice_thld` will be reset to `min(max_num_batched_tokens, 8192) = 8192` in those cases. The optimization only yields a performance gain for `kv_len >= (1 + num_padded_ctx_chunks) * max_num_batched_tokens`, so setting `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD` lower than `max_num_batched_tokens` would hurt performance.
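The clamping behavior described in this reply could be sketched as a small config-resolution helper (the function name and the env-var parsing are illustrative assumptions, not the PR's actual implementation):

```python
import os

def resolve_slice_thld(max_num_batched_tokens: int) -> int:
    """Illustrative sketch: resolve the slicing threshold from the env var.

    Returns -1 when slicing is disabled.
    """
    default = min(max_num_batched_tokens, 8192)
    raw = int(os.environ.get('VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD', default))
    if raw < 0:
        return -1  # a negative value disables slicing entirely
    # Values below the default hurt performance, so reset them to the default.
    return max(raw, default)
```

With `max_num_batched_tokens = 8192` and the env var set to 4096, the threshold is reset to 8192, matching the behavior described in the reply.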
```python
bs = query.shape[0]
q_len = query.shape[-2]
kv_len = key.shape[-2]
if (self.enable_slicing and bs == 1 and q_len != kv_len and kv_len >= self.qkv_slice_thld and is_causal
```
If slicing works only when `bs == 1`, shall we document it?
`bs == 1` is always satisfied for prefills with context. Actually, I'm not sure whether prefills with context even work for `bs > 1`.
```
| Prompt | query length step (`VLLM_PROMPT_QUERY_BUCKET_STEP`) | `block_size` |
| Prompt | query length max (`VLLM_PROMPT_QUERY_BUCKET_MAX`) | `max_num_batched_tokens` |
| Prompt | query length max abs padding (`VLLM_PROMPT_QUERY_BUCKET_PAD_MAX`) | `max_num_batched_tokens` |
| Prompt | query length max abs padding (`VLLM_PROMPT_QUERY_BUCKET_PAD_MAX`) | `max_num_batched_tokens // 4` |
```
IMHO, it would be better to change the two default max padding values in a standalone PR.
Yes, but `max_query_pad` and `max_ctx_pad` share the same env vars and default values in linear bucketing, so they have to be modified simultaneously anyway.
In fact, the `prompt_buckets` have not been generated yet when the `__init__()` of FusedSDPA is called, so we have to derive the maximum possible padding from the configuration instead of from the actual buckets.
czhu15
left a comment
A very important PR to further improve performance on Gaudi.
Some of my comments were made on the old version of the PR; please kindly check whether they still apply to the new one.
```
| `VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS` | Places each chunk in a separate graph to reduce compilation time. | `true` |

!!! note
    These parameters are effective only with the linear bucketing strategy, where the max absolute padding for `query` and `context` determines attention-mask usage.
```
What is the meaning of "where the max absolute padding for query and context determines attention-mask usage"? If it is purely an internal implementation detail, I suggest not exposing it to external users in this document.
The max absolute padding for the query and context determines the chunks for which FusedSDPA has to be called with an attention mask to handle the padding in them.
It is a note message: what should the user do when enabling slicing with the linear bucketing strategy?
To me, "where the max absolute padding for query and context determines attention-mask usage" is a pure implementation detail, and I guess most users won't understand it :) I would suggest removing it from the readme file; keeping only "These parameters are effective only with the linear bucketing strategy" is enough.
I see, will make it clearer, thanks.
```python
bs = query.shape[0]
q_len = query.shape[-2]
kv_len = key.shape[-2]
if (self.enable_slicing and bs == 1 and q_len != kv_len and kv_len >= self.qkv_slice_thld and is_causal
```
There are too many checks in `self.enable_slicing and bs == 1 and q_len != kv_len and kv_len >= self.qkv_slice_thld and is_causal ...`; it may be better to wrap them in an internal function for readability.
BTW, why can't the sliced FSDPA be used when `q_len == kv_len`?
I will add comments here to make it clearer.
For `q_len == kv_len` it's a normal causal FSDPA, and the default call is the most efficient path: it calls FusedSDPA with `is_causal=True` and `valid_seq_length`, avoiding the re-scaling overhead.
Got it, thanks for the explanation.
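The refactor suggested in this thread (wrapping the eligibility checks in one helper) might look roughly like the following sketch; `should_use_sliced_path` and the flat `cfg` object are illustrative, not the PR's actual code:

```python
from types import SimpleNamespace

def should_use_sliced_path(cfg, query_shape, key_shape, is_causal: bool) -> bool:
    """Group the slicing eligibility checks in one readable place.

    q_len != kv_len selects the prefill-with-context case; the plain causal
    prefill (q_len == kv_len) stays on the default FusedSDPA path, which is
    already efficient with is_causal=True and valid_seq_length.
    """
    bs = query_shape[0]
    q_len = query_shape[-2]
    kv_len = key_shape[-2]
    return bool(cfg.enable_slicing and bs == 1 and q_len != kv_len
                and kv_len >= cfg.qkv_slice_thld and is_causal)

cfg = SimpleNamespace(enable_slicing=True, qkv_slice_thld=8192)
print(should_use_sliced_path(cfg, (1, 8, 2048, 128), (1, 8, 16384, 128), True))  # -> True
```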
```python
        last_m = new_m

    if self.with_graph_breaks:
        self.break_graph()
```
```
| Prompt | batch size step (`VLLM_PROMPT_BS_BUCKET_STEP`) | `1` |
| Prompt | batch size max (`VLLM_PROMPT_BS_BUCKET_MAX`) | `max_num_prefill_seqs` |
| Prompt | batch size max abs padding (`VLLM_PROMPT_BS_BUCKET_PAD_MAX`) | `16` |
| Prompt | batch size max abs padding (`VLLM_PROMPT_BS_BUCKET_PAD_MAX`) | `max_num_prefill_seqs / 4` |
```
Why does slicing change the prompt batch-size padding?
It's not for slicing. I aligned `*_BUCKET_PAD_MAX = math.ceil(BUCKET_MAX / 4)` for all the possible configurations.
```
| Prompt | sequence ctx max (`VLLM_PROMPT_CTX_BUCKET_MAX`) | `(max_model_len - block_size) // block_size` |
| Prompt | sequence ctx max abs padding (`VLLM_PROMPT_CTX_BUCKET_PAD_MAX`) | `max_num_batched_tokens // block_size` |
| Prompt | sequence ctx step (`VLLM_PROMPT_CTX_BUCKET_STEP`) | `2` |
| Prompt | sequence ctx max (`VLLM_PROMPT_CTX_BUCKET_MAX`) | `(max_model_len - block_size) / block_size` |
```
`//` should be correct here, not `/`.
It's actually `math.ceil((max_model_len - block_size) / block_size)`. Do you think we should submit the bucketing changes in a separate PR?
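The difference matters whenever `max_model_len - block_size` is not a multiple of `block_size`: floor division drops the partial tail bucket, while the ceiling covers it. A quick check with made-up numbers:

```python
import math

max_model_len, block_size = 2050, 128
tokens = max_model_len - block_size          # 1922
floored = tokens // block_size               # floor division drops the remainder
ceiled = math.ceil(tokens / block_size)      # one extra bucket covers the tail

print(floored, ceiled)  # -> 15 16
```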
Yes, I prefer separate PRs for the bucketing optimization; otherwise this PR contains too many changes unrelated to the slicing feature.
```
| Prompt | sequence ctx max padding percent (`VLLM_PROMPT_CTX_BUCKET_PAD_PERCENT`) | `25` |
| Decode | batch size min (`VLLM_DECODE_BS_BUCKET_MIN`) | `1` |
| Decode | batch size step (`VLLM_DECODE_BS_BUCKET_STEP`) | `32` |
| Decode | batch size step (`VLLM_DECODE_BS_BUCKET_STEP`) | `2` |
```
Why change `VLLM_DECODE_BS_BUCKET_STEP` to such a small value?
The previous step of 32 introduced too much padding in some cases with max-concurrency < 32. The value in the code was already changed to 2 in the PR for linear bucketing with limits, while the doc here was unintentionally left unchanged.
```
| Decode | block size max (`VLLM_DECODE_BLOCK_BUCKET_MAX`) | `max_model_len * max_num_seqs // block_size` by default, or `max_blocks` if `VLLM_CONTIGUOUS_PA = True` |
| Decode | block size max abs padding (`VLLM_DECODE_BLOCK_BUCKET_PAD_MAX`) | `max_num_batched_tokens * max_num_seqs // block_size` |
| Decode | block size max (`VLLM_DECODE_BLOCK_BCUKET_MAX`) | `max_model_len * max_num_seqs / block_size` by default, or `max_blocks` if `VLLM_CONTIGUOUS_PA = True` |
| Decode | block size max abs padding (`VLLM_DECODE_BLOCK_BUCKET_PAD_MAX`) | `VLLM_DECODE_BLOCK_BCUKET_MAX / 4` |
```
```diff
     pad_max=math.ceil(max_decode_blocks / 4),
     pad_percent=25)
-if decode_block_bucket_cfg[2] > max_blocks:
+if contiguous_pa and decode_block_bucket_cfg[2] > max_blocks:
```
Why is the `contiguous_pa` check needed? Does the slice feature depend on contiguous PA?
`decode_block_bucket_cfg[2]` is `VLLM_DECODE_BLOCK_BUCKET_MAX`, and `VLLM_DECODE_BLOCK_BUCKET_MAX <= max_blocks` is guaranteed only for contiguous PA. The original code could leave some shapes not warmed up in the non-contiguous PA case.
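Assuming the guarded branch clamps the configured max down to `max_blocks` (the hunk above only shows the condition, not the body), the reasoning in this thread could be sketched like this; the function name and the clamping body are assumptions:

```python
def clamp_decode_block_max(bucket_cfg, max_blocks, contiguous_pa):
    """Illustrative sketch: clamp the decode block-bucket max to max_blocks
    only under contiguous PA, where bucket_max <= max_blocks must hold.
    Clamping unconditionally could shrink buckets below shapes actually
    seen at runtime without contiguous PA, leaving them not warmed up.

    bucket_cfg is assumed to be a (min, step, max) triple.
    """
    cfg = list(bucket_cfg)
    if contiguous_pa and cfg[2] > max_blocks:
        cfg[2] = max_blocks
    return tuple(cfg)

print(clamp_decode_block_max((128, 128, 4096), 2048, True))   # -> (128, 128, 2048)
print(clamp_decode_block_max((128, 128, 4096), 2048, False))  # -> (128, 128, 4096)
```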
Signed-off-by: Youlei Yang <youlei.yang@intel.com>
Note that this PR depends on:
- the **Boolean** attention mask introduced by vllm-project#1032 to get valid `m` and `linv` from the FusedSDPA kernel,
- the default query/ctx bucketing config modified in vllm-project#1086

Signed-off-by: Youlei Yang <youlei.yang@intel.com>
Port of PR vllm-project#1034 from the aice branch to main. Splits the FusedSDPA kernel into smaller chunks for long sequences to:
- Fit chunks into SRAM for better performance
- Improve TPC/MME pipelining
- Reduce attention-mask usage for padded regions

New env vars:
- VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD: KV length threshold for slicing
- VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE: chunk size (rounded to 1024)
- VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS: graph break control

Only active with the linear bucketing strategy and boolean attention masks.
Depends on: Boolean attention mask (port of vllm-project#1032)
Ref: GAUDISW-245533

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Artur Fierka <artur.fierka@intel.com>
Co-authored-by: yangulei <24203353+yangulei@users.noreply.github.com>