diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py index 4222044e5c..17799dc3e8 100644 --- a/vllm_gaudi/extension/ops.py +++ b/vllm_gaudi/extension/ops.py @@ -349,9 +349,12 @@ def _naive_prompt_attention(query: torch.Tensor, htcore.mark_step() attn_weights.add_(position_bias) if attn_bias is not None: - if attn_weights.dtype != attn_bias.dtype: - attn_bias = attn_bias.to(dtype=attn_weights.dtype) - attn_weights.add_(attn_bias) + if attn_bias.dtype == torch.bool: + attn_weights = attn_weights.masked_fill(~attn_bias, float("-inf")) + else: + if attn_weights.dtype != attn_bias.dtype: + attn_bias = attn_bias.to(dtype=attn_weights.dtype) + attn_weights.add_(attn_bias) if sinks is not None: sink = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) if query_heads != kv_heads: