Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1873,6 +1873,9 @@ def new_tokens_required_next_decode(
new_pages = sum(1 for r in requests if r.kv_committed_len % page_size == 0)
return new_pages * page_size

if self.is_spec_v2:
return self._new_tokens_required_next_decode_spec_v2(requests, page_size)

server_args = get_global_server_args()
len_per_topk = server_args.speculative_num_steps or 1
spec_topk = server_args.speculative_eagle_topk or 1
Expand All @@ -1888,9 +1891,20 @@ def new_tokens_required_next_decode(
spec_tokens = ceil_align(spec_tokens, page_size)

num_tokens = max(len_per_topk * spec_topk, spec_tokens) * len(requests)

# v2 eagle has over-allocation
return num_tokens * (1 + self.is_spec_v2)
return num_tokens

def _new_tokens_required_next_decode_spec_v2(self, requests, page_size):
"""Tight estimate matching eagle_info_v2.prepare_for_decode allocation."""
from sglang.srt.managers.utils import get_alloc_len_per_decode

alloc_len = get_alloc_len_per_decode()
total = 0
for r in requests:
x = max(0, r.kv_committed_len + 2 * alloc_len - r.kv_allocated_len)
cur = r.kv_allocated_len
nxt = cur + x
total += ceil_align(nxt, page_size) - ceil_align(cur, page_size)
return total

def check_decode_mem(self, selected_indices: Optional[List[int]] = None):
num_tokens = self.new_tokens_required_next_decode(selected_indices)
Expand Down
15 changes: 13 additions & 2 deletions python/sglang/srt/mem_cache/hisparse_memory_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
from sglang.srt.utils import is_cuda, is_hip
from sglang.srt.utils.common import get_num_new_pages

# sgl_kernel.kvcacheio is only available in CUDA/ROCm sgl-kernel builds (not XPU/MPS/NPU/CPU).
_is_cuda = is_cuda()
Expand Down Expand Up @@ -246,9 +247,19 @@ def alloc_extend(
extend_num_tokens: int,
):
assert self.page_size > 1
num_tokens = extend_num_tokens + len(seq_lens) * self.page_size

if num_tokens > self.available_size():
num_new_pages = get_num_new_pages(
seq_lens=seq_lens_cpu, page_size=self.page_size, prefix_lens=prefix_lens_cpu
)
if (
num_new_pages
> self.logical_attn_allocator.available_size() // self.page_size
):
return None
if (
num_new_pages
> self.hisparse_attn_allocator.available_size() // self.page_size
):
return None

logical_indices = self.logical_attn_allocator.alloc_extend(
Expand Down
10 changes: 7 additions & 3 deletions python/sglang/srt/mem_cache/swa_memory_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool
from sglang.srt.mem_cache.utils import maybe_init_custom_mem_pool
from sglang.srt.utils import is_npu
from sglang.srt.utils.common import get_num_new_pages

_is_npu = is_npu()

Expand Down Expand Up @@ -377,10 +378,13 @@ def alloc_extend(
extend_num_tokens: int,
):
assert self.page_size > 1
num_tokens = extend_num_tokens + len(seq_lens) * self.page_size
if num_tokens > self.full_attn_allocator.available_size():

num_new_pages = get_num_new_pages(
seq_lens=seq_lens_cpu, page_size=self.page_size, prefix_lens=prefix_lens_cpu
)
if num_new_pages > self.full_attn_allocator.available_size() // self.page_size:
return None
if num_tokens > self.swa_attn_allocator.available_size():
if num_new_pages > self.swa_attn_allocator.available_size() // self.page_size:
return None

swa_last_loc = self.translate_loc_from_full_to_swa(last_loc)
Expand Down
Loading