python/sglang/srt/layers/attention/flashattention_backend.py (133 changes: 125 additions & 8 deletions)
@@ -913,8 +913,10 @@ def forward_decode(
         # Use precomputed metadata across all layers
         metadata = self.forward_metadata
         local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
-        use_local_attention = (
-            self.attention_chunk_size is not None and local_attn_metadata is not None
+        use_local_attn = (
+            self.attention_chunk_size is not None
+            and local_attn_metadata is not None
+            and (hasattr(layer, "use_irope") and layer.use_irope)
         )
         # We do cascade attention for Draft Decode with topk > 1
         use_cascade_attn = self.topk > 1
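The `use_irope` check is the behavioral change in this hunk: only layers that use interleaved RoPE (the local-attention layers in models such as Llama 4) should take the chunked path, while global-attention layers now fall through to the regular decode path. A minimal sketch of the new gating, using stand-in objects rather than the real sglang layer class:

```python
# Stand-in objects (not the real sglang classes) to illustrate the new gating.
class DummyLayer:
    def __init__(self, use_irope: bool):
        self.use_irope = use_irope


def takes_local_path(attention_chunk_size, local_attn_metadata, layer) -> bool:
    return (
        attention_chunk_size is not None
        and local_attn_metadata is not None
        and (hasattr(layer, "use_irope") and layer.use_irope)
    )


# A layer with iRoPE and chunked attention configured takes the local path ...
print(takes_local_path(8192, object(), DummyLayer(use_irope=True)))   # True
# ... while a global-attention layer falls through, even with chunking enabled.
print(takes_local_path(8192, object(), DummyLayer(use_irope=False)))  # False
```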
@@ -970,7 +972,7 @@ def forward_decode(
                 k_descale=k_descale,
                 v_descale=v_descale,
             )
-        elif use_local_attention:
+        elif use_local_attn:
             # Use chunked (local) attention batching for self-attention
             o = flash_attn_with_kvcache(
                 q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -979,7 +981,7 @@
                 page_table=local_attn_metadata.local_block_table,
                 cache_seqlens=local_attn_metadata.local_seqused_k,
                 cu_seqlens_q=local_attn_metadata.local_query_start_loc,
-                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                cu_seqlens_k_new=None,
                 max_seqlen_q=local_attn_metadata.local_max_query_len,
                 softmax_scale=layer.scaling,
                 causal=True,
@@ -1127,7 +1129,6 @@ def init_cuda_graph_state(self, max_bs: int):
         This creates fixed-size tensors that will be reused during CUDA graph replay
         to avoid memory allocations.
         """
-
         # This is being used by normal decode and draft decode when topk == 1
         self.decode_cuda_graph_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
@@ -1154,6 +1155,34 @@ def init_cuda_graph_state(self, max_bs: int):
             ),
         }
 
+        # Only allocate local attention buffers if local attention is enabled
+        # This prevents OOM errors when local attention is not being used
+        if self.attention_chunk_size is not None:
+            # Estimate maximum sizes for local attention metadata
+            max_seq_len = self.max_context_len
+            page_size = self.page_size or 1
+            attn_chunk_size = self.attention_chunk_size
+            max_virtual_batches = max_bs * (
+                (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
+            )
+            max_blocks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
+            max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size
+
+            self.decode_cuda_graph_local_attn_metadata = {
+                "local_query_start_loc": torch.zeros(
+                    max_virtual_batches + 1, dtype=torch.int32, device=self.device
+                ),
+                "local_seqused_k": torch.zeros(
+                    max_virtual_batches, dtype=torch.int32, device=self.device
+                ),
+                "local_block_table": torch.zeros(
+                    max_virtual_batches,
+                    max_blocks_per_seq * max_pages_per_block,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+            }
+
         # This is used by draft decode's first half of metadata when topk > 1
         if self.topk > 1:
             self.draft_decode_metadata_topk_normal = {
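For intuition on how large the preallocated local-attention buffers in the hunk above get, here is a hand-worked sizing example. The configuration values are hypothetical and only illustrate the arithmetic used in the new code:

```python
# Hypothetical configuration, chosen only to illustrate the buffer-size arithmetic above.
max_bs = 2              # max CUDA graph batch size
max_seq_len = 8192      # self.max_context_len
attn_chunk_size = 2048  # self.attention_chunk_size
page_size = 16          # self.page_size

max_virtual_batches = max_bs * ((max_seq_len + attn_chunk_size - 1) // attn_chunk_size)
max_blocks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size

print(max_virtual_batches)                       # 8 -> local_query_start_loc holds 9 int32 entries
print(max_blocks_per_seq * max_pages_per_block)  # 4 * 128 = 512 columns in local_block_table
```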
Expand Down Expand Up @@ -1405,6 +1434,21 @@ def init_forward_metadata_capture_cuda_graph(
)
self.decode_cuda_graph_metadata[bs] = metadata

if self.attention_chunk_size is not None:
metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
local_query_start_loc=self.decode_cuda_graph_local_attn_metadata[
"local_query_start_loc"
],
local_seqused_k=self.decode_cuda_graph_local_attn_metadata[
"local_seqused_k"
],
local_block_table=self.decode_cuda_graph_local_attn_metadata[
"local_block_table"
],
local_max_query_len=1,
local_max_seq_len=1,
)

elif forward_mode.is_target_verify():
if self.topk <= 1:
metadata.cache_seqlens_int32 = self.target_verify_metadata[
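The metadata captured above only needs to reference the right tensors; the placeholder values (max lengths of 1, zeroed buffers) are filled in before each replay. A simplified sketch of the CUDA graph capture/replay contract this relies on (an assumption-level illustration, not sglang code): buffers referenced at capture time must be the same storage that later gets mutated in place.

```python
import torch

# Preallocated at init time and referenced when the graph is captured.
seqused_k_buf = torch.zeros(8, dtype=torch.int32)

# ... graph capture records kernels that read from seqused_k_buf ...

# Before each replay, the real values are copied in place and the tail is zeroed,
# mirroring what _update_local_attn_metadata_for_replay does further down.
actual = torch.tensor([1500, 904], dtype=torch.int32)
seqused_k_buf[: actual.numel()].copy_(actual)
seqused_k_buf[actual.numel() :].fill_(0)
```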
@@ -1572,8 +1616,7 @@ def init_forward_metadata_replay_cuda_graph(
                 metadata_expand.page_table[: cache_loc.shape[0]].copy_(
                     cache_loc[:, :decode_length].contiguous().to(torch.int32)
                 )
-                # TODO: we need to test this part for llama 4 eagle case
-                self._init_local_attn_metadata(metadata, device)
+                # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
             else:
                 metadata = self.decode_cuda_graph_metadata[bs]
                 # Normal Decode
@@ -1599,7 +1642,7 @@
                 metadata.page_table[:, :max_seq_pages].copy_(page_indices)
                 metadata.page_table[:, max_seq_pages:].fill_(0)
 
-                self._init_local_attn_metadata(metadata, device)
+                self._update_local_attn_metadata_for_replay(metadata, bs)
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata = self.target_verify_metadata[bs]
@@ -1755,6 +1798,7 @@ def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
             page_table,
             self.page_size,
         )
+
         local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
             local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
             local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
@@ -1764,6 +1808,79 @@ def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
         )
         metadata.local_attn_metadata = local_metadata
 
+    def _update_local_attn_metadata_for_replay(
+        self, metadata: FlashAttentionMetadata, bs: int
+    ):
+        """Update preallocated local attention metadata in-place before CUDA graph replay."""
+        if self.attention_chunk_size is None:
+            return
+
+        # Access preallocated buffers
+        local_q_buf = self.decode_cuda_graph_local_attn_metadata[
+            "local_query_start_loc"
+        ]
+        local_k_buf = self.decode_cuda_graph_local_attn_metadata["local_seqused_k"]
+        local_block_buf = self.decode_cuda_graph_local_attn_metadata[
+            "local_block_table"
+        ]
+        cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"]
+
+        # Create a modified version for local attention that only processes the last token
+        # This mimics the normal decode pattern
+        cu_seqlens_q = torch.arange(
+            bs + 1, device=cu_seqlens_q.device, dtype=cu_seqlens_q.dtype
+        )
+        seqlens = metadata.cache_seqlens_int32[:bs]
+        # Slice the page_table to match the batch size and actual sequence length
+        # This serves three important purposes:
+        # 1. Ensures we only process the actual batch size (bs) and not the maximum batch size
+        # 2. Limits the sequence length to prevent processing padding tokens or garbage values
+        # 3. Prevents zeros in the block table which can cause garbage output during replay
+        #
+        # Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
+        # beyond the actual sequence length, leading to incorrect attention calculations
+        max_seq_len = int(seqlens.max().item())
+        sliced_page_table = metadata.page_table[:bs, :max_seq_len]
+
+        cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
+        seqlens_np = seqlens.cpu().numpy()
+        (
+            seqlens_q_local_np,
+            cu_seqlens_q_local_np,
+            seqlens_k_local_np,
+            block_table_local,
+        ) = make_local_attention_virtual_batches(
+            self.attention_chunk_size,
+            cu_seqlens_q_np,
+            seqlens_np,
+            sliced_page_table,
+            self.page_size,
+        )
+
+        # Convert back to tensors
+        device = local_q_buf.device
+        cu_seqlens_q_local = torch.from_numpy(cu_seqlens_q_local_np).to(device)
+        seqlens_k_local = torch.from_numpy(seqlens_k_local_np).to(device)
+        block_table_local = block_table_local.to(device)
+        # Get sizes
+        q_len = cu_seqlens_q_local.shape[0]
+        k_len = seqlens_k_local.shape[0]
+        b0, b1 = block_table_local.shape
+
+        # In-place updates into preallocated tensors and zero out the unused space
+        local_q_buf[:q_len].copy_(cu_seqlens_q_local)
+        local_q_buf[q_len:].fill_(0)
+        local_k_buf[:k_len].copy_(seqlens_k_local)
+        local_k_buf[k_len:].fill_(0)
+        local_block_buf[:b0, :b1].copy_(block_table_local)
+        local_block_buf[b0:, :].fill_(0)
+        local_block_buf[:b0, b1:].fill_(0)
+
+        if metadata.local_attn_metadata is not None:
+            lam = metadata.local_attn_metadata
+            lam.local_max_query_len = int(seqlens_q_local_np.max())
+            lam.local_max_seq_len = int(seqlens_k_local_np.max())
+
 
 class FlashAttentionMultiStepBackend:
 
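To make the replay update concrete, here is a hand-computed illustration of what the local metadata looks like for a plain decode step, where each request contributes exactly one query token. The chunk size and sequence lengths are hypothetical, and the arithmetic assumes standard chunked-attention semantics rather than calling the real `make_local_attention_virtual_batches` helper:

```python
chunk = 2048            # hypothetical attention_chunk_size
seqlens = [1500, 5000]  # hypothetical cached tokens per request

# With one query token per request, each request maps to a single virtual batch, so the
# local cu_seqlens_q is simply [0, 1, 2, ...], matching the torch.arange(bs + 1) above.
cu_seqlens_q_local = list(range(len(seqlens) + 1))

# The new token only attends to the already-filled portion of its current chunk.
local_seqused_k = [((s - 1) % chunk) + 1 for s in seqlens]

print(cu_seqlens_q_local)  # [0, 1, 2]
print(local_seqused_k)     # [1500, 904] -> local_max_seq_len becomes 1500, local_max_query_len 1
```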