Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions docs/configuration/env_variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,20 @@ and warm-up. Recommended settings for this case are:

!!! note
If the model config specifies a high `max_model_len`, set it to the sum of `input_tokens` and `output_tokens`, rounded up to a multiple of `block_size` according to actual requirements.

## Additional Performance Tuning Parameters for the FusedSDPA Kernel with Linear Bucketing

FusedSDPA can be split into smaller chunks to improve performance by:

- fitting smaller chunks into SRAM,
- improving TPC/MME pipelining,
- reducing attention-mask usage.

| Parameter name | Description | Default value |
| ---------------------------------------- | -------------------------------------------------------------------------------------------- | ------------------------------------------ |
| `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD` | KV length threshold at or above which slicing is applied. Must be at least the default value. Set to `-1` to disable slicing. | `min(max_num_batched_tokens, 8192)` |
| `VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE` | Chunk size used for both `q_len` and `kv_len`. Must be between `block_size` and the slicing threshold. Rounded up to the next multiple of 1024. | `VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD // 2` |
| `VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS` | Places each chunk in a separate graph to reduce compilation time. | `true` for lazy mode and `false` otherwise |

!!! note
These parameters are effective only with the linear bucketing strategy.
5 changes: 5 additions & 0 deletions vllm_gaudi/extension/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ def get_user_flags():
Env('PT_HPU_SDPA_QKV_SLICE_MODE_FWD', boolean),
Env('PT_HPU_SDPA_BC_FACTOR', int),
Env('VLLM_FUSEDSDPA_SLIDE_THLD', int),

# FusedSDPA slicing flags
Env('VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD', int),
Env('VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE', int),
Env('VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS', boolean),
]
return to_dict(flags)

Expand Down
4 changes: 0 additions & 4 deletions vllm_gaudi/extension/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,10 +399,6 @@ def _fsdpa_prompt_attention(query: torch.Tensor,
recompute_mode = True
assert attn_bias is not None or valid_seq_lengths is not None, \
'Either attn_bias or valid_seq_lengths must be != None'
if is_causal and attn_bias is not None:
# TODO: causal + attn_bias is not yet supported
is_causal = False
valid_seq_lengths = None

args = [
query, key, value, attn_bias, 0.0, is_causal, scale, softmax_mode, recompute_mode, valid_seq_lengths,
Expand Down
215 changes: 215 additions & 0 deletions vllm_gaudi/extension/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
###############################################################################

import os
import math
from functools import lru_cache, wraps
from typing import Optional, Any

import habana_frameworks.torch as htorch
import torch
import itertools
from vllm_gaudi.extension.logger import logger

from vllm_gaudi.extension.runtime import get_config

Expand Down Expand Up @@ -155,6 +157,7 @@ def __init__(self, fusedSDPA):
super().__init__()
assert fusedSDPA is not None, f'fusedSDPA kernel is None'
self._hpu_kernel_fsdpa = fusedSDPA
self.enable_slicing = self._setup_slicing()

def forward(
self,
Expand All @@ -172,6 +175,20 @@ def forward(
window_size=None,
sinks=None,
):
bs = query.shape[0]
q_len = query.shape[-2]
kv_len = key.shape[-2]
if (self.enable_slicing and kv_len >= self.slice_thld \
and bs == 1 # bs should be 1 for chunked prefill
and q_len != kv_len # normal causal prefill route to the default dispatch for better performance
and is_causal and attn_mask is not None # only supports causal attention with mask
):
return self._sliced_fsdpa_fwd(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode,
recompute_mode, valid_sequence_lengths, padding_side)
if is_causal and attn_mask is not None:
# TODO: causal + attn_bias is not yet supported
is_causal = False
valid_sequence_lengths = None
if window_size is not None:
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode,
recompute_mode, valid_sequence_lengths, padding_side, False, False,
Expand All @@ -181,6 +198,204 @@ def forward(
recompute_mode, valid_sequence_lengths, padding_side, False, False,
(-1, -1), sinks)

def _sliced_fsdpa_fwd(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, recompute_mode,
valid_sequence_lengths, padding_side):
# Sliced FusedSDPA forward: splits the query and key/value tensors along dim -2 into
# pieces of at most self.chunk_size, runs torch.ops.hpu.sdpa_recomp_fwd on every
# (query chunk, kv chunk) pair and merges the partial results into one output.
#
# Only causal attention with an explicit mask is accepted (asserted below).
# `recompute_mode` and `valid_sequence_lengths` are accepted for signature parity but
# are not referenced in this body; the kernel is always called with valid_seq_len=None.
# Returns the per-query-chunk results concatenated along dim -2 and cast to q.dtype.
assert is_causal and attn_mask is not None

from habana_frameworks.torch.hpex.kernels.FusedSDPA import is_gqa, gqa_input_reshape_fwd, gqa_output_reshape
# When is_gqa() reports grouped-query attention, inputs (and the mask) are reshaped
# before the kernel calls and every kernel output is reshaped back afterwards.
gqa = is_gqa(query, key)
if gqa:
q, k, v, attn_mask = gqa_input_reshape_fwd(query, key, value, attn_mask)
else:
q, k, v, attn_mask = (query, key, value, attn_mask)
q_len = q.shape[-2]
kv_len = k.shape[-2]
# Number of leading kv positions that have no matching query position;
# presumably the already-cached context of a chunked prefill -- confirm with caller.
prefix_len = kv_len - q_len

chunk_outputs = []
num_q_chunks = math.ceil(q_len / self.chunk_size)
num_prefix_chunks = math.ceil(prefix_len / self.chunk_size)
# Query chunks are walked from the END of the query towards the start, so
# q_chunk_idx == 0 is the last chunk and only the chunk touching position 0 can be
# shorter than self.chunk_size.
for q_chunk_idx in range(num_q_chunks):
q_start = q_len - (q_chunk_idx + 1) * self.chunk_size
q_start = max(q_start, 0)
q_end = q_len - q_chunk_idx * self.chunk_size
q_chunk_size = q_end - q_start
q_chunk = q[..., q_start:q_end, :].clone() if self.with_graph_breaks else q[..., q_start:q_end, :]

# Running merged result for this query chunk (output plus the two statistics
# returned by the kernel); None until the first kv chunk has been processed.
last_out = None
last_m = None
last_linv = None

# the causal part
# kv chunks inside [prefix_len, prefix_len + q_end), walked backwards starting with
# the chunk that lines up with the current query chunk (kv_chunk_idx == 0).
for kv_chunk_idx in range(0, num_q_chunks - q_chunk_idx):
kv_start = prefix_len + q_end - (kv_chunk_idx + 1) * self.chunk_size
kv_start = max(kv_start, prefix_len)
kv_end = prefix_len + q_end - kv_chunk_idx * self.chunk_size
kv_chunk_size = kv_end - kv_start
k_chunk = k[..., kv_start:kv_end, :]
v_chunk = v[..., kv_start:kv_end, :]

# The kernel's own causal masking is used only for the aligned kv chunk, and never
# for the last query chunk (q_chunk_idx == 0), which always goes through the
# explicit-mask / no-mask selection below.
is_causal_chunk = kv_chunk_idx == 0 and q_chunk_idx != 0
# chunk sizes must be multiples of 1024 to get valid m and linv
is_causal_chunk = is_causal_chunk and q_chunk_size % 1024 == 0 and kv_chunk_size % 1024 == 0
# use mask only for the causal chunks that may have padding
mask_chunk = attn_mask[
..., q_start:q_end,
kv_start:kv_end] if kv_chunk_idx < self.num_padded_query_chunks and not is_causal_chunk else None

if self.with_graph_breaks:
# break_graph() cannot break the tensor slicing, use clone to isolate the graph
k_chunk = k_chunk.clone()
v_chunk = v_chunk.clone()
mask_chunk = mask_chunk.clone() if mask_chunk is not None else None
self.break_graph()

chunk_res = torch.ops.hpu.sdpa_recomp_fwd(
q_chunk,
k_chunk,
v_chunk,
mask_chunk,
dropout_p,
scale,
is_causal_chunk,
True, # requires_backward
softmax_mode,
None, # valid_seq_len
padding_side,
)
# First three kernel results are used as (output, m, linv) and promoted to float32
# for the merge arithmetic. Presumably m is the softmax row maximum and linv the
# reciprocal of the softmax row sum -- TODO confirm against the kernel documentation.
chunk_out, chunk_m, chunk_linv = ((gqa_output_reshape(x) if gqa else x).to(torch.float32)
for x in (chunk_res[:3]))

if last_out is None or last_m is None or last_linv is None:
last_out = chunk_out
last_m = chunk_m
last_linv = chunk_linv
else:
# Merge the new partial result into the running one: take the element-wise larger m,
# rescale both 1/linv terms by exp(m - new_m), and blend the two outputs with the
# resulting weights.
new_m = torch.maximum(last_m, chunk_m)
last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m)
chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m)
last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled)
last_out = (last_linv_rescaled * last_linv) * last_out + (chunk_linv_rescaled *
last_linv) * chunk_out
last_m = new_m

if self.with_graph_breaks:
self.break_graph()
Comment thread
czhu15 marked this conversation as resolved.

# the context part
# kv chunks inside [0, prefix_len), walked backwards from the end of the prefix.
# The kernel is always called with is_causal=False here.
for kv_chunk_idx in range(num_prefix_chunks):
kv_start = prefix_len - (kv_chunk_idx + 1) * self.chunk_size
kv_start = max(kv_start, 0)
kv_end = prefix_len - kv_chunk_idx * self.chunk_size
k_chunk = k[..., kv_start:kv_end, :]
v_chunk = v[..., kv_start:kv_end, :]
# use mask only for the chunks that may have padding
mask_chunk = attn_mask[..., q_start:q_end,
kv_start:kv_end] if kv_chunk_idx < self.num_padded_ctx_chunks else None

if self.with_graph_breaks:
# break_graph() cannot break the tensor slicing, use clone to isolate the graph
k_chunk = k_chunk.clone()
v_chunk = v_chunk.clone()
mask_chunk = mask_chunk.clone() if mask_chunk is not None else None
self.break_graph()

chunk_res = torch.ops.hpu.sdpa_recomp_fwd(
q_chunk,
k_chunk,
v_chunk,
mask_chunk,
dropout_p,
scale,
False, # is_causal
True, # requires_backward
softmax_mode,
None, # valid_seq_len
padding_side,
)
chunk_out, chunk_m, chunk_linv = ((gqa_output_reshape(x) if gqa else x).to(torch.float32)
for x in chunk_res[:3])

# The causal loop above always runs at least once per query chunk, so the running
# result must already exist when a context chunk is merged in.
assert not (last_out is None or last_m is None or last_linv is None)
# Same merge arithmetic as in the causal part.
new_m = torch.maximum(last_m, chunk_m)
last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m)
chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m)
last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled)
last_out = (last_linv_rescaled * last_linv) * last_out + (chunk_linv_rescaled * last_linv) * chunk_out
last_m = new_m

if self.with_graph_breaks:
self.break_graph()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment as above.

chunk_outputs.append(last_out)
# Query chunks were produced last-to-first; restore sequence order before concatenating.
chunk_outputs = list(reversed(chunk_outputs))
output = torch.cat(chunk_outputs, dim=-2)
return output.to(q.dtype)

def _setup_slicing(self) -> bool:
    """Decide whether sliced FusedSDPA can be used and cache its tuning parameters.

    Slicing is only enabled when a bucketing manager exists, is initialized and
    uses ``LinearBucketingStrategy``. When it is enabled the following
    attributes are set on ``self``: ``slice_thld``, ``chunk_size``,
    ``num_padded_query_chunks``, ``num_padded_ctx_chunks``,
    ``with_graph_breaks`` and, when graph breaks are requested, ``break_graph``.

    Environment variables read:
        VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD: slicing threshold; defaults to
            ``min(max_num_batched_tokens, 8192)``. A non-positive value below
            the default (e.g. ``-1``) disables slicing. A positive value below
            the default is rejected with a warning and the default is used.
        VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE: chunk size; must lie between the
            bucketing block size and the threshold, otherwise the default
            (half the threshold) is used. Always rounded up to a multiple
            of 1024.
        VLLM_PROMPT_QUERY_BUCKET_PAD_MAX / VLLM_PROMPT_CTX_BUCKET_PAD_MAX:
            used to derive how many chunks are treated as possibly padded.
        VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS: ``1``/``true`` to insert a
            graph break around every chunk; defaults to the lazy-mode flag.

    Returns:
        ``True`` when slicing is enabled, ``False`` otherwise.

    Raises:
        ValueError: if one of the integer environment variables cannot be
            parsed by ``int()``.
    """
    from vllm_gaudi.extension.bucketing.common import get_bucketing_manager
    bucketing_manager = get_bucketing_manager()
    if bucketing_manager is None:
        logger().warning('Bucketing manager is not instantiated, slicing in FSDPA will be disabled.')
        return False
    if not bucketing_manager.initialized:
        logger().warning('Bucketing manager is not initialized, slicing in FSDPA will be disabled.')
        return False

    from vllm_gaudi.extension.bucketing.linear import LinearBucketingStrategy
    strategy = bucketing_manager.get_bucketing_strategy()
    if not isinstance(strategy, LinearBucketingStrategy):
        logger().debug('Not using Linear Bucketing Strategy, slicing in FSDPA will be disabled.')
        return False

    max_num_batched_tokens = bucketing_manager.max_num_batched_tokens
    slice_thld_default = min(max_num_batched_tokens, 8192)
    slice_thld = int(os.getenv("VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD", str(slice_thld_default)))
    if slice_thld < slice_thld_default:
        if slice_thld <= 0:
            # Explicit opt-out: a non-positive threshold turns slicing off silently.
            return False
        # The warning promises a fallback, so actually continue with the default
        # instead of leaving slicing disabled.
        logger().warning('Invalid slice sequence length threshold, the threshold should be '
                         f'>= min(max_num_batched_tokens, 8192), falling back to default {slice_thld_default}.')
        slice_thld = slice_thld_default

    # default to half of the threshold and round up by 1024
    chunk_size_default = math.ceil(slice_thld // 2 / 1024) * 1024
    chunk_size = int(os.getenv("VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE", str(chunk_size_default)))
    block_size = bucketing_manager.block_size
    if chunk_size < block_size or chunk_size > slice_thld:
        logger().warning('Invalid chunk size for FusedSDPA slicing, the chunk size should be between '
                         f'{block_size} and {slice_thld}, falling back to default {chunk_size_default}.')
        chunk_size = chunk_size_default
    if chunk_size % 1024 != 0:
        chunk_size = math.ceil(chunk_size / 1024) * 1024
        logger().warning('Rounded up the chunk size for FusedSDPA slicing to the next multiple of 1024.')
    self.slice_thld = slice_thld
    self.chunk_size = chunk_size

    # Number of chunks for which the attention mask is still consulted,
    # derived from the maximum query / context bucket padding.
    max_query_pad_default = math.ceil(max_num_batched_tokens / 4)
    max_query_pad = int(os.getenv("VLLM_PROMPT_QUERY_BUCKET_PAD_MAX", str(max_query_pad_default)))
    self.num_padded_query_chunks = math.ceil(max_query_pad / self.chunk_size)
    max_ctx_pad_default = math.ceil(max_num_batched_tokens / block_size)
    # The context padding limit is multiplied by block_size before being converted to chunks.
    max_ctx_pad = int(os.getenv("VLLM_PROMPT_CTX_BUCKET_PAD_MAX", str(max_ctx_pad_default)))
    self.num_padded_ctx_chunks = math.ceil(max_ctx_pad * block_size / self.chunk_size)

    import habana_frameworks.torch as ht
    is_lazy = ht.utils.internal.is_lazy()
    self.with_graph_breaks = os.getenv("VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS",
                                       str(is_lazy)).strip().lower() in ("1", "true")
    if self.with_graph_breaks:
        # Lazy mode uses mark_step; otherwise a torch._dynamo graph break is used.
        self.break_graph = ht.core.mark_step if is_lazy else torch._dynamo.graph_break

    msg = (f"FusedSDPA slicing is enabled with sequence length threshold {slice_thld}, "
           f"chunk size {self.chunk_size}, num padded query chunks {self.num_padded_query_chunks}, "
           f"num padded ctx chunks {self.num_padded_ctx_chunks}, with graph breaks {self.with_graph_breaks}.")
    logger().debug(msg)
    return True


class ModuleFP8FusedSDPA(torch.nn.Module):

Expand Down