From d7819d0fe8609e51417a6836a3faa3acf5c218ea Mon Sep 17 00:00:00 2001 From: qingquansong Date: Mon, 31 Mar 2025 17:27:23 +0000 Subject: [PATCH 01/12] Add Eagle Speculative Decoding to FA3 Backend --- .../attention/flashattention_backend.py | 149 ++++++++++++++++-- python/sglang/srt/speculative/eagle_worker.py | 13 ++ sgl-kernel/3rdparty/flashinfer | 1 - sgl-kernel/flash-attention | 1 + 4 files changed, 146 insertions(+), 18 deletions(-) delete mode 160000 sgl-kernel/3rdparty/flashinfer create mode 160000 sgl-kernel/flash-attention diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 93c263f7473..e199d89aa57 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -45,6 +45,9 @@ def __init__( self, model_runner: ModelRunner, skip_prefill: bool = False, + topk = 0, + speculative_num_steps = 0, + step_id = 0 ): super().__init__() @@ -63,6 +66,10 @@ def __init__( self.use_mla = ( model_runner.model_config.attention_arch == AttentionArch.MLA ) and (not global_server_args_dict["disable_mla"]) + self.skip_prefill = skip_prefill + self.topk = topk + self.speculative_num_steps = speculative_num_steps + self.step_id = step_id def init_forward_metadata(self, forward_batch: ForwardBatch): """Initialize forward metadata to cache repetitive calculations.""" @@ -72,37 +79,78 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): # Get sequence information seqlens_in_batch = forward_batch.seq_lens # Precompute int32 version of sequence lengths - metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) batch_size = len(seqlens_in_batch) device = seqlens_in_batch.device + + metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) metadata.cu_seqlens_k = torch.nn.functional.pad( torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) ) # Precompute maximum sequence length - metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + metadata.max_seq_len_k = seqlens_in_batch.max().item() # Precompute page table metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, : metadata.max_seq_len_k ] - # Precompute strided indices - # [0, page_size, 2 * page_size, ...] - if self.page_size > 1: - self.strided_indices = torch.arange( - 0, metadata.page_table.shape[1], self.page_size, device=self.device - ) - metadata.page_table = ( - metadata.page_table[:, self.strided_indices] // self.page_size - ) if forward_batch.forward_mode == ForwardMode.DECODE: - # Precompute cumulative sequence lengths + if self.skip_prefill: + metadata.cu_seqlens_q = torch.arange( + 0, batch_size * self.topk + 1, dtype=torch.int32, device=device + ) + seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1) + metadata.cache_seqlens_int32 = (seq_lens_with_decode).repeat_interleave(self.topk).to(torch.int32) + metadata.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum( + metadata.cache_seqlens_int32, + dim=0, dtype=torch.int32), + (1, 0) + ) + # .repeat_interleave(self.topk) # tensor([7, 7, 7, 8, 8, 8]) + # .repeat(self.topk) # tensor([7, 8, 7, 8, 7, 8]) + metadata.max_seq_len_k = metadata.cache_seqlens_int32.max().item() + metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : metadata.max_seq_len_k + ] # (bsz, max_seq_len) + metadata.page_table = metadata.page_table.repeat_interleave(self.topk, dim=0) + cache_loc = forward_batch.out_cache_loc.view(self.speculative_num_steps, -1).T + + for idx, single_seq_len in enumerate(seq_lens_with_decode): + real_bsz_start_idx = idx * self.topk + real_bsz_end_idx = (idx+1) * self.topk + metadata.page_table[real_bsz_start_idx : real_bsz_end_idx, (single_seq_len-(self.step_id + 1)):single_seq_len] = cache_loc[real_bsz_start_idx:real_bsz_end_idx, :(self.step_id + 1)] + else: + metadata.cu_seqlens_q = torch.arange( + 0, batch_size + 1, dtype=torch.int32, device=device + ) + elif forward_batch.forward_mode == ForwardMode.TARGET_VERIFY: + draft_token_num = forward_batch.spec_info.draft_token_num + metadata.cu_seqlens_q = torch.arange( - 0, batch_size + 1, dtype=torch.int32, device=device + 0, batch_size * draft_token_num + 1, dtype=torch.int32, device=device ) + + aug_seq_lens = (forward_batch.seq_lens + draft_token_num).to(torch.int32) + metadata.cache_seqlens_int32 = aug_seq_lens.repeat_interleave(forward_batch.spec_info.draft_token_num) + metadata.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum( + metadata.cache_seqlens_int32, + dim=0, dtype=torch.int32), + (1, 0) + ) + metadata.max_seq_len_k = metadata.cache_seqlens_int32.max().item() + metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : metadata.max_seq_len_k + ].repeat_interleave(draft_token_num, dim=0) + aug_cum_len = torch.nn.functional.pad(torch.cumsum(aug_seq_lens, dim=0, dtype=torch.int32), (1, 0)) + for idx, single_seq_len in enumerate(aug_seq_lens): + metadata.page_table[idx *draft_token_num :(idx+1)*draft_token_num, : single_seq_len] *= forward_batch.spec_info.custom_mask[aug_cum_len[idx] * draft_token_num: aug_cum_len[idx+1] * draft_token_num].view(draft_token_num, -1) + + metadata.max_seq_len_q = 1 else: # Precompute cumulative sequence lengths - if any(forward_batch.extend_prefix_lens_cpu): + if any(forward_batch.extend_prefix_lens_cpu) or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND: extend_seq_lens = forward_batch.extend_seq_lens metadata.cu_seqlens_q = torch.nn.functional.pad( torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0) @@ -111,6 +159,17 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): else: metadata.cu_seqlens_q = metadata.cu_seqlens_k metadata.max_seq_len_q = metadata.max_seq_len_k + + + # Precompute strided indices + # [0, page_size, 2 * page_size, ...] + if self.page_size > 1: + self.strided_indices = torch.arange( + 0, metadata.page_table.shape[1], self.page_size, device=self.device + ) + metadata.page_table = ( + metadata.page_table[:, self.strided_indices] // self.page_size + ) self.forward_metadata = metadata def forward_extend( @@ -281,8 +340,6 @@ def forward_decode( # Pre-reshape query tensor q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) - - # Run attention with precomputed values o = flash_attn_with_kvcache( q=q_reshaped, k_cache=key_cache, @@ -385,7 +442,7 @@ def init_forward_metadata_capture_cuda_graph( metadata.page_table = self.decode_cuda_graph_metadata["page_table"][ req_pool_indices, : ] - if forward_mode == ForwardMode.DECODE: + if forward_mode.is_cuda_graph(): # Precompute cumulative sequence lengths metadata.cu_seqlens_q = torch.arange( 0, batch_size + 1, dtype=torch.int32, device=device @@ -432,3 +489,61 @@ def init_forward_metadata_replay_cuda_graph( def get_cuda_graph_seq_len_fill_value(self): """Get the fill value for sequence length in CUDA graph.""" return 0 + +class FlashAttentionMultiStepBackend: + + def __init__(self, model_runner: ModelRunner, topk: int, speculative_num_steps: int): + self.model_runner = model_runner + self.topk = topk + self.speculative_num_steps = speculative_num_steps + + self.attn_backends = [] + for i in range(self.speculative_num_steps): + self.attn_backends.append( + FlashAttentionBackend( + model_runner, + skip_prefill=True, + topk = self.topk, + speculative_num_steps=self.speculative_num_steps, + step_id = i, + ) + ) + + def init_forward_metadata(self, forward_batch: ForwardBatch): + for i in range(self.speculative_num_steps - 1): + self.attn_backends[i].init_forward_metadata(forward_batch) + + def init_cuda_graph_state(self, max_bs: int): + for i in range(self.speculative_num_steps): + self.attn_backends[i].init_cuda_graph_state(max_bs) + + def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch): + assert forward_batch.spec_info is not None + assert isinstance(forward_batch.spec_info, EagleDraftInput) + + for i in range(self.speculative_num_steps - 1): + self.attn_backends[i].init_forward_metadata_capture_cuda_graph( + forward_batch.batch_size, + forward_batch.batch_size * self.topk, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + ) + def init_forward_metadata_replay_cuda_graph(self, forward_batch: ForwardBatch, bs: int): + assert forward_batch.spec_info is not None + assert isinstance(forward_batch.spec_info, EagleDraftInput) + + for i in range(self.speculative_num_steps - 1): + self.attn_backends[i].init_forward_metadata_replay_cuda_graph( + bs, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + seq_lens_cpu=forward_batch.decode_seq_lens_cpu, + ) + diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 234e13209e2..e95fc2a3b13 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -184,6 +184,19 @@ def init_attention_backend(self): self.draft_extend_attn_backend = None self.padded_static_len = self.speculative_num_steps + 1 self.has_prefill_wrapper_verify = True + elif self.server_args.attention_backend == "fa3": + from sglang.srt.layers.attention.flashattention_backend import ( + FlashAttentionMultiStepBackend + ) + + self.draft_attn_backend = FlashAttentionMultiStepBackend( + self.draft_model_runner, + self.topk, + self.speculative_num_steps, + ) + self.draft_extend_attn_backend = None + self.padded_static_len = self.speculative_num_steps + 1 + self.has_prefill_wrapper_verify = False else: raise ValueError( f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}" diff --git a/sgl-kernel/3rdparty/flashinfer b/sgl-kernel/3rdparty/flashinfer deleted file mode 160000 index 2b9f16eb79b..00000000000 --- a/sgl-kernel/3rdparty/flashinfer +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 2b9f16eb79bd344e31725e8d7a92fe7fe980ffdf diff --git a/sgl-kernel/flash-attention b/sgl-kernel/flash-attention new file mode 160000 index 00000000000..1a58058a6da --- /dev/null +++ b/sgl-kernel/flash-attention @@ -0,0 +1 @@ +Subproject commit 1a58058a6da83bd7baaf4c512e8a1abe0240bb77 From 189fea3dd1d2580834d3ae8831a04df165ab903d Mon Sep 17 00:00:00 2001 From: hebiao064 Date: Mon, 31 Mar 2025 17:55:29 +0000 Subject: [PATCH 02/12] test --- sgl-kernel/3rdparty/flashinfer | 1 + 1 file changed, 1 insertion(+) create mode 160000 sgl-kernel/3rdparty/flashinfer diff --git a/sgl-kernel/3rdparty/flashinfer b/sgl-kernel/3rdparty/flashinfer new file mode 160000 index 00000000000..2b9f16eb79b --- /dev/null +++ b/sgl-kernel/3rdparty/flashinfer @@ -0,0 +1 @@ +Subproject commit 2b9f16eb79bd344e31725e8d7a92fe7fe980ffdf From 91a004517d0659dd82cbaecb65aeabe5aaa6a908 Mon Sep 17 00:00:00 2001 From: hebiao064 Date: Mon, 31 Mar 2025 17:55:58 +0000 Subject: [PATCH 03/12] test --- sgl-kernel/flash-attention | 1 - 1 file changed, 1 deletion(-) delete mode 160000 sgl-kernel/flash-attention diff --git a/sgl-kernel/flash-attention b/sgl-kernel/flash-attention deleted file mode 160000 index 1a58058a6da..00000000000 --- a/sgl-kernel/flash-attention +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 1a58058a6da83bd7baaf4c512e8a1abe0240bb77 From 77b200c8109dc049496c2ff9c7278f514b295673 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 31 Mar 2025 13:00:15 -0700 Subject: [PATCH 04/12] use .item() on cpu tensors on reformat --- .../attention/flashattention_backend.py | 118 +++++++++++------- python/sglang/srt/speculative/eagle_worker.py | 2 +- 2 files changed, 75 insertions(+), 45 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index e199d89aa57..64baa1b6f80 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -45,9 +45,9 @@ def __init__( self, model_runner: ModelRunner, skip_prefill: bool = False, - topk = 0, - speculative_num_steps = 0, - step_id = 0 + topk=0, + speculative_num_steps=0, + step_id=0, ): super().__init__() @@ -82,48 +82,49 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): batch_size = len(seqlens_in_batch) device = seqlens_in_batch.device - metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) - metadata.cu_seqlens_k = torch.nn.functional.pad( - torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) - ) - # Precompute maximum sequence length - metadata.max_seq_len_k = seqlens_in_batch.max().item() - # Precompute page table - metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ - forward_batch.req_pool_indices, : metadata.max_seq_len_k - ] - - if forward_batch.forward_mode == ForwardMode.DECODE: if self.skip_prefill: metadata.cu_seqlens_q = torch.arange( 0, batch_size * self.topk + 1, dtype=torch.int32, device=device ) - seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1) - metadata.cache_seqlens_int32 = (seq_lens_with_decode).repeat_interleave(self.topk).to(torch.int32) + seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1) + metadata.cache_seqlens_int32 = ( + (seq_lens_with_decode).repeat_interleave(self.topk).to(torch.int32) + ) metadata.cu_seqlens_k = torch.nn.functional.pad( torch.cumsum( - metadata.cache_seqlens_int32, - dim=0, dtype=torch.int32), - (1, 0) + metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 + ), + (1, 0), ) # .repeat_interleave(self.topk) # tensor([7, 7, 7, 8, 8, 8]) # .repeat(self.topk) # tensor([7, 8, 7, 8, 7, 8]) - metadata.max_seq_len_k = metadata.cache_seqlens_int32.max().item() + metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + ( + self.step_id + 1 + ) metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, : metadata.max_seq_len_k - ] # (bsz, max_seq_len) - metadata.page_table = metadata.page_table.repeat_interleave(self.topk, dim=0) - cache_loc = forward_batch.out_cache_loc.view(self.speculative_num_steps, -1).T + ] # (bsz, max_seq_len) + metadata.page_table = metadata.page_table.repeat_interleave( + self.topk, dim=0 + ) + cache_loc = forward_batch.out_cache_loc.view( + self.speculative_num_steps, -1 + ).T for idx, single_seq_len in enumerate(seq_lens_with_decode): - real_bsz_start_idx = idx * self.topk - real_bsz_end_idx = (idx+1) * self.topk - metadata.page_table[real_bsz_start_idx : real_bsz_end_idx, (single_seq_len-(self.step_id + 1)):single_seq_len] = cache_loc[real_bsz_start_idx:real_bsz_end_idx, :(self.step_id + 1)] + real_bsz_start_idx = idx * self.topk + real_bsz_end_idx = (idx + 1) * self.topk + metadata.page_table[ + real_bsz_start_idx:real_bsz_end_idx, + (single_seq_len - (self.step_id + 1)) : single_seq_len, + ] = cache_loc[ + real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1) + ] else: metadata.cu_seqlens_q = torch.arange( 0, batch_size + 1, dtype=torch.int32, device=device - ) + ) elif forward_batch.forward_mode == ForwardMode.TARGET_VERIFY: draft_token_num = forward_batch.spec_info.draft_token_num @@ -132,25 +133,50 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ) aug_seq_lens = (forward_batch.seq_lens + draft_token_num).to(torch.int32) - metadata.cache_seqlens_int32 = aug_seq_lens.repeat_interleave(forward_batch.spec_info.draft_token_num) + metadata.cache_seqlens_int32 = aug_seq_lens.repeat_interleave( + forward_batch.spec_info.draft_token_num + ) metadata.cu_seqlens_k = torch.nn.functional.pad( - torch.cumsum( - metadata.cache_seqlens_int32, - dim=0, dtype=torch.int32), - (1, 0) + torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32), + (1, 0), + ) + metadata.max_seq_len_k = ( + forward_batch.seq_lens_cpu.max().item() + draft_token_num ) - metadata.max_seq_len_k = metadata.cache_seqlens_int32.max().item() metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, : metadata.max_seq_len_k ].repeat_interleave(draft_token_num, dim=0) - aug_cum_len = torch.nn.functional.pad(torch.cumsum(aug_seq_lens, dim=0, dtype=torch.int32), (1, 0)) + aug_cum_len = torch.nn.functional.pad( + torch.cumsum(aug_seq_lens, dim=0, dtype=torch.int32), (1, 0) + ) for idx, single_seq_len in enumerate(aug_seq_lens): - metadata.page_table[idx *draft_token_num :(idx+1)*draft_token_num, : single_seq_len] *= forward_batch.spec_info.custom_mask[aug_cum_len[idx] * draft_token_num: aug_cum_len[idx+1] * draft_token_num].view(draft_token_num, -1) + metadata.page_table[ + idx * draft_token_num : (idx + 1) * draft_token_num, :single_seq_len + ] *= forward_batch.spec_info.custom_mask[ + aug_cum_len[idx] + * draft_token_num : aug_cum_len[idx + 1] + * draft_token_num + ].view( + draft_token_num, -1 + ) - metadata.max_seq_len_q = 1 + metadata.max_seq_len_q = 1 else: + metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) + metadata.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) + ) + # Precompute maximum sequence length + metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + # Precompute page table + metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : metadata.max_seq_len_k + ] # Precompute cumulative sequence lengths - if any(forward_batch.extend_prefix_lens_cpu) or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND: + if ( + any(forward_batch.extend_prefix_lens_cpu) + or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND + ): extend_seq_lens = forward_batch.extend_seq_lens metadata.cu_seqlens_q = torch.nn.functional.pad( torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0) @@ -160,7 +186,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): metadata.cu_seqlens_q = metadata.cu_seqlens_k metadata.max_seq_len_q = metadata.max_seq_len_k - # Precompute strided indices # [0, page_size, 2 * page_size, ...] if self.page_size > 1: @@ -490,9 +515,12 @@ def get_cuda_graph_seq_len_fill_value(self): """Get the fill value for sequence length in CUDA graph.""" return 0 + class FlashAttentionMultiStepBackend: - def __init__(self, model_runner: ModelRunner, topk: int, speculative_num_steps: int): + def __init__( + self, model_runner: ModelRunner, topk: int, speculative_num_steps: int + ): self.model_runner = model_runner self.topk = topk self.speculative_num_steps = speculative_num_steps @@ -503,9 +531,9 @@ def __init__(self, model_runner: ModelRunner, topk: int, speculative_num_steps: FlashAttentionBackend( model_runner, skip_prefill=True, - topk = self.topk, + topk=self.topk, speculative_num_steps=self.speculative_num_steps, - step_id = i, + step_id=i, ) ) @@ -531,7 +559,10 @@ def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch): forward_mode=ForwardMode.DECODE, spec_info=forward_batch.spec_info, ) - def init_forward_metadata_replay_cuda_graph(self, forward_batch: ForwardBatch, bs: int): + + def init_forward_metadata_replay_cuda_graph( + self, forward_batch: ForwardBatch, bs: int + ): assert forward_batch.spec_info is not None assert isinstance(forward_batch.spec_info, EagleDraftInput) @@ -546,4 +577,3 @@ def init_forward_metadata_replay_cuda_graph(self, forward_batch: ForwardBatch, b spec_info=forward_batch.spec_info, seq_lens_cpu=forward_batch.decode_seq_lens_cpu, ) - diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index e95fc2a3b13..ffec4a0adb8 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -186,7 +186,7 @@ def init_attention_backend(self): self.has_prefill_wrapper_verify = True elif self.server_args.attention_backend == "fa3": from sglang.srt.layers.attention.flashattention_backend import ( - FlashAttentionMultiStepBackend + FlashAttentionMultiStepBackend, ) self.draft_attn_backend = FlashAttentionMultiStepBackend( From 8717cba28f4fd8a09a8338872a9af45bd90ed8fe Mon Sep 17 00:00:00 2001 From: zcnrex Date: Mon, 31 Mar 2025 20:52:25 +0000 Subject: [PATCH 05/12] Convert decode for loop into masks --- .../attention/flashattention_backend.py | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 64baa1b6f80..743d3219765 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -97,8 +97,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ), (1, 0), ) - # .repeat_interleave(self.topk) # tensor([7, 7, 7, 8, 8, 8]) - # .repeat(self.topk) # tensor([7, 8, 7, 8, 7, 8]) metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + ( self.step_id + 1 ) @@ -112,15 +110,22 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): self.speculative_num_steps, -1 ).T - for idx, single_seq_len in enumerate(seq_lens_with_decode): - real_bsz_start_idx = idx * self.topk - real_bsz_end_idx = (idx + 1) * self.topk - metadata.page_table[ - real_bsz_start_idx:real_bsz_end_idx, - (single_seq_len - (self.step_id + 1)) : single_seq_len, - ] = cache_loc[ - real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1) - ] + # page table indices to update + # [bsz, topk] + row_indices = torch.arange(batch_size * self.topk, device=device, dtype=torch.int32).view(batch_size, self.topk) + # [max_seq_len : max_seq_len + step_id + 1] + col_indices = torch.arange(forward_batch.seq_lens_cpu.max().item(), metadata.max_seq_len_k, device=device, dtype=torch.int32) + # mask for all valid page table indices + valid_mask = (col_indices.view(1, -1) >= seqlens_in_batch.view(-1, 1)) & (col_indices.view(1, -1) < seq_lens_with_decode.view(-1, 1)) + + # cache indices to read + cache_indices = torch.arange(self.step_id + 1, device=device, dtype=torch.int32) + + metadata.page_table[row_indices, col_indices] = torch.where( + valid_mask, + cache_loc[row_indices, cache_indices].to(torch.int32), + metadata.page_table[row_indices, col_indices] + ) else: metadata.cu_seqlens_q = torch.arange( 0, batch_size + 1, dtype=torch.int32, device=device From ab84b358c46fbdada01f99c13d827c9f8c309ae4 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 31 Mar 2025 14:03:49 -0700 Subject: [PATCH 06/12] small fix --- python/sglang/srt/layers/attention/flashattention_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 743d3219765..6b44fc51d73 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -580,5 +580,5 @@ def init_forward_metadata_replay_cuda_graph( encoder_lens=None, forward_mode=ForwardMode.DECODE, spec_info=forward_batch.spec_info, - seq_lens_cpu=forward_batch.decode_seq_lens_cpu, + seq_lens_cpu=forward_batch.seq_lens_cpu, ) From 3dcd43605db086839f03c14373f8a6b72a6b0dcf Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 31 Mar 2025 14:35:24 -0700 Subject: [PATCH 07/12] fix bug --- .../attention/flashattention_backend.py | 106 ++++++++++-------- 1 file changed, 57 insertions(+), 49 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 6b44fc51d73..f3e42e75894 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -82,54 +82,58 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): batch_size = len(seqlens_in_batch) device = seqlens_in_batch.device - if forward_batch.forward_mode == ForwardMode.DECODE: - if self.skip_prefill: - metadata.cu_seqlens_q = torch.arange( - 0, batch_size * self.topk + 1, dtype=torch.int32, device=device - ) - seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1) - metadata.cache_seqlens_int32 = ( - (seq_lens_with_decode).repeat_interleave(self.topk).to(torch.int32) - ) - metadata.cu_seqlens_k = torch.nn.functional.pad( - torch.cumsum( - metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 - ), - (1, 0), - ) - metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + ( - self.step_id + 1 - ) - metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ - forward_batch.req_pool_indices, : metadata.max_seq_len_k - ] # (bsz, max_seq_len) - metadata.page_table = metadata.page_table.repeat_interleave( - self.topk, dim=0 - ) - cache_loc = forward_batch.out_cache_loc.view( - self.speculative_num_steps, -1 - ).T - - # page table indices to update - # [bsz, topk] - row_indices = torch.arange(batch_size * self.topk, device=device, dtype=torch.int32).view(batch_size, self.topk) - # [max_seq_len : max_seq_len + step_id + 1] - col_indices = torch.arange(forward_batch.seq_lens_cpu.max().item(), metadata.max_seq_len_k, device=device, dtype=torch.int32) - # mask for all valid page table indices - valid_mask = (col_indices.view(1, -1) >= seqlens_in_batch.view(-1, 1)) & (col_indices.view(1, -1) < seq_lens_with_decode.view(-1, 1)) - - # cache indices to read - cache_indices = torch.arange(self.step_id + 1, device=device, dtype=torch.int32) - - metadata.page_table[row_indices, col_indices] = torch.where( - valid_mask, - cache_loc[row_indices, cache_indices].to(torch.int32), - metadata.page_table[row_indices, col_indices] - ) - else: - metadata.cu_seqlens_q = torch.arange( - 0, batch_size + 1, dtype=torch.int32, device=device - ) + if forward_batch.forward_mode == ForwardMode.DECODE and self.skip_prefill: + metadata.cu_seqlens_q = torch.arange( + 0, batch_size * self.topk + 1, dtype=torch.int32, device=device + ) + seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1) + metadata.cache_seqlens_int32 = ( + (seq_lens_with_decode).repeat_interleave(self.topk).to(torch.int32) + ) + metadata.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32), + (1, 0), + ) + metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + ( + self.step_id + 1 + ) + metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : metadata.max_seq_len_k + ] # (bsz, max_seq_len) + metadata.page_table = metadata.page_table.repeat_interleave( + self.topk, dim=0 + ) + cache_loc = forward_batch.out_cache_loc.view( + self.speculative_num_steps, -1 + ).T + + # page table indices to update + # [bsz, topk] + row_indices = torch.arange( + batch_size * self.topk, device=device, dtype=torch.int32 + ).view(batch_size, self.topk) + # [max_seq_len : max_seq_len + step_id + 1] + col_indices = torch.arange( + forward_batch.seq_lens_cpu.max().item(), + metadata.max_seq_len_k, + device=device, + dtype=torch.int32, + ) + # mask for all valid page table indices + valid_mask = (col_indices.view(1, -1) >= seqlens_in_batch.view(-1, 1)) & ( + col_indices.view(1, -1) < seq_lens_with_decode.view(-1, 1) + ) + + # cache indices to read + cache_indices = torch.arange( + self.step_id + 1, device=device, dtype=torch.int32 + ) + + metadata.page_table[row_indices, col_indices] = torch.where( + valid_mask, + cache_loc[row_indices, cache_indices].to(torch.int32), + metadata.page_table[row_indices, col_indices], + ) elif forward_batch.forward_mode == ForwardMode.TARGET_VERIFY: draft_token_num = forward_batch.spec_info.draft_token_num @@ -178,7 +182,11 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): forward_batch.req_pool_indices, : metadata.max_seq_len_k ] # Precompute cumulative sequence lengths - if ( + if forward_batch.forward_mode == ForwardMode.DECODE: + metadata.cu_seqlens_q = torch.arange( + 0, batch_size + 1, dtype=torch.int32, device=device + ) + elif ( any(forward_batch.extend_prefix_lens_cpu) or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND ): From 54f0c1a592494a09682ba0d59d571721c8267a16 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 31 Mar 2025 15:34:13 -0700 Subject: [PATCH 08/12] revert --- .../attention/flashattention_backend.py | 111 +++++++++--------- 1 file changed, 54 insertions(+), 57 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index f3e42e75894..9847f7fe83d 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -82,58 +82,59 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): batch_size = len(seqlens_in_batch) device = seqlens_in_batch.device - if forward_batch.forward_mode == ForwardMode.DECODE and self.skip_prefill: - metadata.cu_seqlens_q = torch.arange( - 0, batch_size * self.topk + 1, dtype=torch.int32, device=device - ) - seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1) - metadata.cache_seqlens_int32 = ( - (seq_lens_with_decode).repeat_interleave(self.topk).to(torch.int32) - ) - metadata.cu_seqlens_k = torch.nn.functional.pad( - torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32), - (1, 0), - ) - metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + ( - self.step_id + 1 - ) - metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ - forward_batch.req_pool_indices, : metadata.max_seq_len_k - ] # (bsz, max_seq_len) - metadata.page_table = metadata.page_table.repeat_interleave( - self.topk, dim=0 - ) - cache_loc = forward_batch.out_cache_loc.view( - self.speculative_num_steps, -1 - ).T - - # page table indices to update - # [bsz, topk] - row_indices = torch.arange( - batch_size * self.topk, device=device, dtype=torch.int32 - ).view(batch_size, self.topk) - # [max_seq_len : max_seq_len + step_id + 1] - col_indices = torch.arange( - forward_batch.seq_lens_cpu.max().item(), - metadata.max_seq_len_k, - device=device, - dtype=torch.int32, - ) - # mask for all valid page table indices - valid_mask = (col_indices.view(1, -1) >= seqlens_in_batch.view(-1, 1)) & ( - col_indices.view(1, -1) < seq_lens_with_decode.view(-1, 1) - ) - - # cache indices to read - cache_indices = torch.arange( - self.step_id + 1, device=device, dtype=torch.int32 - ) - - metadata.page_table[row_indices, col_indices] = torch.where( - valid_mask, - cache_loc[row_indices, cache_indices].to(torch.int32), - metadata.page_table[row_indices, col_indices], - ) + if forward_batch.forward_mode == ForwardMode.DECODE: + if self.skip_prefill: + metadata.cu_seqlens_q = torch.arange( + 0, batch_size * self.topk + 1, dtype=torch.int32, device=device + ) + seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1) + metadata.cache_seqlens_int32 = ( + (seq_lens_with_decode).repeat_interleave(self.topk).to(torch.int32) + ) + metadata.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum( + metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 + ), + (1, 0), + ) + # .repeat_interleave(self.topk) # tensor([7, 7, 7, 8, 8, 8]) + # .repeat(self.topk) # tensor([7, 8, 7, 8, 7, 8]) + metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + ( + self.step_id + 1 + ) + metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : metadata.max_seq_len_k + ] # (bsz, max_seq_len) + metadata.page_table = metadata.page_table.repeat_interleave( + self.topk, dim=0 + ) + cache_loc = forward_batch.out_cache_loc.view( + self.speculative_num_steps, -1 + ).T + + for idx, single_seq_len in enumerate(seq_lens_with_decode): + real_bsz_start_idx = idx * self.topk + real_bsz_end_idx = (idx + 1) * self.topk + metadata.page_table[ + real_bsz_start_idx:real_bsz_end_idx, + (single_seq_len - (self.step_id + 1)) : single_seq_len, + ] = cache_loc[ + real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1) + ] + else: + metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) + metadata.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) + ) + # Precompute maximum sequence length + metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + # Precompute page table + metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : metadata.max_seq_len_k + ] + metadata.cu_seqlens_q = torch.arange( + 0, batch_size + 1, dtype=torch.int32, device=device + ) elif forward_batch.forward_mode == ForwardMode.TARGET_VERIFY: draft_token_num = forward_batch.spec_info.draft_token_num @@ -182,11 +183,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): forward_batch.req_pool_indices, : metadata.max_seq_len_k ] # Precompute cumulative sequence lengths - if forward_batch.forward_mode == ForwardMode.DECODE: - metadata.cu_seqlens_q = torch.arange( - 0, batch_size + 1, dtype=torch.int32, device=device - ) - elif ( + if ( any(forward_batch.extend_prefix_lens_cpu) or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND ): From 654092d556f96720a50f44bf85520db7bc339017 Mon Sep 17 00:00:00 2001 From: hebiao064 Date: Wed, 2 Apr 2025 03:13:30 +0000 Subject: [PATCH 09/12] Add Error Msg about Cuda Graph Not Supported --- .../attention/flashattention_backend.py | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 9847f7fe83d..502b67d296d 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -438,19 +438,23 @@ def init_cuda_graph_state(self, max_bs: int): This creates fixed-size tensors that will be reused during CUDA graph replay to avoid memory allocations. """ + raise NotImplementedError( + "FlashAttentionBackend does not support CUDA graph yet, stay tuned!" + ) + # Initialize fixed size tensors for decode operations - self.decode_cuda_graph_metadata = { - # Page table for token mapping (batch_size, max_context_len) - "page_table": torch.zeros( - max_bs, - (self.max_context_len + self.page_size - 1) // self.page_size, - dtype=torch.int32, - device=self.device, - ), - "strided_indices": torch.arange( - 0, self.max_context_len, self.page_size, device=self.device - ), - } + # self.decode_cuda_graph_metadata = { + # # Page table for token mapping (batch_size, max_context_len) + # "page_table": torch.zeros( + # max_bs, + # (self.max_context_len + self.page_size - 1) // self.page_size, + # dtype=torch.int32, + # device=self.device, + # ), + # "strided_indices": torch.arange( + # 0, self.max_context_len, self.page_size, device=self.device + # ), + # } def init_forward_metadata_capture_cuda_graph( self, From 511e38796cddfade4b4cd78f1529704cfd046b60 Mon Sep 17 00:00:00 2001 From: hebiao064 Date: Wed, 2 Apr 2025 03:23:10 +0000 Subject: [PATCH 10/12] fix error msg --- .../attention/flashattention_backend.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index f25bb7ff328..fa7242503fa 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -438,23 +438,23 @@ def init_cuda_graph_state(self, max_bs: int): This creates fixed-size tensors that will be reused during CUDA graph replay to avoid memory allocations. """ - raise NotImplementedError( - "FlashAttentionBackend does not support CUDA graph yet, stay tuned!" - ) + if self.speculative_num_steps > 0: + raise NotImplementedError( + "FlashAttentionBackend Spec Decoding does not support CUDA graph yet, stay tuned!" + ) - # Initialize fixed size tensors for decode operations - # self.decode_cuda_graph_metadata = { - # # Page table for token mapping (batch_size, max_context_len) - # "page_table": torch.zeros( - # max_bs, - # (self.max_context_len + self.page_size - 1) // self.page_size, - # dtype=torch.int32, - # device=self.device, - # ), - # "strided_indices": torch.arange( - # 0, self.max_context_len, self.page_size, device=self.device - # ), - # } + self.decode_cuda_graph_metadata = { + # Page table for token mapping (batch_size, max_context_len) + "page_table": torch.zeros( + max_bs, + (self.max_context_len + self.page_size - 1) // self.page_size, + dtype=torch.int32, + device=self.device, + ), + "strided_indices": torch.arange( + 0, self.max_context_len, self.page_size, device=self.device + ), + } def init_forward_metadata_capture_cuda_graph( self, From 7364bbd3cd359ffa77bea65c776417f71008bace Mon Sep 17 00:00:00 2001 From: zcnrex Date: Wed, 2 Apr 2025 03:42:59 +0000 Subject: [PATCH 11/12] Vectorize draft decode page table indices compute --- .../attention/flashattention_backend.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index fa7242503fa..36c1353875a 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -112,15 +112,17 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): self.speculative_num_steps, -1 ).T - for idx, single_seq_len in enumerate(seq_lens_with_decode): - real_bsz_start_idx = idx * self.topk - real_bsz_end_idx = (idx + 1) * self.topk - metadata.page_table[ - real_bsz_start_idx:real_bsz_end_idx, - (single_seq_len - (self.step_id + 1)) : single_seq_len, - ] = cache_loc[ - real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1) - ] + # Calculate page table indices and cache location indices to update the page table. + batch_indices = torch.arange(batch_size, device=device).repeat_interleave(self.topk * (self.step_id + 1)) + topk_indices = torch.arange(self.topk, device=device).repeat(batch_size * (self.step_id + 1)) + row_indices = batch_indices * self.topk + topk_indices + + page_table_col_base_indices = seqlens_in_batch.unsqueeze(1) + torch.arange(self.step_id + 1, device=device) + page_table_col_indices = page_table_col_base_indices.view(-1).repeat(self.topk) + + cache_loc_col_indices = torch.arange(self.step_id + 1, device=device, dtype=torch.int32).repeat(batch_size * self.topk) + + metadata.page_table[row_indices, page_table_col_indices] = cache_loc[row_indices, cache_loc_col_indices].to(torch.int32) else: metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) metadata.cu_seqlens_k = torch.nn.functional.pad( From 2a9984e0244b1dacb1733c567ed32f3c37e6806b Mon Sep 17 00:00:00 2001 From: hebiao064 Date: Wed, 2 Apr 2025 05:15:39 +0000 Subject: [PATCH 12/12] Cleanup and Format --- .../attention/flashattention_backend.py | 29 ++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 36c1353875a..1e220b9e2f5 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -97,32 +97,41 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ), (1, 0), ) - # .repeat_interleave(self.topk) # tensor([7, 7, 7, 8, 8, 8]) - # .repeat(self.topk) # tensor([7, 8, 7, 8, 7, 8]) metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + ( self.step_id + 1 ) metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, : metadata.max_seq_len_k - ] # (bsz, max_seq_len) + ] metadata.page_table = metadata.page_table.repeat_interleave( self.topk, dim=0 ) cache_loc = forward_batch.out_cache_loc.view( self.speculative_num_steps, -1 ).T - # Calculate page table indices and cache location indices to update the page table. - batch_indices = torch.arange(batch_size, device=device).repeat_interleave(self.topk * (self.step_id + 1)) - topk_indices = torch.arange(self.topk, device=device).repeat(batch_size * (self.step_id + 1)) + batch_indices = torch.arange( + batch_size, device=device + ).repeat_interleave(self.topk * (self.step_id + 1)) + topk_indices = torch.arange(self.topk, device=device).repeat( + batch_size * (self.step_id + 1) + ) row_indices = batch_indices * self.topk + topk_indices - page_table_col_base_indices = seqlens_in_batch.unsqueeze(1) + torch.arange(self.step_id + 1, device=device) - page_table_col_indices = page_table_col_base_indices.view(-1).repeat(self.topk) + page_table_col_base_indices = seqlens_in_batch.unsqueeze( + 1 + ) + torch.arange(self.step_id + 1, device=device) + page_table_col_indices = page_table_col_base_indices.view(-1).repeat( + self.topk + ) - cache_loc_col_indices = torch.arange(self.step_id + 1, device=device, dtype=torch.int32).repeat(batch_size * self.topk) + cache_loc_col_indices = torch.arange( + self.step_id + 1, device=device, dtype=torch.int32 + ).repeat(batch_size * self.topk) - metadata.page_table[row_indices, page_table_col_indices] = cache_loc[row_indices, cache_loc_col_indices].to(torch.int32) + metadata.page_table[row_indices, page_table_col_indices] = cache_loc[ + row_indices, cache_loc_col_indices + ].to(torch.int32) else: metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) metadata.cu_seqlens_k = torch.nn.functional.pad(