diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 06cd1a660f59..52f3f1d22ebb 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -205,8 +205,8 @@ from sglang.srt.plugins import load_plugins from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args -from sglang.srt.session.session_aware_cache import SessionAwareCache from sglang.srt.session.session_controller import SessionController +from sglang.srt.session.streaming_session import StreamingSession from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils import ( DynamicGradMode, @@ -865,7 +865,6 @@ def init_cache_with_memory_pool(self): ComponentType.SWA if self.is_hybrid_swa else ComponentType.MAMBA ) params.tree_components = tuple(tree_components) - params.enable_streaming_session = server_args.enable_streaming_session self.tree_cache = UnifiedRadixCache(params) elif self.is_hybrid_swa: from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache @@ -894,7 +893,7 @@ def init_cache_with_memory_pool(self): server_args.enable_streaming_session and not self.tree_cache.supports_streaming_session() ): - self.tree_cache = SessionAwareCache(self.tree_cache) + self.tree_cache = StreamingSession(self.tree_cache) if self.enable_hisparse: # Coordinator was created inside ModelRunner.initialize() before CUDA graph capture diff --git a/python/sglang/srt/mem_cache/cache_init_params.py b/python/sglang/srt/mem_cache/cache_init_params.py index 5a7c40ded976..6f6fafae05aa 100644 --- a/python/sglang/srt/mem_cache/cache_init_params.py +++ b/python/sglang/srt/mem_cache/cache_init_params.py @@ -42,5 +42,3 @@ class CacheInitParams: cache_ttl_seconds: Optional[float] = None tree_components: Optional[tuple[ComponentType, ...]] = None - - enable_streaming_session: bool = False diff --git a/python/sglang/srt/mem_cache/common.py b/python/sglang/srt/mem_cache/common.py index f59d9c7bc37d..59c1dd974e13 100644 --- a/python/sglang/srt/mem_cache/common.py +++ b/python/sglang/srt/mem_cache/common.py @@ -478,7 +478,7 @@ def release_kv_cache(req: Req, tree_cache: BasePrefixCache, is_insert: bool = Tr tree_cache.cache_finished_req(req, is_insert=is_insert) - # SessionAwareCache.cache_finished_req handles speculative tail trim + # StreamingSession.cache_finished_req handles speculative tail trim # and bookkeeping flag sync internally, then sets req_pool_idx = None. if req.req_pool_idx is None: return diff --git a/python/sglang/srt/mem_cache/unified_radix_cache.py b/python/sglang/srt/mem_cache/unified_radix_cache.py index 201fca688b09..608583cd7e86 100644 --- a/python/sglang/srt/mem_cache/unified_radix_cache.py +++ b/python/sglang/srt/mem_cache/unified_radix_cache.py @@ -40,7 +40,7 @@ get_and_increase_time_counter, ) from sglang.srt.mem_cache.utils import convert_to_bigram_key -from sglang.srt.session.session_aware_cache import SessionAwareCache +from sglang.srt.session.streaming_session import StreamingSession if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req @@ -210,12 +210,12 @@ def __init__( else: self.key_convert_fn = lambda key: key - # Streaming session: embedded SessionAwareCache with self as inner. + # Streaming session: embedded StreamingSession with self as inner. + # Always on -- zero overhead when no streaming session is open (the + # try_* entries short-circuit on non-streaming reqs / real TreeNodes). # Dispatch methods below pre-check conditions so the session's # internal fall-through to self.inner.xxx never fires -- no recursion. - self.session: Optional[SessionAwareCache] = ( - SessionAwareCache(inner=self) if params.enable_streaming_session else None - ) + self.session = StreamingSession(inner=self) self.reset() logger.info(f"Init Unified RadixTree with components {self.tree_components}") @@ -231,14 +231,12 @@ def reset(self) -> None: self.lru_lists = { ct: UnifiedLRUList(ct, self.tree_components) for ct in self.tree_components } - if self.session is not None: - self.session.slots.clear() + self.session.slots.clear() def match_prefix(self, params: MatchPrefixParams) -> MatchResult: - if self.session is not None: - result = self.session.try_match_prefix(params) - if result is not None: - return result + result = self.session.try_match_prefix(params) + if result is not None: + return result key = params.key key, _ = maybe_bigram_convert(self.is_eagle, key) @@ -289,10 +287,9 @@ def evict(self, params: EvictParams) -> EvictResult: ) def inc_lock_ref(self, node: Any) -> IncLockRefResult: - if self.session is not None: - result = self.session.try_inc_lock_ref(node) - if result is not None: - return result + result = self.session.try_inc_lock_ref(node) + if result is not None: + return result if self.disable: return IncLockRefResult() result = IncLockRefResult() @@ -303,10 +300,9 @@ def inc_lock_ref(self, node: Any) -> IncLockRefResult: def dec_lock_ref( self, node: Any, params: Optional[DecLockRefParams] = None ) -> DecLockRefResult: - if self.session is not None: - result = self.session.try_dec_lock_ref(node, params) - if result is not None: - return result + result = self.session.try_dec_lock_ref(node, params) + if result is not None: + return result if self.disable: return DecLockRefResult() for component in self._components_tuple: @@ -315,9 +311,7 @@ def dec_lock_ref( return DecLockRefResult() def cache_finished_req(self, req: Req, is_insert: bool = True, **kwargs) -> None: - if self.session is not None and self.session.try_cache_finished_req( - req, is_insert=is_insert, **kwargs - ): + if self.session.try_cache_finished_req(req, is_insert=is_insert, **kwargs): return kv_committed_len = req.pop_committed_kv_cache() @@ -389,9 +383,7 @@ def cache_finished_req(self, req: Req, is_insert: bool = True, **kwargs) -> None ) def cache_unfinished_req(self, req: Req, chunked=False, **kwargs) -> None: - if self.session is not None and self.session.try_cache_unfinished_req( - req, chunked=chunked, **kwargs - ): + if self.session.try_cache_unfinished_req(req, chunked=chunked, **kwargs): return token_ids = req.fill_ids @@ -813,33 +805,24 @@ def supports_swa(self) -> bool: def supports_mamba(self) -> bool: return ComponentType.MAMBA in self.components - # ---- Streaming session API (delegates to composed SessionImpl) ---- + # ---- Streaming session API (delegates to composed StreamingSession) ---- def supports_streaming_session(self) -> bool: - return self.session is not None + return True def release_session(self, session_id: str) -> None: - if self.session is not None: - self.session.release_session(session_id) + self.session.release_session(session_id) def session_held_tokens(self, active_pool_idxs: Optional[set] = None) -> int: - if self.session is None: - return 0 return self.session.session_held_tokens(active_pool_idxs) def session_held_full_tokens(self, active_pool_idxs: Optional[set] = None) -> int: - if self.session is None: - return 0 return self.session.session_held_full_tokens(active_pool_idxs) def session_held_swa_tokens(self, active_pool_idxs: Optional[set] = None) -> int: - if self.session is None: - return 0 return self.session.session_held_swa_tokens(active_pool_idxs) def session_held_req_count(self, active_pool_idxs: Optional[set] = None) -> int: - if self.session is None: - return 0 return self.session.session_held_req_count(active_pool_idxs) def evictable_size(self) -> int: @@ -964,7 +947,7 @@ def sanity_check(self): # Skip when streaming sessions hold tree locks: the check asserts # all nodes are unlocked during idle, which streaming sessions break # by design (they hold a first-turn lock across turns). - if self.session is not None and self.session.any_holding_kv(): + if self.session.any_holding_kv(): return try: # 1. Collect all nodes from tree diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 85cc561f5a12..24937d7f0dc5 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -4486,7 +4486,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--enable-streaming-session", action="store_true", default=ServerArgs.enable_streaming_session, - help="Enable streaming session mode and SessionAwareCache wrapper.", + help="Enable streaming session mode and StreamingSession wrapper.", ) parser.add_argument( "--random-seed", diff --git a/python/sglang/srt/session/session_aware_cache.py b/python/sglang/srt/session/streaming_session.py similarity index 99% rename from python/sglang/srt/session/session_aware_cache.py rename to python/sglang/srt/session/streaming_session.py index c34c24651491..c52e9baee2c5 100644 --- a/python/sglang/srt/session/session_aware_cache.py +++ b/python/sglang/srt/session/streaming_session.py @@ -113,11 +113,11 @@ def _is_streaming(req: Optional[Req]) -> bool: return req is not None and req.session is not None and req.session.streaming -class SessionAwareCache(BasePrefixCache): +class StreamingSession(BasePrefixCache): """Adds streaming-session KV save/restore on top of any BasePrefixCache. - Works both as an external wrapper (``SessionAwareCache(RadixCache(...))``) - and in embedded composition (``SessionAwareCache(inner=self)``). For the + Works both as an external wrapper (``StreamingSession(RadixCache(...))``) + and in embedded composition (``StreamingSession(inner=self)``). For the embedded case, the composing cache must pre-check dispatch conditions (``_is_streaming`` / ``find_active_slot`` / ``has_slot``) so the internal fall-through to ``self.inner.xxx`` never fires -- otherwise it recurses. diff --git a/test/registered/unit/mem_cache/test_streaming_session_unit.py b/test/registered/unit/mem_cache/test_streaming_session_unit.py index 4a349f9f6e6b..3caee68878c1 100644 --- a/test/registered/unit/mem_cache/test_streaming_session_unit.py +++ b/test/registered/unit/mem_cache/test_streaming_session_unit.py @@ -5,7 +5,7 @@ from sglang.srt.managers.schedule_batch import FINISH_ABORT from sglang.srt.mem_cache.base_prefix_cache import MatchResult from sglang.srt.mem_cache.common import release_kv_cache -from sglang.srt.session.session_aware_cache import SessionAwareCache, SessionSlot +from sglang.srt.session.streaming_session import SessionSlot, StreamingSession from sglang.test.ci.ci_register import register_cpu_ci register_cpu_ci(est_time=8, suite="stage-a-test-cpu") @@ -97,7 +97,7 @@ def test_streaming_release_kv_cache_defers_tail_free(monkeypatch): req_to_token = torch.arange(128, dtype=torch.int32).reshape(1, 128) req_to_token_pool = SimpleNamespace(req_to_token=req_to_token, free_slots=[]) allocator = _FakeAllocator() - tree_cache = SessionAwareCache( + tree_cache = StreamingSession( _FakeInnerCache(req_to_token_pool, allocator, page_size) ) req = _FakeReq("session-a", req_pool_idx=0, committed=17, allocated=40) @@ -137,7 +137,7 @@ def test_preabort_detaches_session_and_preserves_slot(): ) ], ) - tree_cache = SessionAwareCache(inner) + tree_cache = StreamingSession(inner) tree_cache.slots["session-a"] = SessionSlot( req_pool_idx=0, kv_committed_len=48, @@ -173,7 +173,7 @@ def test_first_mid_abort_nukes_ephemeral_slot(): req_to_token_pool = SimpleNamespace(req_to_token=req_to_token, free_slots=[]) allocator = _FakeAllocator() inner = _FakeInnerCache(req_to_token_pool, allocator, page_size) - tree_cache = SessionAwareCache(inner) + tree_cache = StreamingSession(inner) # No slot exists yet (first request). req = _FakeReq("session-a", req_pool_idx=0, committed=0, allocated=20) @@ -202,7 +202,7 @@ def test_nth_mid_abort_nukes_session_slot(): req_to_token_pool = SimpleNamespace(req_to_token=req_to_token, free_slots=[]) allocator = _FakeAllocator() inner = _FakeInnerCache(req_to_token_pool, allocator, page_size) - tree_cache = SessionAwareCache(inner) + tree_cache = StreamingSession(inner) # Session already has a slot from a previous turn. tree_cache.slots["session-a"] = SessionSlot( @@ -249,7 +249,7 @@ def test_trim_overshoot_postcondition(): req_to_token = torch.arange(128, dtype=torch.int32).reshape(1, 128) req_to_token_pool = SimpleNamespace(req_to_token=req_to_token, free_slots=[]) allocator = _FakeAllocator() - tree_cache = SessionAwareCache( + tree_cache = StreamingSession( _FakeInnerCache(req_to_token_pool, allocator, page_size) )