diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 0c178204a835..aa345da63f60 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -1873,6 +1873,9 @@ def new_tokens_required_next_decode(
             new_pages = sum(1 for r in requests if r.kv_committed_len % page_size == 0)
             return new_pages * page_size
 
+        if self.is_spec_v2:
+            return self._new_tokens_required_next_decode_spec_v2(requests, page_size)
+
         server_args = get_global_server_args()
         len_per_topk = server_args.speculative_num_steps or 1
         spec_topk = server_args.speculative_eagle_topk or 1
@@ -1888,9 +1891,20 @@ def new_tokens_required_next_decode(
         spec_tokens = ceil_align(spec_tokens, page_size)
 
         num_tokens = max(len_per_topk * spec_topk, spec_tokens) * len(requests)
-
-        # v2 eagle has over-allocation
-        return num_tokens * (1 + self.is_spec_v2)
+        return num_tokens
+
+    def _new_tokens_required_next_decode_spec_v2(self, requests, page_size):
+        """Tight estimate matching eagle_info_v2.prepare_for_decode allocation."""
+        from sglang.srt.managers.utils import get_alloc_len_per_decode
+
+        alloc_len = get_alloc_len_per_decode()
+        total = 0
+        for r in requests:
+            x = max(0, r.kv_committed_len + 2 * alloc_len - r.kv_allocated_len)
+            cur = r.kv_allocated_len
+            nxt = cur + x
+            total += ceil_align(nxt, page_size) - ceil_align(cur, page_size)
+        return total
 
     def check_decode_mem(self, selected_indices: Optional[List[int]] = None):
         num_tokens = self.new_tokens_required_next_decode(selected_indices)
diff --git a/python/sglang/srt/mem_cache/hisparse_memory_pool.py b/python/sglang/srt/mem_cache/hisparse_memory_pool.py
index 5af8d257ad6b..76c2284eb8f3 100644
--- a/python/sglang/srt/mem_cache/hisparse_memory_pool.py
+++ b/python/sglang/srt/mem_cache/hisparse_memory_pool.py
@@ -12,6 +12,7 @@
 )
 from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
 from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils.common import get_num_new_pages
 
 # sgl_kernel.kvcacheio is only available in CUDA/ROCm sgl-kernel builds (not XPU/MPS/NPU/CPU).
 _is_cuda = is_cuda()
@@ -246,9 +247,19 @@ def alloc_extend(
         extend_num_tokens: int,
     ):
         assert self.page_size > 1
 
-        num_tokens = extend_num_tokens + len(seq_lens) * self.page_size
-        if num_tokens > self.available_size():
+        num_new_pages = get_num_new_pages(
+            seq_lens=seq_lens_cpu, page_size=self.page_size, prefix_lens=prefix_lens_cpu
+        )
+        if (
+            num_new_pages
+            > self.logical_attn_allocator.available_size() // self.page_size
+        ):
+            return None
+        if (
+            num_new_pages
+            > self.hisparse_attn_allocator.available_size() // self.page_size
+        ):
             return None
 
         logical_indices = self.logical_attn_allocator.alloc_extend(
diff --git a/python/sglang/srt/mem_cache/swa_memory_pool.py b/python/sglang/srt/mem_cache/swa_memory_pool.py
index 96b0e3844914..e89b501edc66 100644
--- a/python/sglang/srt/mem_cache/swa_memory_pool.py
+++ b/python/sglang/srt/mem_cache/swa_memory_pool.py
@@ -13,6 +13,7 @@
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool
 from sglang.srt.mem_cache.utils import maybe_init_custom_mem_pool
 from sglang.srt.utils import is_npu
+from sglang.srt.utils.common import get_num_new_pages
 
 _is_npu = is_npu()
 
@@ -377,10 +378,13 @@ def alloc_extend(
         extend_num_tokens: int,
     ):
         assert self.page_size > 1
-        num_tokens = extend_num_tokens + len(seq_lens) * self.page_size
-        if num_tokens > self.full_attn_allocator.available_size():
+
+        num_new_pages = get_num_new_pages(
+            seq_lens=seq_lens_cpu, page_size=self.page_size, prefix_lens=prefix_lens_cpu
+        )
+        if num_new_pages > self.full_attn_allocator.available_size() // self.page_size:
             return None
-        if num_tokens > self.swa_attn_allocator.available_size():
+        if num_new_pages > self.swa_attn_allocator.available_size() // self.page_size:
             return None
 
         swa_last_loc = self.translate_loc_from_full_to_swa(last_loc)