diff --git a/docs/configuration/env_variables.md b/docs/configuration/env_variables.md index 0b8114fc57..7300aa0c90 100644 --- a/docs/configuration/env_variables.md +++ b/docs/configuration/env_variables.md @@ -106,3 +106,20 @@ and warm-up. Recommended settings for this case are: !!! note If the model config specifies a high `max_model_len`, set it to the sum of `input_tokens` and `output_tokens`, rounded up to a multiple of `block_size` according to actual requirements. + +## Additional Performance Tuning Parameters for the FusedSDPA Kernel with Linear Bucketing + +FusedSDPA can be split into smaller chunks to improve performance by: + +- fitting smaller chunks into SRAM, +- improving TPC/MME pipelining, +- reducing attention-mask usage. + +| Parameter name | Description | Default value | +| ---------------------------------------- | -------------------------------------------------------------------------------------------- | ------------------------------------------ | +| `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD` | KV length threshold above which slicing is applied. Set to `-1` to disable slicing. | `min(max_num_batched_tokens, 8192)` | +| `VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE` | Chunk size for `q_len` and `kv_len` in each chunk. Rounded up to the next multiple of 1024. | `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD // 2` | +| `VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS` | Places each chunk in a separate graph to reduce compilation time. | `true` for lazy mode and `false` otherwise | + +!!! note + These parameters are effective only with the linear bucketing strategy. 
diff --git a/vllm_gaudi/extension/features.py b/vllm_gaudi/extension/features.py index d630bf1a38..bb87415e94 100644 --- a/vllm_gaudi/extension/features.py +++ b/vllm_gaudi/extension/features.py @@ -55,6 +55,11 @@ def get_user_flags(): Env('PT_HPU_SDPA_QKV_SLICE_MODE_FWD', boolean), Env('PT_HPU_SDPA_BC_FACTOR', int), Env('VLLM_FUSEDSDPA_SLIDE_THLD', int), + + # FusedSDPA slicing flags + Env('VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD', int), + Env('VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE', int), + Env('VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS', boolean), ] return to_dict(flags) diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py index c2cc3a368d..39ebf0192c 100644 --- a/vllm_gaudi/extension/ops.py +++ b/vllm_gaudi/extension/ops.py @@ -399,10 +399,6 @@ def _fsdpa_prompt_attention(query: torch.Tensor, recompute_mode = True assert attn_bias is not None or valid_seq_lengths is not None, \ 'Either attn_bias or valid_seq_lengths must be != None' - if is_causal and attn_bias is not None: - # TODO: causal + attn_bias is not yet supported - is_causal = False - valid_seq_lengths = None args = [ query, key, value, attn_bias, 0.0, is_causal, scale, softmax_mode, recompute_mode, valid_seq_lengths, diff --git a/vllm_gaudi/extension/utils.py b/vllm_gaudi/extension/utils.py index 47fe5deb49..df148466d8 100644 --- a/vllm_gaudi/extension/utils.py +++ b/vllm_gaudi/extension/utils.py @@ -6,12 +6,14 @@ ############################################################################### import os +import math from functools import lru_cache, wraps from typing import Optional, Any import habana_frameworks.torch as htorch import torch import itertools +from vllm_gaudi.extension.logger import logger from vllm_gaudi.extension.runtime import get_config @@ -155,6 +157,7 @@ def __init__(self, fusedSDPA): super().__init__() assert fusedSDPA is not None, f'fusedSDPA kernel is None' self._hpu_kernel_fsdpa = fusedSDPA + self.enable_slicing = self._setup_slicing() def forward( self, @@ -172,6 +175,20 @@ 
def _sliced_fsdpa_fwd(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, recompute_mode,
                      valid_sequence_lengths, padding_side):
    """Run causal, masked attention chunk by chunk and merge the partial results.

    The query is walked from its end towards its start in pieces of at most
    ``self.chunk_size`` rows. For every query chunk the keys/values are visited
    in two passes:

    * the "causal part": KV positions ``[prefix_len, prefix_len + q_end)``,
      i.e. the keys that line up with the query tokens themselves;
    * the "context part": KV positions ``[0, prefix_len)``, where
      ``prefix_len = kv_len - q_len``.

    Each ``torch.ops.hpu.sdpa_recomp_fwd`` call yields three tensors that are
    used here as ``(out, m, linv)``. The partial results of one query chunk are
    folded together with the running-maximum rescaling below, then all query
    chunks are concatenated along dim ``-2`` in their original order.

    NOTE(review): the merge treats ``m`` as the per-row softmax maximum and
    ``1 / linv`` as the softmax normaliser relative to ``m`` -- confirm against
    the kernel's documented outputs.

    Args:
        query, key, value: attention inputs; the sequence axis is dim ``-2``.
        attn_mask: must not be ``None``; sliced per chunk on its last two dims.
        dropout_p, scale, softmax_mode, padding_side: forwarded to the kernel.
        is_causal: must be truthy (asserted).
        recompute_mode, valid_sequence_lengths: accepted for signature parity
            with the caller but not used by this path (``None`` is passed to
            the kernel as the valid sequence length).

    Returns:
        The attention output cast to the dtype of the (possibly GQA-reshaped)
        query.
    """
    assert is_causal and attn_mask is not None

    from habana_frameworks.torch.hpex.kernels.FusedSDPA import is_gqa, gqa_input_reshape_fwd, gqa_output_reshape
    gqa = is_gqa(query, key)
    if gqa:
        q, k, v, attn_mask = gqa_input_reshape_fwd(query, key, value, attn_mask)
    else:
        q, k, v, attn_mask = (query, key, value, attn_mask)
    q_len = q.shape[-2]
    kv_len = k.shape[-2]
    prefix_len = kv_len - q_len

    chunk_outputs = []
    num_q_chunks = math.ceil(q_len / self.chunk_size)
    num_prefix_chunks = math.ceil(prefix_len / self.chunk_size)
    for q_chunk_idx in range(num_q_chunks):
        # Query chunks are taken from the end, so only the chunk that starts at
        # row 0 (the last one visited) can be shorter than chunk_size.
        q_start = q_len - (q_chunk_idx + 1) * self.chunk_size
        q_start = max(q_start, 0)
        q_end = q_len - q_chunk_idx * self.chunk_size
        q_chunk_size = q_end - q_start
        q_chunk = q[..., q_start:q_end, :].clone() if self.with_graph_breaks else q[..., q_start:q_end, :]

        last_out = None
        last_m = None
        last_linv = None

        # the causal part
        for kv_chunk_idx in range(0, num_q_chunks - q_chunk_idx):
            kv_start = prefix_len + q_end - (kv_chunk_idx + 1) * self.chunk_size
            kv_start = max(kv_start, prefix_len)
            kv_end = prefix_len + q_end - kv_chunk_idx * self.chunk_size
            kv_chunk_size = kv_end - kv_start
            k_chunk = k[..., kv_start:kv_end, :]
            v_chunk = v[..., kv_start:kv_end, :]

            # kv_chunk_idx == 0 is the diagonal chunk (same positions as the query chunk).
            is_causal_chunk = kv_chunk_idx == 0 and q_chunk_idx != 0
            # chunk sizes must be multiples of 1024 to get valid m and linv
            is_causal_chunk = is_causal_chunk and q_chunk_size % 1024 == 0 and kv_chunk_size % 1024 == 0
            # use mask only for the causal chunks that may have padding; the diagonal chunk
            # additionally always needs the mask when the kernel's causal flag is not used,
            # otherwise it would attend to future positions (e.g. num_padded_query_chunks == 0).
            needs_mask = not is_causal_chunk and (kv_chunk_idx == 0 or kv_chunk_idx < self.num_padded_query_chunks)
            mask_chunk = attn_mask[..., q_start:q_end, kv_start:kv_end] if needs_mask else None

            if self.with_graph_breaks:
                # break_graph() cannot break the tensor slicing, use clone to isolate the graph
                k_chunk = k_chunk.clone()
                v_chunk = v_chunk.clone()
                mask_chunk = mask_chunk.clone() if mask_chunk is not None else None
                self.break_graph()

            chunk_res = torch.ops.hpu.sdpa_recomp_fwd(
                q_chunk,
                k_chunk,
                v_chunk,
                mask_chunk,
                dropout_p,
                scale,
                is_causal_chunk,
                True,  # requires_backward
                softmax_mode,
                None,  # valid_seq_len
                padding_side,
            )
            # Merge in float32; the result is cast back to q.dtype at the very end.
            chunk_out, chunk_m, chunk_linv = ((gqa_output_reshape(x) if gqa else x).to(torch.float32)
                                              for x in (chunk_res[:3]))

            if last_out is None or last_m is None or last_linv is None:
                last_out = chunk_out
                last_m = chunk_m
                last_linv = chunk_linv
            else:
                # Rescale both partial results to the shared maximum, then renormalise.
                new_m = torch.maximum(last_m, chunk_m)
                last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m)
                chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m)
                last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled)
                last_out = (last_linv_rescaled * last_linv) * last_out + (chunk_linv_rescaled *
                                                                          last_linv) * chunk_out
                last_m = new_m

        if self.with_graph_breaks:
            self.break_graph()

        # the context part
        for kv_chunk_idx in range(num_prefix_chunks):
            kv_start = prefix_len - (kv_chunk_idx + 1) * self.chunk_size
            kv_start = max(kv_start, 0)
            kv_end = prefix_len - kv_chunk_idx * self.chunk_size
            k_chunk = k[..., kv_start:kv_end, :]
            v_chunk = v[..., kv_start:kv_end, :]
            # use mask only for the chunks that may have padding
            mask_chunk = attn_mask[..., q_start:q_end,
                                   kv_start:kv_end] if kv_chunk_idx < self.num_padded_ctx_chunks else None

            if self.with_graph_breaks:
                # break_graph() cannot break the tensor slicing, use clone to isolate the graph
                k_chunk = k_chunk.clone()
                v_chunk = v_chunk.clone()
                mask_chunk = mask_chunk.clone() if mask_chunk is not None else None
                self.break_graph()

            chunk_res = torch.ops.hpu.sdpa_recomp_fwd(
                q_chunk,
                k_chunk,
                v_chunk,
                mask_chunk,
                dropout_p,
                scale,
                False,  # is_causal
                True,  # requires_backward
                softmax_mode,
                None,  # valid_seq_len
                padding_side,
            )
            chunk_out, chunk_m, chunk_linv = ((gqa_output_reshape(x) if gqa else x).to(torch.float32)
                                              for x in chunk_res[:3])

            # The causal pass always runs at least once, so a running result must exist here.
            assert not (last_out is None or last_m is None or last_linv is None)
            new_m = torch.maximum(last_m, chunk_m)
            last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m)
            chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m)
            last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled)
            last_out = (last_linv_rescaled * last_linv) * last_out + (chunk_linv_rescaled * last_linv) * chunk_out
            last_m = new_m

        if self.with_graph_breaks:
            self.break_graph()
        chunk_outputs.append(last_out)
    # Query chunks were produced back-to-front; restore the original order.
    chunk_outputs = list(reversed(chunk_outputs))
    output = torch.cat(chunk_outputs, dim=-2)
    return output.to(q.dtype)


def _setup_slicing(self) -> bool:
    """Decide whether sliced FusedSDPA is usable and cache its parameters on ``self``.

    Slicing is enabled only when a bucketing manager exists, is initialized and
    uses ``LinearBucketingStrategy``, and the configured threshold is positive.

    Environment variables read:
        VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD: threshold; default
            ``min(max_num_batched_tokens, 8192)``. A value ``<= 0`` disables
            slicing. A positive value below the default is rejected with a
            warning and replaced by the default.
        VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE: chunk size; default is half the
            threshold rounded up to a multiple of 1024. Values outside
            ``[block_size, threshold]`` fall back to the default; the final
            value is rounded up to a multiple of 1024.
        VLLM_PROMPT_QUERY_BUCKET_PAD_MAX, VLLM_PROMPT_CTX_BUCKET_PAD_MAX: used
            to derive how many chunks may contain padding and need a mask.
        VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS: ``"1"``/``"true"`` enables
            graph breaks between chunks; defaults to whether lazy mode is on.

    When slicing is enabled this sets ``self.slice_thld``, ``self.chunk_size``,
    ``self.num_padded_query_chunks``, ``self.num_padded_ctx_chunks``,
    ``self.with_graph_breaks`` and, if graph breaks are on, ``self.break_graph``.

    Returns:
        ``True`` if slicing is enabled, ``False`` otherwise.

    Raises:
        ValueError: if one of the integer environment variables is not a valid int.
    """
    from vllm_gaudi.extension.bucketing.common import get_bucketing_manager
    bucketing_manager = get_bucketing_manager()
    if bucketing_manager is None:
        logger().warning('Bucketing manager is not instantiated, slicing in FSDPA will be disabled.')
        return False
    if not bucketing_manager.initialized:
        logger().warning('Bucketing manager is not initialized, slicing in FSDPA will be disabled.')
        return False

    from vllm_gaudi.extension.bucketing.linear import LinearBucketingStrategy
    strategy = bucketing_manager.get_bucketing_strategy()
    if not isinstance(strategy, LinearBucketingStrategy):
        logger().debug('Not using Linear Bucketing Strategy, slicing in FSDPA will be disabled.')
        return False

    max_num_batched_tokens = bucketing_manager.max_num_batched_tokens
    slice_thld_default = min(max_num_batched_tokens, 8192)
    slice_thld = int(os.getenv("VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD", str(slice_thld_default)))
    if slice_thld < slice_thld_default:
        if slice_thld <= 0:
            # A non-positive threshold is the explicit "slicing off" switch.
            return False
        # Previously this branch logged the fallback but still left slicing disabled,
        # so the reassigned threshold was never used. Actually fall back now.
        logger().warning('Invalid slice sequence length threshold, the threshold should be '
                         f'>= min(max_num_batched_tokens, 8192), falling back to default {slice_thld_default}.')
        slice_thld = slice_thld_default

    # default to half of the threshold and round up by 1024
    chunk_size_default = math.ceil(slice_thld // 2 / 1024) * 1024
    chunk_size = int(os.getenv("VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE", str(chunk_size_default)))
    block_size = bucketing_manager.block_size
    if chunk_size < block_size or chunk_size > slice_thld:
        logger().warning(f'Invalid chunk size for FusedSDPA slicing, the chunk size should be between '
                         f'{block_size} and {slice_thld}, falling back to default {chunk_size_default}.')
        chunk_size = chunk_size_default
    if chunk_size % 1024 != 0:
        chunk_size = math.ceil(chunk_size / 1024) * 1024
        logger().warning('Rounded up the chunk size for FusedSDPA slicing to the next multiple of 1024.')
    self.slice_thld = slice_thld
    self.chunk_size = chunk_size

    # Number of chunks (counted from the end of the query / context) that may hold padding.
    max_query_pad_default = math.ceil(max_num_batched_tokens / 4)
    max_query_pad = int(os.getenv("VLLM_PROMPT_QUERY_BUCKET_PAD_MAX", str(max_query_pad_default)))
    self.num_padded_query_chunks = math.ceil(max_query_pad / self.chunk_size)
    max_ctx_pad_default = math.ceil(max_num_batched_tokens / block_size)
    max_ctx_pad = int(os.getenv("VLLM_PROMPT_CTX_BUCKET_PAD_MAX", str(max_ctx_pad_default)))
    self.num_padded_ctx_chunks = math.ceil(max_ctx_pad * block_size / self.chunk_size)

    import habana_frameworks.torch as ht
    is_lazy = ht.utils.internal.is_lazy()
    self.with_graph_breaks = os.getenv("VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS",
                                       str(is_lazy)).strip().lower() in ("1", "true")
    if self.with_graph_breaks:
        # Lazy mode and non-lazy mode use different primitives to cut the graph.
        if is_lazy:
            self.break_graph = ht.core.mark_step
        else:
            self.break_graph = torch._dynamo.graph_break
    msg = (f"FusedSDPA slicing is enabled with sequence length threshold {slice_thld}, "
           f"chunk size {self.chunk_size}, num padded query chunks {self.num_padded_query_chunks}, "
           f"num padded ctx chunks {self.num_padded_ctx_chunks}, with graph breaks {self.with_graph_breaks}.")
    logger().debug(msg)
    return True