Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions docs/configuration/env_variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,20 @@ and warm-up. Recommended settings for this case are:

!!! note
If the model config specifies a high `max_model_len`, set it to the sum of `input_tokens` and `output_tokens`, rounded up to a multiple of `block_size` according to actual requirements.

## Additional Performance Tuning Parameters for the FusedSDPA Kernel with Linear Bucketing

FusedSDPA can be split into smaller chunks to improve performance by:

- fitting smaller chunks into SRAM,
- improving TPC/MME pipelining,
- reducing attention-mask usage.

| Parameter name | Description | Default value |
| ---------------------------------------- | -------------------------------------------------------------------------------------------- | ------------------------------------------ |
| `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD` | KV length threshold at or above which slicing is applied. Must not be lower than the default value. Set to `-1` to disable slicing. | `min(max_num_batched_tokens, 8192)` |
| `VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE` | Length of each slice along both `q_len` and `kv_len`. Must be between `block_size` and the slicing threshold; rounded up to the next multiple of 1024. | `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD // 2`, rounded up to a multiple of 1024 |
| `VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS` | Places each chunk in a separate graph to reduce compilation time. | `true` for lazy mode and `false` otherwise |
Comment on lines +106 to +110
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The markdown table uses || at the start of each row, which typically renders as an extra empty column (or may not render as intended depending on the markdown renderer). Use single | delimiters for standard GitHub-flavored markdown tables.

Copilot uses AI. Check for mistakes.

!!! note
These parameters are effective only with the linear bucketing strategy.
5 changes: 5 additions & 0 deletions vllm_gaudi/extension/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def get_user_flags():
Env('PT_HPU_SDPA_QKV_SLICE_MODE_FWD', boolean),
Env('PT_HPU_SDPA_BC_FACTOR', int),
Env('VLLM_FUSEDSDPA_SLIDE_THLD', int),

# FusedSDPA slicing flags
Env('VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD', int),
Env('VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE', int),
Env('VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS', boolean),
]
return to_dict(flags)

Expand Down
21 changes: 12 additions & 9 deletions vllm_gaudi/extension/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,12 @@ def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, sink, bat
batch2block_matmul_op, block2batch_matmul_op):
# When fp32_softmax is enabled attn is left in fp32 after Q@K
# We can return to native dtype after we renormalize and calculate the adjustments
if block_bias is not None and attn.dtype != block_bias.dtype:
block_bias = block_bias.to(dtype=attn.dtype)
if block_bias is not None:
if block_bias.dtype == torch.bool:
# Convert boolean mask (True=valid, False=masked) to additive bias (0.0/-inf)
block_bias = torch.zeros_like(block_bias, dtype=attn.dtype).masked_fill_(~block_bias, float('-inf'))
elif attn.dtype != block_bias.dtype:
block_bias = block_bias.to(dtype=attn.dtype)
# TODO: w/a with 5D req as the block_softmax kernel does not support 4D attn tensor, which is used in e.g. Granite-3B
if get_config().fused_block_softmax and get_config().fused_block_softmax_adjustment and attn.dim() == 5:
attn, block_max, block_sums = torch.ops.hpu.block_softmax(attn, block_bias, block_groups)
Expand Down Expand Up @@ -357,9 +361,12 @@ def _naive_prompt_attention(query: torch.Tensor,
htcore.mark_step()
attn_weights.add_(position_bias)
if attn_bias is not None:
if attn_weights.dtype != attn_bias.dtype:
attn_bias = attn_bias.to(dtype=attn_weights.dtype)
attn_weights.add_(attn_bias)
if attn_bias.dtype == torch.bool:
attn_weights = attn_weights.masked_fill(~attn_bias, float("-inf"))
else:
if attn_weights.dtype != attn_bias.dtype:
attn_bias = attn_bias.to(dtype=attn_weights.dtype)
attn_weights.add_(attn_bias)
Comment on lines +364 to +369
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the boolean-mask path, masked_fill creates a new tensor and can significantly increase peak memory for large attention matrices. Prefer the in-place variant (attn_weights.masked_fill_(...)) when it’s safe in this function’s flow to avoid an extra allocation.

Copilot uses AI. Check for mistakes.
if sinks is not None:
sink = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
if query_heads != kv_heads:
Expand Down Expand Up @@ -404,10 +411,6 @@ def _fsdpa_prompt_attention(query: torch.Tensor,
recompute_mode = True
assert attn_bias is not None or valid_seq_lengths is not None, \
'Either attn_bias or valid_seq_lengths must be != None'
if is_causal and attn_bias is not None:
# TODO: causal + attn_bias is not yet supported
is_causal = False
valid_seq_lengths = None

args = [
query, key, value, attn_bias, 0.0, is_causal, scale, softmax_mode, recompute_mode, valid_seq_lengths,
Expand Down
213 changes: 213 additions & 0 deletions vllm_gaudi/extension/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
###############################################################################

import os
import math
from functools import lru_cache, wraps
from typing import Optional, Any

import habana_frameworks.torch as htorch
import torch
import itertools
from vllm_gaudi.extension.logger import logger

from vllm_gaudi.extension.runtime import get_config

Expand Down Expand Up @@ -155,6 +157,7 @@ def __init__(self, fusedSDPA):
super().__init__()
assert fusedSDPA is not None, f'fusedSDPA kernel is None'
self._hpu_kernel_fsdpa = fusedSDPA
self.enable_slicing = self._setup_slicing()

def forward(
self,
Expand All @@ -172,6 +175,20 @@ def forward(
window_size=None,
sinks=None,
):
bs = query.shape[0]
q_len = query.shape[-2]
kv_len = key.shape[-2]
if (self.enable_slicing and kv_len >= self.slice_thld \
and bs == 1 # bs should be 1 for chunked prefill
and q_len != kv_len # normal causal prefill route to the default dispatch for better performance
and is_causal and attn_mask is not None # only supports causal attention with mask
):
return self._sliced_fsdpa_fwd(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode,
recompute_mode, valid_sequence_lengths, padding_side)
if is_causal and attn_mask is not None:
# TODO: causal + attn_bias is not yet supported
is_causal = False
valid_sequence_lengths = None
Comment on lines +188 to +191
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR description states the causal+attn_bias workaround is removed now that boolean masks are supported, but this ModuleFusedSDPA.forward path still forcibly disables is_causal when attn_mask is present. If boolean masking is intended to enable is_causal + mask in the fused kernel, this block should be removed or updated to only apply to unsupported mask dtypes/shapes; otherwise the description should be updated to reflect that the workaround still exists here.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, if the threshold is met then we do not enter here so it's not that misleading

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The whole `if is_causal ...` block could be removed, because we should be able to support all of that. Slawomir Laba can help if this doesn't work. If removing it causes issues, it is not required, so do not block this work on enabling it.

if window_size is not None:
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode,
recompute_mode, valid_sequence_lengths, padding_side, False, False,
Expand All @@ -181,6 +198,202 @@ def forward(
recompute_mode, valid_sequence_lengths, padding_side, False, False,
(-1, -1), sinks)

def _sliced_fsdpa_fwd(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, recompute_mode,
valid_sequence_lengths, padding_side):
assert is_causal and attn_mask is not None

from habana_frameworks.torch.hpex.kernels.FusedSDPA import is_gqa, gqa_input_reshape_fwd, gqa_output_reshape
gqa = is_gqa(query, key)
if gqa:
q, k, v, attn_mask = gqa_input_reshape_fwd(query, key, value, attn_mask)
else:
q, k, v, attn_mask = (query, key, value, attn_mask)
q_len = q.shape[-2]
kv_len = k.shape[-2]
prefix_len = kv_len - q_len

chunk_outputs = []
num_q_chunks = math.ceil(q_len / self.chunk_size)
num_prefix_chunks = math.ceil(prefix_len / self.chunk_size)
for q_chunk_idx in range(num_q_chunks):
q_start = q_len - (q_chunk_idx + 1) * self.chunk_size
q_start = max(q_start, 0)
q_end = q_len - q_chunk_idx * self.chunk_size
q_chunk_size = q_end - q_start
q_chunk = q[..., q_start:q_end, :].clone() if self.with_graph_breaks else q[..., q_start:q_end, :]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've never seen ... operator, TIL. I trust it's needed here


last_out = None
last_m = None
last_linv = None

# the causal part
for kv_chunk_idx in range(0, num_q_chunks - q_chunk_idx):
kv_start = prefix_len + q_end - (kv_chunk_idx + 1) * self.chunk_size
kv_start = max(kv_start, prefix_len)
kv_end = prefix_len + q_end - kv_chunk_idx * self.chunk_size
kv_chunk_size = kv_end - kv_start
k_chunk = k[..., kv_start:kv_end, :]
v_chunk = v[..., kv_start:kv_end, :]

is_causal_chunk = kv_chunk_idx == 0 and q_chunk_idx != 0
# chunk sizes must be multiples of 1024 to get valid m and linv
is_causal_chunk = is_causal_chunk and q_chunk_size % 1024 == 0 and kv_chunk_size % 1024 == 0
# use mask only for the causal chunks that may have padding
mask_chunk = attn_mask[
..., q_start:q_end,
kv_start:kv_end] if kv_chunk_idx < self.num_padded_query_chunks and not is_causal_chunk else None

if self.with_graph_breaks:
k_chunk = k_chunk.clone()
v_chunk = v_chunk.clone()
mask_chunk = mask_chunk.clone() if mask_chunk is not None else None
self.break_graph()

chunk_res = torch.ops.hpu.sdpa_recomp_fwd(
q_chunk,
k_chunk,
v_chunk,
mask_chunk,
dropout_p,
scale,
is_causal_chunk,
True, # requires_backward
softmax_mode,
None, # valid_seq_len
padding_side,
)
chunk_out, chunk_m, chunk_linv = ((gqa_output_reshape(x) if gqa else x).to(torch.float32)
for x in (chunk_res[:3]))
Comment on lines +252 to +266
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sliced path hard-codes requires_backward=True and unconditionally casts chunk outputs (including chunk_out) to float32. For long contexts this can materially increase memory and runtime overhead in inference. If the kernel/API allows it, consider disabling backward requirements and limiting float32 to the accumulator numerics (e.g., keep m/linv and the running mix in float32, but avoid extra float32 copies of each chunk_out unless needed).

Copilot uses AI. Check for mistakes.

if last_out is None or last_m is None or last_linv is None:
last_out = chunk_out
last_m = chunk_m
last_linv = chunk_linv
else:
new_m = torch.maximum(last_m, chunk_m)
last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m)
chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m)
last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled)
last_out = (last_linv_rescaled * last_linv) * last_out + (chunk_linv_rescaled *
last_linv) * chunk_out
last_m = new_m

if self.with_graph_breaks:
self.break_graph()

# the context part
for kv_chunk_idx in range(num_prefix_chunks):
kv_start = prefix_len - (kv_chunk_idx + 1) * self.chunk_size
kv_start = max(kv_start, 0)
kv_end = prefix_len - kv_chunk_idx * self.chunk_size
k_chunk = k[..., kv_start:kv_end, :]
v_chunk = v[..., kv_start:kv_end, :]
# use mask only for the chunks that may have padding
mask_chunk = attn_mask[..., q_start:q_end,
kv_start:kv_end] if kv_chunk_idx < self.num_padded_ctx_chunks else None

if self.with_graph_breaks:
k_chunk = k_chunk.clone()
v_chunk = v_chunk.clone()
mask_chunk = mask_chunk.clone() if mask_chunk is not None else None
self.break_graph()

chunk_res = torch.ops.hpu.sdpa_recomp_fwd(
q_chunk,
k_chunk,
v_chunk,
mask_chunk,
dropout_p,
scale,
False, # is_causal
True, # requires_backward
softmax_mode,
None, # valid_seq_len
padding_side,
)
chunk_out, chunk_m, chunk_linv = ((gqa_output_reshape(x) if gqa else x).to(torch.float32)
for x in chunk_res[:3])

assert not (last_out is None or last_m is None or last_linv is None)
new_m = torch.maximum(last_m, chunk_m)
last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m)
chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m)
last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled)
last_out = (last_linv_rescaled * last_linv) * last_out + (chunk_linv_rescaled * last_linv) * chunk_out
last_m = new_m

if self.with_graph_breaks:
self.break_graph()
chunk_outputs.append(last_out)
chunk_outputs = list(reversed(chunk_outputs))
output = torch.cat(chunk_outputs, dim=-2)
return output.to(q.dtype)

def _setup_slicing(self) -> bool:
from vllm_gaudi.extension.bucketing.common import get_bucketing_manager
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the imports inside the function needed because the modules are not visible otherwise? If not, we could move them to the top of the file.

bucketing_manager = get_bucketing_manager()
enable_slicing = bucketing_manager is not None
if not enable_slicing:
logger().warning('Bucketing manager is not instantiated, slicing in FSDPA will be disabled.')
return False
assert bucketing_manager is not None
enable_slicing = enable_slicing and bucketing_manager.initialized
if not enable_slicing:
logger().warning('Bucketing manager is not initialized, slicing in FSDPA will be disabled.')
return False

from vllm_gaudi.extension.bucketing.linear import LinearBucketingStrategy
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the imports inside the function needed because the modules are not visible otherwise? If not, we could move them to the top of the file.

strategy = bucketing_manager.get_bucketing_strategy()
enable_slicing = isinstance(strategy, LinearBucketingStrategy)
if not enable_slicing:
logger().debug('Not using Linear Bucketing Strategy, slicing in FSDPA will be disabled.')
return False

max_num_batched_tokens = bucketing_manager.max_num_batched_tokens
slice_thld_default = min(max_num_batched_tokens, 8192)
slice_thld = int(os.getenv("VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD", str(slice_thld_default)))
enable_slicing = enable_slicing and slice_thld >= slice_thld_default
if not enable_slicing and slice_thld > 0:
logger().warning('Invalid slice sequence length threshold, the threshold should be '
f'>= min(max_num_batched_tokens, 8192), falling back to default {slice_thld_default}.')
slice_thld = slice_thld_default
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the user sets VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD to a positive value below the default, the code logs that it is “falling back to default” but slicing remains disabled because enable_slicing is never re-enabled after resetting slice_thld. After assigning slice_thld = slice_thld_default, either set enable_slicing back to True (since prerequisites were met) or restructure the validation to compute enable_slicing only after sanitizing slice_thld.

Suggested change
slice_thld = slice_thld_default
slice_thld = slice_thld_default
enable_slicing = True

Copilot uses AI. Check for mistakes.

if enable_slicing:
# default to half of the threshold and round up by 1024
chunk_size_default = math.ceil(slice_thld // 2 / 1024) * 1024
chunk_size = int(os.getenv("VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE", str(chunk_size_default)))
block_size = bucketing_manager.block_size
if chunk_size < block_size or chunk_size > slice_thld:
logger().warning(f'Invalid chunk size for FusedSDPA slicing, the chunk size should be between '
f'{block_size} and {slice_thld}, falling back to default {chunk_size_default}.')
chunk_size = chunk_size_default
if chunk_size % 1024 != 0:
chunk_size = math.ceil(chunk_size / 1024) * 1024
logger().warning('Rounded up the chunk size for FusedSDPA slicing to the next multiple of 1024.')
self.slice_thld = slice_thld
self.chunk_size = chunk_size
max_query_pad_default = math.ceil(max_num_batched_tokens / 4)
max_query_pad = int(os.getenv("VLLM_PROMPT_QUERY_BUCKET_PAD_MAX", str(max_query_pad_default)))
self.num_padded_query_chunks = math.ceil(max_query_pad / self.chunk_size)
max_ctx_pad_default = math.ceil(max_num_batched_tokens / block_size)
max_ctx_pad = int(os.getenv("VLLM_PROMPT_CTX_BUCKET_PAD_MAX", str(max_ctx_pad_default)))
self.num_padded_ctx_chunks = math.ceil(max_ctx_pad * block_size / self.chunk_size)

import habana_frameworks.torch as ht
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the imports inside the function needed because the modules are not visible otherwise? If not, we could move them to the top of the file.

is_lazy = ht.utils.internal.is_lazy()
self.with_graph_breaks = os.getenv("VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS",
str(is_lazy)).strip().lower() in ("1", "true")
if self.with_graph_breaks:
if is_lazy:
self.break_graph = ht.core.mark_step
else:
self.break_graph = torch._dynamo.graph_break
msg = (f"FusedSDPA slicing is enabled with sequence length threshold {slice_thld}, "
f"chunk size {self.chunk_size}, num padded query chunks {self.num_padded_query_chunks}, "
f"num padded ctx chunks {self.num_padded_ctx_chunks}, with graph breaks {self.with_graph_breaks}.")
logger().debug(msg)
return enable_slicing


class ModuleFP8FusedSDPA(torch.nn.Module):

Expand Down
Loading
Loading