-
Notifications
You must be signed in to change notification settings - Fork 129
Enable slicing for the BF16 FusedSDPA #1034
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,12 +6,14 @@ | |
| ############################################################################### | ||
|
|
||
| import os | ||
| import math | ||
| from functools import lru_cache, wraps | ||
| from typing import Optional, Any | ||
|
|
||
| import habana_frameworks.torch as htorch | ||
| import torch | ||
| import itertools | ||
| from vllm_gaudi.extension.logger import logger | ||
|
|
||
| from vllm_gaudi.extension.runtime import get_config | ||
|
|
||
|
|
@@ -155,6 +157,7 @@ def __init__(self, fusedSDPA): | |
| super().__init__() | ||
| assert fusedSDPA is not None, f'fusedSDPA kernel is None' | ||
| self._hpu_kernel_fsdpa = fusedSDPA | ||
| self.enable_slicing = self._setup_slicing() | ||
|
|
||
| def forward( | ||
| self, | ||
|
|
@@ -172,6 +175,20 @@ def forward( | |
| window_size=None, | ||
| sinks=None, | ||
| ): | ||
| bs = query.shape[0] | ||
| q_len = query.shape[-2] | ||
| kv_len = key.shape[-2] | ||
| if (self.enable_slicing and kv_len >= self.slice_thld \ | ||
| and bs == 1 # bs should be 1 for chunked prefill | ||
| and q_len != kv_len # normal causal prefill route to the default dispatch for better performance | ||
| and is_causal and attn_mask is not None # only supports causal attention with mask | ||
| ): | ||
| return self._sliced_fsdpa_fwd(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, | ||
| recompute_mode, valid_sequence_lengths, padding_side) | ||
| if is_causal and attn_mask is not None: | ||
| # TODO: causal + attn_bias is not yet supported | ||
| is_causal = False | ||
| valid_sequence_lengths = None | ||
| if window_size is not None: | ||
| return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, | ||
| recompute_mode, valid_sequence_lengths, padding_side, False, False, | ||
|
|
@@ -181,6 +198,204 @@ def forward( | |
| recompute_mode, valid_sequence_lengths, padding_side, False, False, | ||
| (-1, -1), sinks) | ||
|
|
||
| def _sliced_fsdpa_fwd(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, recompute_mode, | ||
| valid_sequence_lengths, padding_side): | ||
| assert is_causal and attn_mask is not None | ||
|
|
||
| from habana_frameworks.torch.hpex.kernels.FusedSDPA import is_gqa, gqa_input_reshape_fwd, gqa_output_reshape | ||
| gqa = is_gqa(query, key) | ||
| if gqa: | ||
| q, k, v, attn_mask = gqa_input_reshape_fwd(query, key, value, attn_mask) | ||
| else: | ||
| q, k, v, attn_mask = (query, key, value, attn_mask) | ||
| q_len = q.shape[-2] | ||
| kv_len = k.shape[-2] | ||
| prefix_len = kv_len - q_len | ||
|
|
||
| chunk_outputs = [] | ||
| num_q_chunks = math.ceil(q_len / self.chunk_size) | ||
| num_prefix_chunks = math.ceil(prefix_len / self.chunk_size) | ||
| for q_chunk_idx in range(num_q_chunks): | ||
| q_start = q_len - (q_chunk_idx + 1) * self.chunk_size | ||
| q_start = max(q_start, 0) | ||
| q_end = q_len - q_chunk_idx * self.chunk_size | ||
| q_chunk_size = q_end - q_start | ||
| q_chunk = q[..., q_start:q_end, :].clone() if self.with_graph_breaks else q[..., q_start:q_end, :] | ||
|
|
||
| last_out = None | ||
| last_m = None | ||
| last_linv = None | ||
|
|
||
| # the causal part | ||
| for kv_chunk_idx in range(0, num_q_chunks - q_chunk_idx): | ||
| kv_start = prefix_len + q_end - (kv_chunk_idx + 1) * self.chunk_size | ||
| kv_start = max(kv_start, prefix_len) | ||
| kv_end = prefix_len + q_end - kv_chunk_idx * self.chunk_size | ||
| kv_chunk_size = kv_end - kv_start | ||
| k_chunk = k[..., kv_start:kv_end, :] | ||
| v_chunk = v[..., kv_start:kv_end, :] | ||
|
|
||
| is_causal_chunk = kv_chunk_idx == 0 and q_chunk_idx != 0 | ||
| # chunk sizes must be multiples of 1024 to get valid m and linv | ||
| is_causal_chunk = is_causal_chunk and q_chunk_size % 1024 == 0 and kv_chunk_size % 1024 == 0 | ||
| # use mask only for the causal chunks that may have padding | ||
| mask_chunk = attn_mask[ | ||
| ..., q_start:q_end, | ||
| kv_start:kv_end] if kv_chunk_idx < self.num_padded_query_chunks and not is_causal_chunk else None | ||
|
|
||
| if self.with_graph_breaks: | ||
| # break_graph() cannot break the tensor slicing, use clone to isolate the graph | ||
| k_chunk = k_chunk.clone() | ||
| v_chunk = v_chunk.clone() | ||
| mask_chunk = mask_chunk.clone() if mask_chunk is not None else None | ||
| self.break_graph() | ||
|
|
||
| chunk_res = torch.ops.hpu.sdpa_recomp_fwd( | ||
| q_chunk, | ||
| k_chunk, | ||
| v_chunk, | ||
| mask_chunk, | ||
| dropout_p, | ||
| scale, | ||
| is_causal_chunk, | ||
| True, # requires_backward | ||
| softmax_mode, | ||
| None, # valid_seq_len | ||
| padding_side, | ||
| ) | ||
| chunk_out, chunk_m, chunk_linv = ((gqa_output_reshape(x) if gqa else x).to(torch.float32) | ||
| for x in (chunk_res[:3])) | ||
|
|
||
| if last_out is None or last_m is None or last_linv is None: | ||
| last_out = chunk_out | ||
| last_m = chunk_m | ||
| last_linv = chunk_linv | ||
| else: | ||
| new_m = torch.maximum(last_m, chunk_m) | ||
| last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m) | ||
| chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m) | ||
| last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled) | ||
| last_out = (last_linv_rescaled * last_linv) * last_out + (chunk_linv_rescaled * | ||
| last_linv) * chunk_out | ||
| last_m = new_m | ||
|
|
||
| if self.with_graph_breaks: | ||
| self.break_graph() | ||
|
|
||
| # the context part | ||
| for kv_chunk_idx in range(num_prefix_chunks): | ||
| kv_start = prefix_len - (kv_chunk_idx + 1) * self.chunk_size | ||
| kv_start = max(kv_start, 0) | ||
| kv_end = prefix_len - kv_chunk_idx * self.chunk_size | ||
| k_chunk = k[..., kv_start:kv_end, :] | ||
| v_chunk = v[..., kv_start:kv_end, :] | ||
| # use mask only for the chunks that may have padding | ||
| mask_chunk = attn_mask[..., q_start:q_end, | ||
| kv_start:kv_end] if kv_chunk_idx < self.num_padded_ctx_chunks else None | ||
|
|
||
| if self.with_graph_breaks: | ||
| # break_graph() cannot break the tensor slicing, use clone to isolate the graph | ||
| k_chunk = k_chunk.clone() | ||
| v_chunk = v_chunk.clone() | ||
| mask_chunk = mask_chunk.clone() if mask_chunk is not None else None | ||
| self.break_graph() | ||
|
|
||
| chunk_res = torch.ops.hpu.sdpa_recomp_fwd( | ||
| q_chunk, | ||
| k_chunk, | ||
| v_chunk, | ||
| mask_chunk, | ||
| dropout_p, | ||
| scale, | ||
| False, # is_causal | ||
| True, # requires_backward | ||
| softmax_mode, | ||
| None, # valid_seq_len | ||
| padding_side, | ||
| ) | ||
| chunk_out, chunk_m, chunk_linv = ((gqa_output_reshape(x) if gqa else x).to(torch.float32) | ||
| for x in chunk_res[:3]) | ||
|
|
||
| assert not (last_out is None or last_m is None or last_linv is None) | ||
| new_m = torch.maximum(last_m, chunk_m) | ||
| last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m) | ||
| chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m) | ||
| last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled) | ||
| last_out = (last_linv_rescaled * last_linv) * last_out + (chunk_linv_rescaled * last_linv) * chunk_out | ||
| last_m = new_m | ||
|
|
||
| if self.with_graph_breaks: | ||
| self.break_graph() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same comment as above. |
||
| chunk_outputs.append(last_out) | ||
| chunk_outputs = list(reversed(chunk_outputs)) | ||
| output = torch.cat(chunk_outputs, dim=-2) | ||
| return output.to(q.dtype) | ||
|
|
||
| def _setup_slicing(self) -> bool: | ||
| from vllm_gaudi.extension.bucketing.common import get_bucketing_manager | ||
| bucketing_manager = get_bucketing_manager() | ||
| enable_slicing = bucketing_manager is not None | ||
| if not enable_slicing: | ||
| logger().warning('Bucketing manager is not instantiated, slicing in FSDPA will be disabled.') | ||
| return False | ||
| assert bucketing_manager is not None | ||
| enable_slicing = enable_slicing and bucketing_manager.initialized | ||
| if not enable_slicing: | ||
| logger().warning('Bucketing manager is not initialized, slicing in FSDPA will be disabled.') | ||
| return False | ||
|
|
||
| from vllm_gaudi.extension.bucketing.linear import LinearBucketingStrategy | ||
| strategy = bucketing_manager.get_bucketing_strategy() | ||
| enable_slicing = isinstance(strategy, LinearBucketingStrategy) | ||
| if not enable_slicing: | ||
| logger().debug('Not using Linear Bucketing Strategy, slicing in FSDPA will be disabled.') | ||
|
czhu15 marked this conversation as resolved.
|
||
| return False | ||
|
|
||
| max_num_batched_tokens = bucketing_manager.max_num_batched_tokens | ||
| slice_thld_default = min(max_num_batched_tokens, 8192) | ||
| slice_thld = int(os.getenv("VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD", str(slice_thld_default))) | ||
| enable_slicing = enable_slicing and slice_thld >= slice_thld_default | ||
| if not enable_slicing and slice_thld > 0: | ||
| logger().warning('Invalid slice sequence length threshold, the threshold should be ' | ||
| f'>= min(max_num_batched_tokens, 8192), falling back to default {slice_thld_default}.') | ||
| slice_thld = slice_thld_default | ||
|
|
||
| if enable_slicing: | ||
| # default to half of the threshold and round up by 1024 | ||
| chunk_size_default = math.ceil(slice_thld // 2 / 1024) * 1024 | ||
| chunk_size = int(os.getenv("VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE", str(chunk_size_default))) | ||
| block_size = bucketing_manager.block_size | ||
| if chunk_size < block_size or chunk_size > slice_thld: | ||
| logger().warning(f'Invalid chunk size for FusedSDPA slicing, the chunk size should be between ' | ||
| f'{block_size} and {slice_thld}, falling back to default {chunk_size_default}.') | ||
| chunk_size = chunk_size_default | ||
| if chunk_size % 1024 != 0: | ||
| chunk_size = math.ceil(chunk_size / 1024) * 1024 | ||
| logger().warning('Rounded up the chunk size for FusedSDPA slicing to the next multiple of 1024.') | ||
| self.slice_thld = slice_thld | ||
| self.chunk_size = chunk_size | ||
| max_query_pad_default = math.ceil(max_num_batched_tokens / 4) | ||
| max_query_pad = int(os.getenv("VLLM_PROMPT_QUERY_BUCKET_PAD_MAX", str(max_query_pad_default))) | ||
| self.num_padded_query_chunks = math.ceil(max_query_pad / self.chunk_size) | ||
| max_ctx_pad_default = math.ceil(max_num_batched_tokens / block_size) | ||
| max_ctx_pad = int(os.getenv("VLLM_PROMPT_CTX_BUCKET_PAD_MAX", str(max_ctx_pad_default))) | ||
| self.num_padded_ctx_chunks = math.ceil(max_ctx_pad * block_size / self.chunk_size) | ||
|
|
||
| import habana_frameworks.torch as ht | ||
| is_lazy = ht.utils.internal.is_lazy() | ||
| self.with_graph_breaks = os.getenv("VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS", | ||
| str(is_lazy)).strip().lower() in ("1", "true") | ||
| if self.with_graph_breaks: | ||
| if is_lazy: | ||
| self.break_graph = ht.core.mark_step | ||
| else: | ||
| self.break_graph = torch._dynamo.graph_break | ||
| msg = (f"FusedSDPA slicing is enabled with sequence length threshold {slice_thld}, " | ||
| f"chunk size {self.chunk_size}, num padded query chunks {self.num_padded_query_chunks}, " | ||
| f"num padded ctx chunks {self.num_padded_ctx_chunks}, with graph breaks {self.with_graph_breaks}.") | ||
| logger().debug(msg) | ||
| return enable_slicing | ||
|
|
||
|
|
||
| class ModuleFP8FusedSDPA(torch.nn.Module): | ||
|
|
||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.