diff --git a/docs/configuration/env_variables.md b/docs/configuration/env_variables.md index 184c7e8474..680a530de1 100644 --- a/docs/configuration/env_variables.md +++ b/docs/configuration/env_variables.md @@ -94,3 +94,20 @@ and warm-up. Recommended settings for this case are: !!! note If the model config specifies a high `max_model_len`, set it to the sum of `input_tokens` and `output_tokens`, rounded up to a multiple of `block_size` according to actual requirements. + +## Additional Performance Tuning Parameters for the FusedSDPA Kernel with Linear Bucketing + +FusedSDPA can be split into smaller chunks to improve performance by: + +- fitting smaller chunks into SRAM, +- improving TPC/MME pipelining, +- reducing attention-mask usage. + +| Parameter name | Description | Default value | +| ---------------------------------------- | -------------------------------------------------------------------------------------------- | ------------------------------------------ | +| `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD` | KV length threshold at or above which slicing is applied. Set to `-1` to disable slicing. | `min(max_num_batched_tokens, 8192)` | +| `VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE` | Chunk size for `q_len` and `kv_len` in each chunk. Rounded up to the next multiple of 1024. | `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD // 2` | +| `VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS` | Places each chunk in a separate graph to reduce compilation time. | `true` for lazy mode and `false` otherwise | + +!!! note + These parameters are effective only with the linear bucketing strategy. 
diff --git a/vllm_gaudi/extension/features.py b/vllm_gaudi/extension/features.py index cc6fa3f266..d26d12040b 100644 --- a/vllm_gaudi/extension/features.py +++ b/vllm_gaudi/extension/features.py @@ -44,6 +44,11 @@ def get_user_flags(): Env('PT_HPU_SDPA_QKV_SLICE_MODE_FWD', boolean), Env('PT_HPU_SDPA_BC_FACTOR', int), Env('VLLM_FUSEDSDPA_SLIDE_THLD', int), + + # FusedSDPA slicing flags + Env('VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD', int), + Env('VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE', int), + Env('VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS', boolean), ] return to_dict(flags) diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py index 5b3b2a640f..8c4eeb5455 100644 --- a/vllm_gaudi/extension/ops.py +++ b/vllm_gaudi/extension/ops.py @@ -80,8 +80,12 @@ def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, sink, bat batch2block_matmul_op, block2batch_matmul_op): # When fp32_softmax is enabled attn is left in fp32 after Q@K # We can return to native dtype after we renormalize and calculate the adjustments - if block_bias is not None and attn.dtype != block_bias.dtype: - block_bias = block_bias.to(dtype=attn.dtype) + if block_bias is not None: + if block_bias.dtype == torch.bool: + # Convert boolean mask (True=valid, False=masked) to additive bias (0.0/-inf) + block_bias = torch.zeros_like(block_bias, dtype=attn.dtype).masked_fill_(~block_bias, float('-inf')) + elif attn.dtype != block_bias.dtype: + block_bias = block_bias.to(dtype=attn.dtype) # TODO: w/a with 5D req as the block_softmax kernel does not support 4D attn tensor, which is used in e.g. 
Granite-3B if get_config().fused_block_softmax and get_config().fused_block_softmax_adjustment and attn.dim() == 5: attn, block_max, block_sums = torch.ops.hpu.block_softmax(attn, block_bias, block_groups) @@ -357,9 +361,12 @@ def _naive_prompt_attention(query: torch.Tensor, htcore.mark_step() attn_weights.add_(position_bias) if attn_bias is not None: - if attn_weights.dtype != attn_bias.dtype: - attn_bias = attn_bias.to(dtype=attn_weights.dtype) - attn_weights.add_(attn_bias) + if attn_bias.dtype == torch.bool: + attn_weights = attn_weights.masked_fill(~attn_bias, float("-inf")) + else: + if attn_weights.dtype != attn_bias.dtype: + attn_bias = attn_bias.to(dtype=attn_weights.dtype) + attn_weights.add_(attn_bias) if sinks is not None: sink = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) if query_heads != kv_heads: @@ -404,10 +411,6 @@ def _fsdpa_prompt_attention(query: torch.Tensor, recompute_mode = True assert attn_bias is not None or valid_seq_lengths is not None, \ 'Either attn_bias or valid_seq_lengths must be != None' - if is_causal and attn_bias is not None: - # TODO: causal + attn_bias is not yet supported - is_causal = False - valid_seq_lengths = None args = [ query, key, value, attn_bias, 0.0, is_causal, scale, softmax_mode, recompute_mode, valid_seq_lengths, diff --git a/vllm_gaudi/extension/utils.py b/vllm_gaudi/extension/utils.py index b65dae5734..33c68f97dd 100644 --- a/vllm_gaudi/extension/utils.py +++ b/vllm_gaudi/extension/utils.py @@ -6,12 +6,14 @@ ############################################################################### import os +import math from functools import lru_cache, wraps from typing import Optional, Any import habana_frameworks.torch as htorch import torch import itertools +from vllm_gaudi.extension.logger import logger from vllm_gaudi.extension.runtime import get_config @@ -155,6 +157,7 @@ def __init__(self, fusedSDPA): super().__init__() assert fusedSDPA is not None, f'fusedSDPA kernel is None' 
self._hpu_kernel_fsdpa = fusedSDPA + self.enable_slicing = self._setup_slicing() def forward( self, @@ -172,6 +175,20 @@ def forward( window_size=None, sinks=None, ): + bs = query.shape[0] + q_len = query.shape[-2] + kv_len = key.shape[-2] + if (self.enable_slicing and kv_len >= self.slice_thld \ + and bs == 1 # bs should be 1 for chunked prefill + and q_len != kv_len # normal causal prefill route to the default dispatch for better performance + and is_causal and attn_mask is not None # only supports causal attention with mask + ): + return self._sliced_fsdpa_fwd(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, + recompute_mode, valid_sequence_lengths, padding_side) + if is_causal and attn_mask is not None: + # TODO: causal + attn_bias is not yet supported + is_causal = False + valid_sequence_lengths = None if window_size is not None: return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, recompute_mode, valid_sequence_lengths, padding_side, False, False, @@ -181,6 +198,202 @@ def forward( recompute_mode, valid_sequence_lengths, padding_side, False, False, (-1, -1), sinks) + def _sliced_fsdpa_fwd(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, recompute_mode, + valid_sequence_lengths, padding_side): + assert is_causal and attn_mask is not None + + from habana_frameworks.torch.hpex.kernels.FusedSDPA import is_gqa, gqa_input_reshape_fwd, gqa_output_reshape + gqa = is_gqa(query, key) + if gqa: + q, k, v, attn_mask = gqa_input_reshape_fwd(query, key, value, attn_mask) + else: + q, k, v, attn_mask = (query, key, value, attn_mask) + q_len = q.shape[-2] + kv_len = k.shape[-2] + prefix_len = kv_len - q_len + + chunk_outputs = [] + num_q_chunks = math.ceil(q_len / self.chunk_size) + num_prefix_chunks = math.ceil(prefix_len / self.chunk_size) + for q_chunk_idx in range(num_q_chunks): + q_start = q_len - (q_chunk_idx + 1) * self.chunk_size + q_start = max(q_start, 0) + 
q_end = q_len - q_chunk_idx * self.chunk_size + q_chunk_size = q_end - q_start + q_chunk = q[..., q_start:q_end, :].clone() if self.with_graph_breaks else q[..., q_start:q_end, :] + + last_out = None + last_m = None + last_linv = None + + # the causal part + for kv_chunk_idx in range(0, num_q_chunks - q_chunk_idx): + kv_start = prefix_len + q_end - (kv_chunk_idx + 1) * self.chunk_size + kv_start = max(kv_start, prefix_len) + kv_end = prefix_len + q_end - kv_chunk_idx * self.chunk_size + kv_chunk_size = kv_end - kv_start + k_chunk = k[..., kv_start:kv_end, :] + v_chunk = v[..., kv_start:kv_end, :] + + is_causal_chunk = kv_chunk_idx == 0 and q_chunk_idx != 0 + # chunk sizes must be multiples of 1024 to get valid m and linv + is_causal_chunk = is_causal_chunk and q_chunk_size % 1024 == 0 and kv_chunk_size % 1024 == 0 + # use mask only for the causal chunks that may have padding + mask_chunk = attn_mask[ + ..., q_start:q_end, + kv_start:kv_end] if kv_chunk_idx < self.num_padded_query_chunks and not is_causal_chunk else None + + if self.with_graph_breaks: + k_chunk = k_chunk.clone() + v_chunk = v_chunk.clone() + mask_chunk = mask_chunk.clone() if mask_chunk is not None else None + self.break_graph() + + chunk_res = torch.ops.hpu.sdpa_recomp_fwd( + q_chunk, + k_chunk, + v_chunk, + mask_chunk, + dropout_p, + scale, + is_causal_chunk, + True, # requires_backward + softmax_mode, + None, # valid_seq_len + padding_side, + ) + chunk_out, chunk_m, chunk_linv = ((gqa_output_reshape(x) if gqa else x).to(torch.float32) + for x in (chunk_res[:3])) + + if last_out is None or last_m is None or last_linv is None: + last_out = chunk_out + last_m = chunk_m + last_linv = chunk_linv + else: + new_m = torch.maximum(last_m, chunk_m) + last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m) + chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m) + last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled) + last_out = (last_linv_rescaled * last_linv) * last_out + 
(chunk_linv_rescaled * + last_linv) * chunk_out + last_m = new_m + + if self.with_graph_breaks: + self.break_graph() + + # the context part + for kv_chunk_idx in range(num_prefix_chunks): + kv_start = prefix_len - (kv_chunk_idx + 1) * self.chunk_size + kv_start = max(kv_start, 0) + kv_end = prefix_len - kv_chunk_idx * self.chunk_size + k_chunk = k[..., kv_start:kv_end, :] + v_chunk = v[..., kv_start:kv_end, :] + # use mask only for the chunks that may have padding + mask_chunk = attn_mask[..., q_start:q_end, + kv_start:kv_end] if kv_chunk_idx < self.num_padded_ctx_chunks else None + + if self.with_graph_breaks: + k_chunk = k_chunk.clone() + v_chunk = v_chunk.clone() + mask_chunk = mask_chunk.clone() if mask_chunk is not None else None + self.break_graph() + + chunk_res = torch.ops.hpu.sdpa_recomp_fwd( + q_chunk, + k_chunk, + v_chunk, + mask_chunk, + dropout_p, + scale, + False, # is_causal + True, # requires_backward + softmax_mode, + None, # valid_seq_len + padding_side, + ) + chunk_out, chunk_m, chunk_linv = ((gqa_output_reshape(x) if gqa else x).to(torch.float32) + for x in chunk_res[:3]) + + assert not (last_out is None or last_m is None or last_linv is None) + new_m = torch.maximum(last_m, chunk_m) + last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m) + chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m) + last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled) + last_out = (last_linv_rescaled * last_linv) * last_out + (chunk_linv_rescaled * last_linv) * chunk_out + last_m = new_m + + if self.with_graph_breaks: + self.break_graph() + chunk_outputs.append(last_out) + chunk_outputs = list(reversed(chunk_outputs)) + output = torch.cat(chunk_outputs, dim=-2) + return output.to(q.dtype) + + def _setup_slicing(self) -> bool: + from vllm_gaudi.extension.bucketing.common import get_bucketing_manager + bucketing_manager = get_bucketing_manager() + enable_slicing = bucketing_manager is not None + if not enable_slicing: + 
logger().warning('Bucketing manager is not instantiated, slicing in FSDPA will be disabled.') + return False + assert bucketing_manager is not None + enable_slicing = enable_slicing and bucketing_manager.initialized + if not enable_slicing: + logger().warning('Bucketing manager is not initialized, slicing in FSDPA will be disabled.') + return False + + from vllm_gaudi.extension.bucketing.linear import LinearBucketingStrategy + strategy = bucketing_manager.get_bucketing_strategy() + enable_slicing = isinstance(strategy, LinearBucketingStrategy) + if not enable_slicing: + logger().debug('Not using Linear Bucketing Strategy, slicing in FSDPA will be disabled.') + return False + + max_num_batched_tokens = bucketing_manager.max_num_batched_tokens + slice_thld_default = min(max_num_batched_tokens, 8192) + slice_thld = int(os.getenv("VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD", str(slice_thld_default))) + enable_slicing = enable_slicing and slice_thld >= slice_thld_default + if not enable_slicing and slice_thld > 0: + logger().warning('Invalid slice sequence length threshold, the threshold should be ' + f'>= min(max_num_batched_tokens, 8192), falling back to default {slice_thld_default}.') + slice_thld = slice_thld_default + + if enable_slicing: + # default to half of the threshold and round up by 1024 + chunk_size_default = math.ceil(slice_thld // 2 / 1024) * 1024 + chunk_size = int(os.getenv("VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE", str(chunk_size_default))) + block_size = bucketing_manager.block_size + if chunk_size < block_size or chunk_size > slice_thld: + logger().warning(f'Invalid chunk size for FusedSDPA slicing, the chunk size should be between ' + f'{block_size} and {slice_thld}, falling back to default {chunk_size_default}.') + chunk_size = chunk_size_default + if chunk_size % 1024 != 0: + chunk_size = math.ceil(chunk_size / 1024) * 1024 + logger().warning('Rounded up the chunk size for FusedSDPA slicing to the next multiple of 1024.') + self.slice_thld = slice_thld + 
self.chunk_size = chunk_size + max_query_pad_default = math.ceil(max_num_batched_tokens / 4) + max_query_pad = int(os.getenv("VLLM_PROMPT_QUERY_BUCKET_PAD_MAX", str(max_query_pad_default))) + self.num_padded_query_chunks = math.ceil(max_query_pad / self.chunk_size) + max_ctx_pad_default = math.ceil(max_num_batched_tokens / block_size) + max_ctx_pad = int(os.getenv("VLLM_PROMPT_CTX_BUCKET_PAD_MAX", str(max_ctx_pad_default))) + self.num_padded_ctx_chunks = math.ceil(max_ctx_pad * block_size / self.chunk_size) + + import habana_frameworks.torch as ht + is_lazy = ht.utils.internal.is_lazy() + self.with_graph_breaks = os.getenv("VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS", + str(is_lazy)).strip().lower() in ("1", "true") + if self.with_graph_breaks: + if is_lazy: + self.break_graph = ht.core.mark_step + else: + self.break_graph = torch._dynamo.graph_break + msg = (f"FusedSDPA slicing is enabled with sequence length threshold {slice_thld}, " + f"chunk size {self.chunk_size}, num padded query chunks {self.num_padded_query_chunks}, " + f"num padded ctx chunks {self.num_padded_ctx_chunks}, with graph breaks {self.with_graph_breaks}.") + logger().debug(msg) + return enable_slicing + class ModuleFP8FusedSDPA(torch.nn.Module): diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 6a0d2eb40a..54c0ea7d20 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -2062,7 +2062,6 @@ def _extract_prefill_batch_contents(self, num_prefills, num_decodes, num_schedul return all_batch_contents, num_pad_across_dp def _make_attn_bias(self, context_groups, token_groups): - dtype = self.dtype is_causal = True # TODO: add support for non-causal tasks context_groups = torch.tensor(context_groups, device='cpu', dtype=torch.int16) context_groups = context_groups.repeat_interleave(self.block_size, dim=-1) @@ -2075,7 +2074,7 @@ def _make_attn_bias(self, context_groups, token_groups): causal_mask = 
torch.ones(num_queries, num_queries, device='cpu', dtype=torch.bool) causal_mask = torch.triu(causal_mask, diagonal=1).unsqueeze(0) attn_mask[:, :, context_len:].logical_or_(causal_mask) - attn_mask = attn_mask.to(dtype).masked_fill_(attn_mask, -math.inf) + attn_mask = ~attn_mask return attn_mask.unflatten(0, (1, -1)) @@ -4385,8 +4384,13 @@ def load_model(self) -> None: self.drafter.load_model(self.model.model) if self.use_aux_hidden_state_outputs: if supports_eagle3(self.model.model): - self.model.model.set_aux_hidden_state_layers( - self.model.model.get_eagle3_aux_hidden_state_layers()) + # Try new API name first (upstream >= v0.17.2), + # fall back to old name for older vLLM versions. + if hasattr(self.model.model, 'get_eagle3_default_aux_hidden_state_layers'): + aux_layers = self.model.model.get_eagle3_default_aux_hidden_state_layers() + else: + aux_layers = self.model.model.get_eagle3_aux_hidden_state_layers() + self.model.model.set_aux_hidden_state_layers(aux_layers) else: raise RuntimeError("Model does not support EAGLE3 interface but " "aux_hidden_state_outputs was requested") @@ -6610,7 +6614,7 @@ def _set_attn_bias(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, diagonal=1) mask = causal_mask.logical_or(len_mask) mask = torch.concat((past_mask, mask), dim=-1) - attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)) + attn_bias = ~mask attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", attn_bias=attn_bias) return attn_metadata @@ -6668,16 +6672,13 @@ def _set_attn_bias_for_sliding_window(self, attn_metadata: HPUAttentionMetadataV # seq_lens_t.unsqueeze(-1)).view(batch_size, 1, 1, seq_len)) # causal_mask = causal_mask.logical_and(len_mask) - mask = torch.concat((past_mask, causal_mask), dim=-1) - attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device), - torch.tensor(float('-inf'), dtype=dtype, device=device)) + attn_bias = torch.concat((past_mask, causal_mask), 
dim=-1) else: # CAUSAL MASK without removing padding (CAUSAL+sliding window) # removing padding cause accuracy issue for images input - tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1) + tensor = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool) mask = torch.tril(tensor, diagonal=shift) - mask = torch.triu(mask, diagonal=shift - window_size + 1) - attn_bias = torch.log(mask) + attn_bias = torch.triu(mask, diagonal=shift - window_size + 1) attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", window_attn_bias=attn_bias) return attn_metadata @@ -6730,18 +6731,15 @@ def _set_attn_bias_for_chunked_attention(self, attn_metadata: HPUAttentionMetada causal_mask = causal_mask & same_chunk_mask causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, seq_len, seq_len) - mask = torch.concat((past_mask, causal_mask), dim=-1) - attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device), - torch.tensor(float('-inf'), dtype=dtype, device=device)) + attn_bias = torch.concat((past_mask, causal_mask), dim=-1) else: - tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1) + tensor = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool) mask = torch.tril(tensor, diagonal=shift) idx = torch.arange(seq_len, device=device) chunk_id = idx // chunk_size same_chunk = chunk_id.unsqueeze(0) == chunk_id.unsqueeze(1) same_chunk = same_chunk.unsqueeze(0).unsqueeze(0) - mask = torch.where(same_chunk, mask, torch.tensor(0.0, dtype=dtype, device=device)) - attn_bias = torch.log(mask) + attn_bias = same_chunk & mask attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", chunked_attn_bias=attn_bias) return attn_metadata @@ -6780,8 +6778,7 @@ def _set_block_mapping(self, block_groups = metadata.block_groups mask = torch.arange(0, self.block_size, device=device, 
dtype=torch.int32).unsqueeze(0) - mask = mask >= block_usage.unsqueeze(-1) - attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)) + attn_bias = mask < block_usage.unsqueeze(-1) if not is_fake_hpu(): block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size)