Enable slicing for fp8 FusedSDPA #1285
Pull request overview
This PR extends the existing FusedSDPA slicing mechanism to the FP8 FusedSDPA path on HPU, so long-context “chunked prefill” can use the sliced dispatch when FP8 attention is enabled.
Changes:
- Refactors common slicing setup into a shared ModuleFusedSDPABase.
- Adds a sliced forward path for ModuleFP8FusedSDPA (FP8 FusedSDPA) and routes into it under the same gating conditions as BF16.
- Updates the HPU attention backend to prefer the FP8 FusedSDPA kernel when FP8 attention is enabled and the kernel is available.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| vllm_gaudi/extension/utils.py | Adds shared slicing setup base class and implements FP8 sliced FusedSDPA forward path. |
| vllm_gaudi/attention/backends/hpu_attn.py | Switches fsdpa_op to FP8 FusedSDPA when FP8 attention is enabled and import succeeds. |

    False,  # is_amax_s
    False,  # is_amax_o
    None,  # valid_seq_len
    "right",  # seq_padding_type

fp8_fsdpa_fwd hardcodes seq_padding_type to "right" and ignores the padding_side argument flowing into forward/_sliced_fsdpa_fwd. This can silently break correctness for left-padded inputs and is inconsistent with the BF16 slicing path (which passes padding_side). Thread padding_side through to fp8_fsdpa_fwd and use it instead of the constant.
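
To make the concern concrete, here is a minimal pure-Python sketch of how the padding side changes which positions count as real tokens. It assumes the kernel combines a valid length with the padding side to decide which positions are valid; `valid_positions` is a hypothetical helper written for illustration and is not part of vllm_gaudi or the FusedSDPA kernel.

```python
def valid_positions(seq_len, valid_len, padding_side):
    """Mark which positions hold real tokens (True) versus padding (False).

    Hypothetical helper for illustration only; not part of vllm_gaudi.
    """
    if padding_side not in ("left", "right"):
        raise ValueError(f"unsupported padding_side: {padding_side!r}")
    pad = seq_len - valid_len
    if padding_side == "right":
        # Real tokens come first, padding sits at the end.
        return [i < valid_len for i in range(seq_len)]
    # Real tokens come last, padding sits at the start.
    return [i >= pad for i in range(seq_len)]


# A left-padded sequence of length 6 with 4 real tokens.
actual = valid_positions(6, 4, "left")
# What would be assumed if "right" were hardcoded regardless of the input.
assumed = valid_positions(6, 4, "right")
# Count positions where the hardcoded assumption disagrees with the input.
mismatches = sum(a != b for a, b in zip(actual, assumed))
```

In this toy case the two layouts disagree on four of six positions, which is why passing the caller's padding_side through matters whenever left-padded inputs are possible.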

    mask_chunk = mask_chunk.clone() if mask_chunk is not None else None
    self.break_graph()

    chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_chunk, v_chunk, None, dropout_p, scale, False, softmax_mode)

In the FP8 sliced path (context part), mask_chunk is computed (and even cloned under with_graph_breaks) but fp8_fsdpa_fwd is called with attn_mask=None.
This drops the padding mask for context chunks and can produce incorrect attention results when padding is present; pass mask_chunk through instead of None.
Suggested change:

    - chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_chunk, v_chunk, None, dropout_p, scale, False, softmax_mode)
    + chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_chunk, v_chunk, mask_chunk, dropout_p, scale, False, softmax_mode)

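
The effect of dropping the mask can be seen in a small pure-Python sketch for a single query row. This is not the HPU kernel; `softmax` and `attention_weights` are illustrative helpers, and the mask is assumed to be an additive bias where -inf marks a padded key.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(scores, mask=None):
    """Attention weights for one query; mask is an additive bias or None."""
    if mask is not None:
        scores = [s + b for s, b in zip(scores, mask)]
    return softmax(scores)


scores = [0.5, 1.0, 2.0]
mask = [0.0, 0.0, float("-inf")]  # assumption: the last key is padding

with_mask = attention_weights(scores, mask)
without_mask = attention_weights(scores, None)
```

With the mask applied, the padded key receives zero weight. Without it, the padded key happens to have the largest score here and takes most of the attention, which is the kind of silent error the comment describes.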

    @@ -413,6 +419,7 @@ def __init__(self, fusedSDPA):
    self.d_scale_q = torch.tensor(1.0)
    self.d_scale_k = torch.tensor(1.0)
    self.d_scale_v = torch.tensor(1.0)
    self.d_scale_output = torch.tensor(1.0)

d_scale_output is introduced and used to dequantize sliced FP8 outputs, but it is not configured anywhere outside this module (unlike scale_q/k/v and d_scale_q/k/v, which are set during weight/scale processing). As-is, slicing will always dequant with the default 1.0 scale. Consider wiring d_scale_output into the same scale-setup path (e.g., wherever fused_scaled_dot_product_attention.d_scale_q/k/v is set) or deriving it from the kernel’s output scale contract.
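
A simplified scalar sketch shows why an unconfigured dequantization scale matters. It assumes the common convention that the dequantization scale is the reciprocal of the quantization scale, and it ignores real FP8 rounding and clamping; the function names and the value 4.0 are illustrative and not taken from vllm_gaudi.

```python
def quantize(x, scale):
    """Map a value into the scaled domain (simplified: no FP8 rounding or clamping)."""
    return x * scale


def dequantize(q, d_scale):
    """Map a scaled value back using the dequantization scale."""
    return q * d_scale


scale_output = 4.0                   # assumed output quantization scale
d_scale_output = 1.0 / scale_output  # assumed convention: reciprocal of the scale

x = 0.75
q = quantize(x, scale_output)
restored = dequantize(q, d_scale_output)  # uses the configured scale
stale = dequantize(q, 1.0)                # uses the default left at 1.0
```

If the kernel's output scale really is 1.0, the default is harmless. For any other scale, the result is off by exactly that factor, so the value needs to come from the same place that sets the other scales.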

    enable_slicing = enable_slicing and slice_thld >= slice_thld_default
    if not enable_slicing and slice_thld > 0:
        logger().warning('Invalid slice sequence length threshold, the threshold should be '
                         f'>= min(max_num_batched_tokens, 8192), falling back to default {slice_thld_default}.')
        slice_thld = slice_thld_default

When VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD is set to an invalid value (< default), the code logs that it is “falling back to default”, but enable_slicing stays False and the fallback threshold is never used. Either re-enable slicing after resetting slice_thld, or update the log/message to reflect that slicing remains disabled for invalid thresholds.
Suggested change:

    - enable_slicing = enable_slicing and slice_thld >= slice_thld_default
    - if not enable_slicing and slice_thld > 0:
    -     logger().warning('Invalid slice sequence length threshold, the threshold should be '
    -                      f'>= min(max_num_batched_tokens, 8192), falling back to default {slice_thld_default}.')
    -     slice_thld = slice_thld_default
    + if 0 < slice_thld < slice_thld_default:
    +     logger().warning('Invalid slice sequence length threshold, the threshold should be '
    +                      f'>= min(max_num_batched_tokens, 8192), falling back to default {slice_thld_default}.')
    +     slice_thld = slice_thld_default
    + elif slice_thld <= 0:
    +     enable_slicing = False

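
The two variants can be compared side by side as standalone functions. This is a sketch that lifts the logic out of its surrounding code: the function names are hypothetical, the standard `logging` module stands in for the project's `logger()` helper, and the warning text is shortened.

```python
import logging

log = logging.getLogger("slicing_sketch")


def resolve_original(enable_slicing, slice_thld, slice_thld_default):
    """Original logic: resets the threshold but leaves slicing disabled."""
    enable_slicing = enable_slicing and slice_thld >= slice_thld_default
    if not enable_slicing and slice_thld > 0:
        log.warning("Invalid threshold %d, falling back to default %d.",
                    slice_thld, slice_thld_default)
        slice_thld = slice_thld_default
    return enable_slicing, slice_thld


def resolve_suggested(enable_slicing, slice_thld, slice_thld_default):
    """Suggested logic: a too-small positive threshold falls back and slicing stays on."""
    if 0 < slice_thld < slice_thld_default:
        log.warning("Invalid threshold %d, falling back to default %d.",
                    slice_thld, slice_thld_default)
        slice_thld = slice_thld_default
    elif slice_thld <= 0:
        enable_slicing = False
    return enable_slicing, slice_thld


original = resolve_original(True, 1024, 8192)
suggested = resolve_suggested(True, 1024, 8192)
```

For a threshold of 1024 against a default of 8192, the original variant returns slicing disabled with an unused fallback threshold, while the suggested variant keeps slicing enabled at the default. A threshold of zero or below disables slicing in the suggested variant.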
This reverts commit 9271c08.
Signed-off-by: Youlei Yang <youlei.yang@intel.com>
czhu15 left a comment
LGTM, and the example test code passed.
Please check the comments from Copilot.
The boolean mask handling for attn_bias was accidentally removed in commit f337029 (Enable slicing for fp8 FusedSDPA vllm-project#1285). When attn_bias is a boolean tensor, the code should use masked_fill to set invalid positions to -inf, but instead it was using add_ which only adds 0/1 to the attention weights. This causes incorrect attention scores and accuracy degradation, especially for long prompts where proper masking of padded positions is critical.

Signed-off-by: copilot <copilot@github.com>
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: JyhWind <40982453+JyhWind@users.noreply.github.com>
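
The difference between the two treatments of a boolean attn_bias can be sketched in pure Python for one query row. This is an illustration and not the vllm_gaudi code: lists stand in for tensors, the helper names are hypothetical, and True is assumed to mean "this position may be attended to".

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def apply_bias_fill(scores, attn_bias):
    """masked_fill-style handling: invalid (False) positions become -inf."""
    return [s if keep else float("-inf") for s, keep in zip(scores, attn_bias)]


def apply_bias_add(scores, attn_bias):
    """add_-style handling: booleans are treated as the numbers 1 and 0."""
    return [s + b for s, b in zip(scores, attn_bias)]


scores = [1.0, 2.0, 3.0]
attn_bias = [True, True, False]  # assumption: the last position is padding

filled = softmax(apply_bias_fill(scores, attn_bias))
added = softmax(apply_bias_add(scores, attn_bias))
```

With the fill variant the padded position ends up with zero weight. With the add variant the valid scores are merely shifted by one and the padded position still receives a large share of the attention.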