From 9f238a45b2317e2004334aad30584100659da891 Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Thu, 2 Oct 2025 16:22:42 -0700 Subject: [PATCH 01/20] match_prefix --- python/sglang/srt/managers/schedule_batch.py | 38 ++++++------------- python/sglang/srt/managers/schedule_policy.py | 2 +- 2 files changed, 13 insertions(+), 27 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 31b696f9f540..3a577bab671a 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -685,11 +685,18 @@ def finished(self) -> bool: # Whether request reached finished condition return self.finished_reason is not None - def init_next_round_input( - self, - tree_cache: Optional[BasePrefixCache] = None, - ): + def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None): self.fill_ids = self.origin_input_ids + self.output_ids + input_len = len(self.fill_ids) + # FIXME: To work around some bugs in logprob computation, we need to ensure each + # request has at least one token. Later, we can relax this requirement and use `input_len`. + max_prefix_len = input_len - 1 + if self.return_logprob: + max_prefix_len = min(max_prefix_len, self.logprob_start_len) + max_prefix_len = max(max_prefix_len, 0) + self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) + token_ids = self.fill_ids[:max_prefix_len] + if tree_cache is not None: ( self.prefix_indices, @@ -697,30 +704,9 @@ def init_next_round_input( self.last_host_node, self.host_hit_length, ) = tree_cache.match_prefix( - key=RadixKey( - token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key - ), + key=RadixKey(token_ids=token_ids, extra_key=self.extra_key) ) self.last_matched_prefix_len = len(self.prefix_indices) - self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) - - def adjust_max_prefix_ids(self): - self.fill_ids = self.origin_input_ids + self.output_ids - input_len = len(self.fill_ids) - - # FIXME: To work around some bugs in logprob computation, we need to ensure each - # request has at least one token. Later, we can relax this requirement and use `input_len`. - max_prefix_len = input_len - 1 - - if self.sampling_params.max_new_tokens > 0: - # Need at least one token to compute logits - max_prefix_len = min(max_prefix_len, input_len - 1) - - if self.return_logprob: - max_prefix_len = min(max_prefix_len, self.logprob_start_len) - - max_prefix_len = max(max_prefix_len, 0) - return self.fill_ids[:max_prefix_len] # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313 def init_incremental_detokenize(self): diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 60633552bdb3..9b6057d32829 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -174,7 +174,7 @@ def _compute_prefix_matches( self.waiting_queue_radix_tree.reset() for r in waiting_queue: - prefix_ids = r.adjust_max_prefix_ids() + prefix_ids = r.origin_input_ids + r.output_ids extra_key = r.extra_key # NOTE: the prefix_indices must always be aligned with last_node From c0fc6a0da0536e007b1fb682121f2c2b2678015c Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Thu, 2 Oct 2025 20:07:52 -0700 Subject: [PATCH 02/20] fix --- python/sglang/srt/managers/schedule_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 3a577bab671a..a1f2016ed965 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -694,7 +694,6 @@ def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None): if self.return_logprob: max_prefix_len = min(max_prefix_len, self.logprob_start_len) max_prefix_len = max(max_prefix_len, 0) - self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) token_ids = self.fill_ids[:max_prefix_len] if tree_cache is not None: @@ -707,6 +706,7 @@ def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None): key=RadixKey(token_ids=token_ids, extra_key=self.extra_key) ) self.last_matched_prefix_len = len(self.prefix_indices) + self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313 def init_incremental_detokenize(self): From baa34d784eb20a9df8ee1eff1fd6cc2990e1c3a7 Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Thu, 2 Oct 2025 20:22:55 -0700 Subject: [PATCH 03/20] prepare_for_extend: adjust order --- python/sglang/srt/managers/schedule_batch.py | 125 +++++++++++-------- 1 file changed, 74 insertions(+), 51 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index a1f2016ed965..bd53f59d1437 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1108,6 +1108,37 @@ def alloc_paged_token_slots_decode( else: return out_cache_loc + def write_non_prefix_cache_locs( + self, + req_pool_indices: List[int], + prefix_lens: List[int], + seq_lens: List[int], + extend_lens: List[int], + out_cache_loc: torch.Tensor, + req_pool_indices_tensor: torch.Tensor, + prefix_lens_tensor: torch.Tensor, + seq_lens_tensor: torch.Tensor, + extend_lens_tensor: torch.Tensor, + ): + if support_triton(global_server_args_dict.get("attention_backend")): + write_req_to_token_pool_triton[(len(req_pool_indices),)]( + self.req_to_token_pool.req_to_token, + req_pool_indices_tensor, + prefix_lens_tensor, + seq_lens_tensor, + extend_lens_tensor, + out_cache_loc, + self.req_to_token_pool.req_to_token.shape[1], + ) + else: + pt = 0 + for i in range(len(req_pool_indices)): + self.req_to_token_pool.write( + (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])), + out_cache_loc[pt : pt + extend_lens[i]], + ) + pt += extend_lens[i] + def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]): self.encoder_lens_cpu = [] self.encoder_cached = [] @@ -1185,10 +1216,6 @@ def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]) def prepare_for_extend(self): self.forward_mode = ForwardMode.EXTEND - # Allocate req slots - bs = len(self.reqs) - req_pool_indices = self.alloc_req_slots(bs, self.reqs) - # Init tensors reqs = self.reqs input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] @@ -1202,9 +1229,6 @@ def prepare_for_extend(self): r.token_type_ids for r in reqs if r.token_type_ids is not None ] - req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to( - self.device, non_blocking=True - ) input_ids_tensor = torch.tensor( list(chain.from_iterable(input_ids)), dtype=torch.int64 ).to(self.device, non_blocking=True) @@ -1228,11 +1252,14 @@ def prepare_for_extend(self): extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor - # Copy prefix and do some basic check - input_embeds = [] - extend_input_logprob_token_ids = [] - multimodal_inputs = [] + # Allocate req slots + bs = len(self.reqs) + req_pool_indices = self.alloc_req_slots(bs, self.reqs) + req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to( + self.device, non_blocking=True + ) + # Write prefix to req_to_token_pool for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)): req.req_pool_idx = req_pool_indices[i] assert seq_len - pre_len == req.extend_input_len @@ -1246,6 +1273,42 @@ def prepare_for_extend(self): req, pre_len, self.model_config.attention_chunk_size ) + # Allocate memory + if self.token_to_kv_pool_allocator.page_size == 1: + out_cache_loc = self.alloc_token_slots(extend_num_tokens) + else: + last_loc = get_last_loc( + self.req_to_token_pool.req_to_token, + req_pool_indices_tensor, + prefix_lens_tensor, + ) + out_cache_loc = self.alloc_paged_token_slots_extend( + prefix_lens_tensor, + prefix_lens_cpu_tensor, + seq_lens_tensor, + seq_lens_cpu, + last_loc, + extend_num_tokens, + ) + + self.write_non_prefix_cache_locs( + req_pool_indices, + prefix_lens, + seq_lens, + extend_lens, + out_cache_loc, + req_pool_indices_tensor, + prefix_lens_tensor, + seq_lens_tensor, + extend_lens_tensor, + ) + + # Copy prefix and do some basic check + input_embeds = [] + extend_input_logprob_token_ids = [] + multimodal_inputs = [] + + for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)): # If input_embeds are available, store them if req.input_embeds is not None: # If req.input_embeds is already a list, append its content directly @@ -1335,24 +1398,6 @@ def prepare_for_extend(self): else: extend_input_logprob_token_ids = None - # Allocate memory - if self.token_to_kv_pool_allocator.page_size == 1: - out_cache_loc = self.alloc_token_slots(extend_num_tokens) - else: - last_loc = get_last_loc( - self.req_to_token_pool.req_to_token, - req_pool_indices_tensor, - prefix_lens_tensor, - ) - out_cache_loc = self.alloc_paged_token_slots_extend( - prefix_lens_tensor, - prefix_lens_cpu_tensor, - seq_lens_tensor, - seq_lens_cpu, - last_loc, - extend_num_tokens, - ) - # Set fields self.input_ids = input_ids_tensor self.req_pool_indices = req_pool_indices_tensor @@ -1386,28 +1431,6 @@ def prepare_for_extend(self): self.extend_lens = extend_lens self.extend_input_logprob_token_ids = extend_input_logprob_token_ids - # Write to req_to_token_pool - if support_triton(global_server_args_dict.get("attention_backend")): - # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start) - - write_req_to_token_pool_triton[(bs,)]( - self.req_to_token_pool.req_to_token, - req_pool_indices_tensor, - prefix_lens_tensor, - seq_lens_tensor, - extend_lens_tensor, - out_cache_loc, - self.req_to_token_pool.req_to_token.shape[1], - ) - else: - pt = 0 - for i in range(bs): - self.req_to_token_pool.write( - (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])), - out_cache_loc[pt : pt + extend_lens[i]], - ) - pt += extend_lens[i] - if self.model_config.is_encoder_decoder: self.prepare_encoder_info_extend(input_ids, seq_lens) From ffc57c0bc026ee7d2a6d15d30c1af8d5b144196c Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Fri, 3 Oct 2025 11:54:16 -0700 Subject: [PATCH 04/20] prepare_for_extend: unify indices writing --- python/sglang/srt/managers/schedule_batch.py | 66 +++++++++++++------- 1 file changed, 43 insertions(+), 23 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index bd53f59d1437..8c4a98e84654 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1108,7 +1108,7 @@ def alloc_paged_token_slots_decode( else: return out_cache_loc - def write_non_prefix_cache_locs( + def write_cache_indices( self, req_pool_indices: List[int], prefix_lens: List[int], @@ -1119,11 +1119,13 @@ def write_non_prefix_cache_locs( prefix_lens_tensor: torch.Tensor, seq_lens_tensor: torch.Tensor, extend_lens_tensor: torch.Tensor, + prefix_tensors: torch.Tensor, ): if support_triton(global_server_args_dict.get("attention_backend")): write_req_to_token_pool_triton[(len(req_pool_indices),)]( self.req_to_token_pool.req_to_token, req_pool_indices_tensor, + prefix_tensors, prefix_lens_tensor, seq_lens_tensor, extend_lens_tensor, @@ -1259,29 +1261,19 @@ def prepare_for_extend(self): self.device, non_blocking=True ) - # Write prefix to req_to_token_pool - for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)): - req.req_pool_idx = req_pool_indices[i] - assert seq_len - pre_len == req.extend_input_len - - if pre_len > 0: - self.req_to_token_pool.write( - (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices - ) - if isinstance(self.tree_cache, SWAChunkCache): - self.tree_cache.evict_swa( - req, pre_len, self.model_config.attention_chunk_size - ) - # Allocate memory if self.token_to_kv_pool_allocator.page_size == 1: out_cache_loc = self.alloc_token_slots(extend_num_tokens) else: - last_loc = get_last_loc( - self.req_to_token_pool.req_to_token, - req_pool_indices_tensor, - prefix_lens_tensor, - ) + last_loc = [] + for prefix in [r.prefix_indices for r in self.reqs]: + if len(prefix) > 0: + last_loc.append(prefix[-1:]) + else: + last_loc.append( + torch.tensor([-1], dtype=torch.int64, device="cuda") + ) + last_loc = torch.cat(last_loc) out_cache_loc = self.alloc_paged_token_slots_extend( prefix_lens_tensor, prefix_lens_cpu_tensor, @@ -1291,7 +1283,12 @@ def prepare_for_extend(self): extend_num_tokens, ) - self.write_non_prefix_cache_locs( + # Write allocated tokens to req_to_token_pool + prefix_tensors = torch.tensor( + [r.prefix_indices.data_ptr() for r in reqs], device=self.device + ) + + self.write_cache_indices( req_pool_indices, prefix_lens, seq_lens, @@ -1301,14 +1298,24 @@ def prepare_for_extend(self): prefix_lens_tensor, seq_lens_tensor, extend_lens_tensor, + prefix_tensors, ) - # Copy prefix and do some basic check + # Set fields input_embeds = [] extend_input_logprob_token_ids = [] multimodal_inputs = [] for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)): + req.req_pool_idx = req_pool_indices[i] + assert seq_len - pre_len == req.extend_input_len + + if pre_len > 0: + if isinstance(self.tree_cache, SWAChunkCache): + self.tree_cache.evict_swa( + req, pre_len, self.model_config.attention_chunk_size + ) + # If input_embeds are available, store them if req.input_embeds is not None: # If req.input_embeds is already a list, append its content directly @@ -1398,7 +1405,6 @@ def prepare_for_extend(self): else: extend_input_logprob_token_ids = None - # Set fields self.input_ids = input_ids_tensor self.req_pool_indices = req_pool_indices_tensor self.seq_lens = seq_lens_tensor @@ -2032,6 +2038,7 @@ class ModelWorkerBatch: def write_req_to_token_pool_triton( req_to_token_ptr, # [max_batch, max_context_len] req_pool_indices, + prefix_tensors, pre_lens, seq_lens, extend_lens, @@ -2044,6 +2051,19 @@ def write_req_to_token_pool_triton( req_pool_index = tl.load(req_pool_indices + pid) pre_len = tl.load(pre_lens + pid) seq_len = tl.load(seq_lens + pid) + prefix_tensor = tl.load(prefix_tensors + pid).to(tl.pointer_type(tl.int64)) + + # write prefix + num_loop = tl.cdiv(pre_len, BLOCK_SIZE) + for i in range(num_loop): + offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + mask = offset < pre_len + value = tl.load(prefix_tensor + offset, mask=mask) + tl.store( + req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + offset, + value, + mask=mask, + ) # NOTE: This can be slow for large bs cumsum_start = tl.cast(0, tl.int64) From 75e3cf7fb79a92a472ce55b2cbfb622c5c300319 Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Fri, 3 Oct 2025 13:37:41 -0700 Subject: [PATCH 05/20] clean --- python/sglang/srt/managers/schedule_batch.py | 21 +++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 8c4a98e84654..a8c6b52fb685 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -688,8 +688,7 @@ def finished(self) -> bool: def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None): self.fill_ids = self.origin_input_ids + self.output_ids input_len = len(self.fill_ids) - # FIXME: To work around some bugs in logprob computation, we need to ensure each - # request has at least one token. Later, we can relax this requirement and use `input_len`. + # NOTE: the matched length is at most 1 less than the input length to enable logprob computation max_prefix_len = input_len - 1 if self.return_logprob: max_prefix_len = min(max_prefix_len, self.logprob_start_len) @@ -1119,13 +1118,17 @@ def write_cache_indices( prefix_lens_tensor: torch.Tensor, seq_lens_tensor: torch.Tensor, extend_lens_tensor: torch.Tensor, - prefix_tensors: torch.Tensor, + prefix_tensors: list[torch.Tensor], ): if support_triton(global_server_args_dict.get("attention_backend")): + prefix_pointers = torch.tensor( + [t.data_ptr() for t in prefix_tensors], device=self.device + ) + # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start) write_req_to_token_pool_triton[(len(req_pool_indices),)]( self.req_to_token_pool.req_to_token, req_pool_indices_tensor, - prefix_tensors, + prefix_pointers, prefix_lens_tensor, seq_lens_tensor, extend_lens_tensor, @@ -1135,6 +1138,10 @@ def write_cache_indices( else: pt = 0 for i in range(len(req_pool_indices)): + self.req_to_token_pool.write( + (req_pool_indices[i], slice(0, prefix_lens[i])), + prefix_tensors[i], + ) self.req_to_token_pool.write( (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])), out_cache_loc[pt : pt + extend_lens[i]], @@ -1284,10 +1291,6 @@ def prepare_for_extend(self): ) # Write allocated tokens to req_to_token_pool - prefix_tensors = torch.tensor( - [r.prefix_indices.data_ptr() for r in reqs], device=self.device - ) - self.write_cache_indices( req_pool_indices, prefix_lens, @@ -1298,7 +1301,7 @@ def prepare_for_extend(self): prefix_lens_tensor, seq_lens_tensor, extend_lens_tensor, - prefix_tensors, + [r.prefix_indices for r in reqs], ) # Set fields From 3fe0dd6aa8a76f7bd7251a9359c80520be7357ff Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Fri, 3 Oct 2025 14:00:32 -0700 Subject: [PATCH 06/20] clean --- python/sglang/srt/managers/schedule_batch.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index a8c6b52fb685..57fda2e4b4eb 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1272,21 +1272,20 @@ def prepare_for_extend(self): if self.token_to_kv_pool_allocator.page_size == 1: out_cache_loc = self.alloc_token_slots(extend_num_tokens) else: - last_loc = [] - for prefix in [r.prefix_indices for r in self.reqs]: - if len(prefix) > 0: - last_loc.append(prefix[-1:]) - else: - last_loc.append( - torch.tensor([-1], dtype=torch.int64, device="cuda") - ) - last_loc = torch.cat(last_loc) + last_loc = [ + ( + r.prefix_indices[-1:] + if len(r.prefix_indices) > 0 + else torch.tensor([-1], device=self.device) + ) + for r in self.reqs + ] out_cache_loc = self.alloc_paged_token_slots_extend( prefix_lens_tensor, prefix_lens_cpu_tensor, seq_lens_tensor, seq_lens_cpu, - last_loc, + torch.cat(last_loc), extend_num_tokens, ) From ebb29fca36938b540119c39f657cedb590419af2 Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Fri, 3 Oct 2025 15:36:59 -0700 Subject: [PATCH 07/20] set prefix_indices default --- python/sglang/bench_one_batch.py | 2 -- python/sglang/srt/managers/schedule_batch.py | 4 ++-- test/srt/test_forward_split_prefill.py | 1 - 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index 9def9d8d0b0d..f8a35266bf6f 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -204,7 +204,6 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts): origin_input_ids=tmp_input_ids, sampling_params=sampling_params, ) - req.prefix_indices = [] req.fill_ids = req.origin_input_ids req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) req.logprob_start_len = len(req.origin_input_ids) - 1 @@ -248,7 +247,6 @@ def prepare_synthetic_inputs_for_latency_test( origin_input_ids=list(input_ids[i]), sampling_params=sampling_params, ) - req.prefix_indices = [] req.fill_ids = req.origin_input_ids req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) req.logprob_start_len = len(req.origin_input_ids) - 1 diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 57fda2e4b4eb..1b68743f0ac1 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -537,7 +537,7 @@ def __init__( # Prefix info # The indices to kv cache for the shared prefix. - self.prefix_indices: torch.Tensor = [] + self.prefix_indices: torch.Tensor = torch.empty((0,), dtype=torch.int64) # Number of tokens to run prefill. self.extend_input_len = 0 # The relative logprob_start_len in an extend batch @@ -787,7 +787,7 @@ def check_finished(self): return def reset_for_retract(self): - self.prefix_indices = [] + self.prefix_indices = torch.empty((0,), dtype=torch.int64) self.last_node = None self.swa_uuid_for_lock = None self.extend_input_len = 0 diff --git a/test/srt/test_forward_split_prefill.py b/test/srt/test_forward_split_prefill.py index 060535687ccd..314e35ec9724 100644 --- a/test/srt/test_forward_split_prefill.py +++ b/test/srt/test_forward_split_prefill.py @@ -90,7 +90,6 @@ def prepare_test_batch(self, batch_size=2, input_len=128, is_split_prefill=True) origin_input_ids=list(input_ids[i]), sampling_params=sampling_params, ) - req.prefix_indices = [] req.fill_ids = req.origin_input_ids req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) req.logprob_start_len = len(req.origin_input_ids) - 1 From 5fa6bbf1672cf5cd9dcbda52defe3cf10f9013ea Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Fri, 3 Oct 2025 20:28:54 -0700 Subject: [PATCH 08/20] fix chunk cache --- python/sglang/srt/mem_cache/chunk_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index 6ca8d9995001..54626dffd16e 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -60,7 +60,7 @@ def cache_unfinished_req(self, req: Req, chunked=False): ] # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later - req.prefix_indices = kv_indices + req.prefix_indices = kv_indices.to(dtype=torch.int64, copy=True) def evict(self, num_tokens: int): pass From 9d7b32fefff46f7115a433242d493936190ddae1 Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Mon, 6 Oct 2025 16:15:59 -0700 Subject: [PATCH 09/20] prepare_for_extend --- python/sglang/srt/managers/schedule_batch.py | 288 +++++------------- python/sglang/srt/mem_cache/utils.py | 183 +++++++++++ python/sglang/srt/speculative/eagle_info.py | 7 +- python/sglang/srt/speculative/eagle_worker.py | 7 +- python/sglang/srt/speculative/ngram_utils.py | 7 +- 5 files changed, 261 insertions(+), 231 deletions(-) create mode 100644 python/sglang/srt/mem_cache/utils.py diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 1b68743f0ac1..20ddbd29c857 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -45,8 +45,6 @@ import numpy as np import torch -import triton -import triton.language as tl from sglang.global_config import global_config from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject @@ -65,12 +63,13 @@ from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool from sglang.srt.mem_cache.radix_cache import RadixKey from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache +from sglang.srt.mem_cache.utils import write_cache_indices from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import flatten_nested_list, support_triton +from sglang.srt.utils import flatten_nested_list if TYPE_CHECKING: from sglang.srt.configs.model_config import ModelConfig @@ -1076,6 +1075,75 @@ def alloc_paged_token_slots_extend( else: return out_cache_loc + def alloc_for_extend(self, reqs: list[Req]): + # Extract all needed data from reqs + input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] + extend_num_tokens = sum(len(ids) for ids in input_ids) + + # CPU tensors + prefix_lens = torch.tensor( + [len(r.prefix_indices) for r in reqs], dtype=torch.int64 + ) + seq_lens_cpu = torch.tensor([len(r.fill_ids) for r in reqs], dtype=torch.int64) + extend_lens_cpu = seq_lens_cpu - prefix_lens + + # Copy to GPU + prefix_lens_device = prefix_lens.to(self.device, non_blocking=True) + seq_lens_device = seq_lens_cpu.to(self.device, non_blocking=True) + extend_lens_device = extend_lens_cpu.to(self.device, non_blocking=True) + + # Allocate req slots + bs = len(reqs) + req_pool_indices = self.alloc_req_slots(bs, reqs) + req_pool_indices_cpu = torch.tensor( + req_pool_indices, dtype=torch.int64, device=self.device + ) + req_pool_indices_device = req_pool_indices_cpu.to( + self.device, non_blocking=True + ) + + # Allocate memory + if self.token_to_kv_pool_allocator.page_size == 1: + out_cache_loc = self.alloc_token_slots(extend_num_tokens) + else: + last_loc = [ + ( + r.prefix_indices[-1:] + if len(r.prefix_indices) > 0 + else torch.tensor([-1], device=self.device) + ) + for r in reqs + ] + out_cache_loc = self.alloc_paged_token_slots_extend( + prefix_lens_device, + prefix_lens, + seq_lens_device, + seq_lens_cpu, + torch.cat(last_loc), + extend_num_tokens, + ) + + # Write allocated tokens to req_to_token_pool + write_cache_indices( + out_cache_loc, + req_pool_indices_device, + req_pool_indices_cpu, + prefix_lens_device, + prefix_lens, + seq_lens_device, + seq_lens_cpu, + extend_lens_device, + extend_lens_cpu, + [r.prefix_indices for r in reqs], + self.req_to_token_pool, + ) + + # update requests + for req, req_pool_index in zip(reqs, req_pool_indices): + req.req_pool_idx = req_pool_index + + return out_cache_loc, req_pool_indices_device + def alloc_paged_token_slots_decode( self, seq_lens: torch.Tensor, @@ -1107,47 +1175,6 @@ def alloc_paged_token_slots_decode( else: return out_cache_loc - def write_cache_indices( - self, - req_pool_indices: List[int], - prefix_lens: List[int], - seq_lens: List[int], - extend_lens: List[int], - out_cache_loc: torch.Tensor, - req_pool_indices_tensor: torch.Tensor, - prefix_lens_tensor: torch.Tensor, - seq_lens_tensor: torch.Tensor, - extend_lens_tensor: torch.Tensor, - prefix_tensors: list[torch.Tensor], - ): - if support_triton(global_server_args_dict.get("attention_backend")): - prefix_pointers = torch.tensor( - [t.data_ptr() for t in prefix_tensors], device=self.device - ) - # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start) - write_req_to_token_pool_triton[(len(req_pool_indices),)]( - self.req_to_token_pool.req_to_token, - req_pool_indices_tensor, - prefix_pointers, - prefix_lens_tensor, - seq_lens_tensor, - extend_lens_tensor, - out_cache_loc, - self.req_to_token_pool.req_to_token.shape[1], - ) - else: - pt = 0 - for i in range(len(req_pool_indices)): - self.req_to_token_pool.write( - (req_pool_indices[i], slice(0, prefix_lens[i])), - prefix_tensors[i], - ) - self.req_to_token_pool.write( - (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])), - out_cache_loc[pt : pt + extend_lens[i]], - ) - pt += extend_lens[i] - def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]): self.encoder_lens_cpu = [] self.encoder_cached = [] @@ -1248,10 +1275,6 @@ def prepare_for_extend(self): orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to( self.device, non_blocking=True ) - prefix_lens_tensor = torch.tensor( - prefix_lens, dtype=torch.int64, device=self.device - ) - prefix_lens_cpu_tensor = torch.tensor(prefix_lens, dtype=torch.int64) token_type_ids_tensor = None if len(token_type_ids) > 0: @@ -1259,49 +1282,8 @@ def prepare_for_extend(self): sum(token_type_ids, []), dtype=torch.int64 ).to(self.device, non_blocking=True) - extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor - - # Allocate req slots - bs = len(self.reqs) - req_pool_indices = self.alloc_req_slots(bs, self.reqs) - req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to( - self.device, non_blocking=True - ) - # Allocate memory - if self.token_to_kv_pool_allocator.page_size == 1: - out_cache_loc = self.alloc_token_slots(extend_num_tokens) - else: - last_loc = [ - ( - r.prefix_indices[-1:] - if len(r.prefix_indices) > 0 - else torch.tensor([-1], device=self.device) - ) - for r in self.reqs - ] - out_cache_loc = self.alloc_paged_token_slots_extend( - prefix_lens_tensor, - prefix_lens_cpu_tensor, - seq_lens_tensor, - seq_lens_cpu, - torch.cat(last_loc), - extend_num_tokens, - ) - - # Write allocated tokens to req_to_token_pool - self.write_cache_indices( - req_pool_indices, - prefix_lens, - seq_lens, - extend_lens, - out_cache_loc, - req_pool_indices_tensor, - prefix_lens_tensor, - seq_lens_tensor, - extend_lens_tensor, - [r.prefix_indices for r in reqs], - ) + out_cache_loc, req_pool_indices_tensor = self.alloc_for_extend(reqs) # Set fields input_embeds = [] @@ -1309,7 +1291,6 @@ def prepare_for_extend(self): multimodal_inputs = [] for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)): - req.req_pool_idx = req_pool_indices[i] assert seq_len - pre_len == req.extend_input_len if pre_len > 0: @@ -2034,128 +2015,3 @@ class ModelWorkerBatch: # Whether this batch is prefill-only (no token generation needed) is_prefill_only: bool = False - - -@triton.jit -def write_req_to_token_pool_triton( - req_to_token_ptr, # [max_batch, max_context_len] - req_pool_indices, - prefix_tensors, - pre_lens, - seq_lens, - extend_lens, - out_cache_loc, - req_to_token_ptr_stride: tl.constexpr, -): - BLOCK_SIZE: tl.constexpr = 512 - pid = tl.program_id(0) - - req_pool_index = tl.load(req_pool_indices + pid) - pre_len = tl.load(pre_lens + pid) - seq_len = tl.load(seq_lens + pid) - prefix_tensor = tl.load(prefix_tensors + pid).to(tl.pointer_type(tl.int64)) - - # write prefix - num_loop = tl.cdiv(pre_len, BLOCK_SIZE) - for i in range(num_loop): - offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE - mask = offset < pre_len - value = tl.load(prefix_tensor + offset, mask=mask) - tl.store( - req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + offset, - value, - mask=mask, - ) - - # NOTE: This can be slow for large bs - cumsum_start = tl.cast(0, tl.int64) - for i in range(pid): - cumsum_start += tl.load(extend_lens + i) - - num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE) - for i in range(num_loop): - offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE - mask = offset < (seq_len - pre_len) - value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask) - tl.store( - req_to_token_ptr - + req_pool_index * req_to_token_ptr_stride - + offset - + pre_len, - value, - mask=mask, - ) - - -def get_last_loc( - req_to_token: torch.Tensor, - req_pool_indices_tensor: torch.Tensor, - prefix_lens_tensor: torch.Tensor, -) -> torch.Tensor: - if ( - global_server_args_dict["attention_backend"] != "ascend" - and global_server_args_dict["attention_backend"] != "torch_native" - ): - impl = get_last_loc_triton - else: - impl = get_last_loc_torch - - return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor) - - -def get_last_loc_torch( - req_to_token: torch.Tensor, - req_pool_indices_tensor: torch.Tensor, - prefix_lens_tensor: torch.Tensor, -) -> torch.Tensor: - return torch.where( - prefix_lens_tensor > 0, - req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1], - torch.full_like(prefix_lens_tensor, -1), - ) - - -@triton.jit -def get_last_loc_kernel( - req_to_token, - req_pool_indices_tensor, - prefix_lens_tensor, - result, - num_tokens, - req_to_token_stride, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) - offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE - mask = offset < num_tokens - - prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0) - req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0) - - token_mask = prefix_lens > 0 - token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1) - tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1) - - tl.store(result + offset, tokens, mask=mask) - - -def get_last_loc_triton( - req_to_token: torch.Tensor, - req_pool_indices_tensor: torch.Tensor, - prefix_lens_tensor: torch.Tensor, -) -> torch.Tensor: - BLOCK_SIZE = 256 - num_tokens = prefix_lens_tensor.shape[0] - result = torch.empty_like(prefix_lens_tensor) - grid = (triton.cdiv(num_tokens, BLOCK_SIZE),) - - get_last_loc_kernel[grid]( - req_to_token, - req_pool_indices_tensor, - prefix_lens_tensor, - result, - num_tokens, - req_to_token.stride(0), - BLOCK_SIZE, - ) - return result diff --git a/python/sglang/srt/mem_cache/utils.py b/python/sglang/srt/mem_cache/utils.py new file mode 100644 index 000000000000..9fb08ef615c3 --- /dev/null +++ b/python/sglang/srt/mem_cache/utils.py @@ -0,0 +1,183 @@ +import torch +import triton +import triton.language as tl + +from sglang.srt.mem_cache.memory_pool import ReqToTokenPool +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import support_triton + +GLOBAL_SERVER_ARGS_KEYS = ["attention_backend"] + +global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS} + + +@triton.jit +def write_req_to_token_pool_triton( + req_to_token_ptr, # [max_batch, max_context_len] + req_pool_indices, + prefix_tensors, + pre_lens, + seq_lens, + extend_lens, + out_cache_loc, + req_to_token_ptr_stride: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 512 + pid = tl.program_id(0) + + req_pool_index = tl.load(req_pool_indices + pid) + pre_len = tl.load(pre_lens + pid) + seq_len = tl.load(seq_lens + pid) + prefix_tensor = tl.load(prefix_tensors + pid).to(tl.pointer_type(tl.int64)) + + # write prefix + num_loop = tl.cdiv(pre_len, BLOCK_SIZE) + for i in range(num_loop): + offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + mask = offset < pre_len + value = tl.load(prefix_tensor + offset, mask=mask) + tl.store( + req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + offset, + value, + mask=mask, + ) + + # NOTE: This can be slow for large bs + cumsum_start = tl.cast(0, tl.int64) + for i in range(pid): + cumsum_start += tl.load(extend_lens + i) + + num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE) + for i in range(num_loop): + offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + mask = offset < (seq_len - pre_len) + value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask) + tl.store( + req_to_token_ptr + + req_pool_index * req_to_token_ptr_stride + + offset + + pre_len, + value, + mask=mask, + ) + + +def write_cache_indices( + out_cache_loc: torch.Tensor, + req_pool_indices_tensor: torch.Tensor, + req_pool_indices_cpu: torch.Tensor, + prefix_lens_tensor: torch.Tensor, + prefix_lens_cpu: torch.Tensor, + seq_lens_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + extend_lens_tensor: torch.Tensor, + extend_lens_cpu: torch.Tensor, + prefix_tensors: list[torch.Tensor], + req_to_token_pool: ReqToTokenPool, +): + if support_triton(global_server_args_dict.get("attention_backend")): + prefix_pointers = torch.tensor( + [t.data_ptr() for t in prefix_tensors], device=req_to_token_pool.device + ) + # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start) + write_req_to_token_pool_triton[(req_pool_indices_tensor.shape[0],)]( + req_to_token_pool.req_to_token, + req_pool_indices_tensor, + prefix_pointers, + prefix_lens_tensor, + seq_lens_tensor, + extend_lens_tensor, + out_cache_loc, + req_to_token_pool.req_to_token.shape[1], + ) + else: + pt = 0 + for i in range(req_pool_indices_cpu.shape[0]): + req_idx = req_pool_indices_cpu[i].item() + prefix_len = prefix_lens_cpu[i].item() + seq_len = seq_lens_cpu[i].item() + extend_len = extend_lens_cpu[i].item() + + req_to_token_pool.write( + (req_idx, slice(0, prefix_len)), + prefix_tensors[i], + ) + req_to_token_pool.write( + (req_idx, slice(prefix_len, seq_len)), + out_cache_loc[pt : pt + extend_len], + ) + pt += extend_len + + +def get_last_loc( + req_to_token: torch.Tensor, + req_pool_indices_tensor: torch.Tensor, + prefix_lens_tensor: torch.Tensor, +) -> torch.Tensor: + if ( + global_server_args_dict["attention_backend"] != "ascend" + and global_server_args_dict["attention_backend"] != "torch_native" + ): + impl = get_last_loc_triton + else: + impl = get_last_loc_torch + + return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor) + + +def get_last_loc_torch( + req_to_token: torch.Tensor, + req_pool_indices_tensor: torch.Tensor, + prefix_lens_tensor: torch.Tensor, +) -> torch.Tensor: + return torch.where( + prefix_lens_tensor > 0, + req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1], + torch.full_like(prefix_lens_tensor, -1), + ) + + +@triton.jit +def get_last_loc_kernel( + req_to_token, + req_pool_indices_tensor, + prefix_lens_tensor, + result, + num_tokens, + req_to_token_stride, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE + mask = offset < num_tokens + + prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0) + req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0) + + token_mask = prefix_lens > 0 + token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1) + tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1) + + tl.store(result + offset, tokens, mask=mask) + + +def get_last_loc_triton( + req_to_token: torch.Tensor, + req_pool_indices_tensor: torch.Tensor, + prefix_lens_tensor: torch.Tensor, +) -> torch.Tensor: + BLOCK_SIZE = 256 + num_tokens = prefix_lens_tensor.shape[0] + result = torch.empty_like(prefix_lens_tensor) + grid = (triton.cdiv(num_tokens, BLOCK_SIZE),) + + get_last_loc_kernel[grid]( + req_to_token, + req_pool_indices_tensor, + prefix_lens_tensor, + result, + num_tokens, + req_to_token.stride(0), + BLOCK_SIZE, + ) + return result diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index 5d8c920c45d6..eea5c18f12a1 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -10,12 +10,9 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import apply_custom_logit_processor -from sglang.srt.managers.schedule_batch import ( - ScheduleBatch, - get_last_loc, - global_server_args_dict, -) +from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator +from sglang.srt.mem_cache.utils import get_last_loc from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.speculative.spec_info import SpecInput, SpecInputType from sglang.srt.speculative.spec_utils import ( diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 82bfaa276118..e1456da79a05 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -14,12 +14,9 @@ ) from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs -from sglang.srt.managers.schedule_batch import ( - ScheduleBatch, - get_last_loc, - global_server_args_dict, -) +from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.managers.tp_worker import TpModelWorker +from sglang.srt.mem_cache.utils import get_last_loc from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch, diff --git a/python/sglang/srt/speculative/ngram_utils.py b/python/sglang/srt/speculative/ngram_utils.py index 345fcbd66122..8edb13e69ad0 100644 --- a/python/sglang/srt/speculative/ngram_utils.py +++ b/python/sglang/srt/speculative/ngram_utils.py @@ -16,11 +16,8 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import apply_custom_logit_processor -from sglang.srt.managers.schedule_batch import ( - ScheduleBatch, - get_last_loc, - global_server_args_dict, -) +from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict +from sglang.srt.mem_cache.utils import get_last_loc from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.speculative.spec_info import SpecInput, SpecInputType from sglang.srt.speculative.spec_utils import ( From a30f56bb300a42782d6d4b6501202dcffc18d755 Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Mon, 6 Oct 2025 17:21:25 -0700 Subject: [PATCH 10/20] move functions out --- python/sglang/srt/managers/schedule_batch.py | 208 ++++++++++-------- .../srt/mem_cache/{utils.py => common.py} | 0 python/sglang/srt/speculative/eagle_info.py | 2 +- python/sglang/srt/speculative/eagle_worker.py | 2 +- python/sglang/srt/speculative/ngram_utils.py | 2 +- 5 files changed, 115 insertions(+), 99 deletions(-) rename python/sglang/srt/mem_cache/{utils.py => common.py} (100%) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 20ddbd29c857..d666968a156d 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -60,10 +60,10 @@ ) from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache +from sglang.srt.mem_cache.common import write_cache_indices from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool from sglang.srt.mem_cache.radix_cache import RadixKey from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache -from sglang.srt.mem_cache.utils import write_cache_indices from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo @@ -844,6 +844,103 @@ def __repr__(self): ) +def alloc_for_extend( + batch: "ScheduleBatch", + reqs: list[Req], + prefix_lens: list[int], + seq_lens_cpu: torch.Tensor, + extend_num_tokens: int, +): + # Convert to tensors + prefix_lens_cpu = torch.tensor(prefix_lens, dtype=torch.int64) + extend_lens_cpu = seq_lens_cpu - prefix_lens_cpu + + # Copy to GPU + prefix_lens_device = prefix_lens_cpu.to(batch.device, non_blocking=True) + seq_lens_device = seq_lens_cpu.to(batch.device, non_blocking=True) + extend_lens_device = extend_lens_cpu.to(batch.device, non_blocking=True) + + # Allocate req slots + bs = len(reqs) + req_pool_indices = batch.alloc_req_slots(bs, reqs) + req_pool_indices = torch.tensor( + req_pool_indices, dtype=torch.int64, device=batch.device + ) + req_pool_indices_device = req_pool_indices.to(batch.device, non_blocking=True) + + # Allocate memory + if batch.token_to_kv_pool_allocator.page_size == 1: + out_cache_loc: torch.Tensor = batch.alloc_token_slots(extend_num_tokens) + else: + last_loc = [ + ( + r.prefix_indices[-1:] + if len(r.prefix_indices) > 0 + else torch.tensor([-1], device=batch.device) + ) + for r in reqs + ] + out_cache_loc: torch.Tensor = batch.alloc_paged_token_slots_extend( + prefix_lens_device, + prefix_lens_cpu, + seq_lens_device, + seq_lens_cpu, + torch.cat(last_loc), + extend_num_tokens, + ) + + # Write allocated tokens to req_to_token_pool + write_cache_indices( + out_cache_loc, + req_pool_indices_device, + req_pool_indices, + prefix_lens_device, + prefix_lens_cpu, + seq_lens_device, + seq_lens_cpu, + extend_lens_device, + extend_lens_cpu, + [r.prefix_indices for r in reqs], + batch.req_to_token_pool, + ) + + # update requests + for req, req_pool_index in zip(reqs, req_pool_indices.tolist()): + req.req_pool_idx = req_pool_index + + return out_cache_loc, req_pool_indices_device + + +def alloc_for_decode(batch: "ScheduleBatch", token_per_req: int): + + bs = len(batch.reqs) + seq_lens = batch.seq_lens + req_pool_indices = batch.req_pool_indices + + # Allocate token slots + if batch.token_to_kv_pool_allocator.page_size == 1: + out_cache_loc = batch.alloc_token_slots(bs * token_per_req) + else: + # Get the last token's KV cache location + last_loc = batch.req_to_token_pool.req_to_token[req_pool_indices, seq_lens - 1] + # Prepare tensors for allocation + seq_lens_next = seq_lens + token_per_req + out_cache_loc = batch.alloc_paged_token_slots_decode( + seq_lens_next, batch.seq_lens_cpu + token_per_req, last_loc + ) + + if batch.model_config.is_encoder_decoder: + locs = batch.encoder_lens + seq_lens + else: + locs = seq_lens.clone() + + batch.req_to_token_pool.write( + (req_pool_indices, locs), out_cache_loc.to(torch.int32) + ) + + return out_cache_loc + + @dataclasses.dataclass class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): """Store all information of a batch on the scheduler.""" @@ -1075,75 +1172,6 @@ def alloc_paged_token_slots_extend( else: return out_cache_loc - def alloc_for_extend(self, reqs: list[Req]): - # Extract all needed data from reqs - input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] - extend_num_tokens = sum(len(ids) for ids in input_ids) - - # CPU tensors - prefix_lens = torch.tensor( - [len(r.prefix_indices) for r in reqs], dtype=torch.int64 - ) - seq_lens_cpu = torch.tensor([len(r.fill_ids) for r in reqs], dtype=torch.int64) - extend_lens_cpu = seq_lens_cpu - prefix_lens - - # Copy to GPU - prefix_lens_device = prefix_lens.to(self.device, non_blocking=True) - seq_lens_device = seq_lens_cpu.to(self.device, non_blocking=True) - extend_lens_device = extend_lens_cpu.to(self.device, non_blocking=True) - - # Allocate req slots - bs = len(reqs) - req_pool_indices = self.alloc_req_slots(bs, reqs) - req_pool_indices_cpu = torch.tensor( - req_pool_indices, dtype=torch.int64, device=self.device - ) - req_pool_indices_device = req_pool_indices_cpu.to( - self.device, non_blocking=True - ) - - # Allocate memory - if self.token_to_kv_pool_allocator.page_size == 1: - out_cache_loc = self.alloc_token_slots(extend_num_tokens) - else: - last_loc = [ - ( - r.prefix_indices[-1:] - if len(r.prefix_indices) > 0 - else torch.tensor([-1], device=self.device) - ) - for r in reqs - ] - out_cache_loc = self.alloc_paged_token_slots_extend( - prefix_lens_device, - prefix_lens, - seq_lens_device, - seq_lens_cpu, - torch.cat(last_loc), - extend_num_tokens, - ) - - # Write allocated tokens to req_to_token_pool - write_cache_indices( - out_cache_loc, - req_pool_indices_device, - req_pool_indices_cpu, - prefix_lens_device, - prefix_lens, - seq_lens_device, - seq_lens_cpu, - extend_lens_device, - extend_lens_cpu, - [r.prefix_indices for r in reqs], - self.req_to_token_pool, - ) - - # update requests - for req, req_pool_index in zip(reqs, req_pool_indices): - req.req_pool_idx = req_pool_index - - return out_cache_loc, req_pool_indices_device - def alloc_paged_token_slots_decode( self, seq_lens: torch.Tensor, @@ -1283,7 +1311,9 @@ def prepare_for_extend(self): ).to(self.device, non_blocking=True) # Allocate memory - out_cache_loc, req_pool_indices_tensor = self.alloc_for_extend(reqs) + out_cache_loc, req_pool_indices_tensor = alloc_for_extend( + self, reqs, prefix_lens, seq_lens_cpu, extend_num_tokens + ) # Set fields input_embeds = [] @@ -1657,11 +1687,19 @@ def prepare_for_decode(self): self.output_ids = None if self.model_config.is_encoder_decoder: - locs = self.encoder_lens + self.seq_lens self.prepare_encoder_info_decode() - else: - locs = self.seq_lens.clone() + # free memory + if isinstance(self.tree_cache, SWAChunkCache): + for req in self.reqs: + self.tree_cache.evict_swa( + req, req.seqlen - 1, self.model_config.attention_chunk_size + ) + + # Allocate memory based on current state + self.out_cache_loc = alloc_for_decode(self, 1) # allocate 1 token per request + + # Update seq_lens after allocation if self.enable_overlap: # Do not use in-place operations in the overlap mode self.seq_lens = self.seq_lens + 1 @@ -1674,28 +1712,6 @@ def prepare_for_decode(self): self.orig_seq_lens.add_(1) self.seq_lens_sum += bs - # free memory - if isinstance(self.tree_cache, SWAChunkCache): - for req in self.reqs: - self.tree_cache.evict_swa( - req, req.seqlen - 1, self.model_config.attention_chunk_size - ) - - # Allocate memory - if self.token_to_kv_pool_allocator.page_size == 1: - self.out_cache_loc = self.alloc_token_slots(bs) - else: - last_loc = self.req_to_token_pool.req_to_token[ - self.req_pool_indices, self.seq_lens - 2 - ] - self.out_cache_loc = self.alloc_paged_token_slots_decode( - self.seq_lens, self.seq_lens_cpu, last_loc - ) - - self.req_to_token_pool.write( - (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32) - ) - def filter_batch( self, chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None, diff --git a/python/sglang/srt/mem_cache/utils.py b/python/sglang/srt/mem_cache/common.py similarity index 100% rename from python/sglang/srt/mem_cache/utils.py rename to python/sglang/srt/mem_cache/common.py diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index eea5c18f12a1..dfda8fa0fc78 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -12,7 +12,7 @@ from sglang.srt.layers.sampler import apply_custom_logit_processor from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator -from sglang.srt.mem_cache.utils import get_last_loc +from sglang.srt.mem_cache.common import get_last_loc from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.speculative.spec_info import SpecInput, SpecInputType from sglang.srt.speculative.spec_utils import ( diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index e1456da79a05..4ed77f266bab 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -16,7 +16,7 @@ from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.managers.tp_worker import TpModelWorker -from sglang.srt.mem_cache.utils import get_last_loc +from sglang.srt.mem_cache.common import get_last_loc from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch, diff --git a/python/sglang/srt/speculative/ngram_utils.py b/python/sglang/srt/speculative/ngram_utils.py index 8edb13e69ad0..e40788673ff3 100644 --- a/python/sglang/srt/speculative/ngram_utils.py +++ b/python/sglang/srt/speculative/ngram_utils.py @@ -17,7 +17,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import apply_custom_logit_processor from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict -from sglang.srt.mem_cache.utils import get_last_loc +from sglang.srt.mem_cache.common import get_last_loc from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.speculative.spec_info import SpecInput, SpecInputType from sglang.srt.speculative.spec_utils import ( From 4d4a27280473e522db11aebed7c6d42e0964db2a Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Mon, 6 Oct 2025 19:11:17 -0700 Subject: [PATCH 11/20] better --- python/sglang/srt/managers/schedule_batch.py | 154 ++++-------------- python/sglang/srt/mem_cache/common.py | 145 ++++++++++++++++- python/sglang/srt/speculative/eagle_info.py | 14 +- python/sglang/srt/speculative/eagle_worker.py | 15 +- python/sglang/srt/speculative/ngram_utils.py | 13 +- 5 files changed, 206 insertions(+), 135 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index d666968a156d..109b0d442012 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -60,7 +60,13 @@ ) from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache -from sglang.srt.mem_cache.common import write_cache_indices +from sglang.srt.mem_cache.common import ( + alloc_paged_token_slots_decode, + alloc_paged_token_slots_extend, + alloc_req_slots, + alloc_token_slots, + write_cache_indices, +) from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool from sglang.srt.mem_cache.radix_cache import RadixKey from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache @@ -862,7 +868,7 @@ def alloc_for_extend( # Allocate req slots bs = len(reqs) - req_pool_indices = batch.alloc_req_slots(bs, reqs) + req_pool_indices = alloc_req_slots(batch.req_to_token_pool, bs, reqs) req_pool_indices = torch.tensor( req_pool_indices, dtype=torch.int64, device=batch.device ) @@ -870,7 +876,10 @@ def alloc_for_extend( # Allocate memory if batch.token_to_kv_pool_allocator.page_size == 1: - out_cache_loc: torch.Tensor = batch.alloc_token_slots(extend_num_tokens) + out_cache_loc: torch.Tensor = alloc_token_slots( + batch.tree_cache, + extend_num_tokens, + ) else: last_loc = [ ( @@ -880,13 +889,14 @@ def alloc_for_extend( ) for r in reqs ] - out_cache_loc: torch.Tensor = batch.alloc_paged_token_slots_extend( - prefix_lens_device, - prefix_lens_cpu, - seq_lens_device, - seq_lens_cpu, - torch.cat(last_loc), - extend_num_tokens, + out_cache_loc: torch.Tensor = alloc_paged_token_slots_extend( + tree_cache=batch.tree_cache, + prefix_lens=prefix_lens_device, + prefix_lens_cpu=prefix_lens_cpu, + seq_lens=seq_lens_device, + seq_lens_cpu=seq_lens_cpu, + last_loc=torch.cat(last_loc), + extend_num_tokens=extend_num_tokens, ) # Write allocated tokens to req_to_token_pool @@ -912,21 +922,26 @@ def alloc_for_extend( def alloc_for_decode(batch: "ScheduleBatch", token_per_req: int): - bs = len(batch.reqs) seq_lens = batch.seq_lens req_pool_indices = batch.req_pool_indices # Allocate token slots if batch.token_to_kv_pool_allocator.page_size == 1: - out_cache_loc = batch.alloc_token_slots(bs * token_per_req) + out_cache_loc = alloc_token_slots( + batch.tree_cache, + bs * token_per_req, + ) else: # Get the last token's KV cache location last_loc = batch.req_to_token_pool.req_to_token[req_pool_indices, seq_lens - 1] # Prepare tensors for allocation seq_lens_next = seq_lens + token_per_req - out_cache_loc = batch.alloc_paged_token_slots_decode( - seq_lens_next, batch.seq_lens_cpu + token_per_req, last_loc + out_cache_loc = alloc_paged_token_slots_decode( + batch.tree_cache, + seq_lens_next, + batch.seq_lens_cpu + token_per_req, + last_loc, ) if batch.model_config.is_encoder_decoder: @@ -1092,117 +1107,6 @@ def batch_size(self): def is_empty(self): return len(self.reqs) == 0 - def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None): - if isinstance(self.req_to_token_pool, HybridReqToTokenPool): - req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs) - else: - req_pool_indices = self.req_to_token_pool.alloc(num_reqs) - if req_pool_indices is None: - raise RuntimeError( - "alloc_req_slots runs out of memory. " - "Please set a smaller number for `--max-running-requests`. " - f"{self.req_to_token_pool.available_size()=}, " - f"{num_reqs=}, " - ) - return req_pool_indices - - def alloc_token_slots(self, num_tokens: int, backup_state: bool = False): - self._evict_tree_cache_if_needed(num_tokens) - - if backup_state: - state = self.token_to_kv_pool_allocator.backup_state() - - out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens) - if out_cache_loc is None: - phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode" - error_msg = ( - f"{phase_str} out of memory. Try to lower your batch size.\n" - f"Try to allocate {num_tokens} tokens.\n" - f"{self._available_and_evictable_str()}" - ) - logger.error(error_msg) - if self.tree_cache is not None: - self.tree_cache.pretty_print() - raise RuntimeError(error_msg) - - if backup_state: - return out_cache_loc, state - else: - return out_cache_loc - - def alloc_paged_token_slots_extend( - self, - prefix_lens: torch.Tensor, - prefix_lens_cpu: torch.Tensor, - seq_lens: torch.Tensor, - seq_lens_cpu: torch.Tensor, - last_loc: torch.Tensor, - extend_num_tokens: int, - backup_state: bool = False, - ): - # Over estimate the number of tokens: assume each request needs a new page. - num_tokens = ( - extend_num_tokens - + len(seq_lens_cpu) * self.token_to_kv_pool_allocator.page_size - ) - self._evict_tree_cache_if_needed(num_tokens) - - if backup_state: - state = self.token_to_kv_pool_allocator.backup_state() - - out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend( - prefix_lens, - prefix_lens_cpu, - seq_lens, - seq_lens_cpu, - last_loc, - extend_num_tokens, - ) - if out_cache_loc is None: - error_msg = ( - f"Prefill out of memory. Try to lower your batch size.\n" - f"Try to allocate {extend_num_tokens} tokens.\n" - f"{self._available_and_evictable_str()}" - ) - logger.error(error_msg) - raise RuntimeError(error_msg) - - if backup_state: - return out_cache_loc, state - else: - return out_cache_loc - - def alloc_paged_token_slots_decode( - self, - seq_lens: torch.Tensor, - seq_lens_cpu: torch.Tensor, - last_loc: torch.Tensor, - backup_state: bool = False, - ): - # Over estimate the number of tokens: assume each request needs a new page. - num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size - self._evict_tree_cache_if_needed(num_tokens) - - if backup_state: - state = self.token_to_kv_pool_allocator.backup_state() - - out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode( - seq_lens, seq_lens_cpu, last_loc - ) - if out_cache_loc is None: - error_msg = ( - f"Decode out of memory. Try to lower your batch size.\n" - f"Try to allocate {len(seq_lens)} tokens.\n" - f"{self._available_and_evictable_str()}" - ) - logger.error(error_msg) - raise RuntimeError(error_msg) - - if backup_state: - return out_cache_loc, state - else: - return out_cache_loc - def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]): self.encoder_lens_cpu = [] self.encoder_cached = [] diff --git a/python/sglang/srt/mem_cache/common.py b/python/sglang/srt/mem_cache/common.py index 9fb08ef615c3..4230a580bff5 100644 --- a/python/sglang/srt/mem_cache/common.py +++ b/python/sglang/srt/mem_cache/common.py @@ -1,11 +1,22 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + import torch import triton import triton.language as tl -from sglang.srt.mem_cache.memory_pool import ReqToTokenPool +from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache +from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool from sglang.srt.server_args import ServerArgs from sglang.srt.utils import support_triton +if TYPE_CHECKING: + from sglang.srt.managers.schedule_batch import Req + +logger = logging.getLogger(__name__) + GLOBAL_SERVER_ARGS_KEYS = ["attention_backend"] global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS} @@ -181,3 +192,135 @@ def get_last_loc_triton( BLOCK_SIZE, ) return result + + +def alloc_token_slots( + tree_cache: BasePrefixCache, + num_tokens: int, + backup_state: bool = False, +): + allocator = tree_cache.token_to_kv_pool_allocator + _evict_from_tree_cache(tree_cache, num_tokens) + + state = None + if backup_state: + state = allocator.backup_state() + + out_cache_loc = allocator.alloc(num_tokens) + + if out_cache_loc is None: + return None + + return (out_cache_loc, state) if backup_state else out_cache_loc + + +def _evict_from_tree_cache(tree_cache: BasePrefixCache | None, num_tokens: int): + """Helper to evict from tree cache if needed. Handles both hybrid and standard allocators.""" + if tree_cache is None: + return + + # Check if this is ChunkCache or SWAChunkCache - these don't support eviction + from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache + + if isinstance(tree_cache, (SWAChunkCache, ChunkCache)): + return + + allocator = tree_cache.token_to_kv_pool_allocator + + # Check if this is a hybrid allocator + if hasattr(allocator, "full_available_size"): + # Hybrid allocator + full_available_size = allocator.full_available_size() + swa_available_size = allocator.swa_available_size() + + if full_available_size < num_tokens or swa_available_size < num_tokens: + full_num_tokens = max(0, num_tokens - full_available_size) + swa_num_tokens = max(0, num_tokens - swa_available_size) + tree_cache.evict(full_num_tokens, swa_num_tokens) + else: + # Standard allocator + if allocator.available_size() < num_tokens: + tree_cache.evict(num_tokens) + + +def alloc_paged_token_slots_decode( + tree_cache: BasePrefixCache, + seq_lens: torch.Tensor, + seq_lens_cpu: torch.Tensor, + last_loc: torch.Tensor, + backup_state: bool = False, +): + """Allocate paged token slots for decode with overestimation for safety.""" + allocator = tree_cache.token_to_kv_pool_allocator + # Over estimate the number of tokens: assume each request needs a new page. + num_tokens = len(seq_lens) * allocator.page_size + _evict_from_tree_cache(tree_cache, num_tokens) + + state = None + if backup_state: + state = allocator.backup_state() + + out_cache_loc = allocator.alloc_decode(seq_lens, seq_lens_cpu, last_loc) + + if out_cache_loc is None: + return None + + return (out_cache_loc, state) if backup_state else out_cache_loc + + +def alloc_paged_token_slots_extend( + tree_cache: BasePrefixCache, + prefix_lens: torch.Tensor, + prefix_lens_cpu: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_cpu: torch.Tensor, + last_loc: torch.Tensor, + extend_num_tokens: int, + backup_state: bool = False, +): + # Over estimate the number of tokens: assume each request needs a new page. + allocator = tree_cache.token_to_kv_pool_allocator + num_tokens = extend_num_tokens + len(seq_lens_cpu) * allocator.page_size + _evict_from_tree_cache(tree_cache, num_tokens) + + state = None + if backup_state: + state = allocator.backup_state() + + out_cache_loc = allocator.alloc_extend( + prefix_lens, + prefix_lens_cpu, + seq_lens, + seq_lens_cpu, + last_loc, + extend_num_tokens, + ) + + if out_cache_loc is None: + return None + + return (out_cache_loc, state) if backup_state else out_cache_loc + + +def alloc_req_slots( + req_to_token_pool: ReqToTokenPool, + num_reqs: int, + reqs: list[Req] | None, +) -> list[int]: + """Allocate request slots from the pool.""" + match req_to_token_pool: + case ReqToTokenPool(): + req_pool_indices = req_to_token_pool.alloc(num_reqs) + case HybridReqToTokenPool(): + req_pool_indices = req_to_token_pool.alloc(num_reqs, reqs) + case _: + raise ValueError(f"Unknown {type(req_to_token_pool)=}") + + if req_pool_indices is None: + raise RuntimeError( + "alloc_req_slots runs out of memory. " + "Please set a smaller number for `--max-running-requests`. " + f"{req_to_token_pool.available_size()=}, " + f"{num_reqs=}, " + ) + return req_pool_indices diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index dfda8fa0fc78..46ecc1b32490 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -12,7 +12,11 @@ from sglang.srt.layers.sampler import apply_custom_logit_processor from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator -from sglang.srt.mem_cache.common import get_last_loc +from sglang.srt.mem_cache.common import ( + alloc_paged_token_slots_extend, + alloc_token_slots, + get_last_loc, +) from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.speculative.spec_info import SpecInput, SpecInputType from sglang.srt.speculative.spec_utils import ( @@ -97,7 +101,10 @@ def prepare_for_verify(self, batch: ScheduleBatch, page_size: int): batch.input_ids = self.draft_token if page_size == 1: - batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids)) + batch.out_cache_loc = alloc_token_slots( + batch.tree_cache, + len(batch.input_ids), + ) end_offset = batch.seq_lens + self.draft_token_num else: prefix_lens = batch.seq_lens @@ -109,7 +116,8 @@ def prepare_for_verify(self, batch: ScheduleBatch, page_size: int): batch.req_pool_indices, prefix_lens, ) - batch.out_cache_loc = batch.alloc_paged_token_slots_extend( + batch.out_cache_loc = alloc_paged_token_slots_extend( + batch.tree_cache, prefix_lens, prefix_lens_cpu, end_offset, diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 4ed77f266bab..3886f701ed90 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -16,7 +16,11 @@ from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.managers.tp_worker import TpModelWorker -from sglang.srt.mem_cache.common import get_last_loc +from sglang.srt.mem_cache.common import ( + alloc_paged_token_slots_extend, + alloc_token_slots, + get_last_loc, +) from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch, @@ -533,8 +537,10 @@ def _draft_preprocess_decode(self, batch: ScheduleBatch): # [ topk 0 ] [ topk 1 ] # [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2] if self.page_size == 1: - out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots( - num_seqs * self.speculative_num_steps * self.topk, backup_state=True + out_cache_loc, token_to_kv_pool_state_backup = alloc_token_slots( + batch.tree_cache, + num_seqs * self.speculative_num_steps * self.topk, + backup_state=True, ) else: if self.topk == 1: @@ -593,7 +599,8 @@ def _draft_preprocess_decode(self, batch: ScheduleBatch): extend_num_tokens = torch.sum((seq_lens_cpu - prefix_lens_cpu)).item() out_cache_loc, token_to_kv_pool_state_backup = ( - batch.alloc_paged_token_slots_extend( + alloc_paged_token_slots_extend( + batch.tree_cache, prefix_lens, prefix_lens_cpu, seq_lens, diff --git a/python/sglang/srt/speculative/ngram_utils.py b/python/sglang/srt/speculative/ngram_utils.py index e40788673ff3..d9f8f369182c 100644 --- a/python/sglang/srt/speculative/ngram_utils.py +++ b/python/sglang/srt/speculative/ngram_utils.py @@ -7,6 +7,11 @@ import torch import triton +from sglang.srt.mem_cache.common import ( + alloc_paged_token_slots_extend, + alloc_token_slots, +) + logger = logging.getLogger(__name__) from dataclasses import dataclass @@ -71,7 +76,10 @@ def prepare_for_verify(self, batch: ScheduleBatch, page_size: int): batch.input_ids = self.draft_token if page_size == 1: - batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids)) + batch.out_cache_loc = alloc_token_slots( + batch.tree_cache, + len(batch.input_ids), + ) end_offset = batch.seq_lens + self.draft_token_num else: # TODO(lsyin): add prefix lens cpu here to support page size > 1 @@ -84,7 +92,8 @@ def prepare_for_verify(self, batch: ScheduleBatch, page_size: int): batch.req_pool_indices, prefix_lens, ) - batch.out_cache_loc = batch.alloc_paged_token_slots_extend( + batch.out_cache_loc = alloc_paged_token_slots_extend( + batch.tree_cache, prefix_lens, prefix_lens_cpu, end_offset, From e3e1d28e4ca4a4f698196c1650247c588831316f Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Tue, 7 Oct 2025 14:37:23 -0700 Subject: [PATCH 12/20] wip --- python/sglang/srt/managers/schedule_batch.py | 208 ++++--------------- python/sglang/srt/mem_cache/common.py | 139 +++++++++++++ 2 files changed, 185 insertions(+), 162 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 109b0d442012..73d040ba04f4 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -60,13 +60,7 @@ ) from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache -from sglang.srt.mem_cache.common import ( - alloc_paged_token_slots_decode, - alloc_paged_token_slots_extend, - alloc_req_slots, - alloc_token_slots, - write_cache_indices, -) +from sglang.srt.mem_cache.common import alloc_decode_batch, alloc_extend_batch from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool from sglang.srt.mem_cache.radix_cache import RadixKey from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache @@ -850,112 +844,6 @@ def __repr__(self): ) -def alloc_for_extend( - batch: "ScheduleBatch", - reqs: list[Req], - prefix_lens: list[int], - seq_lens_cpu: torch.Tensor, - extend_num_tokens: int, -): - # Convert to tensors - prefix_lens_cpu = torch.tensor(prefix_lens, dtype=torch.int64) - extend_lens_cpu = seq_lens_cpu - prefix_lens_cpu - - # Copy to GPU - prefix_lens_device = prefix_lens_cpu.to(batch.device, non_blocking=True) - seq_lens_device = seq_lens_cpu.to(batch.device, non_blocking=True) - extend_lens_device = extend_lens_cpu.to(batch.device, non_blocking=True) - - # Allocate req slots - bs = len(reqs) - req_pool_indices = alloc_req_slots(batch.req_to_token_pool, bs, reqs) - req_pool_indices = torch.tensor( - req_pool_indices, dtype=torch.int64, device=batch.device - ) - req_pool_indices_device = req_pool_indices.to(batch.device, non_blocking=True) - - # Allocate memory - if batch.token_to_kv_pool_allocator.page_size == 1: - out_cache_loc: torch.Tensor = alloc_token_slots( - batch.tree_cache, - extend_num_tokens, - ) - else: - last_loc = [ - ( - r.prefix_indices[-1:] - if len(r.prefix_indices) > 0 - else torch.tensor([-1], device=batch.device) - ) - for r in reqs - ] - out_cache_loc: torch.Tensor = alloc_paged_token_slots_extend( - tree_cache=batch.tree_cache, - prefix_lens=prefix_lens_device, - prefix_lens_cpu=prefix_lens_cpu, - seq_lens=seq_lens_device, - seq_lens_cpu=seq_lens_cpu, - last_loc=torch.cat(last_loc), - extend_num_tokens=extend_num_tokens, - ) - - # Write allocated tokens to req_to_token_pool - write_cache_indices( - out_cache_loc, - req_pool_indices_device, - req_pool_indices, - prefix_lens_device, - prefix_lens_cpu, - seq_lens_device, - seq_lens_cpu, - extend_lens_device, - extend_lens_cpu, - [r.prefix_indices for r in reqs], - batch.req_to_token_pool, - ) - - # update requests - for req, req_pool_index in zip(reqs, req_pool_indices.tolist()): - req.req_pool_idx = req_pool_index - - return out_cache_loc, req_pool_indices_device - - -def alloc_for_decode(batch: "ScheduleBatch", token_per_req: int): - bs = len(batch.reqs) - seq_lens = batch.seq_lens - req_pool_indices = batch.req_pool_indices - - # Allocate token slots - if batch.token_to_kv_pool_allocator.page_size == 1: - out_cache_loc = alloc_token_slots( - batch.tree_cache, - bs * token_per_req, - ) - else: - # Get the last token's KV cache location - last_loc = batch.req_to_token_pool.req_to_token[req_pool_indices, seq_lens - 1] - # Prepare tensors for allocation - seq_lens_next = seq_lens + token_per_req - out_cache_loc = alloc_paged_token_slots_decode( - batch.tree_cache, - seq_lens_next, - batch.seq_lens_cpu + token_per_req, - last_loc, - ) - - if batch.model_config.is_encoder_decoder: - locs = batch.encoder_lens + seq_lens - else: - locs = seq_lens.clone() - - batch.req_to_token_pool.write( - (req_pool_indices, locs), out_cache_loc.to(torch.int32) - ) - - return out_cache_loc - - @dataclasses.dataclass class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): """Store all information of a batch on the scheduler.""" @@ -1030,6 +918,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): extend_logprob_start_lens: List[int] = None # It comes empty list if logprob is not required. extend_input_logprob_token_ids: Optional[torch.Tensor] = None + # For allocation - list of prefix_indices tensors per request + prefix_indices_list: Optional[List[torch.Tensor]] = None # For encoder-decoder architectures encoder_cached: Optional[List[bool]] = None @@ -1183,48 +1073,55 @@ def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]) def prepare_for_extend(self): self.forward_mode = ForwardMode.EXTEND - - # Init tensors reqs = self.reqs + + # Phase 1: Prepare tensors and assign to batch input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] - extend_num_tokens = sum(len(ids) for ids in input_ids) - seq_lens = [len(r.fill_ids) for r in reqs] - orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs] - prefix_lens = [len(r.prefix_indices) for r in reqs] - extend_lens = [r.extend_input_len for r in reqs] + self.input_ids = torch.tensor( + list(chain.from_iterable(input_ids)), dtype=torch.int64 + ).to(self.device, non_blocking=True) + + self.orig_seq_lens = torch.tensor( + [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs], + dtype=torch.int32, + ).to(self.device, non_blocking=True) token_type_ids = [ r.token_type_ids for r in reqs if r.token_type_ids is not None ] - - input_ids_tensor = torch.tensor( - list(chain.from_iterable(input_ids)), dtype=torch.int64 - ).to(self.device, non_blocking=True) - seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to( - self.device, non_blocking=True - ) - seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64) - orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to( - self.device, non_blocking=True - ) - - token_type_ids_tensor = None if len(token_type_ids) > 0: - token_type_ids_tensor = torch.tensor( + self.token_type_ids = torch.tensor( sum(token_type_ids, []), dtype=torch.int64 ).to(self.device, non_blocking=True) + else: + self.token_type_ids = None - # Allocate memory - out_cache_loc, req_pool_indices_tensor = alloc_for_extend( - self, reqs, prefix_lens, seq_lens_cpu, extend_num_tokens + # Batch-level data for allocation + self.prefix_lens = [len(r.prefix_indices) for r in reqs] + self.extend_lens = [r.extend_input_len for r in reqs] + self.extend_num_tokens = sum( + len(r.fill_ids) - len(r.prefix_indices) for r in reqs ) + self.seq_lens_cpu = torch.tensor( + [len(r.fill_ids) for r in reqs], dtype=torch.int64 + ) + self.seq_lens = self.seq_lens_cpu.to(self.device, non_blocking=True) + self.seq_lens_sum = self.seq_lens_cpu.sum().item() + self.prefix_indices_list = [r.prefix_indices for r in reqs] + + # Phase 2: Allocate memory - pure allocation using batch fields + out_cache_loc, req_pool_indices_tensor = alloc_extend_batch(self) + self.out_cache_loc = out_cache_loc + self.req_pool_indices = req_pool_indices_tensor - # Set fields + # Phase 3: Process requests - validation, adjustment, logprob setup input_embeds = [] extend_input_logprob_token_ids = [] multimodal_inputs = [] - for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)): + for req in reqs: + seq_len = len(req.fill_ids) + pre_len = len(req.prefix_indices) assert seq_len - pre_len == req.extend_input_len if pre_len > 0: @@ -1315,19 +1212,7 @@ def prepare_for_extend(self): ) ) - if self.return_logprob: - extend_input_logprob_token_ids = torch.tensor( - extend_input_logprob_token_ids - ) - else: - extend_input_logprob_token_ids = None - - self.input_ids = input_ids_tensor - self.req_pool_indices = req_pool_indices_tensor - self.seq_lens = seq_lens_tensor - self.seq_lens_cpu = seq_lens_cpu - self.orig_seq_lens = orig_seq_lens_tensor - self.out_cache_loc = out_cache_loc + # Assign processed data to batch self.input_embeds = ( torch.tensor(input_embeds).to(self.device, non_blocking=True) if input_embeds @@ -1341,21 +1226,20 @@ def prepare_for_extend(self): if isinstance(pixel_values, torch.Tensor): mm_item.feature = pixel_values.to(self.device, non_blocking=True) self.multimodal_inputs = multimodal_inputs - self.token_type_ids = token_type_ids_tensor - self.seq_lens_sum = sum(seq_lens) + + self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs] if self.return_logprob: self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] self.token_ids_logprobs = [r.token_ids_logprob for r in reqs] - - self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs] - self.extend_num_tokens = extend_num_tokens - self.prefix_lens = prefix_lens - self.extend_lens = extend_lens - self.extend_input_logprob_token_ids = extend_input_logprob_token_ids + self.extend_input_logprob_token_ids = torch.tensor( + extend_input_logprob_token_ids + ) + else: + self.extend_input_logprob_token_ids = None if self.model_config.is_encoder_decoder: - self.prepare_encoder_info_extend(input_ids, seq_lens) + self.prepare_encoder_info_extend(input_ids, self.seq_lens_cpu.tolist()) # Build sampling info self.sampling_info = SamplingBatchInfo.from_schedule_batch( @@ -1601,7 +1485,7 @@ def prepare_for_decode(self): ) # Allocate memory based on current state - self.out_cache_loc = alloc_for_decode(self, 1) # allocate 1 token per request + self.out_cache_loc = alloc_decode_batch(self, token_per_req=1) # Update seq_lens after allocation if self.enable_overlap: diff --git a/python/sglang/srt/mem_cache/common.py b/python/sglang/srt/mem_cache/common.py index 4230a580bff5..f841df46cd1a 100644 --- a/python/sglang/srt/mem_cache/common.py +++ b/python/sglang/srt/mem_cache/common.py @@ -324,3 +324,142 @@ def alloc_req_slots( f"{num_reqs=}, " ) return req_pool_indices + + +def alloc_extend_batch(batch) -> tuple[torch.Tensor, torch.Tensor]: + """ + Allocate KV cache for extend batch and write to req_to_token_pool. + + Pure allocation function - no request traversal or computation. + All data must be pre-computed by caller and set on batch fields. + + This is a batch-level operation because: + - req_to_token_pool lives on GPU + - Allocation and write use triton kernels for efficiency + - Triton kernels operate on batched GPU tensors + + Expected batch fields (must be set before calling): + - prefix_lens: list[int] + - extend_lens: list[int] + - extend_num_tokens: int + - seq_lens_cpu: torch.Tensor + - prefix_indices_list: list[torch.Tensor] + - reqs: list[Req] (only for alloc_req_slots) + + Returns: + out_cache_loc: allocated KV cache locations + req_pool_indices: request pool indices (GPU tensor) + """ + from sglang.srt.managers.schedule_batch import ScheduleBatch + + assert isinstance(batch, ScheduleBatch) + bs = len(batch.reqs) + + # Convert batch data to tensors + prefix_lens_cpu = torch.tensor(batch.prefix_lens, dtype=torch.int64) + seq_lens_cpu = batch.seq_lens_cpu + extend_lens_cpu = seq_lens_cpu - prefix_lens_cpu + + # Copy to GPU + seq_lens_device = seq_lens_cpu.to(batch.device, non_blocking=True) + prefix_lens_device = prefix_lens_cpu.to(batch.device, non_blocking=True) + extend_lens_device = extend_lens_cpu.to(batch.device, non_blocking=True) + + # Allocate req slots + req_pool_indices = alloc_req_slots(batch.req_to_token_pool, bs, batch.reqs) + req_pool_indices_cpu = torch.tensor(req_pool_indices, dtype=torch.int64) + req_pool_indices_device = req_pool_indices_cpu.to(batch.device, non_blocking=True) + + # Allocate KV cache - abstract page_size branching + page_size = batch.token_to_kv_pool_allocator.page_size + if page_size == 1: + out_cache_loc = alloc_token_slots(batch.tree_cache, batch.extend_num_tokens) + else: + # Build last_loc from prefix_indices_list + last_loc = [ + ( + prefix_indices[-1:] + if len(prefix_indices) > 0 + else torch.tensor([-1], device=batch.device) + ) + for prefix_indices in batch.prefix_indices_list + ] + out_cache_loc = alloc_paged_token_slots_extend( + tree_cache=batch.tree_cache, + prefix_lens=prefix_lens_device, + prefix_lens_cpu=prefix_lens_cpu, + seq_lens=seq_lens_device, + seq_lens_cpu=seq_lens_cpu, + last_loc=torch.cat(last_loc), + extend_num_tokens=batch.extend_num_tokens, + ) + + # Write allocated locations to req_to_token_pool + write_cache_indices( + out_cache_loc, + req_pool_indices_device, + req_pool_indices_cpu, + prefix_lens_device, + prefix_lens_cpu, + seq_lens_device, + seq_lens_cpu, + extend_lens_device, + extend_lens_cpu, + batch.prefix_indices_list, + batch.req_to_token_pool, + ) + + # Update requests with pool indices + for req, req_pool_index in zip(batch.reqs, req_pool_indices): + req.req_pool_idx = req_pool_index + + return out_cache_loc, req_pool_indices_device + + +def alloc_decode_batch(batch, token_per_req: int = 1) -> torch.Tensor: + """ + Allocate KV cache for decode batch and write to req_to_token_pool. + + This is a batch-level operation for the same reasons as alloc_extend_batch. + + Args: + batch: ScheduleBatch with requests + token_per_req: number of tokens to allocate per request (usually 1) + + Returns: + out_cache_loc: allocated KV cache locations + """ + from sglang.srt.managers.schedule_batch import ScheduleBatch + + assert isinstance(batch, ScheduleBatch) + bs = len(batch.reqs) + seq_lens = batch.seq_lens + req_pool_indices = batch.req_pool_indices + + # Allocate KV cache - branch on page_size + page_size = batch.token_to_kv_pool_allocator.page_size + if page_size == 1: + # Non-paged allocation + out_cache_loc = alloc_token_slots(batch.tree_cache, bs * token_per_req) + else: + # Paged allocation + last_loc = batch.req_to_token_pool.req_to_token[req_pool_indices, seq_lens - 1] + seq_lens_next = seq_lens + token_per_req + out_cache_loc = alloc_paged_token_slots_decode( + batch.tree_cache, + seq_lens_next, + batch.seq_lens_cpu + token_per_req, + last_loc, + ) + + # Write to req_to_token_pool + if batch.model_config.is_encoder_decoder: + locs = batch.encoder_lens + seq_lens + else: + locs = seq_lens.clone() + + batch.req_to_token_pool.write( + (req_pool_indices, locs), out_cache_loc.to(torch.int32) + ) + + return out_cache_loc From 217fa5b6b29b27c2b58f2dcfc693910495a6ebd6 Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Tue, 7 Oct 2025 17:47:21 -0700 Subject: [PATCH 13/20] wip --- python/sglang/srt/managers/schedule_batch.py | 33 ++-- python/sglang/srt/mem_cache/common.py | 165 +++++++++---------- 2 files changed, 100 insertions(+), 98 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 73d040ba04f4..d80b7da8c03e 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -60,7 +60,7 @@ ) from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache -from sglang.srt.mem_cache.common import alloc_decode_batch, alloc_extend_batch +from sglang.srt.mem_cache.common import alloc_for_decode, alloc_for_extend from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool from sglang.srt.mem_cache.radix_cache import RadixKey from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache @@ -1109,10 +1109,25 @@ def prepare_for_extend(self): self.seq_lens_sum = self.seq_lens_cpu.sum().item() self.prefix_indices_list = [r.prefix_indices for r in reqs] - # Phase 2: Allocate memory - pure allocation using batch fields - out_cache_loc, req_pool_indices_tensor = alloc_extend_batch(self) + # Free memory before allocation + if isinstance(self.tree_cache, SWAChunkCache): + for req in reqs: + pre_len = len(req.prefix_indices) + if pre_len > 0: + self.tree_cache.evict_swa( + req, pre_len, self.model_config.attention_chunk_size + ) + + # Phase 2: Allocate memory and write to pool + out_cache_loc, req_pool_indices_device, req_pool_indices = alloc_for_extend( + self + ) self.out_cache_loc = out_cache_loc - self.req_pool_indices = req_pool_indices_tensor + self.req_pool_indices = req_pool_indices_device + + # Update requests with pool indices + for req, req_pool_idx in zip(reqs, req_pool_indices): + req.req_pool_idx = req_pool_idx # Phase 3: Process requests - validation, adjustment, logprob setup input_embeds = [] @@ -1124,12 +1139,6 @@ def prepare_for_extend(self): pre_len = len(req.prefix_indices) assert seq_len - pre_len == req.extend_input_len - if pre_len > 0: - if isinstance(self.tree_cache, SWAChunkCache): - self.tree_cache.evict_swa( - req, pre_len, self.model_config.attention_chunk_size - ) - # If input_embeds are available, store them if req.input_embeds is not None: # If req.input_embeds is already a list, append its content directly @@ -1484,8 +1493,8 @@ def prepare_for_decode(self): req, req.seqlen - 1, self.model_config.attention_chunk_size ) - # Allocate memory based on current state - self.out_cache_loc = alloc_decode_batch(self, token_per_req=1) + # Allocate memory + self.out_cache_loc = alloc_for_decode(self, token_per_req=1) # Update seq_lens after allocation if self.enable_overlap: diff --git a/python/sglang/srt/mem_cache/common.py b/python/sglang/srt/mem_cache/common.py index f841df46cd1a..399a88f5c9b1 100644 --- a/python/sglang/srt/mem_cache/common.py +++ b/python/sglang/srt/mem_cache/common.py @@ -13,7 +13,7 @@ from sglang.srt.utils import support_triton if TYPE_CHECKING: - from sglang.srt.managers.schedule_batch import Req + from sglang.srt.managers.schedule_batch import Req, ScheduleBatch logger = logging.getLogger(__name__) @@ -200,7 +200,7 @@ def alloc_token_slots( backup_state: bool = False, ): allocator = tree_cache.token_to_kv_pool_allocator - _evict_from_tree_cache(tree_cache, num_tokens) + evict_from_tree_cache(tree_cache, num_tokens) state = None if backup_state: @@ -214,7 +214,7 @@ def alloc_token_slots( return (out_cache_loc, state) if backup_state else out_cache_loc -def _evict_from_tree_cache(tree_cache: BasePrefixCache | None, num_tokens: int): +def evict_from_tree_cache(tree_cache: BasePrefixCache | None, num_tokens: int): """Helper to evict from tree cache if needed. Handles both hybrid and standard allocators.""" if tree_cache is None: return @@ -254,7 +254,7 @@ def alloc_paged_token_slots_decode( allocator = tree_cache.token_to_kv_pool_allocator # Over estimate the number of tokens: assume each request needs a new page. num_tokens = len(seq_lens) * allocator.page_size - _evict_from_tree_cache(tree_cache, num_tokens) + evict_from_tree_cache(tree_cache, num_tokens) state = None if backup_state: @@ -281,7 +281,7 @@ def alloc_paged_token_slots_extend( # Over estimate the number of tokens: assume each request needs a new page. allocator = tree_cache.token_to_kv_pool_allocator num_tokens = extend_num_tokens + len(seq_lens_cpu) * allocator.page_size - _evict_from_tree_cache(tree_cache, num_tokens) + evict_from_tree_cache(tree_cache, num_tokens) state = None if backup_state: @@ -326,125 +326,118 @@ def alloc_req_slots( return req_pool_indices -def alloc_extend_batch(batch) -> tuple[torch.Tensor, torch.Tensor]: +def alloc_kv_cache_for_extend( + tree_cache: BasePrefixCache, + token_to_kv_pool_allocator, + prefix_lens_device: torch.Tensor, + prefix_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + seq_lens_cpu: torch.Tensor, + prefix_indices_list: list[torch.Tensor], + extend_num_tokens: int, + device: str, +) -> torch.Tensor: """ - Allocate KV cache for extend batch and write to req_to_token_pool. - - Pure allocation function - no request traversal or computation. - All data must be pre-computed by caller and set on batch fields. + Allocate KV cache for extend batch. - This is a batch-level operation because: - - req_to_token_pool lives on GPU - - Allocation and write use triton kernels for efficiency - - Triton kernels operate on batched GPU tensors - - Expected batch fields (must be set before calling): - - prefix_lens: list[int] - - extend_lens: list[int] - - extend_num_tokens: int - - seq_lens_cpu: torch.Tensor - - prefix_indices_list: list[torch.Tensor] - - reqs: list[Req] (only for alloc_req_slots) + Pure allocation - no batch manipulation. + Handles page_size branching internally. Returns: out_cache_loc: allocated KV cache locations - req_pool_indices: request pool indices (GPU tensor) """ - from sglang.srt.managers.schedule_batch import ScheduleBatch - - assert isinstance(batch, ScheduleBatch) - bs = len(batch.reqs) + page_size = token_to_kv_pool_allocator.page_size - # Convert batch data to tensors - prefix_lens_cpu = torch.tensor(batch.prefix_lens, dtype=torch.int64) - seq_lens_cpu = batch.seq_lens_cpu - extend_lens_cpu = seq_lens_cpu - prefix_lens_cpu - - # Copy to GPU - seq_lens_device = seq_lens_cpu.to(batch.device, non_blocking=True) - prefix_lens_device = prefix_lens_cpu.to(batch.device, non_blocking=True) - extend_lens_device = extend_lens_cpu.to(batch.device, non_blocking=True) - - # Allocate req slots - req_pool_indices = alloc_req_slots(batch.req_to_token_pool, bs, batch.reqs) - req_pool_indices_cpu = torch.tensor(req_pool_indices, dtype=torch.int64) - req_pool_indices_device = req_pool_indices_cpu.to(batch.device, non_blocking=True) - - # Allocate KV cache - abstract page_size branching - page_size = batch.token_to_kv_pool_allocator.page_size if page_size == 1: - out_cache_loc = alloc_token_slots(batch.tree_cache, batch.extend_num_tokens) + return alloc_token_slots(tree_cache, extend_num_tokens) else: - # Build last_loc from prefix_indices_list + # Paged allocation - build last_loc last_loc = [ ( prefix_indices[-1:] if len(prefix_indices) > 0 - else torch.tensor([-1], device=batch.device) + else torch.tensor([-1], device=device) ) - for prefix_indices in batch.prefix_indices_list + for prefix_indices in prefix_indices_list ] - out_cache_loc = alloc_paged_token_slots_extend( - tree_cache=batch.tree_cache, + return alloc_paged_token_slots_extend( + tree_cache=tree_cache, prefix_lens=prefix_lens_device, prefix_lens_cpu=prefix_lens_cpu, seq_lens=seq_lens_device, seq_lens_cpu=seq_lens_cpu, last_loc=torch.cat(last_loc), - extend_num_tokens=batch.extend_num_tokens, + extend_num_tokens=extend_num_tokens, ) - # Write allocated locations to req_to_token_pool + +def alloc_for_extend( + batch: ScheduleBatch, +) -> tuple[torch.Tensor, torch.Tensor, list[int]]: + """ + Allocate KV cache for extend batch and write to req_to_token_pool. + + Returns: + out_cache_loc: allocated cache locations + req_pool_indices_device: request pool indices on device + req_pool_indices: request pool indices as list + """ + bs = len(batch.reqs) + + # Create tensors for allocation + prefix_lens_cpu = torch.tensor(batch.prefix_lens, dtype=torch.int64) + extend_lens_cpu = torch.tensor(batch.extend_lens, dtype=torch.int64) + prefix_lens_device = prefix_lens_cpu.to(batch.device, non_blocking=True) + extend_lens_device = extend_lens_cpu.to(batch.device, non_blocking=True) + + # Allocate req slots + req_pool_indices = alloc_req_slots(batch.req_to_token_pool, bs, batch.reqs) + req_pool_indices_cpu = torch.tensor(req_pool_indices, dtype=torch.int64) + req_pool_indices_device = req_pool_indices_cpu.to(batch.device, non_blocking=True) + + # Allocate KV cache + out_cache_loc = alloc_kv_cache_for_extend( + tree_cache=batch.tree_cache, + token_to_kv_pool_allocator=batch.token_to_kv_pool_allocator, + prefix_lens_device=prefix_lens_device, + prefix_lens_cpu=prefix_lens_cpu, + seq_lens_device=batch.seq_lens, + seq_lens_cpu=batch.seq_lens_cpu, + prefix_indices_list=batch.prefix_indices_list, + extend_num_tokens=batch.extend_num_tokens, + device=batch.device, + ) + + # Write to req_to_token_pool write_cache_indices( out_cache_loc, req_pool_indices_device, req_pool_indices_cpu, prefix_lens_device, prefix_lens_cpu, - seq_lens_device, - seq_lens_cpu, + batch.seq_lens, + batch.seq_lens_cpu, extend_lens_device, extend_lens_cpu, batch.prefix_indices_list, batch.req_to_token_pool, ) - # Update requests with pool indices - for req, req_pool_index in zip(batch.reqs, req_pool_indices): - req.req_pool_idx = req_pool_index - - return out_cache_loc, req_pool_indices_device + return out_cache_loc, req_pool_indices_device, req_pool_indices -def alloc_decode_batch(batch, token_per_req: int = 1) -> torch.Tensor: - """ - Allocate KV cache for decode batch and write to req_to_token_pool. - - This is a batch-level operation for the same reasons as alloc_extend_batch. +def alloc_for_decode(batch: ScheduleBatch, token_per_req: int) -> torch.Tensor: + bs = batch.seq_lens.shape[0] - Args: - batch: ScheduleBatch with requests - token_per_req: number of tokens to allocate per request (usually 1) - - Returns: - out_cache_loc: allocated KV cache locations - """ - from sglang.srt.managers.schedule_batch import ScheduleBatch - - assert isinstance(batch, ScheduleBatch) - bs = len(batch.reqs) - seq_lens = batch.seq_lens - req_pool_indices = batch.req_pool_indices - - # Allocate KV cache - branch on page_size - page_size = batch.token_to_kv_pool_allocator.page_size - if page_size == 1: + if batch.tree_cache.page_size == 1: # Non-paged allocation out_cache_loc = alloc_token_slots(batch.tree_cache, bs * token_per_req) else: # Paged allocation - last_loc = batch.req_to_token_pool.req_to_token[req_pool_indices, seq_lens - 1] - seq_lens_next = seq_lens + token_per_req + last_loc = batch.req_to_token_pool.req_to_token[ + batch.req_pool_indices, batch.seq_lens - 1 + ] + seq_lens_next = batch.seq_lens + token_per_req out_cache_loc = alloc_paged_token_slots_decode( batch.tree_cache, seq_lens_next, @@ -454,12 +447,12 @@ def alloc_decode_batch(batch, token_per_req: int = 1) -> torch.Tensor: # Write to req_to_token_pool if batch.model_config.is_encoder_decoder: - locs = batch.encoder_lens + seq_lens + locs = batch.encoder_lens + batch.seq_lens else: - locs = seq_lens.clone() + locs = batch.seq_lens.clone() batch.req_to_token_pool.write( - (req_pool_indices, locs), out_cache_loc.to(torch.int32) + (batch.req_pool_indices, locs), out_cache_loc.to(torch.int32) ) return out_cache_loc From ba41c12a681bbbdc6a9a9130c7e995042f80cb54 Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Tue, 7 Oct 2025 18:34:22 -0700 Subject: [PATCH 14/20] clean --- python/sglang/srt/managers/schedule_batch.py | 3 -- python/sglang/srt/mem_cache/common.py | 38 ++++++++------------ python/sglang/srt/speculative/ngram_utils.py | 11 +++--- 3 files changed, 20 insertions(+), 32 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 2153cbfb58ff..f1321b663e2d 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -920,8 +920,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): extend_logprob_start_lens: List[int] = None # It comes empty list if logprob is not required. extend_input_logprob_token_ids: Optional[torch.Tensor] = None - # For allocation - list of prefix_indices tensors per request - prefix_indices_list: Optional[List[torch.Tensor]] = None # For encoder-decoder architectures encoder_cached: Optional[List[bool]] = None @@ -1109,7 +1107,6 @@ def prepare_for_extend(self): ) self.seq_lens = self.seq_lens_cpu.to(self.device, non_blocking=True) self.seq_lens_sum = self.seq_lens_cpu.sum().item() - self.prefix_indices_list = [r.prefix_indices for r in reqs] # Free memory before allocation if isinstance(self.tree_cache, SWAChunkCache): diff --git a/python/sglang/srt/mem_cache/common.py b/python/sglang/srt/mem_cache/common.py index 399a88f5c9b1..8949a1283557 100644 --- a/python/sglang/srt/mem_cache/common.py +++ b/python/sglang/srt/mem_cache/common.py @@ -88,7 +88,8 @@ def write_cache_indices( ): if support_triton(global_server_args_dict.get("attention_backend")): prefix_pointers = torch.tensor( - [t.data_ptr() for t in prefix_tensors], device=req_to_token_pool.device + [t.data_ptr() for t in prefix_tensors], + device=req_to_token_pool.device, ) # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start) write_req_to_token_pool_triton[(req_pool_indices_tensor.shape[0],)]( @@ -328,14 +329,12 @@ def alloc_req_slots( def alloc_kv_cache_for_extend( tree_cache: BasePrefixCache, - token_to_kv_pool_allocator, prefix_lens_device: torch.Tensor, prefix_lens_cpu: torch.Tensor, seq_lens_device: torch.Tensor, seq_lens_cpu: torch.Tensor, - prefix_indices_list: list[torch.Tensor], + prefix_tensors: list[torch.Tensor], extend_num_tokens: int, - device: str, ) -> torch.Tensor: """ Allocate KV cache for extend batch. @@ -346,19 +345,13 @@ def alloc_kv_cache_for_extend( Returns: out_cache_loc: allocated KV cache locations """ - page_size = token_to_kv_pool_allocator.page_size - - if page_size == 1: + if tree_cache.page_size == 1: return alloc_token_slots(tree_cache, extend_num_tokens) else: # Paged allocation - build last_loc last_loc = [ - ( - prefix_indices[-1:] - if len(prefix_indices) > 0 - else torch.tensor([-1], device=device) - ) - for prefix_indices in prefix_indices_list + (t[-1:] if len(t) > 0 else torch.tensor([-1], device=tree_cache.device)) + for t in prefix_tensors ] return alloc_paged_token_slots_extend( tree_cache=tree_cache, @@ -383,6 +376,7 @@ def alloc_for_extend( req_pool_indices: request pool indices as list """ bs = len(batch.reqs) + prefix_tensors = [r.prefix_indices for r in batch.reqs] # Create tensors for allocation prefix_lens_cpu = torch.tensor(batch.prefix_lens, dtype=torch.int64) @@ -397,15 +391,13 @@ def alloc_for_extend( # Allocate KV cache out_cache_loc = alloc_kv_cache_for_extend( - tree_cache=batch.tree_cache, - token_to_kv_pool_allocator=batch.token_to_kv_pool_allocator, - prefix_lens_device=prefix_lens_device, - prefix_lens_cpu=prefix_lens_cpu, - seq_lens_device=batch.seq_lens, - seq_lens_cpu=batch.seq_lens_cpu, - prefix_indices_list=batch.prefix_indices_list, - extend_num_tokens=batch.extend_num_tokens, - device=batch.device, + batch.tree_cache, + prefix_lens_device, + prefix_lens_cpu, + batch.seq_lens, + batch.seq_lens_cpu, + prefix_tensors, + batch.extend_num_tokens, ) # Write to req_to_token_pool @@ -419,7 +411,7 @@ def alloc_for_extend( batch.seq_lens_cpu, extend_lens_device, extend_lens_cpu, - batch.prefix_indices_list, + prefix_tensors, batch.req_to_token_pool, ) diff --git a/python/sglang/srt/speculative/ngram_utils.py b/python/sglang/srt/speculative/ngram_utils.py index d9f8f369182c..ce4557b89b5a 100644 --- a/python/sglang/srt/speculative/ngram_utils.py +++ b/python/sglang/srt/speculative/ngram_utils.py @@ -7,11 +7,6 @@ import torch import triton -from sglang.srt.mem_cache.common import ( - alloc_paged_token_slots_extend, - alloc_token_slots, -) - logger = logging.getLogger(__name__) from dataclasses import dataclass @@ -22,7 +17,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import apply_custom_logit_processor from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict -from sglang.srt.mem_cache.common import get_last_loc +from sglang.srt.mem_cache.common import ( + alloc_paged_token_slots_extend, + alloc_token_slots, + get_last_loc, +) from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.speculative.spec_info import SpecInput, SpecInputType from sglang.srt.speculative.spec_utils import ( From d4984646e8f0fcfa370065aef8b4b754444f5769 Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Tue, 7 Oct 2025 18:47:15 -0700 Subject: [PATCH 15/20] clean --- python/sglang/srt/managers/schedule_batch.py | 109 ++++++++++--------- 1 file changed, 55 insertions(+), 54 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index f1321b663e2d..13a598215b99 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1073,71 +1073,64 @@ def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]) def prepare_for_extend(self): self.forward_mode = ForwardMode.EXTEND - reqs = self.reqs - # Phase 1: Prepare tensors and assign to batch + # Init tensors + reqs = self.reqs input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] - self.input_ids = torch.tensor( - list(chain.from_iterable(input_ids)), dtype=torch.int64 - ).to(self.device, non_blocking=True) - - self.orig_seq_lens = torch.tensor( - [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs], - dtype=torch.int32, - ).to(self.device, non_blocking=True) + extend_num_tokens = sum(len(ids) for ids in input_ids) + seq_lens = [len(r.fill_ids) for r in reqs] + orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs] + prefix_lens = [len(r.prefix_indices) for r in reqs] + extend_lens = [r.extend_input_len for r in reqs] token_type_ids = [ r.token_type_ids for r in reqs if r.token_type_ids is not None ] - if len(token_type_ids) > 0: - self.token_type_ids = torch.tensor( - sum(token_type_ids, []), dtype=torch.int64 - ).to(self.device, non_blocking=True) - else: - self.token_type_ids = None - # Batch-level data for allocation - self.prefix_lens = [len(r.prefix_indices) for r in reqs] - self.extend_lens = [r.extend_input_len for r in reqs] - self.extend_num_tokens = sum( - len(r.fill_ids) - len(r.prefix_indices) for r in reqs + input_ids_tensor = torch.tensor( + list(chain.from_iterable(input_ids)), dtype=torch.int64 + ).to(self.device, non_blocking=True) + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to( + self.device, non_blocking=True ) - self.seq_lens_cpu = torch.tensor( - [len(r.fill_ids) for r in reqs], dtype=torch.int64 + seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64) + orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to( + self.device, non_blocking=True ) - self.seq_lens = self.seq_lens_cpu.to(self.device, non_blocking=True) - self.seq_lens_sum = self.seq_lens_cpu.sum().item() - # Free memory before allocation - if isinstance(self.tree_cache, SWAChunkCache): - for req in reqs: - pre_len = len(req.prefix_indices) - if pre_len > 0: - self.tree_cache.evict_swa( - req, pre_len, self.model_config.attention_chunk_size - ) + token_type_ids_tensor = None + if len(token_type_ids) > 0: + token_type_ids_tensor = torch.tensor( + sum(token_type_ids, []), dtype=torch.int64 + ).to(self.device, non_blocking=True) - # Phase 2: Allocate memory and write to pool - out_cache_loc, req_pool_indices_device, req_pool_indices = alloc_for_extend( + # Set batch fields needed by alloc_for_extend + self.prefix_lens = prefix_lens + self.extend_lens = extend_lens + self.seq_lens = seq_lens_tensor + self.seq_lens_cpu = seq_lens_cpu + self.extend_num_tokens = extend_num_tokens + + # Allocate memory + out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend( self ) - self.out_cache_loc = out_cache_loc - self.req_pool_indices = req_pool_indices_device - # Update requests with pool indices - for req, req_pool_idx in zip(reqs, req_pool_indices): - req.req_pool_idx = req_pool_idx - - # Phase 3: Process requests - validation, adjustment, logprob setup + # Set fields input_embeds = [] extend_input_logprob_token_ids = [] multimodal_inputs = [] - for req in reqs: - seq_len = len(req.fill_ids) - pre_len = len(req.prefix_indices) + for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)): + req.req_pool_idx = req_pool_indices[i] assert seq_len - pre_len == req.extend_input_len + if pre_len > 0: + if isinstance(self.tree_cache, SWAChunkCache): + self.tree_cache.evict_swa( + req, pre_len, self.model_config.attention_chunk_size + ) + # If input_embeds are available, store them if req.input_embeds is not None: # If req.input_embeds is already a list, append its content directly @@ -1220,7 +1213,17 @@ def prepare_for_extend(self): ) ) - # Assign processed data to batch + if self.return_logprob: + extend_input_logprob_token_ids = torch.tensor( + extend_input_logprob_token_ids + ) + else: + extend_input_logprob_token_ids = None + + self.input_ids = input_ids_tensor + self.req_pool_indices = req_pool_indices_tensor + self.orig_seq_lens = orig_seq_lens_tensor + self.out_cache_loc = out_cache_loc self.input_embeds = ( torch.tensor(input_embeds).to(self.device, non_blocking=True) if input_embeds @@ -1234,20 +1237,18 @@ def prepare_for_extend(self): if isinstance(pixel_values, torch.Tensor): mm_item.feature = pixel_values.to(self.device, non_blocking=True) self.multimodal_inputs = multimodal_inputs - - self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs] + self.token_type_ids = token_type_ids_tensor + self.seq_lens_sum = sum(seq_lens) if self.return_logprob: self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] self.token_ids_logprobs = [r.token_ids_logprob for r in reqs] - self.extend_input_logprob_token_ids = torch.tensor( - extend_input_logprob_token_ids - ) - else: - self.extend_input_logprob_token_ids = None + + self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs] + self.extend_input_logprob_token_ids = extend_input_logprob_token_ids if self.model_config.is_encoder_decoder: - self.prepare_encoder_info_extend(input_ids, self.seq_lens_cpu.tolist()) + self.prepare_encoder_info_extend(input_ids, seq_lens) # Build sampling info self.sampling_info = SamplingBatchInfo.from_schedule_batch( From 4e0bb42f6e3b406021d22d86f5b8af6cfc940921 Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Tue, 7 Oct 2025 18:59:06 -0700 Subject: [PATCH 16/20] clean --- python/sglang/srt/mem_cache/common.py | 45 ++++++++------------------- 1 file changed, 13 insertions(+), 32 deletions(-) diff --git a/python/sglang/srt/mem_cache/common.py b/python/sglang/srt/mem_cache/common.py index 8949a1283557..1c07f5873466 100644 --- a/python/sglang/srt/mem_cache/common.py +++ b/python/sglang/srt/mem_cache/common.py @@ -216,7 +216,6 @@ def alloc_token_slots( def evict_from_tree_cache(tree_cache: BasePrefixCache | None, num_tokens: int): - """Helper to evict from tree cache if needed. Handles both hybrid and standard allocators.""" if tree_cache is None: return @@ -244,31 +243,6 @@ def evict_from_tree_cache(tree_cache: BasePrefixCache | None, num_tokens: int): tree_cache.evict(num_tokens) -def alloc_paged_token_slots_decode( - tree_cache: BasePrefixCache, - seq_lens: torch.Tensor, - seq_lens_cpu: torch.Tensor, - last_loc: torch.Tensor, - backup_state: bool = False, -): - """Allocate paged token slots for decode with overestimation for safety.""" - allocator = tree_cache.token_to_kv_pool_allocator - # Over estimate the number of tokens: assume each request needs a new page. - num_tokens = len(seq_lens) * allocator.page_size - evict_from_tree_cache(tree_cache, num_tokens) - - state = None - if backup_state: - state = allocator.backup_state() - - out_cache_loc = allocator.alloc_decode(seq_lens, seq_lens_cpu, last_loc) - - if out_cache_loc is None: - return None - - return (out_cache_loc, state) if backup_state else out_cache_loc - - def alloc_paged_token_slots_extend( tree_cache: BasePrefixCache, prefix_lens: torch.Tensor, @@ -372,7 +346,7 @@ def alloc_for_extend( Returns: out_cache_loc: allocated cache locations - req_pool_indices_device: request pool indices on device + req_pool_indices_device: request pool indices at a device tensor req_pool_indices: request pool indices as list """ bs = len(batch.reqs) @@ -419,6 +393,12 @@ def alloc_for_extend( def alloc_for_decode(batch: ScheduleBatch, token_per_req: int) -> torch.Tensor: + """ + Allocate KV cache for decode batch and write to req_to_token_pool. + + Returns: + out_cache_loc: allocated cache locations + """ bs = batch.seq_lens.shape[0] if batch.tree_cache.page_size == 1: @@ -430,11 +410,12 @@ def alloc_for_decode(batch: ScheduleBatch, token_per_req: int) -> torch.Tensor: batch.req_pool_indices, batch.seq_lens - 1 ] seq_lens_next = batch.seq_lens + token_per_req - out_cache_loc = alloc_paged_token_slots_decode( - batch.tree_cache, - seq_lens_next, - batch.seq_lens_cpu + token_per_req, - last_loc, + allocator = batch.tree_cache.token_to_kv_pool_allocator + # Over estimate the number of tokens: assume each request needs a new page. + num_tokens = len(seq_lens_next) * allocator.page_size + evict_from_tree_cache(batch.tree_cache, num_tokens) + out_cache_loc = allocator.alloc_decode( + seq_lens_next, batch.seq_lens_cpu + token_per_req, last_loc ) # Write to req_to_token_pool From 2e8d1a43984e68076180667336484767e64e4a66 Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Wed, 8 Oct 2025 11:40:42 -0700 Subject: [PATCH 17/20] fix benchmark --- python/sglang/bench_one_batch.py | 10 ++- python/sglang/srt/mem_cache/common.py | 67 ++++++------------- python/sglang/srt/speculative/eagle_worker.py | 1 + test/srt/test_forward_split_prefill.py | 10 ++- 4 files changed, 40 insertions(+), 48 deletions(-) diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index f8a35266bf6f..97bff98ab48a 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -51,6 +51,7 @@ import multiprocessing import os import time +from types import SimpleNamespace from typing import Tuple import numpy as np @@ -257,11 +258,18 @@ def prepare_synthetic_inputs_for_latency_test( @torch.no_grad def extend(reqs, model_runner): + # Create dummy tree_cache for benchmarks (no prefix caching, just allocation) + dummy_tree_cache = SimpleNamespace( + page_size=1, + device=model_runner.device, + token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator, + ) + batch = ScheduleBatch.init_new( reqs=reqs, req_to_token_pool=model_runner.req_to_token_pool, token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator, - tree_cache=None, + tree_cache=dummy_tree_cache, model_config=model_runner.model_config, enable_overlap=False, spec_algorithm=SpeculativeAlgorithm.NONE, diff --git a/python/sglang/srt/mem_cache/common.py b/python/sglang/srt/mem_cache/common.py index 1c07f5873466..83a1019ab712 100644 --- a/python/sglang/srt/mem_cache/common.py +++ b/python/sglang/srt/mem_cache/common.py @@ -301,43 +301,6 @@ def alloc_req_slots( return req_pool_indices -def alloc_kv_cache_for_extend( - tree_cache: BasePrefixCache, - prefix_lens_device: torch.Tensor, - prefix_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, - seq_lens_cpu: torch.Tensor, - prefix_tensors: list[torch.Tensor], - extend_num_tokens: int, -) -> torch.Tensor: - """ - Allocate KV cache for extend batch. - - Pure allocation - no batch manipulation. - Handles page_size branching internally. - - Returns: - out_cache_loc: allocated KV cache locations - """ - if tree_cache.page_size == 1: - return alloc_token_slots(tree_cache, extend_num_tokens) - else: - # Paged allocation - build last_loc - last_loc = [ - (t[-1:] if len(t) > 0 else torch.tensor([-1], device=tree_cache.device)) - for t in prefix_tensors - ] - return alloc_paged_token_slots_extend( - tree_cache=tree_cache, - prefix_lens=prefix_lens_device, - prefix_lens_cpu=prefix_lens_cpu, - seq_lens=seq_lens_device, - seq_lens_cpu=seq_lens_cpu, - last_loc=torch.cat(last_loc), - extend_num_tokens=extend_num_tokens, - ) - - def alloc_for_extend( batch: ScheduleBatch, ) -> tuple[torch.Tensor, torch.Tensor, list[int]]: @@ -364,15 +327,27 @@ def alloc_for_extend( req_pool_indices_device = req_pool_indices_cpu.to(batch.device, non_blocking=True) # Allocate KV cache - out_cache_loc = alloc_kv_cache_for_extend( - batch.tree_cache, - prefix_lens_device, - prefix_lens_cpu, - batch.seq_lens, - batch.seq_lens_cpu, - prefix_tensors, - batch.extend_num_tokens, - ) + if batch.tree_cache.page_size == 1: + out_cache_loc = alloc_token_slots(batch.tree_cache, batch.extend_num_tokens) + else: + # Paged allocation - build last_loc + last_loc = [ + ( + t[-1:] + if len(t) > 0 + else torch.tensor([-1], device=batch.tree_cache.device) + ) + for t in prefix_tensors + ] + out_cache_loc = alloc_paged_token_slots_extend( + tree_cache=batch.tree_cache, + prefix_lens=prefix_lens_device, + prefix_lens_cpu=prefix_lens_cpu, + seq_lens=batch.seq_lens, + seq_lens_cpu=batch.seq_lens_cpu, + last_loc=torch.cat(last_loc), + extend_num_tokens=batch.extend_num_tokens, + ) # Write to req_to_token_pool write_cache_indices( diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 5276788718fc..162ce53ecf43 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -15,6 +15,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict +from sglang.srt.managers.scheduler import GenerationBatchResult from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.mem_cache.common import ( alloc_paged_token_slots_extend, diff --git a/test/srt/test_forward_split_prefill.py b/test/srt/test_forward_split_prefill.py index 314e35ec9724..4ca3c12fe0d8 100644 --- a/test/srt/test_forward_split_prefill.py +++ b/test/srt/test_forward_split_prefill.py @@ -8,6 +8,7 @@ """ import unittest +from types import SimpleNamespace import numpy as np import torch @@ -95,11 +96,18 @@ def prepare_test_batch(self, batch_size=2, input_len=128, is_split_prefill=True) req.logprob_start_len = len(req.origin_input_ids) - 1 reqs.append(req) + # Create dummy tree_cache for tests (no prefix caching, just allocation) + dummy_tree_cache = SimpleNamespace( + page_size=1, + device=self.model_runner.device, + token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator, + ) + batch = ScheduleBatch.init_new( reqs=reqs, req_to_token_pool=self.model_runner.req_to_token_pool, token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator, - tree_cache=None, + tree_cache=dummy_tree_cache, model_config=self.model_config, enable_overlap=False, spec_algorithm=SpeculativeAlgorithm.NONE, From 987a3dd06542670a452b5e2d70d411e97e7b0d72 Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Wed, 8 Oct 2025 13:23:24 -0700 Subject: [PATCH 18/20] fix hybrid --- python/sglang/srt/mem_cache/common.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/mem_cache/common.py b/python/sglang/srt/mem_cache/common.py index 83a1019ab712..e59cafce83a5 100644 --- a/python/sglang/srt/mem_cache/common.py +++ b/python/sglang/srt/mem_cache/common.py @@ -283,13 +283,10 @@ def alloc_req_slots( reqs: list[Req] | None, ) -> list[int]: """Allocate request slots from the pool.""" - match req_to_token_pool: - case ReqToTokenPool(): - req_pool_indices = req_to_token_pool.alloc(num_reqs) - case HybridReqToTokenPool(): - req_pool_indices = req_to_token_pool.alloc(num_reqs, reqs) - case _: - raise ValueError(f"Unknown {type(req_to_token_pool)=}") + if isinstance(req_to_token_pool, HybridReqToTokenPool): + req_pool_indices = req_to_token_pool.alloc(num_reqs, reqs) + else: + req_pool_indices = req_to_token_pool.alloc(num_reqs) if req_pool_indices is None: raise RuntimeError( From bd22a188dea5b656e0b91071e42a334b7c5cd2ae Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Fri, 10 Oct 2025 14:10:44 -0700 Subject: [PATCH 19/20] update error msg --- python/sglang/srt/managers/schedule_batch.py | 17 ---- python/sglang/srt/mem_cache/common.py | 83 +++++++++++++++++--- 2 files changed, 74 insertions(+), 26 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 13a598215b99..d0dae17dacf4 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1729,23 +1729,6 @@ def _is_available_size_sufficient(self, num_tokens: int) -> bool: else: return self.token_to_kv_pool_allocator.available_size() >= num_tokens - def _available_and_evictable_str(self) -> str: - if self.is_hybrid: - full_available_size = self.token_to_kv_pool_allocator.full_available_size() - swa_available_size = self.token_to_kv_pool_allocator.swa_available_size() - full_evictable_size = self.tree_cache.full_evictable_size() - swa_evictable_size = self.tree_cache.swa_evictable_size() - return ( - f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n" - f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n" - f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n" - f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n" - ) - else: - available_size = self.token_to_kv_pool_allocator.available_size() - evictable_size = self.tree_cache.evictable_size() - return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n" - def __str__(self): return ( f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, " diff --git a/python/sglang/srt/mem_cache/common.py b/python/sglang/srt/mem_cache/common.py index e59cafce83a5..f0754678947d 100644 --- a/python/sglang/srt/mem_cache/common.py +++ b/python/sglang/srt/mem_cache/common.py @@ -7,6 +7,7 @@ import triton import triton.language as tl +from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool from sglang.srt.server_args import ServerArgs @@ -210,7 +211,15 @@ def alloc_token_slots( out_cache_loc = allocator.alloc(num_tokens) if out_cache_loc is None: - return None + error_msg = ( + f"Out of memory. Try to lower your batch size.\n" + f"Try to allocate {num_tokens} tokens.\n" + f"{available_and_evictable_str(tree_cache)}" + ) + logger.error(error_msg) + if tree_cache is not None: + tree_cache.pretty_print() + raise RuntimeError(error_msg) return (out_cache_loc, state) if backup_state else out_cache_loc @@ -272,7 +281,15 @@ def alloc_paged_token_slots_extend( ) if out_cache_loc is None: - return None + error_msg = ( + f"Prefill out of memory. Try to lower your batch size.\n" + f"Try to allocate {extend_num_tokens} tokens.\n" + f"{available_and_evictable_str(tree_cache)}" + ) + logger.error(error_msg) + if tree_cache is not None: + tree_cache.pretty_print() + raise RuntimeError(error_msg) return (out_cache_loc, state) if backup_state else out_cache_loc @@ -323,7 +340,7 @@ def alloc_for_extend( req_pool_indices_cpu = torch.tensor(req_pool_indices, dtype=torch.int64) req_pool_indices_device = req_pool_indices_cpu.to(batch.device, non_blocking=True) - # Allocate KV cache + # Allocate KV cache (throws exception on failure) if batch.tree_cache.page_size == 1: out_cache_loc = alloc_token_slots(batch.tree_cache, batch.extend_num_tokens) else: @@ -364,6 +381,35 @@ def alloc_for_extend( return out_cache_loc, req_pool_indices_device, req_pool_indices +def alloc_paged_token_slots_decode( + tree_cache: BasePrefixCache, + seq_lens: torch.Tensor, + seq_lens_cpu: torch.Tensor, + last_loc: torch.Tensor, + token_per_req: int = 1, +) -> torch.Tensor: + """Allocate paged KV cache for decode batch.""" + allocator = tree_cache.token_to_kv_pool_allocator + # Over estimate the number of tokens: assume each request needs a new page. + num_tokens = len(seq_lens) * allocator.page_size + evict_from_tree_cache(tree_cache, num_tokens) + + out_cache_loc = allocator.alloc_decode(seq_lens, seq_lens_cpu, last_loc) + + if out_cache_loc is None: + error_msg = ( + f"Decode out of memory. Try to lower your batch size.\n" + f"Try to allocate {len(seq_lens) * token_per_req} tokens.\n" + f"{available_and_evictable_str(tree_cache)}" + ) + logger.error(error_msg) + if tree_cache is not None: + tree_cache.pretty_print() + raise RuntimeError(error_msg) + + return out_cache_loc + + def alloc_for_decode(batch: ScheduleBatch, token_per_req: int) -> torch.Tensor: """ Allocate KV cache for decode batch and write to req_to_token_pool. @@ -382,12 +428,12 @@ def alloc_for_decode(batch: ScheduleBatch, token_per_req: int) -> torch.Tensor: batch.req_pool_indices, batch.seq_lens - 1 ] seq_lens_next = batch.seq_lens + token_per_req - allocator = batch.tree_cache.token_to_kv_pool_allocator - # Over estimate the number of tokens: assume each request needs a new page. - num_tokens = len(seq_lens_next) * allocator.page_size - evict_from_tree_cache(batch.tree_cache, num_tokens) - out_cache_loc = allocator.alloc_decode( - seq_lens_next, batch.seq_lens_cpu + token_per_req, last_loc + out_cache_loc = alloc_paged_token_slots_decode( + tree_cache=batch.tree_cache, + seq_lens=seq_lens_next, + seq_lens_cpu=batch.seq_lens_cpu + token_per_req, + last_loc=last_loc, + token_per_req=token_per_req, ) # Write to req_to_token_pool @@ -401,3 +447,22 @@ def alloc_for_decode(batch: ScheduleBatch, token_per_req: int) -> torch.Tensor: ) return out_cache_loc + + +def available_and_evictable_str(tree_cache) -> str: + token_to_kv_pool_allocator = tree_cache.token_to_kv_pool_allocator + if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator): + full_available_size = token_to_kv_pool_allocator.full_available_size() + swa_available_size = token_to_kv_pool_allocator.swa_available_size() + full_evictable_size = tree_cache.full_evictable_size() + swa_evictable_size = tree_cache.swa_evictable_size() + return ( + f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n" + f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n" + f"Full LRU list evictable size: {tree_cache.full_lru_list_evictable_size()}\n" + f"SWA LRU list evictable size: {tree_cache.swa_lru_list_evictable_size()}\n" + ) + else: + available_size = token_to_kv_pool_allocator.available_size() + evictable_size = tree_cache.evictable_size() + return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n" From 563ba3a37ab4450ef5d0dcf5269a3b287ce45bbd Mon Sep 17 00:00:00 2001 From: Shiyang Chen Date: Fri, 10 Oct 2025 14:33:45 -0700 Subject: [PATCH 20/20] swa --- python/sglang/srt/managers/schedule_batch.py | 13 ------------- python/sglang/srt/mem_cache/common.py | 17 ++++++++++++++--- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index d0dae17dacf4..b9e93c53e692 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1125,12 +1125,6 @@ def prepare_for_extend(self): req.req_pool_idx = req_pool_indices[i] assert seq_len - pre_len == req.extend_input_len - if pre_len > 0: - if isinstance(self.tree_cache, SWAChunkCache): - self.tree_cache.evict_swa( - req, pre_len, self.model_config.attention_chunk_size - ) - # If input_embeds are available, store them if req.input_embeds is not None: # If req.input_embeds is already a list, append its content directly @@ -1486,13 +1480,6 @@ def prepare_for_decode(self): if self.model_config.is_encoder_decoder: self.prepare_encoder_info_decode() - # free memory - if isinstance(self.tree_cache, SWAChunkCache): - for req in self.reqs: - self.tree_cache.evict_swa( - req, req.seqlen - 1, self.model_config.attention_chunk_size - ) - # Allocate memory self.out_cache_loc = alloc_for_decode(self, token_per_req=1) diff --git a/python/sglang/srt/mem_cache/common.py b/python/sglang/srt/mem_cache/common.py index f0754678947d..040bc45bf9b0 100644 --- a/python/sglang/srt/mem_cache/common.py +++ b/python/sglang/srt/mem_cache/common.py @@ -9,6 +9,7 @@ from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache +from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool from sglang.srt.server_args import ServerArgs from sglang.srt.utils import support_triton @@ -228,9 +229,6 @@ def evict_from_tree_cache(tree_cache: BasePrefixCache | None, num_tokens: int): if tree_cache is None: return - # Check if this is ChunkCache or SWAChunkCache - these don't support eviction - from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache - if isinstance(tree_cache, (SWAChunkCache, ChunkCache)): return @@ -326,6 +324,13 @@ def alloc_for_extend( req_pool_indices_device: request pool indices at a device tensor req_pool_indices: request pool indices as list """ + # free out-of-window swa tokens + if isinstance(batch.tree_cache, SWAChunkCache): + for req, pre_len in zip(batch.reqs, batch.prefix_lens): + batch.tree_cache.evict_swa( + req, pre_len, batch.model_config.attention_chunk_size + ) + bs = len(batch.reqs) prefix_tensors = [r.prefix_indices for r in batch.reqs] @@ -417,6 +422,12 @@ def alloc_for_decode(batch: ScheduleBatch, token_per_req: int) -> torch.Tensor: Returns: out_cache_loc: allocated cache locations """ + if isinstance(batch.tree_cache, SWAChunkCache): + for req in batch.reqs: + batch.tree_cache.evict_swa( + req, req.seqlen - 1, batch.model_config.attention_chunk_size + ) + bs = batch.seq_lens.shape[0] if batch.tree_cache.page_size == 1: