From e3f89b27842970e8bb3cbf5633e1f904fc51eb04 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Thu, 2 Apr 2026 07:45:40 +0000 Subject: [PATCH 1/5] Revert "Use Boolean attention mask (#1032)" This reverts commit 9271c0884d383f21b187fb388fc84e251fb1cc1a. --- vllm_gaudi/extension/ops.py | 9 +++----- vllm_gaudi/v1/worker/hpu_model_runner.py | 26 ++++++++++++++++-------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py index d65cdf1f5c..d8a6649947 100644 --- a/vllm_gaudi/extension/ops.py +++ b/vllm_gaudi/extension/ops.py @@ -349,12 +349,9 @@ def _naive_prompt_attention(query: torch.Tensor, htcore.mark_step() attn_weights.add_(position_bias) if attn_bias is not None: - if attn_bias.dtype == torch.bool: - attn_weights = attn_weights.masked_fill(~attn_bias, float("-inf")) - else: - if attn_weights.dtype != attn_bias.dtype: - attn_bias = attn_bias.to(dtype=attn_weights.dtype) - attn_weights.add_(attn_bias) + if attn_weights.dtype != attn_bias.dtype: + attn_bias = attn_bias.to(dtype=attn_weights.dtype) + attn_weights.add_(attn_bias) if sinks is not None: sink = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) if query_heads != kv_heads: diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index b10de61e82..bc439ef66e 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -2054,6 +2054,7 @@ def _extract_prefill_batch_contents(self, num_prefills, num_decodes, num_schedul return all_batch_contents, num_pad_across_dp def _make_attn_bias(self, context_groups, token_groups): + dtype = self.dtype is_causal = True # TODO: add support for non-causal tasks context_groups = torch.tensor(context_groups, device='cpu', dtype=torch.int16) context_groups = context_groups.repeat_interleave(self.block_size, dim=-1) @@ -2066,7 +2067,7 @@ def _make_attn_bias(self, context_groups, token_groups): causal_mask = torch.ones(num_queries, num_queries, device='cpu', dtype=torch.bool) causal_mask = torch.triu(causal_mask, diagonal=1).unsqueeze(0) attn_mask[:, :, context_len:].logical_or_(causal_mask) - attn_mask = ~attn_mask + attn_mask = attn_mask.to(dtype).masked_fill_(attn_mask, -math.inf) return attn_mask.unflatten(0, (1, -1)) @@ -6547,7 +6548,7 @@ def _set_attn_bias(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, diagonal=1) mask = causal_mask.logical_or(len_mask) mask = torch.concat((past_mask, mask), dim=-1) - attn_bias = ~mask + attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)) attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", attn_bias=attn_bias) return attn_metadata @@ -6605,13 +6606,16 @@ def _set_attn_bias_for_sliding_window(self, attn_metadata: HPUAttentionMetadataV # seq_lens_t.unsqueeze(-1)).view(batch_size, 1, 1, seq_len)) # causal_mask = causal_mask.logical_and(len_mask) - attn_bias = torch.concat((past_mask, causal_mask), dim=-1) + mask = torch.concat((past_mask, causal_mask), dim=-1) + attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device), + torch.tensor(float('-inf'), dtype=dtype, device=device)) else: # CAUSAL MASK without removing padding (CAUSAL+sliding window) # removing padding cause accuracy issue for images input - tensor = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool) + tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1) mask = torch.tril(tensor, diagonal=shift) - attn_bias = torch.triu(mask, diagonal=shift - window_size + 1) + mask = torch.triu(mask, diagonal=shift - window_size + 1) + attn_bias = torch.log(mask) attn_metadata = prefill_metadata._replace(window_attn_bias=attn_bias) return attn_metadata @@ -6664,15 +6668,18 @@ def _set_attn_bias_for_chunked_attention(self, attn_metadata: HPUAttentionMetada causal_mask = causal_mask & same_chunk_mask causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, seq_len, seq_len) - attn_bias = torch.concat((past_mask, causal_mask), dim=-1) + mask = torch.concat((past_mask, causal_mask), dim=-1) + attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device), + torch.tensor(float('-inf'), dtype=dtype, device=device)) else: - tensor = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool) + tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1) mask = torch.tril(tensor, diagonal=shift) idx = torch.arange(seq_len, device=device) chunk_id = idx // chunk_size same_chunk = chunk_id.unsqueeze(0) == chunk_id.unsqueeze(1) same_chunk = same_chunk.unsqueeze(0).unsqueeze(0) - attn_bias = same_chunk & mask + mask = torch.where(same_chunk, mask, torch.tensor(0.0, dtype=dtype, device=device)) + attn_bias = torch.log(mask) attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", chunked_attn_bias=attn_bias) return attn_metadata @@ -6711,7 +6718,8 @@ def _set_block_mapping(self, block_groups = metadata.block_groups mask = torch.arange(0, self.block_size, device=device, dtype=torch.int32).unsqueeze(0) - attn_bias = mask < block_usage.unsqueeze(-1) + mask = mask >= block_usage.unsqueeze(-1) + attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)) if not is_fake_hpu(): block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size) From f3ff913f113906e9cf434d76c3a2042318d9ef32 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Thu, 2 Apr 2026 08:24:18 +0000 Subject: [PATCH 2/5] use finite numbers for attention mask value Signed-off-by: Youlei Yang --- vllm_gaudi/v1/worker/hpu_model_runner.py | 33 ++++++++++++++---------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index bc439ef66e..61475346d6 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -133,6 +133,9 @@ logger = init_logger() +MASK_VALUES = {torch.float32: -3E38, torch.bfloat16: -3E38, torch.float16: -6E4} +DEFAULT_MASK_VALUE = -3E38 + try: from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorMetadata except ImportError: @@ -2067,7 +2070,7 @@ def _make_attn_bias(self, context_groups, token_groups): causal_mask = torch.ones(num_queries, num_queries, device='cpu', dtype=torch.bool) causal_mask = torch.triu(causal_mask, diagonal=1).unsqueeze(0) attn_mask[:, :, context_len:].logical_or_(causal_mask) - attn_mask = attn_mask.to(dtype).masked_fill_(attn_mask, -math.inf) + attn_mask = attn_mask.to(dtype).masked_fill_(attn_mask, MASK_VALUES.get(dtype, DEFAULT_MASK_VALUE)) return attn_mask.unflatten(0, (1, -1)) @@ -3787,15 +3790,11 @@ def set_attn_bias(self, attn_metadata, batch_size, seq_len, device, dtype): if self.is_pooling_model: len_mask_v = len_mask.view(batch_size, 1, seq_len, 1) mask = attn_mask.logical_or(len_mask).logical_or(len_mask_v) - off_value = -3E38 # small number, avoid nan and overflow - if dtype == torch.float16: - off_value = -63000 # a small value close to float16.min else: mask = attn_mask.logical_or(len_mask) # no need for len_mask_v as decode overwrites it - off_value = -math.inf mask = torch.concat((past_mask, mask), dim=-1) - attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, off_value)) + attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, MASK_VALUES.get(dtype, DEFAULT_MASK_VALUE))) attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", attn_bias=attn_bias) return attn_metadata @@ -6548,7 +6547,7 @@ def _set_attn_bias(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, diagonal=1) mask = causal_mask.logical_or(len_mask) mask = torch.concat((past_mask, mask), dim=-1) - attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)) + attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, MASK_VALUES.get(dtype, DEFAULT_MASK_VALUE))) attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", attn_bias=attn_bias) return attn_metadata @@ -6607,15 +6606,18 @@ def _set_attn_bias_for_sliding_window(self, attn_metadata: HPUAttentionMetadataV # causal_mask = causal_mask.logical_and(len_mask) mask = torch.concat((past_mask, causal_mask), dim=-1) - attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device), - torch.tensor(float('-inf'), dtype=dtype, device=device)) + attn_bias = torch.where( + mask, torch.tensor(0.0, dtype=dtype, device=device), + torch.tensor(MASK_VALUES.get(dtype, DEFAULT_MASK_VALUE), dtype=dtype, device=device)) else: # CAUSAL MASK without removing padding (CAUSAL+sliding window) # removing padding cause accuracy issue for images input tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1) mask = torch.tril(tensor, diagonal=shift) mask = torch.triu(mask, diagonal=shift - window_size + 1) - attn_bias = torch.log(mask) + attn_bias = torch.where( + mask, torch.tensor(0.0, dtype=dtype, device=device), + torch.tensor(MASK_VALUES.get(dtype, DEFAULT_MASK_VALUE), dtype=dtype, device=device)) attn_metadata = prefill_metadata._replace(window_attn_bias=attn_bias) return attn_metadata @@ -6669,8 +6671,9 @@ def _set_attn_bias_for_chunked_attention(self, attn_metadata: HPUAttentionMetada causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, seq_len, seq_len) mask = torch.concat((past_mask, causal_mask), dim=-1) - attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device), - torch.tensor(float('-inf'), dtype=dtype, device=device)) + attn_bias = torch.where( + mask, torch.tensor(0.0, dtype=dtype, device=device), + torch.tensor(MASK_VALUES.get(dtype, DEFAULT_MASK_VALUE), dtype=dtype, device=device)) else: tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1) mask = torch.tril(tensor, diagonal=shift) @@ -6679,7 +6682,9 @@ def _set_attn_bias_for_chunked_attention(self, attn_metadata: HPUAttentionMetada same_chunk = chunk_id.unsqueeze(0) == chunk_id.unsqueeze(1) same_chunk = same_chunk.unsqueeze(0).unsqueeze(0) mask = torch.where(same_chunk, mask, torch.tensor(0.0, dtype=dtype, device=device)) - attn_bias = torch.log(mask) + attn_bias = torch.where( + mask, torch.tensor(0.0, dtype=dtype, device=device), + torch.tensor(MASK_VALUES.get(dtype, DEFAULT_MASK_VALUE), dtype=dtype, device=device)) attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", chunked_attn_bias=attn_bias) return attn_metadata @@ -6719,7 +6724,7 @@ def _set_block_mapping(self, mask = torch.arange(0, self.block_size, device=device, dtype=torch.int32).unsqueeze(0) mask = mask >= block_usage.unsqueeze(-1) - attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)) + attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, MASK_VALUES.get(dtype, DEFAULT_MASK_VALUE))) if not is_fake_hpu(): block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size) From 1a004e04dba165520600c1df7dc1742e0283fbfb Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Wed, 1 Apr 2026 06:45:59 +0000 Subject: [PATCH 3/5] enable slicing for fp8 FusedSDPA Signed-off-by: Youlei Yang --- vllm_gaudi/attention/backends/hpu_attn.py | 7 + vllm_gaudi/extension/utils.py | 329 +++++++++++++++++----- 2 files changed, 259 insertions(+), 77 deletions(-) diff --git a/vllm_gaudi/attention/backends/hpu_attn.py b/vllm_gaudi/attention/backends/hpu_attn.py index 370d479b88..bc30597711 100644 --- a/vllm_gaudi/attention/backends/hpu_attn.py +++ b/vllm_gaudi/attention/backends/hpu_attn.py @@ -454,6 +454,13 @@ def __init__( HPUFusedSDPA = kernels.fsdpa() self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \ else ModuleFusedSDPA(HPUFusedSDPA) + try: + from habana_frameworks.torch.hpex.kernels import fp8_fused_sdpa + if self.enable_fp8_attn: + self.fused_scaled_dot_product_attention = ModuleFP8FusedSDPA(fp8_fused_sdpa) + except ImportError: + pass + self.prefill_impl = get_config().prompt_attn_impl self.use_contiguous_pa = get_config().use_contiguous_pa self.use_merged_prefill = get_config().merged_prefill diff --git a/vllm_gaudi/extension/utils.py b/vllm_gaudi/extension/utils.py index df148466d8..e39233988c 100644 --- a/vllm_gaudi/extension/utils.py +++ b/vllm_gaudi/extension/utils.py @@ -151,13 +151,84 @@ def forward(self, input, other, **kwargs): return output -class ModuleFusedSDPA(torch.nn.Module): +class ModuleFusedSDPABase(torch.nn.Module): + + def __init__(self): + super().__init__() + self.enable_slicing = self._setup_slicing() + + def _setup_slicing(self) -> bool: + from vllm_gaudi.extension.bucketing.common import get_bucketing_manager + bucketing_manager = get_bucketing_manager() + enable_slicing = bucketing_manager is not None + if not enable_slicing: + logger().warning('Bucketing manager is not instantiated, slicing in FSDPA will be disabled.') + return False + assert bucketing_manager is not None + enable_slicing = enable_slicing and bucketing_manager.initialized + if not enable_slicing: + logger().warning('Bucketing manager is not initialized, slicing in FSDPA will be disabled.') + return False + + from vllm_gaudi.extension.bucketing.linear import LinearBucketingStrategy + strategy = bucketing_manager.get_bucketing_strategy() + enable_slicing = isinstance(strategy, LinearBucketingStrategy) + if not enable_slicing: + logger().debug('Not using Linear Bucketing Strategy, slicing in FSDPA will be disabled.') + return False + + max_num_batched_tokens = bucketing_manager.max_num_batched_tokens + slice_thld_default = min(max_num_batched_tokens, 8192) + slice_thld = int(os.getenv("VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD", str(slice_thld_default))) + enable_slicing = enable_slicing and slice_thld >= slice_thld_default + if not enable_slicing and slice_thld > 0: + logger().warning('Invalid slice sequence length threshold, the threshold should be ' + f'>= min(max_num_batched_tokens, 8192), falling back to default {slice_thld_default}.') + slice_thld = slice_thld_default + + if enable_slicing: + # default to half of the threshold and round up by 1024 + chunk_size_default = math.ceil(slice_thld // 2 / 1024) * 1024 + chunk_size = int(os.getenv("VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE", str(chunk_size_default))) + block_size = bucketing_manager.block_size + if chunk_size < block_size or chunk_size > slice_thld: + logger().warning(f'Invalid chunk size for FusedSDPA slicing, the chunk size should be between ' + f'{block_size} and {slice_thld}, falling back to default {chunk_size_default}.') + chunk_size = chunk_size_default + if chunk_size % 1024 != 0: + chunk_size = math.ceil(chunk_size / 1024) * 1024 + logger().warning('Rounded up the chunk size for FusedSDPA slicing to the next multiple of 1024.') + self.slice_thld = slice_thld + self.chunk_size = chunk_size + max_query_pad_default = math.ceil(max_num_batched_tokens / 4) + max_query_pad = int(os.getenv("VLLM_PROMPT_QUERY_BUCKET_PAD_MAX", str(max_query_pad_default))) + self.num_padded_query_chunks = math.ceil(max_query_pad / self.chunk_size) + max_ctx_pad_default = math.ceil(max_num_batched_tokens / block_size) + max_ctx_pad = int(os.getenv("VLLM_PROMPT_CTX_BUCKET_PAD_MAX", str(max_ctx_pad_default))) + self.num_padded_ctx_chunks = math.ceil(max_ctx_pad * block_size / self.chunk_size) + + import habana_frameworks.torch as ht + is_lazy = ht.utils.internal.is_lazy() + self.with_graph_breaks = os.getenv("VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS", + str(is_lazy)).strip().lower() in ("1", "true") + if self.with_graph_breaks: + if is_lazy: + self.break_graph = ht.core.mark_step + else: + self.break_graph = torch._dynamo.graph_break + msg = (f"FusedSDPA slicing is enabled with sequence length threshold {slice_thld}, " + f"chunk size {self.chunk_size}, num padded query chunks {self.num_padded_query_chunks}, " + f"num padded ctx chunks {self.num_padded_ctx_chunks}, with graph breaks {self.with_graph_breaks}.") + logger().debug(msg) + return enable_slicing + + +class ModuleFusedSDPA(ModuleFusedSDPABase): def __init__(self, fusedSDPA): super().__init__() assert fusedSDPA is not None, f'fusedSDPA kernel is None' self._hpu_kernel_fsdpa = fusedSDPA - self.enable_slicing = self._setup_slicing() def forward( self, @@ -184,7 +255,7 @@ def forward( and is_causal and attn_mask is not None # only supports causal attention with mask ): return self._sliced_fsdpa_fwd(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, - recompute_mode, valid_sequence_lengths, padding_side) + padding_side) if is_causal and attn_mask is not None: # TODO: causal + attn_bias is not yet supported is_causal = False @@ -198,8 +269,7 @@ def forward( recompute_mode, valid_sequence_lengths, padding_side, False, False, (-1, -1), sinks) - def _sliced_fsdpa_fwd(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, recompute_mode, - valid_sequence_lengths, padding_side): + def _sliced_fsdpa_fwd(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, padding_side): assert is_causal and attn_mask is not None from habana_frameworks.torch.hpex.kernels.FusedSDPA import is_gqa, gqa_input_reshape_fwd, gqa_output_reshape @@ -211,6 +281,8 @@ def _sliced_fsdpa_fwd(self, query, key, value, attn_mask, dropout_p, is_causal, q_len = q.shape[-2] kv_len = k.shape[-2] prefix_len = kv_len - q_len + if scale is None: + scale = 1.0 / (query.shape[-1]**0.5) chunk_outputs = [] num_q_chunks = math.ceil(q_len / self.chunk_size) @@ -235,13 +307,12 @@ def _sliced_fsdpa_fwd(self, query, key, value, attn_mask, dropout_p, is_causal, k_chunk = k[..., kv_start:kv_end, :] v_chunk = v[..., kv_start:kv_end, :] - is_causal_chunk = kv_chunk_idx == 0 and q_chunk_idx != 0 + is_causal_chunk = kv_chunk_idx == 0 and q_chunk_idx >= self.num_padded_query_chunks # chunk sizes must be multiples of 1024 to get valid m and linv is_causal_chunk = is_causal_chunk and q_chunk_size % 1024 == 0 and kv_chunk_size % 1024 == 0 # use mask only for the causal chunks that may have padding - mask_chunk = attn_mask[ - ..., q_start:q_end, - kv_start:kv_end] if kv_chunk_idx < self.num_padded_query_chunks and not is_causal_chunk else None + mask_chunk = (attn_mask[..., q_start:q_end, kv_start:kv_end] + if kv_chunk_idx < self.num_padded_query_chunks and not is_causal_chunk else None) if self.with_graph_breaks: # break_graph() cannot break the tensor slicing, use clone to isolate the graph @@ -290,8 +361,8 @@ def _sliced_fsdpa_fwd(self, query, key, value, attn_mask, dropout_p, is_causal, k_chunk = k[..., kv_start:kv_end, :] v_chunk = v[..., kv_start:kv_end, :] # use mask only for the chunks that may have padding - mask_chunk = attn_mask[..., q_start:q_end, - kv_start:kv_end] if kv_chunk_idx < self.num_padded_ctx_chunks else None + mask_chunk = (attn_mask[..., q_start:q_end, + kv_start:kv_end] if kv_chunk_idx < self.num_padded_ctx_chunks else None) if self.with_graph_breaks: # break_graph() cannot break the tensor slicing, use clone to isolate the graph @@ -331,73 +402,8 @@ def _sliced_fsdpa_fwd(self, query, key, value, attn_mask, dropout_p, is_causal, output = torch.cat(chunk_outputs, dim=-2) return output.to(q.dtype) - def _setup_slicing(self) -> bool: - from vllm_gaudi.extension.bucketing.common import get_bucketing_manager - bucketing_manager = get_bucketing_manager() - enable_slicing = bucketing_manager is not None - if not enable_slicing: - logger().warning('Bucketing manager is not instantiated, slicing in FSDPA will be disabled.') - return False - assert bucketing_manager is not None - enable_slicing = enable_slicing and bucketing_manager.initialized - if not enable_slicing: - logger().warning('Bucketing manager is not initialized, slicing in FSDPA will be disabled.') - return False - - from vllm_gaudi.extension.bucketing.linear import LinearBucketingStrategy - strategy = bucketing_manager.get_bucketing_strategy() - enable_slicing = isinstance(strategy, LinearBucketingStrategy) - if not enable_slicing: - logger().debug('Not using Linear Bucketing Strategy, slicing in FSDPA will be disabled.') - return False - - max_num_batched_tokens = bucketing_manager.max_num_batched_tokens - slice_thld_default = min(max_num_batched_tokens, 8192) - slice_thld = int(os.getenv("VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD", str(slice_thld_default))) - enable_slicing = enable_slicing and slice_thld >= slice_thld_default - if not enable_slicing and slice_thld > 0: - logger().warning('Invalid slice sequence length threshold, the threshold should be ' - f'>= min(max_num_batched_tokens, 8192), falling back to default {slice_thld_default}.') - slice_thld = slice_thld_default - - if enable_slicing: - # default to half of the threshold and round up by 1024 - chunk_size_default = math.ceil(slice_thld // 2 / 1024) * 1024 - chunk_size = int(os.getenv("VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE", str(chunk_size_default))) - block_size = bucketing_manager.block_size - if chunk_size < block_size or chunk_size > slice_thld: - logger().warning(f'Invalid chunk size for FusedSDPA slicing, the chunk size should be between ' - f'{block_size} and {slice_thld}, falling back to default {chunk_size_default}.') - chunk_size = chunk_size_default - if chunk_size % 1024 != 0: - chunk_size = math.ceil(chunk_size / 1024) * 1024 - logger().warning('Rounded up the chunk size for FusedSDPA slicing to the next multiple of 1024.') - self.slice_thld = slice_thld - self.chunk_size = chunk_size - max_query_pad_default = math.ceil(max_num_batched_tokens / 4) - max_query_pad = int(os.getenv("VLLM_PROMPT_QUERY_BUCKET_PAD_MAX", str(max_query_pad_default))) - self.num_padded_query_chunks = math.ceil(max_query_pad / self.chunk_size) - max_ctx_pad_default = math.ceil(max_num_batched_tokens / block_size) - max_ctx_pad = int(os.getenv("VLLM_PROMPT_CTX_BUCKET_PAD_MAX", str(max_ctx_pad_default))) - self.num_padded_ctx_chunks = math.ceil(max_ctx_pad * block_size / self.chunk_size) - - import habana_frameworks.torch as ht - is_lazy = ht.utils.internal.is_lazy() - self.with_graph_breaks = os.getenv("VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS", - str(is_lazy)).strip().lower() in ("1", "true") - if self.with_graph_breaks: - if is_lazy: - self.break_graph = ht.core.mark_step - else: - self.break_graph = torch._dynamo.graph_break - msg = (f"FusedSDPA slicing is enabled with sequence length threshold {slice_thld}, " - f"chunk size {self.chunk_size}, num padded query chunks {self.num_padded_query_chunks}, " - f"num padded ctx chunks {self.num_padded_ctx_chunks}, with graph breaks {self.with_graph_breaks}.") - logger().debug(msg) - return enable_slicing - -class ModuleFP8FusedSDPA(torch.nn.Module): +class ModuleFP8FusedSDPA(ModuleFusedSDPABase): def __init__(self, fusedSDPA): super().__init__() @@ -413,6 +419,7 @@ def __init__(self, fusedSDPA): self.d_scale_q = torch.tensor(1.0) self.d_scale_k = torch.tensor(1.0) self.d_scale_v = torch.tensor(1.0) + self.d_scale_output = torch.tensor(1.0) def quant_input(self, x, scale): return torch.ops.hpu.cast_to_fp8_v2(x, scale, False, False, torch.float8_e4m3fn)[0] @@ -440,6 +447,22 @@ def forward( kinput = self.quant_input(key, self.scale_k) vinput = self.quant_input(value, self.scale_v) + bs = query.shape[0] + q_len = query.shape[-2] + kv_len = key.shape[-2] + if (self.enable_slicing and kv_len >= self.slice_thld \ + and bs == 1 # bs should be 1 for chunked prefill + and q_len != kv_len # normal causal prefill route to the default dispatch for better performance + and is_causal and attn_mask is not None # only supports causal attention with mask + ): + return self._sliced_fsdpa_fwd(qinput, kinput, vinput, attn_mask, dropout_p, is_causal, scale, softmax_mode, + padding_side).to(query.dtype) + + if is_causal and attn_mask is not None: + # TODO: causal + attn_bias is not yet supported + is_causal = False + valid_sequence_lengths = None + results = self.fp8_fused_sdpa( qinput, kinput, @@ -463,6 +486,158 @@ def forward( output = results[0] return output + def _sliced_fsdpa_fwd(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, padding_side): + assert is_causal and attn_mask is not None + + from habana_frameworks.torch.hpex.kernels.Fp8FusedSDPA import is_gqa, gqa_input_reshape_fwd, gqa_output_reshape + gqa = is_gqa(query, key) + if gqa: + q, k, v, attn_mask = gqa_input_reshape_fwd(query, key, value, attn_mask) + else: + q, k, v, attn_mask = (query, key, value, attn_mask) + q_len = query.shape[-2] + kv_len = key.shape[-2] + prefix_len = kv_len - q_len + softmax_mode = softmax_mode if softmax_mode == "fp32" else "fast" + if scale is None: + scale = 1.0 / (query.shape[-1]**0.5) + + chunk_outputs = [] + num_q_chunks = math.ceil(q_len / self.chunk_size) + num_prefix_chunks = math.ceil(prefix_len / self.chunk_size) + for q_chunk_idx in range(num_q_chunks): + q_start = q_len - (q_chunk_idx + 1) * self.chunk_size + q_start = max(q_start, 0) + q_end = q_len - q_chunk_idx * self.chunk_size + q_chunk_size = q_end - q_start + q_chunk = q[..., q_start:q_end, :].clone() if self.with_graph_breaks else q[..., q_start:q_end, :] + + last_out = None + last_m = None + last_linv = None + + # the causal part + for kv_chunk_idx in range(0, num_q_chunks - q_chunk_idx): + kv_start = prefix_len + q_end - (kv_chunk_idx + 1) * self.chunk_size + kv_start = max(kv_start, prefix_len) + kv_end = prefix_len + q_end - kv_chunk_idx * self.chunk_size + kv_chunk_size = kv_end - kv_start + k_chunk = k[..., kv_start:kv_end, :] + v_chunk = v[..., kv_start:kv_end, :] + + is_causal_chunk = kv_chunk_idx == 0 and q_chunk_idx >= self.num_padded_query_chunks + # chunk sizes must be multiples of 1024 to get valid m and linv + is_causal_chunk = is_causal_chunk and q_chunk_size % 1024 == 0 and kv_chunk_size % 1024 == 0 + # use mask only for the causal chunks that may have padding + mask_chunk = (attn_mask[..., q_start:q_end, kv_start:kv_end] + if kv_chunk_idx < self.num_padded_query_chunks and not is_causal_chunk else None) + + if self.with_graph_breaks: + # break_graph() cannot break the tensor slicing, use clone to isolate the graph + k_chunk = k_chunk.clone() + v_chunk = v_chunk.clone() + mask_chunk = mask_chunk.clone() if mask_chunk is not None else None + self.break_graph() + + chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_chunk, v_chunk, mask_chunk, dropout_p, scale, is_causal_chunk, + softmax_mode) + + chunk_out, chunk_m, chunk_linv = (gqa_output_reshape(x) if gqa else x for x in chunk_res[:3]) + chunk_m = chunk_m.to(torch.float32) + chunk_linv = chunk_linv.to(torch.float32) * (128.0 if softmax_mode == "fast" else 1.0) + chunk_out = self.dequant_output(chunk_out, self.d_scale_output).to(torch.float32) + + if last_out is None or last_m is None or last_linv is None: + last_out = chunk_out + last_m = chunk_m + last_linv = chunk_linv + else: + new_m = torch.maximum(last_m, chunk_m) + last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m) + chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m) + last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled) + last_out = (last_linv_rescaled * last_linv) * last_out + \ + (chunk_linv_rescaled * last_linv) * chunk_out + last_m = new_m + + if self.with_graph_breaks: + self.break_graph() + + # the context part + for kv_chunk_idx in range(num_prefix_chunks): + kv_start = prefix_len - (kv_chunk_idx + 1) * self.chunk_size + kv_start = max(kv_start, 0) + kv_end = prefix_len - kv_chunk_idx * self.chunk_size + k_chunk = k[..., kv_start:kv_end, :] + v_chunk = v[..., kv_start:kv_end, :] + # use mask only for the chunks that may have padding + mask_chunk = (attn_mask[..., q_start:q_end, + kv_start:kv_end] if kv_chunk_idx < self.num_padded_ctx_chunks else None) + + if self.with_graph_breaks: + # break_graph() cannot break the tensor slicing, use clone to isolate the graph + k_chunk = k_chunk.clone() + v_chunk = v_chunk.clone() + mask_chunk = mask_chunk.clone() if mask_chunk is not None else None + self.break_graph() + + chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_chunk, v_chunk, None, dropout_p, scale, False, softmax_mode) + chunk_out, chunk_m, chunk_linv = (gqa_output_reshape(x) if gqa else x for x in chunk_res[:3]) + chunk_m = chunk_m.to(torch.float32) + chunk_linv = chunk_linv.to(torch.float32) * (128.0 if softmax_mode == "fast" else 1.0) + chunk_out = self.dequant_output(chunk_out, self.d_scale_output).to(torch.float32) + + assert not (last_out is None or last_m is None or last_linv is None) + new_m = torch.maximum(last_m, chunk_m) + last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m) + chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m) + last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled) + last_out = (last_linv_rescaled * last_linv) * last_out + (chunk_linv_rescaled * last_linv) * chunk_out + last_m = new_m + + if self.with_graph_breaks: + self.break_graph() + + chunk_outputs.append(last_out) + chunk_outputs = list(reversed(chunk_outputs)) + return torch.cat(chunk_outputs, dim=-2) + + def fp8_fsdpa_fwd( + self, + q, + k, + v, + attn_mask, + dropout_p, + scale, + is_causal, + softmax_mode, + ): + results = torch.ops.hpu.fp8_sdpa_recomp_fwd( + q, + k, + v, + attn_mask, + dropout_p, + scale, + is_causal, + True, # requires_backward + softmax_mode, # softmax_mode + self.d_scale_q, # d_scale_q + self.d_scale_k, # d_scale_k + self.d_scale_v, # d_scale_v + self.scale_amax, # q_scale_s + self.d_scale_output, # q_scale_o + self.descale_amax, # d_scale_s + False, # is_amax_s + False, # is_amax_o + None, # valid_seq_len + "right", # seq_padding_type + (-1, -1), # window_size + None, # sink + ) + return results + def pad_list(input, target_len, val_generator): padding = target_len - len(input) From 6c3c91a8bedc3e9cc3675650286c2242e4a4a224 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Wed, 1 Apr 2026 08:24:42 +0000 Subject: [PATCH 4/5] fix gc error Signed-off-by: Youlei Yang --- vllm_gaudi/extension/utils.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/vllm_gaudi/extension/utils.py b/vllm_gaudi/extension/utils.py index e39233988c..03c413ab98 100644 --- a/vllm_gaudi/extension/utils.py +++ b/vllm_gaudi/extension/utils.py @@ -411,15 +411,15 @@ def __init__(self, fusedSDPA): self.fp8_fused_sdpa = fusedSDPA # set the descale_amax and scale_amax 1.0 temporarily - self.descale_amax = torch.tensor(1.0) - self.scale_amax = torch.tensor(1.0) - self.scale_q = torch.tensor(1.0) - self.scale_k = torch.tensor(1.0) - self.scale_v = torch.tensor(1.0) - self.d_scale_q = torch.tensor(1.0) - self.d_scale_k = torch.tensor(1.0) - self.d_scale_v = torch.tensor(1.0) - self.d_scale_output = torch.tensor(1.0) + self.descale_amax = torch.tensor([1.0], dtype=torch.float32, device="hpu") + self.scale_amax = torch.tensor([1.0], dtype=torch.float32, device="hpu") + self.scale_q = torch.tensor([1.0], dtype=torch.float32, device="hpu") + self.scale_k = torch.tensor([1.0], dtype=torch.float32, device="hpu") + self.scale_v = torch.tensor([1.0], dtype=torch.float32, device="hpu") + self.d_scale_q = torch.tensor([1.0], dtype=torch.float32, device="hpu") + self.d_scale_k = torch.tensor([1.0], dtype=torch.float32, device="hpu") + self.d_scale_v = torch.tensor([1.0], dtype=torch.float32, device="hpu") + self.d_scale_output = torch.tensor([1.0], dtype=torch.float32, device="hpu") def quant_input(self, x, scale): return torch.ops.hpu.cast_to_fp8_v2(x, scale, False, False, torch.float8_e4m3fn)[0] @@ -581,7 +581,8 @@ def _sliced_fsdpa_fwd(self, query, key, value, attn_mask, dropout_p, is_causal, mask_chunk = mask_chunk.clone() if mask_chunk is not None else None self.break_graph() - chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_chunk, v_chunk, None, dropout_p, scale, False, softmax_mode) + chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_chunk, v_chunk, mask_chunk, dropout_p, scale, False, + softmax_mode) chunk_out, chunk_m, chunk_linv = (gqa_output_reshape(x) if gqa else x for x in chunk_res[:3]) chunk_m = chunk_m.to(torch.float32) chunk_linv = chunk_linv.to(torch.float32) * (128.0 if softmax_mode == "fast" else 1.0) From 9f186b868a882a7d04e8a6f7666654482da58ee0 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Tue, 7 Apr 2026 03:20:55 +0000 Subject: [PATCH 5/5] use detach() instead of clone() to solve the accuracy issue Signed-off-by: Youlei Yang --- vllm_gaudi/extension/utils.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/vllm_gaudi/extension/utils.py b/vllm_gaudi/extension/utils.py index 03c413ab98..64b89b1a67 100644 --- a/vllm_gaudi/extension/utils.py +++ b/vllm_gaudi/extension/utils.py @@ -443,9 +443,9 @@ def forward( window_size=None, ): - qinput = self.quant_input(query, self.scale_q) - kinput = self.quant_input(key, self.scale_k) - vinput = self.quant_input(value, self.scale_v) + qinput = self.quant_input(query, self.scale_q).detach() + kinput = self.quant_input(key, self.scale_k).detach() + vinput = self.quant_input(value, self.scale_v).detach() bs = query.shape[0] q_len = query.shape[-2] @@ -510,8 +510,7 @@ def _sliced_fsdpa_fwd(self, query, key, value, attn_mask, dropout_p, is_causal, q_start = max(q_start, 0) q_end = q_len - q_chunk_idx * self.chunk_size q_chunk_size = q_end - q_start - q_chunk = q[..., q_start:q_end, :].clone() if self.with_graph_breaks else q[..., q_start:q_end, :] - + q_chunk = q[..., q_start:q_end, :].detach() last_out = None last_m = None last_linv = None @@ -522,21 +521,17 @@ def _sliced_fsdpa_fwd(self, query, key, value, attn_mask, dropout_p, is_causal, kv_start = max(kv_start, prefix_len) kv_end = prefix_len + q_end - kv_chunk_idx * self.chunk_size kv_chunk_size = kv_end - kv_start - k_chunk = k[..., kv_start:kv_end, :] - v_chunk = v[..., kv_start:kv_end, :] + k_chunk = k[..., kv_start:kv_end, :].detach() + v_chunk = v[..., kv_start:kv_end, :].detach() is_causal_chunk = kv_chunk_idx == 0 and q_chunk_idx >= self.num_padded_query_chunks # chunk sizes must be multiples of 1024 to get valid m and linv is_causal_chunk = is_causal_chunk and q_chunk_size % 1024 == 0 and kv_chunk_size % 1024 == 0 # use mask only for the causal chunks that may have padding - mask_chunk = (attn_mask[..., q_start:q_end, kv_start:kv_end] + mask_chunk = (attn_mask[..., q_start:q_end, kv_start:kv_end].detach() \ if kv_chunk_idx < self.num_padded_query_chunks and not is_causal_chunk else None) if self.with_graph_breaks: - # break_graph() cannot break the tensor slicing, use clone to isolate the graph - k_chunk = k_chunk.clone() - v_chunk = v_chunk.clone() - mask_chunk = mask_chunk.clone() if mask_chunk is not None else None self.break_graph() chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_chunk, v_chunk, mask_chunk, dropout_p, scale, is_causal_chunk,