Enable slicing for fp8 FusedSDPA #1285

Merged
czhu15 merged 5 commits into vllm-project:aice from yangulei:fp8_slice
Apr 8, 2026

Conversation

Collaborator

@yangulei yangulei commented Apr 1, 2026

No description provided.

Copilot AI review requested due to automatic review settings April 1, 2026 06:51
Contributor

Copilot AI left a comment

Pull request overview

This PR extends the existing FusedSDPA slicing mechanism to the FP8 FusedSDPA path on HPU, so long-context “chunked prefill” can use the sliced dispatch when FP8 attention is enabled.

Changes:

  • Refactors common slicing setup into a shared ModuleFusedSDPABase.
  • Adds a sliced forward path for ModuleFP8FusedSDPA (FP8 FusedSDPA) and routes into it under the same gating conditions as BF16.
  • Updates the HPU attention backend to prefer the FP8 FusedSDPA kernel when FP8 attention is enabled and the kernel is available.
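
The slicing idea described above can be illustrated with a minimal pure-Python sketch. It assumes the simplest form of slicing, where the query rows are split into chunks and each chunk attends to the full keys and values; the PR's actual kernel dispatch, causal handling, and FP8 scaling are not reproduced here. The names `attention` and `sliced_attention` are hypothetical and are not part of `vllm_gaudi`.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v, scale):
    """Plain scaled dot-product attention on lists of row vectors."""
    out = []
    for q_row in q:
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        weights = softmax(scores)
        out.append([sum(w * v_row[j] for w, v_row in zip(weights, v))
                    for j in range(len(v[0]))])
    return out


def sliced_attention(q, k, v, scale, slice_len):
    """Process the query rows slice_len at a time and concatenate the results."""
    out = []
    for start in range(0, len(q), slice_len):
        out.extend(attention(q[start:start + slice_len], k, v, scale))
    return out


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 2.0], [0.5, -1.0]]
v = [[1.0, 0.0], [0.0, 1.0]]
full = attention(q, k, v, 0.5)
sliced = sliced_attention(q, k, v, 0.5, 2)
```

Because each query row is processed independently, the sliced result matches the unsliced one while only a slice of the score matrix exists at any time, which is the motivation for slicing long prefills.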

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| vllm_gaudi/extension/utils.py | Adds shared slicing setup base class and implements FP8 sliced FusedSDPA forward path. |
| vllm_gaudi/attention/backends/hpu_attn.py | Switches fsdpa_op to FP8 FusedSDPA when FP8 attention is enabled and import succeeds. |

False, # is_amax_s
False, # is_amax_o
None, # valid_seq_len
"right", # seq_padding_type

Copilot AI Apr 1, 2026

fp8_fsdpa_fwd hardcodes seq_padding_type to "right" and ignores the padding_side argument flowing into forward/_sliced_fsdpa_fwd. This can silently break correctness for left-padded inputs and is inconsistent with the BF16 slicing path (which passes padding_side). Thread padding_side through to fp8_fsdpa_fwd and use it instead of the constant.
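
To show why the padding side matters, here is a small sketch. It assumes that `seq_padding_type` tells the kernel which end of the sequence holds the padding tokens; the helper `valid_positions` is hypothetical and only mimics that idea in plain Python.

```python
def valid_positions(seq_len, valid_len, padding_side):
    """Return True for real tokens and False for padding tokens."""
    if padding_side == "right":
        return [i < valid_len for i in range(seq_len)]
    if padding_side == "left":
        return [i >= seq_len - valid_len for i in range(seq_len)]
    raise ValueError(f"unknown padding_side: {padding_side!r}")


# A sequence of 5 slots holding 3 real tokens.
as_left_padded = valid_positions(5, 3, "left")
as_right_padded = valid_positions(5, 3, "right")

# Slots that a hardcoded "right" would classify wrongly for left-padded input.
misclassified = [i for i in range(5) if as_left_padded[i] != as_right_padded[i]]
```

Under this assumption, a constant `"right"` treats the two leading padding slots of a left-padded input as real tokens and the two trailing real tokens as padding.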

Comment thread vllm_gaudi/extension/utils.py Outdated
mask_chunk = mask_chunk.clone() if mask_chunk is not None else None
self.break_graph()

chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_chunk, v_chunk, None, dropout_p, scale, False, softmax_mode)

Copilot AI Apr 1, 2026

In the FP8 sliced path (context part), mask_chunk is computed (and even cloned under with_graph_breaks) but fp8_fsdpa_fwd is called with attn_mask=None.
This drops the padding mask for context chunks and can produce incorrect attention results when padding is present; pass mask_chunk through instead of None.

Suggested change
- chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_chunk, v_chunk, None, dropout_p, scale, False, softmax_mode)
+ chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_chunk, v_chunk, mask_chunk, dropout_p, scale, False, softmax_mode)
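
The effect of passing `None` instead of the mask can be sketched in plain Python. This is an illustration under the assumption that a mask entry of `False` marks a padded key position that must receive zero attention weight; `attention_weights` is a hypothetical helper, not the kernel call from the PR.

```python
import math


def attention_weights(scores, mask=None):
    """Softmax over scores; positions where mask is False get zero weight."""
    if mask is not None:
        scores = [s if keep else float("-inf") for s, keep in zip(scores, mask)]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


scores = [0.5, 1.0, 2.0]
mask = [True, True, False]  # the last key position is padding

with_mask = attention_weights(scores, mask)
without_mask = attention_weights(scores, None)
```

Here the padded position has the largest raw score, so dropping the mask gives it the majority of the attention weight, while passing the mask gives it none.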

Comment thread vllm_gaudi/extension/utils.py Outdated
Comment on lines +418 to +422
@@ -413,6 +419,7 @@ def __init__(self, fusedSDPA):
self.d_scale_q = torch.tensor(1.0)
self.d_scale_k = torch.tensor(1.0)
self.d_scale_v = torch.tensor(1.0)
self.d_scale_output = torch.tensor(1.0)

Copilot AI Apr 1, 2026

d_scale_output is introduced and used to dequantize sliced FP8 outputs, but it is not configured anywhere outside this module (unlike scale_q/k/v and d_scale_q/k/v, which are set during weight/scale processing). As-is, slicing will always dequant with the default 1.0 scale. Consider wiring d_scale_output into the same scale-setup path (e.g., wherever fused_scaled_dot_product_attention.d_scale_q/k/v is set) or deriving it from the kernel’s output scale contract.
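
A toy example of what dequantizing with the default scale does. It assumes the common convention that the dequantization scale is the reciprocal of the quantization scale, and it uses integer rounding as a stand-in for an FP8 cast; `quantize`, `dequantize`, and the value `16.0` are hypothetical and chosen only for illustration.

```python
def quantize(x, scale):
    """Toy stand-in for an FP8 cast: scale up, then round to an integer."""
    return round(x * scale)


def dequantize(q, d_scale):
    """Map a quantized value back to the original range."""
    return q * d_scale


scale_output = 16.0                  # assumed quantization scale
d_scale_output = 1.0 / scale_output  # assumed matching dequantization scale

x = 0.75
q = quantize(x, scale_output)

restored = dequantize(q, d_scale_output)  # uses the configured scale
left_at_default = dequantize(q, 1.0)      # scale never wired in
```

With the matching scale the original value is recovered; with the default of 1.0 the result stays in the quantized range. Whether this matters in practice depends on the kernel's output scale contract, which this sketch does not model.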

Comment on lines +183 to +187
enable_slicing = enable_slicing and slice_thld >= slice_thld_default
if not enable_slicing and slice_thld > 0:
    logger().warning('Invalid slice sequence length threshold, the threshold should be '
                     f'>= min(max_num_batched_tokens, 8192), falling back to default {slice_thld_default}.')
    slice_thld = slice_thld_default

Copilot AI Apr 1, 2026

When VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD is set to an invalid value (< default), the code logs that it is “falling back to default”, but enable_slicing stays False and the fallback threshold is never used. Either re-enable slicing after resetting slice_thld, or update the log/message to reflect that slicing remains disabled for invalid thresholds.

Suggested change
- enable_slicing = enable_slicing and slice_thld >= slice_thld_default
- if not enable_slicing and slice_thld > 0:
-     logger().warning('Invalid slice sequence length threshold, the threshold should be '
-                      f'>= min(max_num_batched_tokens, 8192), falling back to default {slice_thld_default}.')
-     slice_thld = slice_thld_default
+ if 0 < slice_thld < slice_thld_default:
+     logger().warning('Invalid slice sequence length threshold, the threshold should be '
+                      f'>= min(max_num_batched_tokens, 8192), falling back to default {slice_thld_default}.')
+     slice_thld = slice_thld_default
+ elif slice_thld <= 0:
+     enable_slicing = False
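
The suggested logic can be exercised in isolation with a small standalone sketch. The wrapper function `resolve_slicing` and the logger name are hypothetical; only the branch structure follows the suggestion above.

```python
import logging

logger = logging.getLogger("fsdpa_slicing_sketch")


def resolve_slicing(enable_slicing, slice_thld, slice_thld_default):
    """Return (enable_slicing, slice_thld) after validating the threshold."""
    if 0 < slice_thld < slice_thld_default:
        # Too small: warn and actually use the default, keeping slicing on.
        logger.warning("Invalid slice sequence length threshold, "
                       "falling back to default %d.", slice_thld_default)
        slice_thld = slice_thld_default
    elif slice_thld <= 0:
        # Non-positive threshold: slicing is turned off.
        enable_slicing = False
    return enable_slicing, slice_thld


too_small = resolve_slicing(True, 4096, 8192)
disabled = resolve_slicing(True, 0, 8192)
valid = resolve_slicing(True, 16384, 8192)
```

In this form a too-small threshold really does fall back to the default with slicing still enabled, so the warning message matches the behavior.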

yangulei added 4 commits April 2, 2026 07:45
Signed-off-by: Youlei Yang <youlei.yang@intel.com>
Signed-off-by: Youlei Yang <youlei.yang@intel.com>
Signed-off-by: Youlei Yang <youlei.yang@intel.com>
Signed-off-by: Youlei Yang <youlei.yang@intel.com>
Collaborator

@czhu15 czhu15 left a comment

LGTM and example test code passed.
Pls check the comments from copilot.

@czhu15 czhu15 merged commit f337029 into vllm-project:aice Apr 8, 2026
1 check passed
@yangulei yangulei deleted the fp8_slice branch April 8, 2026 05:11
Copilot AI added a commit to JyhWind/vllm-gaudi that referenced this pull request May 19, 2026
The boolean mask handling for attn_bias was accidentally removed in commit
f337029 (Enable slicing for fp8 FusedSDPA vllm-project#1285). When attn_bias is a boolean
tensor, the code should use masked_fill to set invalid positions to -inf,
but instead it was using add_ which only adds 0/1 to the attention weights.
This causes incorrect attention scores and accuracy degradation, especially
for long prompts where proper masking of padded positions is critical.

Signed-off-by: copilot <copilot@github.com>
Signed-off-by: GitHub <noreply@github.com>

Co-authored-by: JyhWind <40982453+JyhWind@users.noreply.github.com>
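
The difference between the two mask treatments described in the commit message above can be sketched in plain Python. The sketch assumes that `True` in the boolean mask marks a position to keep, which is the convention for boolean masks in PyTorch's `scaled_dot_product_attention`; the helper names are hypothetical, and the list operations only imitate what `add_` and `masked_fill` would do to a tensor.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def bias_by_add(scores, bool_mask):
    """Imitates add_: the boolean mask is treated as the numbers 0 and 1."""
    return [s + float(keep) for s, keep in zip(scores, bool_mask)]


def bias_by_masked_fill(scores, bool_mask):
    """Imitates masked_fill: positions that are not kept become -inf."""
    return [s if keep else float("-inf") for s, keep in zip(scores, bool_mask)]


scores = [0.2, 0.4, 3.0]
bool_mask = [True, True, False]  # the last position is padding

weights_add = softmax(bias_by_add(scores, bool_mask))
weights_fill = softmax(bias_by_masked_fill(scores, bool_mask))
```

With addition, the padded position still takes part in the softmax and only loses a constant offset of 1 relative to the kept positions; with the fill, its weight is exactly zero.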

3 participants