diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index f04742c0cca..1e220b9e2f5 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -45,6 +45,9 @@ def __init__(
         self,
         model_runner: ModelRunner,
         skip_prefill: bool = False,
+        topk=0,
+        speculative_num_steps=0,
+        step_id=0,
     ):
         super().__init__()
 
@@ -63,6 +66,10 @@ def __init__(
         self.use_mla = (
             model_runner.model_config.attention_arch == AttentionArch.MLA
         ) and (not global_server_args_dict["disable_mla"])
+        self.skip_prefill = skip_prefill
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.step_id = step_id
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata to cache repetitive calculations."""
@@ -72,37 +79,125 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         # Get sequence information
         seqlens_in_batch = forward_batch.seq_lens
         # Precompute int32 version of sequence lengths
-        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
         batch_size = len(seqlens_in_batch)
         device = seqlens_in_batch.device
-        metadata.cu_seqlens_k = torch.nn.functional.pad(
-            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
-        )
-        # Precompute maximum sequence length
-        metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
-        # Precompute page table
-        metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
-            forward_batch.req_pool_indices, : metadata.max_seq_len_k
-        ]
-
-        # Precompute strided indices
-        # [0, page_size, 2 * page_size, ...]
-        if self.page_size > 1:
-            self.strided_indices = torch.arange(
-                0, metadata.page_table.shape[1], self.page_size, device=self.device
-            )
-            metadata.page_table = (
-                metadata.page_table[:, self.strided_indices] // self.page_size
-            )
         if forward_batch.forward_mode == ForwardMode.DECODE:
-            # Precompute cumulative sequence lengths
+            if self.skip_prefill:
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size * self.topk + 1, dtype=torch.int32, device=device
+                )
+                seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
+                metadata.cache_seqlens_int32 = (
+                    (seq_lens_with_decode).repeat_interleave(self.topk).to(torch.int32)
+                )
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                    self.step_id + 1
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+                metadata.page_table = metadata.page_table.repeat_interleave(
+                    self.topk, dim=0
+                )
+                cache_loc = forward_batch.out_cache_loc.view(
+                    self.speculative_num_steps, -1
+                ).T
+                # Calculate page table indices and cache location indices to update the page table.
+                batch_indices = torch.arange(
+                    batch_size, device=device
+                ).repeat_interleave(self.topk * (self.step_id + 1))
+                topk_indices = torch.arange(self.topk, device=device).repeat(
+                    batch_size * (self.step_id + 1)
+                )
+                row_indices = batch_indices * self.topk + topk_indices
+
+                page_table_col_base_indices = seqlens_in_batch.unsqueeze(
+                    1
+                ) + torch.arange(self.step_id + 1, device=device)
+                page_table_col_indices = page_table_col_base_indices.view(-1).repeat(
+                    self.topk
+                )
+
+                cache_loc_col_indices = torch.arange(
+                    self.step_id + 1, device=device, dtype=torch.int32
+                ).repeat(batch_size * self.topk)
+
+                metadata.page_table[row_indices, page_table_col_indices] = cache_loc[
+                    row_indices, cache_loc_col_indices
+                ].to(torch.int32)
+            else:
+                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+                )
+                # Precompute maximum sequence length
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                # Precompute page table
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+        elif forward_batch.forward_mode == ForwardMode.TARGET_VERIFY:
+            draft_token_num = forward_batch.spec_info.draft_token_num
+
             metadata.cu_seqlens_q = torch.arange(
-                0, batch_size + 1, dtype=torch.int32, device=device
+                0, batch_size * draft_token_num + 1, dtype=torch.int32, device=device
+            )
+
+            aug_seq_lens = (forward_batch.seq_lens + draft_token_num).to(torch.int32)
+            metadata.cache_seqlens_int32 = aug_seq_lens.repeat_interleave(
+                forward_batch.spec_info.draft_token_num
+            )
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
+                (1, 0),
             )
+            metadata.max_seq_len_k = (
+                forward_batch.seq_lens_cpu.max().item() + draft_token_num
+            )
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.max_seq_len_k
+            ].repeat_interleave(draft_token_num, dim=0)
+            aug_cum_len = torch.nn.functional.pad(
+                torch.cumsum(aug_seq_lens, dim=0, dtype=torch.int32), (1, 0)
+            )
+            for idx, single_seq_len in enumerate(aug_seq_lens):
+                metadata.page_table[
+                    idx * draft_token_num : (idx + 1) * draft_token_num, :single_seq_len
+                ] *= forward_batch.spec_info.custom_mask[
+                    aug_cum_len[idx]
+                    * draft_token_num : aug_cum_len[idx + 1]
+                    * draft_token_num
+                ].view(
+                    draft_token_num, -1
+                )
+
+            metadata.max_seq_len_q = 1
         else:
+            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+            )
+            # Precompute maximum sequence length
+            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+            # Precompute page table
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.max_seq_len_k
+            ]
             # Precompute cumulative sequence lengths
-            if any(forward_batch.extend_prefix_lens_cpu):
+            if (
+                any(forward_batch.extend_prefix_lens_cpu)
+                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            ):
                 extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
                     torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
@@ -111,6 +206,16 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             else:
                 metadata.cu_seqlens_q = metadata.cu_seqlens_k
                 metadata.max_seq_len_q = metadata.max_seq_len_k
+
+        # Precompute strided indices
+        # [0, page_size, 2 * page_size, ...]
+        if self.page_size > 1:
+            self.strided_indices = torch.arange(
+                0, metadata.page_table.shape[1], self.page_size, device=self.device
+            )
+            metadata.page_table = (
+                metadata.page_table[:, self.strided_indices] // self.page_size
+            )
         self.forward_metadata = metadata
 
     def forward_extend(
@@ -281,8 +386,6 @@ def forward_decode(
 
         # Pre-reshape query tensor
         q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-
-        # Run attention with precomputed values
         o = flash_attn_with_kvcache(
             q=q_reshaped,
             k_cache=key_cache,
@@ -346,7 +449,11 @@ def init_cuda_graph_state(self, max_bs: int):
         This creates fixed-size tensors that will be reused during CUDA graph
         replay to avoid memory allocations.
         """
-        # Initialize fixed size tensors for decode operations
+        if self.speculative_num_steps > 0:
+            raise NotImplementedError(
+                "FlashAttentionBackend Spec Decoding does not support CUDA graph yet, stay tuned!"
+            )
+
         self.decode_cuda_graph_metadata = {
             # Page table for token mapping (batch_size, max_context_len)
             "page_table": torch.zeros(
@@ -385,7 +492,7 @@ def init_forward_metadata_capture_cuda_graph(
         metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
             req_pool_indices, :
         ]
-        if forward_mode == ForwardMode.DECODE:
+        if forward_mode.is_cuda_graph():
             # Precompute cumulative sequence lengths
             metadata.cu_seqlens_q = torch.arange(
                 0, batch_size + 1, dtype=torch.int32, device=device
@@ -432,3 +539,66 @@ def init_forward_metadata_replay_cuda_graph(
     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""
         return 0
+
+
+class FlashAttentionMultiStepBackend:
+
+    def __init__(
+        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
+    ):
+        self.model_runner = model_runner
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                FlashAttentionBackend(
+                    model_runner,
+                    skip_prefill=True,
+                    topk=self.topk,
+                    speculative_num_steps=self.speculative_num_steps,
+                    step_id=i,
+                )
+            )
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(max_bs)
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
+            )
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 234e13209e2..ffec4a0adb8 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -184,6 +184,19 @@ def init_attention_backend(self):
             self.draft_extend_attn_backend = None
             self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = True
+        elif self.server_args.attention_backend == "fa3":
+            from sglang.srt.layers.attention.flashattention_backend import (
+                FlashAttentionMultiStepBackend,
+            )
+
+            self.draft_attn_backend = FlashAttentionMultiStepBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
+            self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
                 f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"