Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@
from sglang.srt.plugins import load_plugins
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
from sglang.srt.session.session_aware_cache import SessionAwareCache
from sglang.srt.session.session_controller import SessionController
from sglang.srt.session.streaming_session import StreamingSession
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
DynamicGradMode,
Expand Down Expand Up @@ -865,7 +865,6 @@ def init_cache_with_memory_pool(self):
ComponentType.SWA if self.is_hybrid_swa else ComponentType.MAMBA
)
params.tree_components = tuple(tree_components)
params.enable_streaming_session = server_args.enable_streaming_session
self.tree_cache = UnifiedRadixCache(params)
elif self.is_hybrid_swa:
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
Expand Down Expand Up @@ -894,7 +893,7 @@ def init_cache_with_memory_pool(self):
server_args.enable_streaming_session
and not self.tree_cache.supports_streaming_session()
):
self.tree_cache = SessionAwareCache(self.tree_cache)
self.tree_cache = StreamingSession(self.tree_cache)

if self.enable_hisparse:
# Coordinator was created inside ModelRunner.initialize() before CUDA graph capture
Expand Down
2 changes: 0 additions & 2 deletions python/sglang/srt/mem_cache/cache_init_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,3 @@ class CacheInitParams:
cache_ttl_seconds: Optional[float] = None

tree_components: Optional[tuple[ComponentType, ...]] = None

enable_streaming_session: bool = False
2 changes: 1 addition & 1 deletion python/sglang/srt/mem_cache/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def release_kv_cache(req: Req, tree_cache: BasePrefixCache, is_insert: bool = Tr

tree_cache.cache_finished_req(req, is_insert=is_insert)

# SessionAwareCache.cache_finished_req handles speculative tail trim
# StreamingSession.cache_finished_req handles speculative tail trim
# and bookkeeping flag sync internally, then sets req_pool_idx = None.
if req.req_pool_idx is None:
return
Expand Down
59 changes: 21 additions & 38 deletions python/sglang/srt/mem_cache/unified_radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
get_and_increase_time_counter,
)
from sglang.srt.mem_cache.utils import convert_to_bigram_key
from sglang.srt.session.session_aware_cache import SessionAwareCache
from sglang.srt.session.streaming_session import StreamingSession

if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
Expand Down Expand Up @@ -210,12 +210,12 @@ def __init__(
else:
self.key_convert_fn = lambda key: key

# Streaming session: embedded SessionAwareCache with self as inner.
# Streaming session: embedded StreamingSession with self as inner.
# Always on -- zero overhead when no streaming session is open (the
# try_* entries short-circuit on non-streaming reqs / real TreeNodes).
# Dispatch methods below pre-check conditions so the session's
# internal fall-through to self.inner.xxx never fires -- no recursion.
self.session: Optional[SessionAwareCache] = (
SessionAwareCache(inner=self) if params.enable_streaming_session else None
)
self.session = StreamingSession(inner=self)

self.reset()
logger.info(f"Init Unified RadixTree with components {self.tree_components}")
Expand All @@ -231,14 +231,12 @@ def reset(self) -> None:
self.lru_lists = {
ct: UnifiedLRUList(ct, self.tree_components) for ct in self.tree_components
}
if self.session is not None:
self.session.slots.clear()
self.session.slots.clear()

def match_prefix(self, params: MatchPrefixParams) -> MatchResult:
if self.session is not None:
result = self.session.try_match_prefix(params)
if result is not None:
return result
result = self.session.try_match_prefix(params)
if result is not None:
return result

key = params.key
key, _ = maybe_bigram_convert(self.is_eagle, key)
Expand Down Expand Up @@ -289,10 +287,9 @@ def evict(self, params: EvictParams) -> EvictResult:
)

def inc_lock_ref(self, node: Any) -> IncLockRefResult:
if self.session is not None:
result = self.session.try_inc_lock_ref(node)
if result is not None:
return result
result = self.session.try_inc_lock_ref(node)
if result is not None:
return result
if self.disable:
return IncLockRefResult()
result = IncLockRefResult()
Expand All @@ -303,10 +300,9 @@ def inc_lock_ref(self, node: Any) -> IncLockRefResult:
def dec_lock_ref(
self, node: Any, params: Optional[DecLockRefParams] = None
) -> DecLockRefResult:
if self.session is not None:
result = self.session.try_dec_lock_ref(node, params)
if result is not None:
return result
result = self.session.try_dec_lock_ref(node, params)
if result is not None:
return result
if self.disable:
return DecLockRefResult()
for component in self._components_tuple:
Expand All @@ -315,9 +311,7 @@ def dec_lock_ref(
return DecLockRefResult()

def cache_finished_req(self, req: Req, is_insert: bool = True, **kwargs) -> None:
if self.session is not None and self.session.try_cache_finished_req(
req, is_insert=is_insert, **kwargs
):
if self.session.try_cache_finished_req(req, is_insert=is_insert, **kwargs):
return

kv_committed_len = req.pop_committed_kv_cache()
Expand Down Expand Up @@ -389,9 +383,7 @@ def cache_finished_req(self, req: Req, is_insert: bool = True, **kwargs) -> None
)

def cache_unfinished_req(self, req: Req, chunked=False, **kwargs) -> None:
if self.session is not None and self.session.try_cache_unfinished_req(
req, chunked=chunked, **kwargs
):
if self.session.try_cache_unfinished_req(req, chunked=chunked, **kwargs):
return

token_ids = req.fill_ids
Expand Down Expand Up @@ -813,33 +805,24 @@ def supports_swa(self) -> bool:
def supports_mamba(self) -> bool:
return ComponentType.MAMBA in self.components

# ---- Streaming session API (delegates to composed SessionImpl) ----
# ---- Streaming session API (delegates to composed StreamingSession) ----

def supports_streaming_session(self) -> bool:
return self.session is not None
return True

def release_session(self, session_id: str) -> None:
if self.session is not None:
self.session.release_session(session_id)
self.session.release_session(session_id)

def session_held_tokens(self, active_pool_idxs: Optional[set] = None) -> int:
if self.session is None:
return 0
return self.session.session_held_tokens(active_pool_idxs)

def session_held_full_tokens(self, active_pool_idxs: Optional[set] = None) -> int:
if self.session is None:
return 0
return self.session.session_held_full_tokens(active_pool_idxs)

def session_held_swa_tokens(self, active_pool_idxs: Optional[set] = None) -> int:
if self.session is None:
return 0
return self.session.session_held_swa_tokens(active_pool_idxs)

def session_held_req_count(self, active_pool_idxs: Optional[set] = None) -> int:
if self.session is None:
return 0
return self.session.session_held_req_count(active_pool_idxs)

def evictable_size(self) -> int:
Expand Down Expand Up @@ -964,7 +947,7 @@ def sanity_check(self):
# Skip when streaming sessions hold tree locks: the check asserts
# all nodes are unlocked during idle, which streaming sessions break
# by design (they hold a first-turn lock across turns).
if self.session is not None and self.session.any_holding_kv():
if self.session.any_holding_kv():
return
try:
# 1. Collect all nodes from tree
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -4486,7 +4486,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
"--enable-streaming-session",
action="store_true",
default=ServerArgs.enable_streaming_session,
help="Enable streaming session mode and SessionAwareCache wrapper.",
help="Enable streaming session mode and StreamingSession wrapper.",
)
parser.add_argument(
"--random-seed",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,11 @@ def _is_streaming(req: Optional[Req]) -> bool:
return req is not None and req.session is not None and req.session.streaming


class SessionAwareCache(BasePrefixCache):
class StreamingSession(BasePrefixCache):
"""Adds streaming-session KV save/restore on top of any BasePrefixCache.

Works both as an external wrapper (``SessionAwareCache(RadixCache(...))``)
and in embedded composition (``SessionAwareCache(inner=self)``). For the
Works both as an external wrapper (``StreamingSession(RadixCache(...))``)
and in embedded composition (``StreamingSession(inner=self)``). For the
embedded case, the composing cache must pre-check dispatch conditions
(``_is_streaming`` / ``find_active_slot`` / ``has_slot``) so the internal
fall-through to ``self.inner.xxx`` never fires -- otherwise it recurses.
Expand Down
12 changes: 6 additions & 6 deletions test/registered/unit/mem_cache/test_streaming_session_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sglang.srt.managers.schedule_batch import FINISH_ABORT
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
from sglang.srt.mem_cache.common import release_kv_cache
from sglang.srt.session.session_aware_cache import SessionAwareCache, SessionSlot
from sglang.srt.session.streaming_session import SessionSlot, StreamingSession
from sglang.test.ci.ci_register import register_cpu_ci

register_cpu_ci(est_time=8, suite="stage-a-test-cpu")
Expand Down Expand Up @@ -97,7 +97,7 @@ def test_streaming_release_kv_cache_defers_tail_free(monkeypatch):
req_to_token = torch.arange(128, dtype=torch.int32).reshape(1, 128)
req_to_token_pool = SimpleNamespace(req_to_token=req_to_token, free_slots=[])
allocator = _FakeAllocator()
tree_cache = SessionAwareCache(
tree_cache = StreamingSession(
_FakeInnerCache(req_to_token_pool, allocator, page_size)
)
req = _FakeReq("session-a", req_pool_idx=0, committed=17, allocated=40)
Expand Down Expand Up @@ -137,7 +137,7 @@ def test_preabort_detaches_session_and_preserves_slot():
)
],
)
tree_cache = SessionAwareCache(inner)
tree_cache = StreamingSession(inner)
tree_cache.slots["session-a"] = SessionSlot(
req_pool_idx=0,
kv_committed_len=48,
Expand Down Expand Up @@ -173,7 +173,7 @@ def test_first_mid_abort_nukes_ephemeral_slot():
req_to_token_pool = SimpleNamespace(req_to_token=req_to_token, free_slots=[])
allocator = _FakeAllocator()
inner = _FakeInnerCache(req_to_token_pool, allocator, page_size)
tree_cache = SessionAwareCache(inner)
tree_cache = StreamingSession(inner)

# No slot exists yet (first request).
req = _FakeReq("session-a", req_pool_idx=0, committed=0, allocated=20)
Expand Down Expand Up @@ -202,7 +202,7 @@ def test_nth_mid_abort_nukes_session_slot():
req_to_token_pool = SimpleNamespace(req_to_token=req_to_token, free_slots=[])
allocator = _FakeAllocator()
inner = _FakeInnerCache(req_to_token_pool, allocator, page_size)
tree_cache = SessionAwareCache(inner)
tree_cache = StreamingSession(inner)

# Session already has a slot from a previous turn.
tree_cache.slots["session-a"] = SessionSlot(
Expand Down Expand Up @@ -249,7 +249,7 @@ def test_trim_overshoot_postcondition():
req_to_token = torch.arange(128, dtype=torch.int32).reshape(1, 128)
req_to_token_pool = SimpleNamespace(req_to_token=req_to_token, free_slots=[])
allocator = _FakeAllocator()
tree_cache = SessionAwareCache(
tree_cache = StreamingSession(
_FakeInnerCache(req_to_token_pool, allocator, page_size)
)

Expand Down
Loading