Use finite numbers for the attention mask #1290
This reverts commit 9271c08.
Signed-off-by: Youlei Yang <youlei.yang@intel.com>
Pull request overview
This PR aims to address fp8 FusedSDPA accuracy issues by reverting the boolean attention-mask approach (from #1032) and consistently using additive attention biases with large finite negative values.
Changes:
- Introduces a dtype→finite-mask-value mapping (MASK_VALUES) and uses it to build additive attention biases instead of boolean masks.
- Updates multiple attention-bias construction paths (prefill/paged, sliding-window, chunked, block-mapping) to emit numeric biases.
- Simplifies _naive_prompt_attention to always cast-and-add attn_bias (removing the boolean-mask special case).
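To illustrate why always casting and adding attn_bias can stand in for the boolean special case, here is a minimal sketch. It is not the actual _naive_prompt_attention: the function name naive_attention, the tensor shapes, and the use of -3E38 in float32 are assumptions chosen for illustration.

```python
import torch


def naive_attention(q, k, v, attn_bias):
    """Hypothetical stand-in for a naive attention path with an additive bias."""
    # scores has shape [batch, heads, q_len, kv_len]
    scores = torch.matmul(q, k.transpose(-2, -1)) * (q.shape[-1] ** -0.5)
    # Always treat attn_bias as additive: cast to the score dtype and add.
    scores = scores + attn_bias.to(scores.dtype)
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)


torch.manual_seed(0)
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# Causal pattern: True where a query position may attend to a key position.
keep = torch.tril(torch.ones(4, 4, dtype=torch.bool))

# Additive bias: 0 where attention is allowed, a large finite negative elsewhere.
finite_bias = torch.zeros(4, 4).masked_fill(~keep, -3e38)
out_additive = naive_attention(q, k, v, finite_bias)

# Reference: classic boolean masking with -inf before the softmax.
ref_scores = torch.matmul(q, k.transpose(-2, -1)) * (8 ** -0.5)
ref_scores = ref_scores.masked_fill(~keep, float("-inf"))
out_reference = torch.matmul(torch.softmax(ref_scores, dim=-1), v)
```

In float32 the two outputs agree because the masked positions receive zero probability either way, while the bias tensor itself never contains an infinity.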
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| vllm_gaudi/v1/worker/hpu_model_runner.py | Switches attention bias/mask generation to finite numeric additive biases across several metadata paths; adds global mask constants. |
| vllm_gaudi/extension/ops.py | Removes boolean-mask handling in naive attention and always treats attn_bias as additive. |
| MASK_VALUES = {torch.float32: -3E38, torch.bfloat16: -3E38, torch.float16: -6E4} | ||
| DEFAULT_MASK_VALUE = -3E38 |
MASK_VALUES doesn’t cover fp8 dtypes (e.g., torch.float8_e4m3fn). For fp8 models, the fallback DEFAULT_MASK_VALUE (-3e38) will overflow when cast to fp8 (likely becoming -inf), defeating the PR goal of using finite mask values for fp8 FusedSDPA. Consider deriving the mask value from torch.finfo(dtype).min (or adding explicit fp8 entries) so the value is finite and representable for every dtype you pass into the attention bias builders.
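A minimal sketch of the torch.finfo approach suggested above. The helper name finite_mask_value is hypothetical and not part of the PR; the fp8 dtypes are probed with getattr because they exist only in newer PyTorch builds.

```python
import torch


def finite_mask_value(dtype: torch.dtype) -> float:
    """Most negative finite value representable in `dtype` (hypothetical helper)."""
    return torch.finfo(dtype).min


# Dtypes already covered by MASK_VALUES in the PR.
for dt in (torch.float32, torch.bfloat16, torch.float16):
    print(dt, finite_mask_value(dt))

# fp8 dtypes are not present in every PyTorch build, so look them up defensively.
for name in ("float8_e4m3fn", "float8_e5m2"):
    dt = getattr(torch, name, None)
    if dt is not None:
        print(name, finite_mask_value(dt))
```

One caveat with this approach: adding a bias equal to the dtype minimum to an already negative score can still round to -inf in narrow dtypes, which may be why the PR uses -6E4 for float16 rather than the exact minimum. Leaving some headroom below the minimum magnitude is a reasonable variation.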
| attn_mask = attn_mask.to(dtype).masked_fill_(attn_mask, MASK_VALUES.get(dtype, DEFAULT_MASK_VALUE)) | ||
| return attn_mask.unflatten(0, (1, -1)) |
_make_attn_bias materializes the bias tensor on CPU using dtype = self.dtype via attn_mask.to(dtype). If self.dtype can be an fp8 dtype, this conversion will either be unsupported on CPU or will quantize/overflow the mask values before the tensor is moved to HPU. A safer approach is to keep the CPU mask boolean and create/convert the numeric bias after transferring to HPU (or build it in a CPU-supported dtype like bf16/float32 and cast on-device).
| attn_mask = attn_mask.to(dtype).masked_fill_(attn_mask, MASK_VALUES.get(dtype, DEFAULT_MASK_VALUE)) | |
| return attn_mask.unflatten(0, (1, -1)) | |
| # Avoid materializing the bias tensor on CPU in an fp8 dtype. Build the | |
| # numeric bias in a CPU-supported floating dtype and rely on downstream | |
| # device-side casting when needed. | |
| fp8_dtypes = {getattr(torch, name) for name in ("float8_e4m3fn", "float8_e5m2") if hasattr(torch, name)} | |
| cpu_bias_dtype = torch.bfloat16 if dtype in fp8_dtypes else dtype | |
| attn_bias = attn_mask.to(cpu_bias_dtype) | |
| attn_bias.masked_fill_(attn_mask, MASK_VALUES.get(dtype, DEFAULT_MASK_VALUE)) | |
| return attn_bias.unflatten(0, (1, -1)) |
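The other option mentioned in the comment, keeping the mask boolean on the host and building the numeric bias only after the transfer, could look roughly like the sketch below. The helper name bias_from_bool_mask and its parameters are hypothetical; device defaults to "cpu" here so the example runs anywhere, whereas the real code would target the HPU. As in _make_attn_bias, True marks positions to be masked out.

```python
import torch


def bias_from_bool_mask(masked, mask_value, dtype, device="cpu"):
    """Build an additive bias from a boolean mask (True = masked out).

    Hypothetical helper: only the compact boolean mask is transferred, and the
    numeric bias is created directly on the target device in the target dtype.
    """
    masked = masked.to(device)
    bias = torch.zeros(masked.shape, dtype=dtype, device=device)
    bias.masked_fill_(masked, mask_value)
    return bias


# Causal example: positions above the diagonal are masked out.
masked = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)
bias = bias_from_bool_mask(masked, -3e38, torch.bfloat16)
```

The mask value passed in must be representable in the target dtype; -3E38 fits bfloat16 and float32 but not float16 or the fp8 formats.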
| tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1) | ||
| mask = torch.tril(tensor, diagonal=shift) | ||
| attn_bias = torch.triu(mask, diagonal=shift - window_size + 1) | ||
| mask = torch.triu(mask, diagonal=shift - window_size + 1) | ||
| attn_bias = torch.where( | ||
| mask, torch.tensor(0.0, dtype=dtype, device=device), | ||
| torch.tensor(MASK_VALUES.get(dtype, DEFAULT_MASK_VALUE), dtype=dtype, device=device)) |
In this branch, tensor/mask are created with dtype=dtype (floating), but torch.where requires a boolean condition tensor. With typical dtype values (fp16/bf16/fp32), this will raise at runtime. Keep the mask boolean (e.g., build tensor with dtype=torch.bool and use logical ops), then produce the numeric bias via masked_fill/where using that boolean mask.
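A sketch of the boolean-first construction described above, under stated assumptions: the function name sliding_window_bias and its signature are hypothetical, and the device defaults to "cpu" so the example is runnable without an HPU.

```python
import torch


def sliding_window_bias(batch_size, seq_len, window_size, shift, dtype,
                        mask_value, device="cpu"):
    """Sliding-window causal bias built from a boolean mask (hypothetical helper)."""
    ones = torch.ones((batch_size, 1, seq_len, seq_len),
                      dtype=torch.bool, device=device)
    # Causal part: keep keys at or before the (shifted) diagonal.
    keep = torch.tril(ones, diagonal=shift)
    # Window part: drop keys further back than window_size positions.
    keep = torch.triu(keep, diagonal=shift - window_size + 1)
    # Only now convert to a numeric additive bias.
    bias = torch.zeros(keep.shape, dtype=dtype, device=device)
    return bias.masked_fill_(~keep, mask_value)


sw_bias = sliding_window_bias(batch_size=1, seq_len=4, window_size=2, shift=0,
                              dtype=torch.float32, mask_value=-3e38)
```

With window_size=2 and shift=0, each query position keeps itself and the one key before it; everything else receives the finite mask value.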
| tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1) | ||
| mask = torch.tril(tensor, diagonal=shift) | ||
| idx = torch.arange(seq_len, device=device) | ||
| chunk_id = idx // chunk_size | ||
| same_chunk = chunk_id.unsqueeze(0) == chunk_id.unsqueeze(1) | ||
| same_chunk = same_chunk.unsqueeze(0).unsqueeze(0) | ||
| attn_bias = same_chunk & mask | ||
| mask = torch.where(same_chunk, mask, torch.tensor(0.0, dtype=dtype, device=device)) | ||
| attn_bias = torch.where( | ||
| mask, torch.tensor(0.0, dtype=dtype, device=device), | ||
| torch.tensor(MASK_VALUES.get(dtype, DEFAULT_MASK_VALUE), dtype=dtype, device=device)) |
Same issue as sliding-window else-branch: mask is a floating tensor (created from torch.full(..., dtype=dtype) and further modified with torch.where), but it’s used as the condition for torch.where, which must be boolean. This will fail at runtime for float dtypes. Build/keep mask as boolean (e.g., mask = torch.tril(torch.ones(..., dtype=torch.bool)) then mask &= same_chunk) and then generate the numeric bias from that boolean mask.
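The chunked case can follow the same pattern. The sketch below assumes a hypothetical helper chunked_causal_bias with an illustrative signature; it keeps every intermediate mask boolean and converts to a numeric bias only at the end.

```python
import torch


def chunked_causal_bias(batch_size, seq_len, chunk_size, shift, dtype,
                        mask_value, device="cpu"):
    """Chunked causal bias built from boolean masks (hypothetical helper)."""
    ones = torch.ones((batch_size, 1, seq_len, seq_len),
                      dtype=torch.bool, device=device)
    keep = torch.tril(ones, diagonal=shift)  # causal part, still boolean
    idx = torch.arange(seq_len, device=device)
    chunk_id = idx // chunk_size
    # True where query and key positions fall in the same chunk: [seq, seq]
    same_chunk = chunk_id.unsqueeze(0) == chunk_id.unsqueeze(1)
    keep &= same_chunk  # broadcasts over the batch and head dimensions
    bias = torch.zeros(keep.shape, dtype=dtype, device=device)
    return bias.masked_fill_(~keep, mask_value)


chunk_bias = chunked_causal_bias(batch_size=1, seq_len=4, chunk_size=2, shift=0,
                                 dtype=torch.float32, mask_value=-3e38)
```

With chunk_size=2, position 2 can attend to itself but not to position 1, since they belong to different chunks.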
Closing, as this is included in #1285.
Revert "Use Boolean attention mask (#1032)" and use finite numbers for the attention mask to fix the accuracy issue with fp8 FusedSDPA.