From 8b1524ee2a642bbab0c51e775690b44e61016c38 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 19 May 2026 05:43:06 +0000 Subject: [PATCH] Fix: restore boolean attention mask handling in _naive_prompt_attention The boolean mask handling for attn_bias was accidentally removed in commit f337029 (Enable slicing for fp8 FusedSDPA #1285). When attn_bias is a boolean tensor, the code should use masked_fill to set invalid positions to -inf, but instead it was using add_ which only adds 0/1 to the attention weights. This causes incorrect attention scores and accuracy degradation, especially for long prompts where proper masking of padded positions is critical. Signed-off-by: copilot Signed-off-by: GitHub Co-authored-by: JyhWind <40982453+JyhWind@users.noreply.github.com> --- vllm_gaudi/extension/ops.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py index 4222044e5c..17799dc3e8 100644 --- a/vllm_gaudi/extension/ops.py +++ b/vllm_gaudi/extension/ops.py @@ -349,9 +349,12 @@ def _naive_prompt_attention(query: torch.Tensor, htcore.mark_step() attn_weights.add_(position_bias) if attn_bias is not None: - if attn_weights.dtype != attn_bias.dtype: - attn_bias = attn_bias.to(dtype=attn_weights.dtype) - attn_weights.add_(attn_bias) + if attn_bias.dtype == torch.bool: + attn_weights = attn_weights.masked_fill(~attn_bias, float("-inf")) + else: + if attn_weights.dtype != attn_bias.dtype: + attn_bias = attn_bias.to(dtype=attn_weights.dtype) + attn_weights.add_(attn_bias) if sinks is not None: sink = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) if query_heads != kv_heads: