Use finite numbers for the attention mask #1290

Closed
yangulei wants to merge 2 commits into vllm-project:aice from yangulei:float_mask

Collaborator

@yangulei yangulei commented Apr 2, 2026

Revert "Use Boolean attention mask (#1032)" and use finite numbers for attention mask to solve the accuracy issue for fp8 FusedSDPA.

yangulei added 2 commits April 2, 2026 07:45
Signed-off-by: Youlei Yang <youlei.yang@intel.com>
Copilot AI review requested due to automatic review settings April 2, 2026 08:42
Contributor

Copilot AI left a comment

Pull request overview

This PR aims to address fp8 FusedSDPA accuracy issues by reverting the boolean attention-mask approach (from #1032) and consistently using additive attention biases with large finite negative values.

Changes:

  • Introduces a dtype→finite-mask-value mapping (MASK_VALUES) and uses it to build additive attention biases instead of boolean masks.
  • Updates multiple attention-bias construction paths (prefill/paged, sliding-window, chunked, block-mapping) to emit numeric biases.
  • Simplifies _naive_prompt_attention to always cast-and-add attn_bias (removing the boolean-mask special case).
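
The change summarized above, replacing a boolean mask with an additive bias that holds a large finite negative value, can be sketched as follows. This is a minimal illustration and not the PR's code: the helper `bool_mask_to_bias` and the convention that `keep` is True where attention is allowed are assumptions made for the example. Only `MASK_VALUES` and `DEFAULT_MASK_VALUE` are copied from the diff under review.

```python
import torch

# Constants as they appear in the diff under review.
MASK_VALUES = {torch.float32: -3E38, torch.bfloat16: -3E38, torch.float16: -6E4}
DEFAULT_MASK_VALUE = -3E38


def bool_mask_to_bias(keep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical helper: 0 where attention is allowed, a finite negative elsewhere."""
    bias = torch.zeros(keep.shape, dtype=dtype, device=keep.device)
    return bias.masked_fill(~keep, MASK_VALUES.get(dtype, DEFAULT_MASK_VALUE))


# Row 0 keeps one position; row 1 is fully masked.
keep = torch.tensor([[True, False], [False, False]])
scores = torch.zeros(2, 2)

# Additive bias with a finite mask value.
finite_probs = torch.softmax(scores + bool_mask_to_bias(keep, torch.float32), dim=-1)

# Same mask expressed with -inf, for comparison.
inf_bias = torch.zeros(2, 2).masked_fill(~keep, float("-inf"))
inf_probs = torch.softmax(scores + inf_bias, dim=-1)
```

In this sketch the fully masked row stays finite with the finite mask value, while the `-inf` variant produces NaN for that row.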

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| vllm_gaudi/v1/worker/hpu_model_runner.py | Switches attention bias/mask generation to finite numeric additive biases across several metadata paths; adds global mask constants. |
| vllm_gaudi/extension/ops.py | Removes boolean-mask handling in naive attention and always treats attn_bias as additive. |

Comment on lines +136 to +137
MASK_VALUES = {torch.float32: -3E38, torch.bfloat16: -3E38, torch.float16: -6E4}
DEFAULT_MASK_VALUE = -3E38

Copilot AI Apr 2, 2026

MASK_VALUES doesn’t cover fp8 dtypes (e.g., torch.float8_e4m3fn). For fp8 models, the fallback DEFAULT_MASK_VALUE (-3e38) will overflow when cast to fp8 (likely becoming -inf, or NaN for formats such as float8_e4m3fn that have no infinity encoding), defeating the PR goal of using finite mask values for fp8 FusedSDPA. Consider deriving the mask value from torch.finfo(dtype).min (or adding explicit fp8 entries) so the value is finite and representable for every dtype you pass into the attention bias builders.
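
The suggestion to derive the mask value from `torch.finfo(dtype).min` could look roughly like the sketch below. The helper name `finite_mask_value` is hypothetical and not part of the PR. Only float32, bfloat16 and float16 are exercised, because whether `torch.finfo` accepts fp8 dtypes depends on the installed PyTorch version.

```python
import torch


def finite_mask_value(dtype: torch.dtype) -> float:
    """Hypothetical helper: most negative finite value representable in `dtype`."""
    return torch.finfo(dtype).min


# The derived value stays finite in each dtype it was derived for.
derived = {
    dt: torch.tensor(finite_mask_value(dt), dtype=dt)
    for dt in (torch.float32, torch.bfloat16, torch.float16)
}

# By contrast, a single fixed default overflows once it is cast to a narrower dtype.
overflowed = torch.tensor(-3e38, dtype=torch.float32).to(torch.float16)
```

One caveat: in low-precision dtypes, adding `finfo.min` to an already negative score can still round to `-inf`, so a value with some headroom (as with the PR's `-6E4` for float16) may be preferable.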

Comment on lines +2073 to 2075
attn_mask = attn_mask.to(dtype).masked_fill_(attn_mask, MASK_VALUES.get(dtype, DEFAULT_MASK_VALUE))

return attn_mask.unflatten(0, (1, -1))

Copilot AI Apr 2, 2026

_make_attn_bias materializes the bias tensor on CPU using dtype = self.dtype via attn_mask.to(dtype). If self.dtype can be an fp8 dtype, this conversion will either be unsupported on CPU or will quantize/overflow the mask values before the tensor is moved to HPU. A safer approach is to keep the CPU mask boolean and create/convert the numeric bias after transferring to HPU (or build it in a CPU-supported dtype like bf16/float32 and cast on-device).

Suggested change
- attn_mask = attn_mask.to(dtype).masked_fill_(attn_mask, MASK_VALUES.get(dtype, DEFAULT_MASK_VALUE))
- return attn_mask.unflatten(0, (1, -1))
+ # Avoid materializing the bias tensor on CPU in an fp8 dtype. Build the
+ # numeric bias in a CPU-supported floating dtype and rely on downstream
+ # device-side casting when needed.
+ fp8_dtypes = {getattr(torch, name) for name in ("float8_e4m3fn", "float8_e5m2") if hasattr(torch, name)}
+ cpu_bias_dtype = torch.bfloat16 if dtype in fp8_dtypes else dtype
+ attn_bias = attn_mask.to(cpu_bias_dtype)
+ attn_bias.masked_fill_(attn_mask, MASK_VALUES.get(dtype, DEFAULT_MASK_VALUE))
+ return attn_bias.unflatten(0, (1, -1))

Comment on lines +6615 to +6620
tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1)
mask = torch.tril(tensor, diagonal=shift)
attn_bias = torch.triu(mask, diagonal=shift - window_size + 1)
mask = torch.triu(mask, diagonal=shift - window_size + 1)
attn_bias = torch.where(
mask, torch.tensor(0.0, dtype=dtype, device=device),
torch.tensor(MASK_VALUES.get(dtype, DEFAULT_MASK_VALUE), dtype=dtype, device=device))

Copilot AI Apr 2, 2026

In this branch, tensor/mask are created with dtype=dtype (floating), but torch.where requires a boolean condition tensor. With typical dtype values (fp16/bf16/fp32), this will raise at runtime. Keep the mask boolean (e.g., build tensor with dtype=torch.bool and use logical ops), then produce the numeric bias via masked_fill/where using that boolean mask.
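
A sketch of the boolean-first approach described in this comment, under some simplifying assumptions: the function name `sliding_window_bias` is hypothetical, the mask is built as a 2-D `(seq_len, seq_len)` tensor instead of the `(batch_size, 1, seq_len, seq_len)` shape in the diff, and the mask value is passed in explicitly.

```python
import torch


def sliding_window_bias(seq_len, window_size, dtype, mask_value, shift=0, device="cpu"):
    """Hypothetical helper: causal sliding-window bias built from a boolean mask."""
    ones = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device)
    # Keep positions on or below the (shifted) diagonal that are inside the window.
    keep = torch.tril(ones, diagonal=shift) & torch.triu(ones, diagonal=shift - window_size + 1)
    # Only now switch to the floating dtype: 0 where kept, mask_value elsewhere.
    bias = torch.zeros((seq_len, seq_len), dtype=dtype, device=device)
    return bias.masked_fill(~keep, mask_value)


bias = sliding_window_bias(seq_len=4, window_size=2, dtype=torch.float32, mask_value=-3e38)
```

Because `keep` stays boolean until the final `masked_fill`, no floating tensor is ever used as a condition.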

Comment on lines +6678 to +6687
tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1)
mask = torch.tril(tensor, diagonal=shift)
idx = torch.arange(seq_len, device=device)
chunk_id = idx // chunk_size
same_chunk = chunk_id.unsqueeze(0) == chunk_id.unsqueeze(1)
same_chunk = same_chunk.unsqueeze(0).unsqueeze(0)
attn_bias = same_chunk & mask
mask = torch.where(same_chunk, mask, torch.tensor(0.0, dtype=dtype, device=device))
attn_bias = torch.where(
mask, torch.tensor(0.0, dtype=dtype, device=device),
torch.tensor(MASK_VALUES.get(dtype, DEFAULT_MASK_VALUE), dtype=dtype, device=device))

Copilot AI Apr 2, 2026

Same issue as sliding-window else-branch: mask is a floating tensor (created from torch.full(..., dtype=dtype) and further modified with torch.where), but it’s used as the condition for torch.where, which must be boolean. This will fail at runtime for float dtypes. Build/keep mask as boolean (e.g., mask = torch.tril(torch.ones(..., dtype=torch.bool)) then mask &= same_chunk) and then generate the numeric bias from that boolean mask.
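
The chunked case can be handled the same way. The sketch below is illustrative only: `chunked_causal_bias` is a hypothetical name, the result is 2-D instead of `(batch_size, 1, seq_len, seq_len)`, and `shift` is assumed to be 0.

```python
import torch


def chunked_causal_bias(seq_len, chunk_size, dtype, mask_value, device="cpu"):
    """Hypothetical helper: causal bias restricted to positions in the same chunk."""
    idx = torch.arange(seq_len, device=device)
    chunk_id = idx // chunk_size
    # True where query and key positions fall in the same chunk.
    same_chunk = chunk_id.unsqueeze(0) == chunk_id.unsqueeze(1)
    causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=device))
    keep = causal & same_chunk
    # Convert the boolean mask to an additive bias as the last step.
    bias = torch.zeros((seq_len, seq_len), dtype=dtype, device=device)
    return bias.masked_fill(~keep, mask_value)


chunk_bias = chunked_causal_bias(seq_len=4, chunk_size=2, dtype=torch.float32, mask_value=-3e38)
```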

@yangulei
Collaborator Author

Closing, as this is included in #1285.

@yangulei yangulei closed this Apr 14, 2026