diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 0682e52b5ed..4bdb218207b 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -16,6 +16,7 @@ from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner +from sgl_kernel import merge_state_v2 from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache @@ -30,7 +31,7 @@ class FlashAttentionMetadata: # Sequence lengths for the forward batch cache_seqlens_int32: torch.Tensor = None # Maximum sequence length for query - max_seq_len_q: int = 0 + max_seq_len_q: int = 1 # Maximum sequence length for key max_seq_len_k: int = 0 # Cumulative sequence lengths for query @@ -267,6 +268,12 @@ def cdiv(a: int, b: int) -> int: return -(a // -b) +# TODO(hebiao064): remove this once we have a better way to handle the merge_state_v2 torch.compile issue +@torch._dynamo.disable() +def merge_state_v2_wrapper(o, s_a, o_exp, s_b): + return merge_state_v2(o, s_a, o_exp, s_b) + + class FlashAttentionBackend(AttentionBackend): """FlashAttention backend implementation. @@ -301,6 +308,8 @@ def __init__( ), "Sliding window and cross attention are not supported together" self.forward_metadata: FlashAttentionMetadata = None + # extra metdata for handling speculative decoding topk > 1, extended draft decode and verify + self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None self.max_context_len = model_runner.model_config.context_len self.device = model_runner.device self.decode_cuda_graph_metadata = {} @@ -311,8 +320,7 @@ def __init__( self.page_size = model_runner.page_size self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA self.skip_prefill = skip_prefill - - self.topk = topk + self.topk = model_runner.server_args.speculative_eagle_topk or 0 self.speculative_num_steps = speculative_num_steps self.speculative_num_draft_tokens = ( model_runner.server_args.speculative_num_draft_tokens @@ -336,14 +344,107 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): if forward_batch.forward_mode.is_decode_or_idle(): # Draft Decode if forward_batch.spec_info is not None: + if self.topk <= 1: + metadata.cache_seqlens_int32 = ( + seqlens_in_batch + (self.speculative_step_id + 1) + ).to(torch.int32) + metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + ( + self.speculative_step_id + 1 + ) + metadata.cu_seqlens_q = torch.arange( + 0, batch_size + 1, dtype=torch.int32, device=device + ) + metadata.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum( + metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 + ), + (1, 0), + ) + metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : metadata.max_seq_len_k + ] + else: + metadata.cache_seqlens_int32 = (seqlens_in_batch).to(torch.int32) + metadata.max_seq_len_q = self.topk + metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + metadata.cu_seqlens_q = torch.arange( + 0, + batch_size * self.topk + 1, + step=self.topk, + dtype=torch.int32, + device=device, + ) + metadata.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum( + metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 + ), + (1, 0), + ) + metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : metadata.max_seq_len_k + ] + + 
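+                    # Two-pass (cascade) attention for draft decode with topk > 1:
+                    # the metadata above covers the prefix KV cache shared by all topk
+                    # draft branches, while metadata_expand below covers only the draft
+                    # tokens written during the current speculative steps. Both passes
+                    # run non-causally and their partial outputs are combined with
+                    # merge_state_v2 using the returned softmax LSEs, conceptually
+                    # o = (o_prefix * exp(lse_prefix) + o_draft * exp(lse_draft))
+                    #     / (exp(lse_prefix) + exp(lse_draft)).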
metadata_expand = FlashAttentionMetadata() + decode_length = self.speculative_step_id + 1 + metadata_expand.cache_seqlens_int32 = torch.full( + (seqlens_in_batch.numel() * self.topk,), + decode_length, + device=device, + dtype=torch.int32, + ) + metadata_expand.max_seq_len_q = 1 + metadata_expand.max_seq_len_k = self.speculative_step_id + 1 + metadata_expand.cu_seqlens_q = torch.arange( + 0, + metadata_expand.cache_seqlens_int32.numel() + 1, + dtype=torch.int32, + device=device, + ) + metadata_expand.cu_seqlens_k = torch.arange( + 0, + metadata_expand.cache_seqlens_int32.numel() * decode_length + 1, + step=decode_length, + dtype=torch.int32, + device=device, + ) + cache_loc = forward_batch.out_cache_loc.view( + self.speculative_num_steps, -1 + ).T.contiguous() + metadata_expand.page_table = ( + cache_loc[:, :decode_length].contiguous().to(torch.int32) + ) + self.forward_metadata_spec_decode_expand = metadata_expand + else: + # Normal Decode + metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) + metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + metadata.cu_seqlens_q = torch.arange( + 0, batch_size + 1, dtype=torch.int32, device=device + ) + metadata.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) + ) + metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : metadata.max_seq_len_k + ] + # TODO: we need to test this part for llama 4 eagle case + self._init_local_attn_metadata(metadata, device) + elif forward_batch.forward_mode.is_target_verify(): + if self.topk <= 1: metadata.cache_seqlens_int32 = ( - seqlens_in_batch + (self.speculative_step_id + 1) + forward_batch.seq_lens + self.speculative_num_draft_tokens ).to(torch.int32) - metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + ( - self.speculative_step_id + 1 + metadata.max_seq_len_q = self.speculative_num_draft_tokens + metadata.max_seq_len_k = ( + forward_batch.seq_lens_cpu.max().item() + + self.speculative_num_draft_tokens ) metadata.cu_seqlens_q = torch.arange( - 0, batch_size + 1, dtype=torch.int32, device=device + 0, + batch_size * self.speculative_num_draft_tokens + 1, + self.speculative_num_draft_tokens, + dtype=torch.int32, + device=device, ) metadata.cu_seqlens_k = torch.nn.functional.pad( torch.cumsum( @@ -357,44 +458,101 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): self._init_local_attn_metadata(metadata, device) else: - # Normal Decode - metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) + metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32) + metadata.max_seq_len_q = self.speculative_num_draft_tokens metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() metadata.cu_seqlens_q = torch.arange( - 0, batch_size + 1, dtype=torch.int32, device=device + 0, + batch_size * self.speculative_num_draft_tokens + 1, + step=self.speculative_num_draft_tokens, + dtype=torch.int32, + device=device, ) metadata.cu_seqlens_k = torch.nn.functional.pad( - torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) + torch.cumsum( + metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 + ), + (1, 0), ) metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, : metadata.max_seq_len_k ] - self._init_local_attn_metadata(metadata, device) - elif forward_batch.forward_mode.is_target_verify(): - metadata.cache_seqlens_int32 = ( - forward_batch.seq_lens + self.speculative_num_draft_tokens - 
).to(torch.int32) - metadata.max_seq_len_q = self.speculative_num_draft_tokens - metadata.max_seq_len_k = ( - forward_batch.seq_lens_cpu.max().item() - + self.speculative_num_draft_tokens - ) - metadata.cu_seqlens_q = torch.arange( - 0, - batch_size * self.speculative_num_draft_tokens + 1, - self.speculative_num_draft_tokens, - dtype=torch.int32, - device=device, - ) - metadata.cu_seqlens_k = torch.nn.functional.pad( - torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32), - (1, 0), - ) - metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ - forward_batch.req_pool_indices, : metadata.max_seq_len_k - ] + metadata_expand = FlashAttentionMetadata() + metadata_expand.max_seq_len_q = 1 + metadata_expand.cu_seqlens_q = torch.arange( + 0, + forward_batch.seq_lens.numel() * self.speculative_num_draft_tokens + + 1, + dtype=torch.int32, + device=device, + ) + + # create expand page table + offsets = torch.arange( + self.speculative_num_draft_tokens, device=device + ).unsqueeze( + 0 + ) # shape: (1, self.speculative_num_draft_tokens) + cols = offsets.expand( + forward_batch.seq_lens.numel(), -1 + ) + forward_batch.seq_lens.unsqueeze(1) + cum_len = torch.nn.functional.pad( + torch.cumsum( + ( + forward_batch.seq_lens + self.speculative_num_draft_tokens + ).repeat_interleave(self.speculative_num_draft_tokens), + dim=0, + ), + (1, 0), + )[:-1] + mask_extraction_indices = ( + cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0) + + cum_len[:, None] + ).view(1, -1) + mask = forward_batch.spec_info.custom_mask[ + mask_extraction_indices + ].view( + -1, self.speculative_num_draft_tokens + ) # (bsz * draft_num, draft_num) + + # shift table indices to avoid padding + # non_masked_page_table [[8, 9, 10], mask (display with int format) [[1, 0, 0], + # [8, 9, 10], [1, 1, 0], + # [8, 9, 10]] [1, 0, 1]] + # if masked with padding [[8, 0, 0], our mask without padding [[8, 9, 10], + # [8, 9, 0], [8, 9, 10], + # [8, 0, 10]] [8, 10, 9]] + # note here cache_seqlens_int32 is [1, 2, 2] so extra page indices will be ignored in each row + col_indices = offsets.expand( + mask.shape[0], self.speculative_num_draft_tokens + ) + # Build keys: if an entry is valid (mask==True), keep its original index; + # if not, add self.speculative_num_draft_tokens so that it sorts after all valid entries. 
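+                # Worked example with speculative_num_draft_tokens = 3: for the third
+                # mask row [1, 0, 1] above, col_indices is [0, 1, 2], so keys become
+                # [0, 4, 2]; sorting gives order [0, 2, 1], and gathering [8, 9, 10]
+                # with that order yields [8, 10, 9], matching the table shown above.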
+ keys = torch.where( + mask, col_indices, col_indices + self.speculative_num_draft_tokens + ) + _, sort_order = torch.sort(keys, dim=1) + non_masked_page_table = ( + forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : + ] + .gather(1, cols) + .repeat_interleave(self.speculative_num_draft_tokens, dim=0) + ) # (bsz, draft_num) + metadata_expand.page_table = non_masked_page_table.gather(1, sort_order) + metadata_expand.cache_seqlens_int32 = mask.sum(dim=1).to(torch.int32) + metadata_expand.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum( + metadata_expand.cache_seqlens_int32, dim=0, dtype=torch.int32 + ), + (1, 0), + ) + metadata_expand.max_seq_len_k = ( + metadata_expand.cache_seqlens_int32.max().item() + ) + self.forward_metadata_spec_decode_expand = metadata_expand elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed(): metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() @@ -514,6 +672,11 @@ def forward_extend( and (hasattr(layer, "use_irope") and layer.use_irope) ) + # We do cascade attention for Target Verify with topk > 1 + use_cascade_attn = ( + forward_batch.forward_mode.is_target_verify() and self.topk > 1 + ) + # Get the appropriate page table based on whether we're using local attention if use_local_attn: local_metadata = metadata.local_attn_metadata @@ -548,7 +711,7 @@ def forward_extend( cu_seqlens_k = metadata.encoder_cu_seqlens_k window_size = (-1, -1) - o = flash_attn_with_kvcache( + result = flash_attn_with_kvcache( q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), k_cache=key_cache, v_cache=value_cache, @@ -558,13 +721,41 @@ def forward_extend( cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None, max_seqlen_q=max_seqlen_q, softmax_scale=layer.scaling, - causal=causal, + causal=False if use_cascade_attn else causal, window_size=window_size, softcap=layer.logit_cap, k_descale=k_descale, v_descale=v_descale, + return_softmax_lse=use_cascade_attn, ) - return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) + + if use_cascade_attn: + o, softmax_lse, *rest = result + o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache( + q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + k_cache=key_cache, + v_cache=value_cache, + page_table=self.forward_metadata_spec_decode_expand.page_table, + cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32, + cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q, + cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k, + max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q, + softmax_scale=layer.scaling, + causal=False, + window_size=window_size, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=True, + ) + o, _ = merge_state_v2_wrapper( + o, + softmax_lse.T.contiguous(), + o_expand, + softmax_lse_expand.T.contiguous(), + ) + else: + o = result else: if ( not global_server_args_dict["disable_chunked_prefix_cache"] @@ -627,7 +818,8 @@ def forward_extend( q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) q_nope = q_all[:, :, : layer.v_head_dim] q_rope = q_all[:, :, layer.v_head_dim :] - o = flash_attn_with_kvcache( + + result = flash_attn_with_kvcache( q=q_rope, k_cache=k_rope_cache, v_cache=c_kv_cache, @@ -638,13 +830,44 @@ def forward_extend( cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None, max_seqlen_q=max_seqlen_q, 
softmax_scale=layer.scaling, - causal=True, + causal=False if use_cascade_attn else causal, softcap=layer.logit_cap, k_descale=k_descale, v_descale=v_descale, + return_softmax_lse=use_cascade_attn, ) + if use_cascade_attn: + o, softmax_lse, *rest = result + o_expand, softmax_lse_expand, *rest_expand = ( + flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope_cache, + v_cache=c_kv_cache, + qv=q_nope, + page_table=self.forward_metadata_spec_decode_expand.page_table, + cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32, + cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q, + cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k, + max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q, + softmax_scale=layer.scaling, + causal=False, + window_size=window_size, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=True, + ) + ) + o, _ = merge_state_v2_wrapper( + o, + softmax_lse.T.contiguous(), + o_expand, + softmax_lse_expand.T.contiguous(), + ) + else: + o = result - return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) def forward_decode( self, @@ -681,6 +904,8 @@ def forward_decode( use_local_attention = ( self.attention_chunk_size is not None and local_attn_metadata is not None ) + # We do cascade attention for Draft Decode with topk > 1 + use_cascade_attn = self.topk > 1 # Calculate window size (can be moved to metadata if layer properties don't change) # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1 @@ -752,23 +977,61 @@ def forward_decode( v_descale=v_descale, ) else: + page_table = metadata.page_table + cache_seqlens = metadata.cache_seqlens_int32 + cu_seqlens_k = metadata.cu_seqlens_k + max_seqlen_q = metadata.max_seq_len_q + q_reshaped = q.contiguous().view( + -1, layer.tp_q_head_num, layer.head_dim + ) + # Default: single-token self-attention - o = flash_attn_with_kvcache( - q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + result = flash_attn_with_kvcache( + q=q_reshaped, k_cache=key_cache, v_cache=value_cache, - page_table=metadata.page_table, - cache_seqlens=metadata.cache_seqlens_int32, + page_table=page_table, + cache_seqlens=cache_seqlens, cu_seqlens_q=metadata.cu_seqlens_q, - cu_seqlens_k_new=metadata.cu_seqlens_k, - max_seqlen_q=1, + cu_seqlens_k_new=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, softmax_scale=layer.scaling, - causal=True, + causal=False if use_cascade_attn else causal, window_size=window_size, softcap=layer.logit_cap, k_descale=k_descale, v_descale=v_descale, + return_softmax_lse=use_cascade_attn, ) + if use_cascade_attn: + o, softmax_lse, *rest = result + o_expand, softmax_lse_expand, *rest_expand = ( + flash_attn_with_kvcache( + q=q_reshaped, + k_cache=key_cache, + v_cache=value_cache, + page_table=self.forward_metadata_spec_decode_expand.page_table, + cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32, + cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q, + cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k, + max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q, + softmax_scale=layer.scaling, + causal=False, + window_size=window_size, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=True, + ) + ) + o, _ = merge_state_v2( + o, + softmax_lse.T.contiguous(), + o_expand, + 
softmax_lse_expand.T.contiguous(), + ) + else: + o = result else: # Do absorbed multi-latent attention kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) @@ -787,8 +1050,9 @@ def forward_decode( q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) q_nope = q_all[:, :, : layer.v_head_dim] q_rope = q_all[:, :, layer.v_head_dim :] + max_seqlen_q = metadata.max_seq_len_q - o = flash_attn_with_kvcache( + result = flash_attn_with_kvcache( q=q_rope, k_cache=k_rope_cache, v_cache=c_kv_cache, @@ -797,13 +1061,43 @@ def forward_decode( cache_seqlens=metadata.cache_seqlens_int32, cu_seqlens_q=metadata.cu_seqlens_q, cu_seqlens_k_new=metadata.cu_seqlens_k, - max_seqlen_q=1, + max_seqlen_q=max_seqlen_q, softmax_scale=layer.scaling, - causal=True, + causal=False if use_cascade_attn else causal, softcap=layer.logit_cap, k_descale=k_descale, v_descale=v_descale, + return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states ) + if use_cascade_attn: + o, softmax_lse, *rest = result + o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope_cache, + v_cache=c_kv_cache, + qv=q_nope, + page_table=self.forward_metadata_spec_decode_expand.page_table, + cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32, + cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q, + cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k, + max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q, + softmax_scale=layer.scaling, + causal=False, + window_size=window_size, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=True, + ) + o, _ = merge_state_v2( + o, + softmax_lse.T.contiguous(), + o_expand, + softmax_lse_expand.T.contiguous(), + ) + else: + o = result + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) def init_cuda_graph_state(self, max_bs: int): @@ -815,6 +1109,8 @@ def init_cuda_graph_state(self, max_bs: int): This creates fixed-size tensors that will be reused during CUDA graph replay to avoid memory allocations. 
""" + + # This is being used by normal decode and draft decode when topk == 1 self.decode_cuda_graph_metadata = { "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device), "cu_seqlens_q": torch.arange( @@ -840,24 +1136,136 @@ def init_cuda_graph_state(self, max_bs: int): ), } - self.target_verify_metadata = { - "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device), - "cu_seqlens_q": torch.zeros( - max_bs + 1, dtype=torch.int32, device=self.device - ), - "cu_seqlens_k": torch.zeros( - max_bs + 1, dtype=torch.int32, device=self.device - ), - "page_table": torch.zeros( - max_bs, - (self.max_context_len + self.page_size - 1) // self.page_size, - dtype=torch.int32, - device=self.device, - ), - "strided_indices": torch.arange( - 0, self.max_context_len, self.page_size, device=self.device - ), - } + # This is used by draft decode's first half of metadata when topk > 1 + if self.topk > 1: + self.draft_decode_metadata_topk_normal = { + "cache_seqlens": torch.zeros( + max_bs, dtype=torch.int32, device=self.device + ), + "cu_seqlens_q": torch.arange( + 0, + max_bs * self.topk + 1, + step=self.topk, + dtype=torch.int32, + device=self.device, + ), + "cu_seqlens_k": torch.zeros( + max_bs + 1, dtype=torch.int32, device=self.device + ), + "page_table": torch.zeros( + max_bs, + self.max_context_len, + dtype=torch.int32, + device=self.device, + ), + } + + # This is used by draft decode's second half of metadata when topk > 1 + decode_length = self.speculative_step_id + 1 + self.draft_decode_metadata_topk_expand = { + "cache_seqlens": torch.full( + (max_bs * self.topk,), + decode_length, + device=self.device, + dtype=torch.int32, + ), + "cu_seqlens_q": torch.arange( + 0, + max_bs * self.topk + 1, + dtype=torch.int32, + device=self.device, + ), + "cu_seqlens_k": torch.arange( + 0, + max_bs * self.topk * decode_length + 1, + step=decode_length, + dtype=torch.int32, + device=self.device, + ), + "page_table": torch.zeros( + max_bs * self.topk, + decode_length, + dtype=torch.int32, + device=self.device, + ), + } + + if ( + self.speculative_num_draft_tokens is not None + and self.speculative_num_draft_tokens > 0 + ): + self.target_verify_metadata = { + "cache_seqlens": torch.zeros( + max_bs, dtype=torch.int32, device=self.device + ), + "cu_seqlens_q": torch.arange( + 0, + max_bs * self.speculative_num_draft_tokens + 1, + step=self.speculative_num_draft_tokens, + dtype=torch.int32, + device=self.device, + ), + "cu_seqlens_k": torch.zeros( + max_bs + 1, dtype=torch.int32, device=self.device + ), + "page_table": torch.zeros( + max_bs, + (self.max_context_len + self.page_size - 1) // self.page_size, + dtype=torch.int32, + device=self.device, + ), + "strided_indices": torch.arange( + 0, self.max_context_len, self.page_size, device=self.device + ), + } + + if self.topk > 1: + self.target_verify_metadata_topk_normal = { + "cache_seqlens": torch.zeros( + max_bs, dtype=torch.int32, device=self.device + ), + "cu_seqlens_q": torch.arange( + 0, + max_bs * self.speculative_num_draft_tokens + 1, + step=self.speculative_num_draft_tokens, + dtype=torch.int32, + device=self.device, + ), + "cu_seqlens_k": torch.zeros( + max_bs + 1, dtype=torch.int32, device=self.device + ), + "page_table": torch.zeros( + max_bs, + self.max_context_len, + dtype=torch.int32, + device=self.device, + ), + } + + self.target_verify_metadata_topk_expand = { + "cache_seqlens": torch.zeros( + max_bs * self.speculative_num_draft_tokens, + dtype=torch.int32, + device=self.device, + ), + "cu_seqlens_k": torch.zeros( + 
max_bs * self.speculative_num_draft_tokens + 1, + dtype=torch.int32, + device=self.device, + ), + "cu_seqlens_q": torch.arange( + 0, + max_bs * self.speculative_num_draft_tokens + 1, + dtype=torch.int32, + device=self.device, + ), + "page_table": torch.zeros( + max_bs * self.speculative_num_draft_tokens, + self.speculative_num_draft_tokens, + dtype=torch.int32, + device=self.device, + ), + } self.encoder_metadata = { "encoder_page_table": torch.zeros( @@ -886,28 +1294,78 @@ def init_forward_metadata_capture_cuda_graph( ): """Initialize forward metadata for capturing CUDA graph.""" metadata = FlashAttentionMetadata() + + # metadata_expand is needed for Spec Decoding when top k > 1 + metadata_expand = FlashAttentionMetadata() + device = seq_lens.device if forward_mode.is_decode_or_idle(): if spec_info is not None: # Draft Decode - metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[ - "cache_seqlens" - ][:bs] - metadata.max_seq_len_k = seq_lens.max().item() + ( - self.speculative_step_id + 1 - ) - metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][ - : bs + 1 - ] - metadata.cu_seqlens_k = torch.nn.functional.pad( - torch.cumsum( - metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 - ), - (1, 0), - ) - metadata.page_table = self.decode_cuda_graph_metadata[ - "page_table_draft_decode" - ][req_pool_indices, :] + if self.topk <= 1: + # When topk = 1, we use the normal decode metadata + metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[ + "cache_seqlens" + ][:bs] + metadata.max_seq_len_k = seq_lens.max().item() + ( + self.speculative_step_id + 1 + ) + metadata.cu_seqlens_q = self.decode_cuda_graph_metadata[ + "cu_seqlens_q" + ][: bs + 1] + metadata.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum( + metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 + ), + (1, 0), + ) + metadata.page_table = self.decode_cuda_graph_metadata[ + "page_table_draft_decode" + ][req_pool_indices, :] + self.decode_cuda_graph_metadata[bs] = metadata + else: + # When top k > 1, we need two specific draft decode metadata, and then merge states + # 1. The first half of metadata for prefix tokens + metadata.cache_seqlens_int32 = ( + self.draft_decode_metadata_topk_normal["cache_seqlens"][:bs] + ) + metadata.max_seq_len_q = self.topk + metadata.max_seq_len_k = seq_lens.max().item() + metadata.cu_seqlens_q = self.draft_decode_metadata_topk_normal[ + "cu_seqlens_q" + ][: bs + 1] + metadata.cu_seqlens_k = self.draft_decode_metadata_topk_normal[ + "cu_seqlens_k" + ][: bs + 1] + metadata.page_table = self.draft_decode_metadata_topk_normal[ + "page_table" + ][req_pool_indices, :] + + # 2. 
The second half of metadata for draft tokens (per_batch_num_tokens = topk) + metadata_expand.cache_seqlens_int32 = ( + self.draft_decode_metadata_topk_expand["cache_seqlens"][ + : bs * self.topk + ] + ) + metadata_expand.max_seq_len_q = 1 + metadata_expand.max_seq_len_k = ( + self.speculative_step_id + 1 + ) # , do this in replay + metadata_expand.cu_seqlens_q = ( + self.draft_decode_metadata_topk_expand["cu_seqlens_q"][ + : bs * self.topk + 1 + ] + ) + metadata_expand.cu_seqlens_k = ( + self.draft_decode_metadata_topk_expand["cu_seqlens_k"][ + : bs * self.topk + 1 + ] + ) + metadata_expand.page_table = self.draft_decode_metadata_topk_expand[ + "page_table" + ][: bs * self.topk] + self.draft_decode_metadata_topk_normal[bs] = metadata + self.draft_decode_metadata_topk_expand[bs] = metadata_expand else: # Normal Decode # Get sequence information @@ -927,37 +1385,77 @@ def init_forward_metadata_capture_cuda_graph( metadata.cu_seqlens_q = torch.arange( 0, batch_size + 1, dtype=torch.int32, device=device ) - self.decode_cuda_graph_metadata[bs] = metadata + self.decode_cuda_graph_metadata[bs] = metadata + elif forward_mode.is_target_verify(): - metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][ - :bs - ] - metadata.cache_seqlens_int32.copy_( - (seq_lens + self.speculative_num_draft_tokens).to(torch.int32) - ) + if self.topk <= 1: + metadata.cache_seqlens_int32 = self.target_verify_metadata[ + "cache_seqlens" + ][:bs] + metadata.cache_seqlens_int32.copy_( + (seq_lens + self.speculative_num_draft_tokens).to(torch.int32) + ) - metadata.max_seq_len_q = self.speculative_num_draft_tokens - metadata.max_seq_len_k = ( - seq_lens.max().item() + self.speculative_num_draft_tokens - ) + metadata.max_seq_len_q = self.speculative_num_draft_tokens + metadata.max_seq_len_k = ( + seq_lens.max().item() + self.speculative_num_draft_tokens + ) - metadata.cu_seqlens_q = torch.arange( - 0, - bs * self.speculative_num_draft_tokens + 1, - self.speculative_num_draft_tokens, - dtype=torch.int32, - device=device, - ) + metadata.cu_seqlens_q = torch.arange( + 0, + bs * self.speculative_num_draft_tokens + 1, + self.speculative_num_draft_tokens, + dtype=torch.int32, + device=device, + ) - metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][ - : (bs + 1) - ] + metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][ + : (bs + 1) + ] - metadata.page_table = self.target_verify_metadata["page_table"][ - req_pool_indices, : - ] + metadata.page_table = self.target_verify_metadata["page_table"][ + req_pool_indices, : + ] - self.target_verify_metadata[bs] = metadata + self.target_verify_metadata[bs] = metadata + else: + # When topk > 1, we need two specific target verify metadata, and then merge states + # 1. The first half of metadata for prefix tokens + metadata.cache_seqlens_int32 = self.target_verify_metadata_topk_normal[ + "cache_seqlens" + ][:bs] + metadata.max_seq_len_q = self.speculative_num_draft_tokens + # metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item(), do this in replay + metadata.cu_seqlens_q = self.target_verify_metadata_topk_normal[ + "cu_seqlens_q" + ][: bs + 1] + metadata.cu_seqlens_k = self.target_verify_metadata_topk_normal[ + "cu_seqlens_k" + ][: bs + 1] + metadata.page_table = self.target_verify_metadata_topk_normal[ + "page_table" + ][req_pool_indices, :] + + # 2. 
The second half of metadata for draft tokens (per_batch_num_tokens = topk) + metadata_expand.cache_seqlens_int32 = ( + self.target_verify_metadata_topk_expand["cache_seqlens"][ + : bs * self.speculative_num_draft_tokens + ] + ) + metadata_expand.max_seq_len_q = 1 + metadata_expand.cu_seqlens_q = self.target_verify_metadata_topk_expand[ + "cu_seqlens_q" + ][: bs * self.speculative_num_draft_tokens + 1] + metadata_expand.cu_seqlens_k = self.target_verify_metadata_topk_expand[ + "cu_seqlens_k" + ][: bs * self.speculative_num_draft_tokens + 1] + + metadata_expand.page_table = self.target_verify_metadata_topk_expand[ + "page_table" + ][: bs * self.speculative_num_draft_tokens] + + self.target_verify_metadata_topk_normal[bs] = metadata + self.target_verify_metadata_topk_expand[bs] = metadata_expand if encoder_lens is not None: encoder_bs = encoder_lens.numel() @@ -973,6 +1471,7 @@ def init_forward_metadata_capture_cuda_graph( ] self.forward_metadata = metadata + self.forward_metadata_spec_decode_expand = metadata_expand def init_forward_metadata_replay_cuda_graph( self, @@ -986,41 +1485,85 @@ def init_forward_metadata_replay_cuda_graph( seq_lens_cpu: Optional[torch.Tensor], out_cache_loc: torch.Tensor = None, ): - # """Initialize forward metadata for replaying CUDA graph.""" + """Initialize forward metadata for replaying CUDA graph.""" seq_lens = seq_lens[:bs] seq_lens_cpu = seq_lens_cpu[:bs] req_pool_indices = req_pool_indices[:bs] device = seq_lens.device + metadata = None + metadata_expand = None if forward_mode.is_decode_or_idle(): - metadata = self.decode_cuda_graph_metadata[bs] if spec_info is not None: # Draft Decode - metadata.cache_seqlens_int32.copy_( - (seq_lens + (self.speculative_step_id + 1)).to(torch.int32) - ) + if self.topk <= 1: + metadata = self.decode_cuda_graph_metadata[bs] + # When topk = 1, we use the normal decode metadata + metadata.cache_seqlens_int32.copy_( + (seq_lens + (self.speculative_step_id + 1)).to(torch.int32) + ) - metadata.max_seq_len_k = seq_lens_cpu.max().item() + ( - self.speculative_step_id + 1 - ) - metadata.cu_seqlens_k.copy_( - torch.nn.functional.pad( - torch.cumsum( - metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 - ), - (1, 0), + metadata.max_seq_len_k = seq_lens_cpu.max().item() + ( + self.speculative_step_id + 1 + ) + metadata.cu_seqlens_k.copy_( + torch.nn.functional.pad( + torch.cumsum( + metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 + ), + (1, 0), + ) ) - ) - page_table = self.req_to_token[ - req_pool_indices, : metadata.max_seq_len_k - ] + max_seq_pages = ( + metadata.max_seq_len_k + self.page_size - 1 + ) // self.page_size + page_indices = self.req_to_token[ + req_pool_indices[:, None], + self.decode_cuda_graph_metadata["strided_indices"][ + :max_seq_pages + ], + ] + + page_indices //= self.page_size + metadata.page_table[:, :max_seq_pages].copy_(page_indices) + else: + # When top k > 1, we need two specific draft decode metadata, and then merge states + # 1. 
The first half of metadata for prefix tokens + metadata = self.draft_decode_metadata_topk_normal[bs] + metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32)) + # metadata.max_seq_len_q = self.topk, already set in capture + metadata.max_seq_len_k = seq_lens_cpu.max().item() + # metadata.cu_seqlens_q already set in capture + metadata.cu_seqlens_k.copy_( + torch.nn.functional.pad( + torch.cumsum( + metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 + ), + (1, 0), + ) + ) - metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table) + page_table = self.req_to_token[ + req_pool_indices, : metadata.max_seq_len_k + ] + + metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table) + # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk) + metadata_expand = self.draft_decode_metadata_topk_expand[bs] + decode_length = self.speculative_step_id + 1 + cache_loc = out_cache_loc.view( + self.speculative_num_steps, -1 + ).T.contiguous() + metadata_expand.page_table[: cache_loc.shape[0]].copy_( + cache_loc[:, :decode_length].contiguous().to(torch.int32) + ) + # TODO: we need to test this part for llama 4 eagle case self._init_local_attn_metadata(metadata, device) else: + metadata = self.decode_cuda_graph_metadata[bs] # Normal Decode max_len = seq_lens_cpu.max().item() metadata.max_seq_len_k = max_len @@ -1045,24 +1588,117 @@ def init_forward_metadata_replay_cuda_graph( self._init_local_attn_metadata(metadata, device) elif forward_mode.is_target_verify(): - metadata = self.target_verify_metadata[bs] - metadata.cache_seqlens_int32.copy_( - (seq_lens + self.speculative_num_draft_tokens).to(torch.int32) - ) + if self.topk <= 1: + metadata = self.target_verify_metadata[bs] + metadata.cache_seqlens_int32.copy_( + (seq_lens + self.speculative_num_draft_tokens).to(torch.int32) + ) - metadata.max_seq_len_k = ( - seq_lens_cpu.max().item() + self.speculative_num_draft_tokens - ) - metadata.cu_seqlens_k.copy_( - torch.nn.functional.pad( + metadata.max_seq_len_k = ( + seq_lens_cpu.max().item() + self.speculative_num_draft_tokens + ) + metadata.cu_seqlens_k.copy_( + torch.nn.functional.pad( + torch.cumsum( + metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 + ), + (1, 0), + ) + ) + max_seq_pages = ( + metadata.max_seq_len_k + self.page_size - 1 + ) // self.page_size + page_indices = self.req_to_token[ + req_pool_indices[:, None], + self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages], + ] + page_indices //= self.page_size + metadata.page_table[:, :max_seq_pages].copy_(page_indices) + else: + # When topk > 1, we need two specific target verify metadata, and then merge states + # 1. The first half of metadata for prefix tokens + metadata = self.target_verify_metadata_topk_normal[bs] + metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32)) + # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture + metadata.max_seq_len_k = seq_lens_cpu.max().item() + # metadata.cu_seqlens_q already set in capture + metadata.cu_seqlens_k.copy_( + torch.nn.functional.pad( + torch.cumsum( + metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 + ), + (1, 0), + ) + ) + page_table = self.req_to_token[ + req_pool_indices, : metadata.max_seq_len_k + ] + metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table) + + # 2. 
The second half of metadata for draft tokens (per_batch_num_tokens = topk) + metadata_expand = self.target_verify_metadata_topk_expand[bs] + # metadata_expand.max_seq_len_q = 1, already set in capture + # metadata_expand.cu_seqlens_q already set in capture + + offsets = torch.arange( + self.speculative_num_draft_tokens, device=device + ).unsqueeze( + 0 + ) # shape: (1, self.speculative_num_draft_tokens) + cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1) + cum_len = torch.nn.functional.pad( torch.cumsum( - metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 + ( + seq_lens + self.speculative_num_draft_tokens + ).repeat_interleave(self.speculative_num_draft_tokens), + dim=0, ), (1, 0), + )[:-1] + mask_extraction_indices = ( + cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0) + + cum_len[:, None] + ).view(1, -1) + # avoid extracting padded seq indices which will be out of boundary + mask_extraction_indices[ + :, spec_info.positions.numel() * self.speculative_num_draft_tokens : + ].fill_(0) + + mask = spec_info.custom_mask[mask_extraction_indices].view( + -1, self.speculative_num_draft_tokens + ) # (bsz * draft_num, draft_num) + col_indices = offsets.expand( + mask.shape[0], self.speculative_num_draft_tokens + ) + keys = torch.where( + mask, col_indices, col_indices + self.speculative_num_draft_tokens + ) + _, sort_order = torch.sort(keys, dim=1) + + non_masked_page_table = ( + self.req_to_token[req_pool_indices, :] + .gather(1, cols) + .repeat_interleave(self.speculative_num_draft_tokens, dim=0) + ) # (bsz, draft_num) + metadata_expand.page_table.copy_( + non_masked_page_table.gather(1, sort_order) + ) + metadata_expand.cache_seqlens_int32.copy_( + mask.sum(dim=1).to(torch.int32) + ) + metadata_expand.cu_seqlens_k.copy_( + torch.nn.functional.pad( + torch.cumsum( + metadata_expand.cache_seqlens_int32, + dim=0, + dtype=torch.int32, + ), + (1, 0), + ) + ) + metadata_expand.max_seq_len_k = ( + metadata_expand.cache_seqlens_int32.max().item() ) - ) - page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k] - metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table) if encoder_lens is not None: # Only support encoder size 1 for now @@ -1089,6 +1725,7 @@ def init_forward_metadata_replay_cuda_graph( metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table) self.forward_metadata = metadata + self.forward_metadata_spec_decode_expand = metadata_expand def get_cuda_graph_seq_len_fill_value(self): """Get the fill value for sequence length in CUDA graph.""" @@ -1139,12 +1776,6 @@ def __init__( self.model_runner = model_runner self.topk = topk self.speculative_num_steps = speculative_num_steps - - # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding - assert ( - self.topk == 1 - ), "speculative_eagle_topk must be 1 for FlashAttentionMultiStepBackend" - self.attn_backends = [] for i in range(self.speculative_num_steps): self.attn_backends.append( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 6a48a60b716..8833ccc4287 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -221,7 +221,16 @@ def model_specific_adjustment(self): server_args = self.server_args if server_args.attention_backend is None: - # By default, use flashinfer for non-mla attention and triton for mla attention + """ + We auto select the fastest attention backend according to the current offering + 1. 
Models with MHA Architecture (e.g: Llama, QWen) + 1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1. + 1.2 In other cases, we will use flashinfer if available, otherwise use triton. + 2. Models with MLA Architecture and using FA3 + 2.1 We will use FA3 backend on hopper. + 2.2 Otherwise, we will use triton backend. + """ + if not self.use_mla_backend: if ( is_hopper_with_cuda_12_3() @@ -234,9 +243,7 @@ def model_specific_adjustment(self): "flashinfer" if is_flashinfer_available() else "triton" ) else: - if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one( - server_args - ): + if is_hopper_with_cuda_12_3(): server_args.attention_backend = "fa3" else: server_args.attention_backend = "triton" diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ddbbdf35d8c..e294a56f11f 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -359,7 +359,18 @@ def __post_init__(self): if self.page_size > 1 and self.speculative_eagle_topk > 1: self.speculative_eagle_topk = 1 - logger.info("speculative_eagle_topk is changed to 1 when page_size > 1") + logger.info( + "speculative_eagle_topk is adjusted to 1 when page_size > 1" + ) + + if ( + self.speculative_eagle_topk == 1 + and self.speculative_num_draft_tokens != self.speculative_num_steps + 1 + ): + logger.info( + "speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1" + ) + self.speculative_num_draft_tokens = self.speculative_num_steps + 1 # The token generated from the verify step is counted. # If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded. diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 5b8b2cbe818..7cc70dcb701 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1909,6 +1909,8 @@ def is_page_size_one(server_args): return server_args.page_size == 1 +# TODO(hebiao064): Accelerate FA3 Spec Decode with topk > 1. +# TODO(hebiao064): Improve the acc rate for FA3 Spec Decode with topk == 1 and page_size > 1. 
def is_no_spec_infer_or_topk_one(server_args): return server_args.speculative_eagle_topk is None or ( server_args.speculative_eagle_topk is not None diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 3f7d846a552..4cdcb240d85 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -29,7 +29,7 @@ class TestFile: TestFile("test_chunked_prefill.py", 336), TestFile("test_eagle_infer.py", 500), TestFile("test_ebnf_constrained.py"), - TestFile("test_fa3.py", 5), + TestFile("test_fa3.py", 200), TestFile("test_fp8_kernel.py", 8), TestFile("test_embedding_openai_server.py", 36), TestFile("test_hidden_states.py", 55), diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py index 90b462aaa38..886e19db1d2 100644 --- a/test/srt/test_fa3.py +++ b/test/srt/test_fa3.py @@ -173,6 +173,60 @@ def test_gsm8k(self): self.assertGreater(avg_spec_accept_length, 1.5) +class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest): + """Test FlashAttention3 with speculative decode enabled, topk > 1""" + + model = "meta-llama/Llama-3.1-8B-Instruct" + + @classmethod + def get_server_args(cls): + args = super().get_server_args() + args.extend( + [ + "--cuda-graph-max-bs", + "2", + "--speculative-algorithm", + "EAGLE3", + "--speculative-draft", + "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B", + "--speculative-num-steps", + "5", + "--speculative-eagle-topk", + "4", + "--speculative-num-draft-tokens", + "8", + "--dtype", + "float16", + ] + ) + return args + + def test_gsm8k(self): + """ + Override the test_gsm8k to further test for average speculative accept length. + """ + requests.get(self.base_url + "/flush_cache") + + args = SimpleNamespace( + num_shots=5, + data_path=DATA_PATH, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.60) + + server_info = requests.get(self.base_url + "/get_server_info") + avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] + print(f"{avg_spec_accept_length=}") + self.assertGreater(avg_spec_accept_length, 1.8) + + class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest): """Test FlashAttention3 with speculative decode enabled."""