diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index 8618c01f33e..ff6d81c85ce 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -913,8 +913,10 @@ def forward_decode(
         # Use precomputed metadata across all layers
         metadata = self.forward_metadata
         local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
-        use_local_attention = (
-            self.attention_chunk_size is not None and local_attn_metadata is not None
+        use_local_attn = (
+            self.attention_chunk_size is not None
+            and local_attn_metadata is not None
+            and (hasattr(layer, "use_irope") and layer.use_irope)
         )
         # We do cascade attention for Draft Decode with topk > 1
         use_cascade_attn = self.topk > 1
@@ -970,7 +972,7 @@
                 k_descale=k_descale,
                 v_descale=v_descale,
             )
-        elif use_local_attention:
+        elif use_local_attn:
             # Use chunked (local) attention batching for self-attention
             o = flash_attn_with_kvcache(
                 q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -979,7 +981,7 @@
                 page_table=local_attn_metadata.local_block_table,
                 cache_seqlens=local_attn_metadata.local_seqused_k,
                 cu_seqlens_q=local_attn_metadata.local_query_start_loc,
-                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                cu_seqlens_k_new=None,
                 max_seqlen_q=local_attn_metadata.local_max_query_len,
                 softmax_scale=layer.scaling,
                 causal=True,
@@ -1127,7 +1129,6 @@ def init_cuda_graph_state(self, max_bs: int):
         This creates fixed-size tensors that will be reused during CUDA graph replay
         to avoid memory allocations.
         """
-
         # This is being used by normal decode and draft decode when topk == 1
         self.decode_cuda_graph_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
@@ -1154,6 +1155,34 @@
             ),
         }
 
+        # Only allocate local attention buffers if local attention is enabled
+        # This prevents OOM errors when local attention is not being used
+        if self.attention_chunk_size is not None:
+            # Estimate maximum sizes for local attention metadata
+            max_seq_len = self.max_context_len
+            page_size = self.page_size or 1
+            attn_chunk_size = self.attention_chunk_size
+            max_virtual_batches = max_bs * (
+                (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
+            )
+            max_blocks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
+            max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size
+
+            self.decode_cuda_graph_local_attn_metadata = {
+                "local_query_start_loc": torch.zeros(
+                    max_virtual_batches + 1, dtype=torch.int32, device=self.device
+                ),
+                "local_seqused_k": torch.zeros(
+                    max_virtual_batches, dtype=torch.int32, device=self.device
+                ),
+                "local_block_table": torch.zeros(
+                    max_virtual_batches,
+                    max_blocks_per_seq * max_pages_per_block,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+            }
+
         # This is used by draft decode's first half of metadata when topk > 1
         if self.topk > 1:
             self.draft_decode_metadata_topk_normal = {
@@ -1405,6 +1434,21 @@ def init_forward_metadata_capture_cuda_graph(
             )
             self.decode_cuda_graph_metadata[bs] = metadata
 
+            if self.attention_chunk_size is not None:
+                metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+                    local_query_start_loc=self.decode_cuda_graph_local_attn_metadata[
+                        "local_query_start_loc"
+                    ],
+                    local_seqused_k=self.decode_cuda_graph_local_attn_metadata[
+                        "local_seqused_k"
+                    ],
+                    local_block_table=self.decode_cuda_graph_local_attn_metadata[
+                        "local_block_table"
+                    ],
+                    local_max_query_len=1,
+                    local_max_seq_len=1,
+                )
+
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata.cache_seqlens_int32 = self.target_verify_metadata[
@@ -1572,8 +1616,7 @@ def init_forward_metadata_replay_cuda_graph(
                 metadata_expand.page_table[: cache_loc.shape[0]].copy_(
                     cache_loc[:, :decode_length].contiguous().to(torch.int32)
                 )
-                # TODO: we need to test this part for llama 4 eagle case
-                self._init_local_attn_metadata(metadata, device)
+                # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
             else:
                 metadata = self.decode_cuda_graph_metadata[bs]
                 # Normal Decode
@@ -1599,7 +1642,7 @@
                 metadata.page_table[:, :max_seq_pages].copy_(page_indices)
                 metadata.page_table[:, max_seq_pages:].fill_(0)
 
-                self._init_local_attn_metadata(metadata, device)
+                self._update_local_attn_metadata_for_replay(metadata, bs)
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata = self.target_verify_metadata[bs]
@@ -1755,6 +1798,7 @@ def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
             page_table,
             self.page_size,
         )
+
         local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
             local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
             local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
@@ -1764,6 +1808,79 @@
         )
         metadata.local_attn_metadata = local_metadata
 
+    def _update_local_attn_metadata_for_replay(
+        self, metadata: FlashAttentionMetadata, bs: int
+    ):
+        """Update preallocated local attention metadata in-place before CUDA graph replay."""
+        if self.attention_chunk_size is None:
+            return
+
+        # Access preallocated buffers
+        local_q_buf = self.decode_cuda_graph_local_attn_metadata[
+            "local_query_start_loc"
+        ]
+        local_k_buf = self.decode_cuda_graph_local_attn_metadata["local_seqused_k"]
+        local_block_buf = self.decode_cuda_graph_local_attn_metadata[
+            "local_block_table"
+        ]
+        cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"]
+
+        # Create a modified version for local attention that only processes the last token
+        # This mimics the normal decode pattern
+        cu_seqlens_q = torch.arange(
+            bs + 1, device=cu_seqlens_q.device, dtype=cu_seqlens_q.dtype
+        )
+        seqlens = metadata.cache_seqlens_int32[:bs]
+        # Slice the page_table to match the batch size and actual sequence length
+        # This serves three important purposes:
+        # 1. Ensures we only process the actual batch size (bs) and not the maximum batch size
+        # 2. Limits the sequence length to prevent processing padding tokens or garbage values
+        # 3. Prevents zeros in the block table which can cause garbage output during replay
+        #
+        # Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
+        # beyond the actual sequence length, leading to incorrect attention calculations
+        max_seq_len = int(seqlens.max().item())
+        sliced_page_table = metadata.page_table[:bs, :max_seq_len]
+
+        cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
+        seqlens_np = seqlens.cpu().numpy()
+        (
+            seqlens_q_local_np,
+            cu_seqlens_q_local_np,
+            seqlens_k_local_np,
+            block_table_local,
+        ) = make_local_attention_virtual_batches(
+            self.attention_chunk_size,
+            cu_seqlens_q_np,
+            seqlens_np,
+            sliced_page_table,
+            self.page_size,
+        )
+
+        # Convert back to tensors
+        device = local_q_buf.device
+        cu_seqlens_q_local = torch.from_numpy(cu_seqlens_q_local_np).to(device)
+        seqlens_k_local = torch.from_numpy(seqlens_k_local_np).to(device)
+        block_table_local = block_table_local.to(device)
+        # Get sizes
+        q_len = cu_seqlens_q_local.shape[0]
+        k_len = seqlens_k_local.shape[0]
+        b0, b1 = block_table_local.shape
+
+        # In-place updates into preallocated tensors and zero out the unused space
+        local_q_buf[:q_len].copy_(cu_seqlens_q_local)
+        local_q_buf[q_len:].fill_(0)
+        local_k_buf[:k_len].copy_(seqlens_k_local)
+        local_k_buf[k_len:].fill_(0)
+        local_block_buf[:b0, :b1].copy_(block_table_local)
+        local_block_buf[b0:, :].fill_(0)
+        local_block_buf[:b0, b1:].fill_(0)
+
+        if metadata.local_attn_metadata is not None:
+            lam = metadata.local_attn_metadata
+            lam.local_max_query_len = int(seqlens_q_local_np.max())
+            lam.local_max_seq_len = int(seqlens_k_local_np.max())
+
 
 class FlashAttentionMultiStepBackend:
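Note for reviewers: the buffer shapes preallocated in the new `init_cuda_graph_state` block are simple ceiling divisions over the attention chunk size and page size. Below is a minimal, standalone sketch of that arithmetic; the helper name and the example numbers are illustrative only and are not part of the patch.

```python
# Illustrative only: mirrors the ceiling-division sizing in the patched
# init_cuda_graph_state, with made-up example values.
def local_attn_buffer_sizes(max_bs, max_context_len, attention_chunk_size, page_size):
    page_size = page_size or 1
    # Each request spans at most ceil(max_context_len / chunk) local chunks,
    # and every chunk becomes one "virtual batch" entry.
    chunks_per_seq = (max_context_len + attention_chunk_size - 1) // attention_chunk_size
    max_virtual_batches = max_bs * chunks_per_seq
    # A single chunk needs at most ceil(chunk / page_size) KV-cache pages.
    max_pages_per_block = (attention_chunk_size + page_size - 1) // page_size
    return {
        "local_query_start_loc": (max_virtual_batches + 1,),
        "local_seqused_k": (max_virtual_batches,),
        "local_block_table": (max_virtual_batches, chunks_per_seq * max_pages_per_block),
    }


# Example: batch 8, 8K context, 2K attention chunk, page size 16.
print(local_attn_buffer_sizes(8, 8192, 2048, 16))
# {'local_query_start_loc': (33,), 'local_seqused_k': (32,),
#  'local_block_table': (32, 512)}
```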
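The replay path writes into these preallocated tensors with `copy_`/`fill_` rather than rebinding the metadata fields, so the addresses captured in the CUDA graph stay valid and no stale values from a larger earlier batch leak through. A minimal standalone sketch of that copy-then-zero pattern, assuming only PyTorch; the helper below is hypothetical and not sglang API.

```python
# Hypothetical illustration of the CUDA-graph-safe update pattern used by
# _update_local_attn_metadata_for_replay; sizes and values are arbitrary.
import torch


def update_preallocated(buf_1d: torch.Tensor, values: torch.Tensor) -> None:
    """Copy `values` into the front of `buf_1d` in place and zero the rest."""
    n = values.shape[0]
    buf_1d[:n].copy_(values)  # overwrite the valid region for this replay
    buf_1d[n:].fill_(0)       # clear stale entries left over from a larger batch


# Preallocated once (as init_cuda_graph_state does); reused on every replay.
local_seqused_k_buf = torch.zeros(16, dtype=torch.int32)

# Freshly computed per-replay values for a smaller number of virtual batches.
new_seqlens = torch.tensor([7, 3, 12], dtype=torch.int32)
update_preallocated(local_seqused_k_buf, new_seqlens)
print(local_seqused_k_buf[:5])  # tensor([ 7,  3, 12,  0,  0], dtype=torch.int32)
```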