Skip to content

Commit

Permalink
Support chunked prefill when radix cache is disabled (#811)
Browse files Browse the repository at this point in the history
  • Loading branch information
hnyls2002 authored Aug 1, 2024
1 parent ca600e8 commit c020f9c
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
limitations under the License.
"""

"""Base cache class."""
"""Base tool cache for constrained decoding tools."""

import time


class BaseCache:
class BaseToolCache:
def __init__(self, enable=True):
self.enable = enable
self.reset()
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/constrained/fsm_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
"""Cache for the compressed finite state machine."""

from sglang.srt.constrained import RegexGuide, TransformerTokenizer
from sglang.srt.constrained.base_cache import BaseCache
from sglang.srt.constrained.base_tool_cache import BaseToolCache


class FSMCache(BaseCache):
class FSMCache(BaseToolCache):
def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
super().__init__(enable=enable)

Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/constrained/jump_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
make_byte_level_fsm,
make_deterministic_fsm,
)
from sglang.srt.constrained.base_cache import BaseCache
from sglang.srt.constrained.base_tool_cache import BaseToolCache

IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

Expand Down Expand Up @@ -151,7 +151,7 @@ def is_jump_forward_symbol_state(self, state):
)


class JumpForwardCache(BaseCache):
class JumpForwardCache(BaseToolCache):
def __init__(self):
super().__init__()

Expand Down
38 changes: 29 additions & 9 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from sglang.global_config import global_config
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.mem_cache.radix_cache import RadixCache

Expand Down Expand Up @@ -486,15 +487,33 @@ def retract_decode(self):
req = self.reqs[idx]
retracted_reqs.append(req)

# TODO: apply more fine-grained retraction
last_uncached_pos = len(req.prefix_indices)
token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[idx]
][last_uncached_pos : seq_lens_cpu[idx]]
self.token_to_kv_pool.free(token_indices)

# release the last node
self.tree_cache.dec_lock_ref(req.last_node)
if isinstance(self.tree_cache, ChunkCache):
# ChunkCache does not have eviction
token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[idx]
][: seq_lens_cpu[idx]]
self.token_to_kv_pool.free(token_indices)
self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
del self.tree_cache.entries[req.rid]
else:
# TODO: apply more fine-grained retraction
last_uncached_pos = len(req.prefix_indices)
token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[idx]
][last_uncached_pos : seq_lens_cpu[idx]]
self.token_to_kv_pool.free(token_indices)
self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))

# release the last node
self.tree_cache.dec_lock_ref(req.last_node)

# NOTE(lsyin): we should use the newly evictable memory instantly.
residual_size = (
len(sorted_indices) * global_config.retract_decode_steps
- self.token_to_kv_pool.available_size()
)
residual_size = max(0, residual_size)
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

req.prefix_indices = None
req.last_node = None
Expand Down Expand Up @@ -575,6 +594,7 @@ def check_for_jump_forward(self, model_runner):
if req_pool_indices_cpu is None:
req_pool_indices_cpu = self.req_pool_indices.tolist()
self.tree_cache.cache_req(
rid=req.rid,
token_ids=cur_all_ids,
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
Expand Down
27 changes: 21 additions & 6 deletions python/sglang/srt/managers/tp_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
ForwardMode,
Req,
)
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.model_config import ModelConfig
from sglang.srt.model_executor.model_runner import ModelRunner
Expand Down Expand Up @@ -144,11 +145,20 @@ def __init__(
)

# Init cache
self.tree_cache = RadixCache(
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
disable=server_args.disable_radix_cache,
)
if (
server_args.chunked_prefill_size is not None
and server_args.disable_radix_cache
):
self.tree_cache = ChunkCache(
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
)
else:
self.tree_cache = RadixCache(
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
disable=server_args.disable_radix_cache,
)
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.scheduler = PolicyScheduler(
self.schedule_policy,
Expand Down Expand Up @@ -354,7 +364,10 @@ def get_new_prefill_batch(self) -> Optional[Batch]:
# Compute matched prefix length
for req in self.waiting_queue:
req.input_ids = req.origin_input_ids + req.output_ids
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
prefix_indices, last_node = self.tree_cache.match_prefix(
rid=req.rid,
key=req.input_ids,
)
if req.return_logprob:
prefix_indices = prefix_indices[: req.logprob_start_len]
req.extend_input_len = len(req.input_ids) - len(prefix_indices)
Expand Down Expand Up @@ -614,6 +627,7 @@ def cache_filled_batch(self, batch: Batch):
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
for i, req in enumerate(batch.reqs):
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
rid=req.rid,
token_ids=tuple(req.input_ids),
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
Expand Down Expand Up @@ -771,6 +785,7 @@ def handle_finished_requests(self, batch: Batch):
for i in finished_indices:
req = batch.reqs[i]
self.tree_cache.cache_req(
rid=req.rid,
token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
Expand Down
43 changes: 43 additions & 0 deletions python/sglang/srt/mem_cache/base_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from abc import ABC, abstractmethod


class BasePrefixCache(ABC):
    """Cache can be indexed by either rid or key.

    Abstract interface shared by RadixCache (prefix-sharing tree) and
    ChunkCache (per-request bookkeeping for chunked prefill). Implementations
    accept **kwargs on the lookup/insert methods so callers can pass either a
    request id or a token-id key uniformly.
    """

    @abstractmethod
    def reset(self):
        """Clear all cached entries and reset internal state."""
        pass

    @abstractmethod
    def match_prefix(self, **kwargs):
        """Return (cached indices, handle) for the longest cached prefix."""
        pass

    @abstractmethod
    def insert(self, **kwargs):
        """Insert a key/value pair into the cache."""
        pass

    @abstractmethod
    def cache_req(self, **kwargs):
        """Cache (or release) the memory associated with one request."""
        pass

    @abstractmethod
    def evict(self, num_tokens, evict_callback):
        """Evict at least `num_tokens` tokens, freeing each via `evict_callback`."""
        pass

    @abstractmethod
    def inc_lock_ref(self, node):
        """Pin `node` (and its path) so it cannot be evicted."""
        pass

    @abstractmethod
    def dec_lock_ref(self, node):
        """Release one pin on `node` taken by inc_lock_ref."""
        pass

    @abstractmethod
    def evictable_size(self):
        """Return the number of tokens currently eligible for eviction."""
        pass

    # Optional diagnostics: subclasses may override; not required for operation.
    def total_size(self):
        raise NotImplementedError

    def pretty_print(self):
        raise NotImplementedError
60 changes: 60 additions & 0 deletions python/sglang/srt/mem_cache/chunk_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Cache for chunked prefill, used when RadixCache is disabled."""

from sglang.srt.mem_cache.base_cache import BasePrefixCache


class ChunkCacheEntry:
    """Per-request record: the token indices currently held for `rid`."""

    def __init__(self, rid, value):
        # rid: request id; value: tensor/sequence of token indices.
        self.rid, self.value = rid, value


class ChunkCache(BasePrefixCache):
    """Degenerate prefix cache used for chunked prefill with RadixCache disabled.

    Entries are keyed by request id only; there is no prefix sharing across
    requests, no tree, and nothing is ever evictable. Memory is freed
    directly when a request finishes (see cache_req).
    """

    def __init__(self, req_to_token_pool, token_to_kv_pool):
        # Mirror RadixCache's `disable` flag so callers can branch uniformly.
        self.disable = True
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool = token_to_kv_pool

        self.reset()

    def reset(self):
        """Drop every per-request entry."""
        self.entries = {}

    def match_prefix(self, rid, **kwargs):
        """Return (cached token indices, entry) for `rid`, or ([], None)."""
        entry = self.entries.get(rid)
        if entry is None:
            return [], None
        return entry.value, entry

    def cache_req(
        self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs
    ):
        """Record (or release) the token indices held for request `rid`.

        With del_in_memory_pool=True this is the request's final call: its
        req-to-token slot and KV memory are freed. Otherwise the entry is
        created/updated and (indices, entry) is returned.
        """
        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]

        if del_in_memory_pool:
            # NOTE(review): the entry is left in self.entries here; the
            # retract path deletes it externally — confirm the normal finish
            # path also removes it, otherwise entries accumulate.
            assert rid in self.entries
            self.req_to_token_pool.free(req_pool_idx)
            self.token_to_kv_pool.free(indices)
            return

        entry = self.entries.setdefault(rid, ChunkCacheEntry(rid, indices))
        entry.value = indices
        return indices, entry

    def insert(self):
        # Prefix insertion is meaningless without a shared tree.
        raise NotImplementedError

    def evict(self, num_tokens, evict_callback):
        # Nothing is evictable; memory is released directly in cache_req.
        pass

    def inc_lock_ref(self, node):
        # No tree nodes to pin.
        return 0

    def dec_lock_ref(self, node):
        return 0

    def evictable_size(self):
        return 0
7 changes: 5 additions & 2 deletions python/sglang/srt/mem_cache/radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

import torch

from sglang.srt.mem_cache.base_cache import BasePrefixCache


class TreeNode:
def __init__(self):
Expand All @@ -46,7 +48,7 @@ def _key_match(key0, key1):
return i


class RadixCache:
class RadixCache(BasePrefixCache):
def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool
Expand All @@ -62,7 +64,7 @@ def reset(self):
self.root_node.lock_ref = 1
self.evictable_size_ = 0

def match_prefix(self, key):
def match_prefix(self, key, **kwargs):
if self.disable:
return [], self.root_node

Expand Down Expand Up @@ -90,6 +92,7 @@ def cache_req(
req_pool_idx,
del_in_memory_pool=True,
old_last_node=None,
**kwargs,
):
# Insert the request into radix cache
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
Expand Down
4 changes: 0 additions & 4 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,10 +419,6 @@ def check_server_args(self):
self.dp_size > 1 and self.node_rank is not None
), "multi-node data parallel is not supported"

assert not (
self.chunked_prefill_size is not None and self.disable_radix_cache
), "chunked prefill is not supported with radix cache disabled currently"


@dataclasses.dataclass
class PortArgs:
Expand Down

0 comments on commit c020f9c

Please sign in to comment.