-
Notifications
You must be signed in to change notification settings - Fork 133
Use Boolean attention mask and enable FusedSDPA slicing for long sequences #1149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b5a5fc1
b534667
5aaf0f5
0cf717e
e216a6b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -80,8 +80,12 @@ def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, sink, bat | |
| batch2block_matmul_op, block2batch_matmul_op): | ||
| # When fp32_softmax is enabled attn is left in fp32 after Q@K | ||
| # We can return to native dtype after we renormalize and calculate the adjustments | ||
| if block_bias is not None and attn.dtype != block_bias.dtype: | ||
| block_bias = block_bias.to(dtype=attn.dtype) | ||
| if block_bias is not None: | ||
| if block_bias.dtype == torch.bool: | ||
| # Convert boolean mask (True=valid, False=masked) to additive bias (0.0/-inf) | ||
| block_bias = torch.zeros_like(block_bias, dtype=attn.dtype).masked_fill_(~block_bias, float('-inf')) | ||
| elif attn.dtype != block_bias.dtype: | ||
| block_bias = block_bias.to(dtype=attn.dtype) | ||
| # TODO: w/a with 5D req as the block_softmax kernel does not support 4D attn tensor, which is used in e.g. Granite-3B | ||
| if get_config().fused_block_softmax and get_config().fused_block_softmax_adjustment and attn.dim() == 5: | ||
| attn, block_max, block_sums = torch.ops.hpu.block_softmax(attn, block_bias, block_groups) | ||
|
|
@@ -357,9 +361,12 @@ def _naive_prompt_attention(query: torch.Tensor, | |
| htcore.mark_step() | ||
| attn_weights.add_(position_bias) | ||
| if attn_bias is not None: | ||
| if attn_weights.dtype != attn_bias.dtype: | ||
| attn_bias = attn_bias.to(dtype=attn_weights.dtype) | ||
| attn_weights.add_(attn_bias) | ||
| if attn_bias.dtype == torch.bool: | ||
| attn_weights = attn_weights.masked_fill(~attn_bias, float("-inf")) | ||
| else: | ||
| if attn_weights.dtype != attn_bias.dtype: | ||
| attn_bias = attn_bias.to(dtype=attn_weights.dtype) | ||
| attn_weights.add_(attn_bias) | ||
|
Comment on lines
+364
to
+369
|
||
| if sinks is not None: | ||
| sink = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) | ||
| if query_heads != kv_heads: | ||
|
|
@@ -404,10 +411,6 @@ def _fsdpa_prompt_attention(query: torch.Tensor, | |
| recompute_mode = True | ||
| assert attn_bias is not None or valid_seq_lengths is not None, \ | ||
| 'Either attn_bias or valid_seq_lengths must be != None' | ||
| if is_causal and attn_bias is not None: | ||
| # TODO: causal + attn_bias is not yet supported | ||
| is_causal = False | ||
| valid_seq_lengths = None | ||
|
|
||
| args = [ | ||
| query, key, value, attn_bias, 0.0, is_causal, scale, softmax_mode, recompute_mode, valid_seq_lengths, | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -6,12 +6,14 @@ | |||||||
| ############################################################################### | ||||||||
|
|
||||||||
| import os | ||||||||
| import math | ||||||||
| from functools import lru_cache, wraps | ||||||||
| from typing import Optional, Any | ||||||||
|
|
||||||||
| import habana_frameworks.torch as htorch | ||||||||
| import torch | ||||||||
| import itertools | ||||||||
| from vllm_gaudi.extension.logger import logger | ||||||||
|
|
||||||||
| from vllm_gaudi.extension.runtime import get_config | ||||||||
|
|
||||||||
|
|
@@ -155,6 +157,7 @@ def __init__(self, fusedSDPA): | |||||||
| super().__init__() | ||||||||
| assert fusedSDPA is not None, f'fusedSDPA kernel is None' | ||||||||
| self._hpu_kernel_fsdpa = fusedSDPA | ||||||||
| self.enable_slicing = self._setup_slicing() | ||||||||
|
|
||||||||
| def forward( | ||||||||
| self, | ||||||||
|
|
@@ -172,6 +175,20 @@ def forward( | |||||||
| window_size=None, | ||||||||
| sinks=None, | ||||||||
| ): | ||||||||
| bs = query.shape[0] | ||||||||
| q_len = query.shape[-2] | ||||||||
| kv_len = key.shape[-2] | ||||||||
| if (self.enable_slicing and kv_len >= self.slice_thld \ | ||||||||
| and bs == 1 # bs should be 1 for chunked prefill | ||||||||
| and q_len != kv_len # normal causal prefill route to the default dispatch for better performance | ||||||||
| and is_causal and attn_mask is not None # only supports causal attention with mask | ||||||||
| ): | ||||||||
| return self._sliced_fsdpa_fwd(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, | ||||||||
| recompute_mode, valid_sequence_lengths, padding_side) | ||||||||
| if is_causal and attn_mask is not None: | ||||||||
| # TODO: causal + attn_bias is not yet supported | ||||||||
| is_causal = False | ||||||||
| valid_sequence_lengths = None | ||||||||
|
Comment on lines
+188
to
+191
|
||||||||
| if window_size is not None: | ||||||||
| return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, | ||||||||
| recompute_mode, valid_sequence_lengths, padding_side, False, False, | ||||||||
|
|
@@ -181,6 +198,202 @@ def forward( | |||||||
| recompute_mode, valid_sequence_lengths, padding_side, False, False, | ||||||||
| (-1, -1), sinks) | ||||||||
|
|
||||||||
| def _sliced_fsdpa_fwd(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, recompute_mode, | ||||||||
| valid_sequence_lengths, padding_side): | ||||||||
| assert is_causal and attn_mask is not None | ||||||||
|
|
||||||||
| from habana_frameworks.torch.hpex.kernels.FusedSDPA import is_gqa, gqa_input_reshape_fwd, gqa_output_reshape | ||||||||
| gqa = is_gqa(query, key) | ||||||||
| if gqa: | ||||||||
| q, k, v, attn_mask = gqa_input_reshape_fwd(query, key, value, attn_mask) | ||||||||
| else: | ||||||||
| q, k, v, attn_mask = (query, key, value, attn_mask) | ||||||||
| q_len = q.shape[-2] | ||||||||
| kv_len = k.shape[-2] | ||||||||
| prefix_len = kv_len - q_len | ||||||||
|
|
||||||||
| chunk_outputs = [] | ||||||||
| num_q_chunks = math.ceil(q_len / self.chunk_size) | ||||||||
| num_prefix_chunks = math.ceil(prefix_len / self.chunk_size) | ||||||||
| for q_chunk_idx in range(num_q_chunks): | ||||||||
| q_start = q_len - (q_chunk_idx + 1) * self.chunk_size | ||||||||
| q_start = max(q_start, 0) | ||||||||
| q_end = q_len - q_chunk_idx * self.chunk_size | ||||||||
| q_chunk_size = q_end - q_start | ||||||||
| q_chunk = q[..., q_start:q_end, :].clone() if self.with_graph_breaks else q[..., q_start:q_end, :] | ||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I've never seen the `...` (Ellipsis) operator used in indexing before, TIL. I trust it's needed here. |
||||||||
|
|
||||||||
| last_out = None | ||||||||
| last_m = None | ||||||||
| last_linv = None | ||||||||
|
|
||||||||
| # the causal part | ||||||||
| for kv_chunk_idx in range(0, num_q_chunks - q_chunk_idx): | ||||||||
| kv_start = prefix_len + q_end - (kv_chunk_idx + 1) * self.chunk_size | ||||||||
| kv_start = max(kv_start, prefix_len) | ||||||||
| kv_end = prefix_len + q_end - kv_chunk_idx * self.chunk_size | ||||||||
| kv_chunk_size = kv_end - kv_start | ||||||||
| k_chunk = k[..., kv_start:kv_end, :] | ||||||||
| v_chunk = v[..., kv_start:kv_end, :] | ||||||||
|
|
||||||||
| is_causal_chunk = kv_chunk_idx == 0 and q_chunk_idx != 0 | ||||||||
| # chunk sizes must be multiples of 1024 to get valid m and linv | ||||||||
| is_causal_chunk = is_causal_chunk and q_chunk_size % 1024 == 0 and kv_chunk_size % 1024 == 0 | ||||||||
| # use mask only for the causal chunks that may have padding | ||||||||
| mask_chunk = attn_mask[ | ||||||||
| ..., q_start:q_end, | ||||||||
| kv_start:kv_end] if kv_chunk_idx < self.num_padded_query_chunks and not is_causal_chunk else None | ||||||||
|
|
||||||||
| if self.with_graph_breaks: | ||||||||
| k_chunk = k_chunk.clone() | ||||||||
| v_chunk = v_chunk.clone() | ||||||||
| mask_chunk = mask_chunk.clone() if mask_chunk is not None else None | ||||||||
| self.break_graph() | ||||||||
|
|
||||||||
| chunk_res = torch.ops.hpu.sdpa_recomp_fwd( | ||||||||
| q_chunk, | ||||||||
| k_chunk, | ||||||||
| v_chunk, | ||||||||
| mask_chunk, | ||||||||
| dropout_p, | ||||||||
| scale, | ||||||||
| is_causal_chunk, | ||||||||
| True, # requires_backward | ||||||||
| softmax_mode, | ||||||||
| None, # valid_seq_len | ||||||||
| padding_side, | ||||||||
| ) | ||||||||
| chunk_out, chunk_m, chunk_linv = ((gqa_output_reshape(x) if gqa else x).to(torch.float32) | ||||||||
| for x in (chunk_res[:3])) | ||||||||
|
Comment on lines
+252
to
+266
|
||||||||
|
|
||||||||
| if last_out is None or last_m is None or last_linv is None: | ||||||||
| last_out = chunk_out | ||||||||
| last_m = chunk_m | ||||||||
| last_linv = chunk_linv | ||||||||
| else: | ||||||||
| new_m = torch.maximum(last_m, chunk_m) | ||||||||
| last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m) | ||||||||
| chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m) | ||||||||
| last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled) | ||||||||
| last_out = (last_linv_rescaled * last_linv) * last_out + (chunk_linv_rescaled * | ||||||||
| last_linv) * chunk_out | ||||||||
| last_m = new_m | ||||||||
|
|
||||||||
| if self.with_graph_breaks: | ||||||||
| self.break_graph() | ||||||||
|
|
||||||||
| # the context part | ||||||||
| for kv_chunk_idx in range(num_prefix_chunks): | ||||||||
| kv_start = prefix_len - (kv_chunk_idx + 1) * self.chunk_size | ||||||||
| kv_start = max(kv_start, 0) | ||||||||
| kv_end = prefix_len - kv_chunk_idx * self.chunk_size | ||||||||
| k_chunk = k[..., kv_start:kv_end, :] | ||||||||
| v_chunk = v[..., kv_start:kv_end, :] | ||||||||
| # use mask only for the chunks that may have padding | ||||||||
| mask_chunk = attn_mask[..., q_start:q_end, | ||||||||
| kv_start:kv_end] if kv_chunk_idx < self.num_padded_ctx_chunks else None | ||||||||
|
|
||||||||
| if self.with_graph_breaks: | ||||||||
| k_chunk = k_chunk.clone() | ||||||||
| v_chunk = v_chunk.clone() | ||||||||
| mask_chunk = mask_chunk.clone() if mask_chunk is not None else None | ||||||||
| self.break_graph() | ||||||||
|
|
||||||||
| chunk_res = torch.ops.hpu.sdpa_recomp_fwd( | ||||||||
| q_chunk, | ||||||||
| k_chunk, | ||||||||
| v_chunk, | ||||||||
| mask_chunk, | ||||||||
| dropout_p, | ||||||||
| scale, | ||||||||
| False, # is_causal | ||||||||
| True, # requires_backward | ||||||||
| softmax_mode, | ||||||||
| None, # valid_seq_len | ||||||||
| padding_side, | ||||||||
| ) | ||||||||
| chunk_out, chunk_m, chunk_linv = ((gqa_output_reshape(x) if gqa else x).to(torch.float32) | ||||||||
| for x in chunk_res[:3]) | ||||||||
|
|
||||||||
| assert not (last_out is None or last_m is None or last_linv is None) | ||||||||
| new_m = torch.maximum(last_m, chunk_m) | ||||||||
| last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m) | ||||||||
| chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m) | ||||||||
| last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled) | ||||||||
| last_out = (last_linv_rescaled * last_linv) * last_out + (chunk_linv_rescaled * last_linv) * chunk_out | ||||||||
| last_m = new_m | ||||||||
|
|
||||||||
| if self.with_graph_breaks: | ||||||||
| self.break_graph() | ||||||||
| chunk_outputs.append(last_out) | ||||||||
| chunk_outputs = list(reversed(chunk_outputs)) | ||||||||
| output = torch.cat(chunk_outputs, dim=-2) | ||||||||
| return output.to(q.dtype) | ||||||||
|
|
||||||||
| def _setup_slicing(self) -> bool: | ||||||||
| from vllm_gaudi.extension.bucketing.common import get_bucketing_manager | ||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Are the imports placed inside the function because they are not visible otherwise? If not, we could move them outside. |
||||||||
| bucketing_manager = get_bucketing_manager() | ||||||||
| enable_slicing = bucketing_manager is not None | ||||||||
| if not enable_slicing: | ||||||||
| logger().warning('Bucketing manager is not instantiated, slicing in FSDPA will be disabled.') | ||||||||
| return False | ||||||||
| assert bucketing_manager is not None | ||||||||
| enable_slicing = enable_slicing and bucketing_manager.initialized | ||||||||
| if not enable_slicing: | ||||||||
| logger().warning('Bucketing manager is not initialized, slicing in FSDPA will be disabled.') | ||||||||
| return False | ||||||||
|
|
||||||||
| from vllm_gaudi.extension.bucketing.linear import LinearBucketingStrategy | ||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Are the imports placed inside the function because they are not visible otherwise? If not, we could move them outside. |
||||||||
| strategy = bucketing_manager.get_bucketing_strategy() | ||||||||
| enable_slicing = isinstance(strategy, LinearBucketingStrategy) | ||||||||
| if not enable_slicing: | ||||||||
| logger().debug('Not using Linear Bucketing Strategy, slicing in FSDPA will be disabled.') | ||||||||
| return False | ||||||||
|
|
||||||||
| max_num_batched_tokens = bucketing_manager.max_num_batched_tokens | ||||||||
| slice_thld_default = min(max_num_batched_tokens, 8192) | ||||||||
| slice_thld = int(os.getenv("VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD", str(slice_thld_default))) | ||||||||
| enable_slicing = enable_slicing and slice_thld >= slice_thld_default | ||||||||
| if not enable_slicing and slice_thld > 0: | ||||||||
| logger().warning('Invalid slice sequence length threshold, the threshold should be ' | ||||||||
| f'>= min(max_num_batched_tokens, 8192), falling back to default {slice_thld_default}.') | ||||||||
| slice_thld = slice_thld_default | ||||||||
|
||||||||
| slice_thld = slice_thld_default | |
| slice_thld = slice_thld_default | |
| enable_slicing = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are the imports placed inside the function because they are not visible otherwise? If not, we could move them outside.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The markdown table uses `||` at the start of each row, which typically renders as an extra empty column (or may not render as intended, depending on the markdown renderer). Use single `|` delimiters for standard GitHub-flavored markdown tables.