-
Notifications
You must be signed in to change notification settings - Fork 139
Use Boolean attention mask #1032
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1925,7 +1925,6 @@ def _extract_prefill_batch_contents(self, num_prefills, num_decodes, num_schedul | |||||||||||||||
| return all_batch_contents, num_pad_across_dp | ||||||||||||||||
|
|
||||||||||||||||
| def _make_attn_bias(self, context_groups, token_groups): | ||||||||||||||||
| dtype = self.dtype | ||||||||||||||||
| is_causal = True # TODO: add support for non-causal tasks | ||||||||||||||||
| context_groups = torch.tensor(context_groups, device='cpu', dtype=torch.int16) | ||||||||||||||||
| context_groups = context_groups.repeat_interleave(self.block_size, dim=-1) | ||||||||||||||||
|
|
@@ -1938,7 +1937,7 @@ def _make_attn_bias(self, context_groups, token_groups): | |||||||||||||||
| causal_mask = torch.ones(num_queries, num_queries, device='cpu', dtype=torch.bool) | ||||||||||||||||
| causal_mask = torch.triu(causal_mask, diagonal=1).unsqueeze(0) | ||||||||||||||||
| attn_mask[:, :, context_len:].logical_or_(causal_mask) | ||||||||||||||||
| attn_mask = attn_mask.to(dtype).masked_fill_(attn_mask, -math.inf) | ||||||||||||||||
| attn_mask = ~attn_mask | ||||||||||||||||
|
|
||||||||||||||||
| return attn_mask.unflatten(0, (1, -1)) | ||||||||||||||||
|
|
||||||||||||||||
|
|
@@ -6114,7 +6113,7 @@ def _set_attn_bias(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, | |||||||||||||||
| diagonal=1) | ||||||||||||||||
| mask = causal_mask.logical_or(len_mask) | ||||||||||||||||
| mask = torch.concat((past_mask, mask), dim=-1) | ||||||||||||||||
| attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)) | ||||||||||||||||
| attn_bias = ~mask | ||||||||||||||||
|
yangulei marked this conversation as resolved.
|
||||||||||||||||
| attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", attn_bias=attn_bias) | ||||||||||||||||
| return attn_metadata | ||||||||||||||||
|
|
||||||||||||||||
|
|
@@ -6172,16 +6171,13 @@ def _set_attn_bias_for_sliding_window(self, attn_metadata: HPUAttentionMetadataV | |||||||||||||||
| # seq_lens_t.unsqueeze(-1)).view(batch_size, 1, 1, seq_len)) | ||||||||||||||||
| # causal_mask = causal_mask.logical_and(len_mask) | ||||||||||||||||
|
|
||||||||||||||||
| mask = torch.concat((past_mask, causal_mask), dim=-1) | ||||||||||||||||
| attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device), | ||||||||||||||||
| torch.tensor(float('-inf'), dtype=dtype, device=device)) | ||||||||||||||||
| attn_bias = torch.concat((past_mask, causal_mask), dim=-1) | ||||||||||||||||
|
||||||||||||||||
| else: | ||||||||||||||||
| # CAUSAL MASK without removing padding (CAUSAL+sliding window) | ||||||||||||||||
| # removing padding cause accuracy issue for images input | ||||||||||||||||
| tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1) | ||||||||||||||||
| tensor = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool) | ||||||||||||||||
| mask = torch.tril(tensor, diagonal=shift) | ||||||||||||||||
| mask = torch.triu(mask, diagonal=shift - window_size + 1) | ||||||||||||||||
| attn_bias = torch.log(mask) | ||||||||||||||||
| attn_bias = torch.triu(mask, diagonal=shift - window_size + 1) | ||||||||||||||||
|
|
||||||||||||||||
|
||||||||||||||||
| # Convert boolean mask to numeric attention bias: 0.0 for allowed positions, -inf for masked. | |
| if attn_bias.dtype == torch.bool: | |
| zero = torch.zeros(1, dtype=dtype, device=device) | |
| neg_inf = torch.full((1,), float("-inf"), dtype=dtype, device=device) | |
| attn_bias = torch.where(attn_bias, zero, neg_inf) |
Copilot
AI
Feb 25, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The boolean mask for chunked attention will be incorrectly converted when added to attention scores. The concatenated past_mask and causal_mask are boolean tensors that will be converted to float (True→1.0, False→0.0) when used in operations that expect numeric bias values (0.0 for allowed, -inf for masked).
Copilot
AI
Feb 25, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The boolean mask result from the AND operation (same_chunk & mask) will be incorrectly used as attention bias. When this boolean tensor is used in attention operations that add the bias to attention scores, True will become 1.0 and False will become 0.0, which is incorrect. The old code correctly used torch.where with explicit 0.0 and -inf values, then applied torch.log to get the proper bias values.
Copilot
AI
Feb 25, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The boolean attention bias cannot be correctly added to attention scores in decode operations. In pipelined_pa (ops.py line 82-83), block_bias is added to attention weights. Converting a boolean tensor to float produces 1.0 for True and 0.0 for False, which is incorrect for attention masking that requires 0.0 for allowed positions and -inf for masked positions.
| attn_bias = mask < block_usage.unsqueeze(-1) | |
| valid_positions = mask < block_usage.unsqueeze(-1) | |
| # Convert boolean mask to additive attention bias: 0.0 for allowed, -inf for masked | |
| attn_bias = torch.zeros_like(valid_positions, dtype=dtype) | |
| attn_bias.masked_fill_(~valid_positions, float("-inf")) |
Uh oh!
There was an error while loading. Please reload this page.