Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/sglang/srt/entrypoints/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,19 +491,27 @@ def open_session(
self,
capacity_of_str_len: int,
session_id: Optional[str] = None,
streaming: bool = False,
timeout: Optional[float] = None,
) -> str:
"""Open a session for multi-turn conversation with shared context.
Args:
capacity_of_str_len: Maximum string length capacity for the session.
session_id: Optional session ID. If not provided, a UUID will be generated.
streaming: Use low-overhead path for realtime streaming (append-only mode).
timeout: If set, the session is automatically closed after being inactive
for this many seconds. Inactivity is measured from session open or the
most recent request submission.
Returns:
The session ID (either the provided one or a newly generated UUID).
"""
obj = OpenSessionReqInput(
capacity_of_str_len=capacity_of_str_len,
session_id=session_id,
streaming=streaming,
timeout=timeout,
)
return self.loop.run_until_complete(
self.tokenizer_manager.open_session(obj, None)
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -1610,6 +1610,8 @@ class ConfigureLoggingReq(BaseReq):
# Request payload for opening a session. The scheduler's open_session handler
# visible in this view builds a Session from these fields.
class OpenSessionReqInput(BaseReq):
# Maximum string length capacity for the session.
capacity_of_str_len: int
# Optional caller-supplied session ID; Session generates a UUID hex when None.
session_id: Optional[str] = None
# Request the append-only streaming path. The scheduler coerces this with
# bool(), so None behaves the same as False.
streaming: Optional[bool] = None
# Inactivity timeout in seconds; None means the session never times out.
timeout: Optional[float] = None


@dataclass
Expand Down
7 changes: 4 additions & 3 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
from typing import Any, Dict

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.managers.session_controller import Session
from sglang.srt.observability.scheduler_metrics_mixin import PrefillStats
from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
Expand Down Expand Up @@ -499,7 +500,7 @@ def __init__(
lora_id: Optional[str] = None,
input_embeds: Optional[List[List[float]]] = None,
token_type_ids: List[int] = None,
session_id: Optional[str] = None,
session: Optional[Session] = None,
custom_logit_processor: Optional[str] = None,
require_reasoning: bool = False,
return_hidden_states: bool = False,
Expand Down Expand Up @@ -535,7 +536,7 @@ def __init__(
self.output_ids = []
# fill_ids = origin_input_ids + output_ids. Updated if chunked.
self.fill_ids = []
self.session_id = session_id
self.session = session
self.input_embeds = input_embeds

# For req-level memory management
Expand Down Expand Up @@ -874,7 +875,7 @@ def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
match_result = tree_cache.match_prefix(
MatchPrefixParams(
key=RadixKey(token_ids=token_ids, extra_key=self.extra_key),
req=self if tree_cache.supports_mamba() else None,
req=self,
cow_mamba=tree_cache.supports_mamba(),
)
)
Expand Down
41 changes: 35 additions & 6 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,10 @@ def init_cache_with_memory_pool(self):
else:
self.tree_cache = RadixCache(params)

from sglang.srt.mem_cache.session_aware_cache import SessionAwareCache

self.tree_cache = SessionAwareCache(self.tree_cache)

if (
server_args.disaggregation_mode == "decode"
and server_args.disaggregation_decode_enable_offload_kvcache
Expand Down Expand Up @@ -751,6 +755,7 @@ def init_running_status(self):
self.num_retracted_reqs: int = 0
self.num_paused_reqs: int = 0
self.sessions: Dict[str, Session] = {}
self._last_reap_sessions: float = 0.0
self.forward_sleep_time = None
self._engine_paused = False

Expand Down Expand Up @@ -1118,7 +1123,8 @@ def event_loop_normal(self):
result = self.run_batch(batch)
self.process_batch_result(batch, result)
else:
# When the server is idle, do self-check and re-init some states
# When the server is idle, do self-check and re-init some states.
# Skip if there are any streaming sessions (latency sensitive).
self.self_check_during_idle()

# Update last_batch
Expand Down Expand Up @@ -1349,7 +1355,10 @@ def _split_work_and_control_reqs(self, recv_reqs: List):
return work_reqs, control_reqs

def process_input_requests(self, recv_reqs: List):

now = time.monotonic()
if now - self._last_reap_sessions > 1.0: # reap sessions every second
self._last_reap_sessions = now
self.reap_timed_out_sessions()
for recv_req in recv_reqs:
# If it is a health check generation request and there are running requests, ignore it.
if is_health_check_generate_req(recv_req) and (
Expand Down Expand Up @@ -1458,7 +1467,7 @@ def _maybe_clear_mm_inputs(self, batch: ScheduleBatch) -> None:
if not req.finished() or not (mm_inputs := req.multimodal_inputs):
continue
# For session requests, keep mm_inputs for the next request
if req.session_id:
if req.session:
continue
# For non-session requests, clear features and mm_inputs
for item in mm_inputs.mm_items:
Expand Down Expand Up @@ -2929,17 +2938,37 @@ def open_session(self, recv_req: OpenSessionReqInput):
return OpenSessionReqOutput(session_id, False)
else:
self.sessions[session_id] = Session(
recv_req.capacity_of_str_len, session_id
recv_req.capacity_of_str_len,
session_id,
streaming=bool(recv_req.streaming),
timeout=recv_req.timeout,
)
return OpenSessionReqOutput(session_id, True)

def close_session(self, recv_req: CloseSessionReqInput):
    """Close the session named by ``recv_req.session_id``.

    Unknown session IDs are not an error: a warning is logged and nothing
    else happens.

    Args:
        recv_req: Close request carrying the ``session_id`` to tear down.
    """
    session_id = recv_req.session_id
    if session_id not in self.sessions:
        logger.warning(f"session id {session_id} does not exist, cannot delete.")
    else:
        # Do not delete the entry here: _close_session looks the session up in
        # self.sessions itself and removes it once its cleanup is done.
        self._close_session(session_id)

# Tear down one session: detach a still-running streaming request from it,
# ask the tree cache to release the session, then drop it from self.sessions.
# Raises KeyError if session_id is not registered; the callers visible in this
# view either check membership first or iterate over self.sessions.
# NOTE(review): indentation is lost in this view, so the nesting of the
# statements below is inferred - confirm against the real file.
def _close_session(self, session_id: str):
session = self.sessions[session_id]
# Only a streaming session that currently tracks a request needs detaching.
if session.streaming and session.req_nodes:
# Streaming sessions are expected to track exactly one request node.
assert len(session.req_nodes) == 1
req = next(iter(session.req_nodes.values())).req
# Clear the back-reference so an unfinished request no longer points at
# the session being closed.
if not req.finished():
req.session = None
# Presumably frees cache state held on behalf of this session - confirm
# against SessionAwareCache.release_session.
self.tree_cache.release_session(session_id)
del self.sessions[session_id]

def reap_timed_out_sessions(self):
    """Close every session whose inactivity timeout has elapsed.

    The expired IDs are collected into a list first because
    ``_close_session`` deletes entries from ``self.sessions``, which must not
    happen while that dict is being iterated.
    """
    timed_out = [
        sid for sid, session in self.sessions.items() if session.is_timed_out()
    ]
    for sid in timed_out:
        logger.info(f"Session {sid} timed out, closing.")
        self._close_session(sid)

def maybe_sleep_on_idle(self):
if self.idle_sleeper is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ def process_batch_result_prefill(
release_kv_cache(req, self.tree_cache)
req.time_stats.set_completion_time()
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
# This updates radix so others can match
self.tree_cache.cache_unfinished_req(req)

self.maybe_collect_customized_info(i, req, logits_output)
Expand Down
30 changes: 19 additions & 11 deletions python/sglang/srt/managers/scheduler_runtime_checker_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,11 @@ def _check_hybrid_memory(self: Scheduler):
swa_available_size,
swa_evictable_size,
) = self._get_swa_token_info()
memory_leak = full_num_used != 0 or swa_num_used != 0
session_held = self.tree_cache.session_held_tokens()
memory_leak = (full_num_used - session_held) != 0 or swa_num_used != 0
token_msg = (
f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}, {session_held=}\n"
)
return memory_leak, token_msg

Expand All @@ -110,8 +111,9 @@ def _check_mamba_memory(self: Scheduler):
mamba_available_size,
mamba_evictable_size,
) = self._get_mamba_token_info()
session_held = self.tree_cache.session_held_tokens()
memory_leak = (
full_num_used != self.tree_cache.full_protected_size()
full_num_used != self.tree_cache.full_protected_size() + session_held
or mamba_num_used != self.tree_cache.mamba_protected_size()
)
if memory_leak:
Expand Down Expand Up @@ -150,14 +152,11 @@ def _check_mamba_memory(self: Scheduler):
def _check_radix_cache_memory(self: Scheduler):
    """Check token accounting for the radix cache.

    Tokens that are neither protected nor held by sessions are expected to
    be either available or evictable; any other total is reported as a leak.

    Returns:
        A ``(memory_leak, token_msg)`` tuple: ``memory_leak`` is True when the
        accounting does not balance, and ``token_msg`` is a one-line summary
        of the quantities involved for logging.
    """
    _, _, available_size, evictable_size = self._get_token_info()
    protected_size = self.tree_cache.protected_size()
    # Tokens reported by the cache as held for sessions are accounted for
    # separately, so they must not be reported as leaked.
    session_held = self.tree_cache.session_held_tokens()
    memory_leak = (available_size + evictable_size) != (
        self.max_total_num_tokens - protected_size - session_held
    )
    token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}, {session_held=}\n"
    return memory_leak, token_msg

def _get_batch_uncached_size(self: Scheduler, batch: ScheduleBatch) -> int:
Expand Down Expand Up @@ -205,7 +204,14 @@ def self_check_during_busy(self: Scheduler):
log_msg = f"[Mem Check (BUSY)] {available_size=}, {evictable_size=}, {protected_size=}, {uncached_size=}"
logger.info(log_msg)

total_tokens = available_size + evictable_size + protected_size + uncached_size
session_held = self.tree_cache.session_held_tokens()
total_tokens = (
available_size
+ evictable_size
+ protected_size
+ uncached_size
+ session_held
)
assert (
total_tokens == self.max_total_num_tokens
), f"Mem Leak Detected! {total_tokens=} vs {self.max_total_num_tokens=}"
Expand All @@ -218,10 +224,12 @@ def _check_req_pool(self: Scheduler):
else:
req_total_size = self.req_to_token_pool.size

if len(self.req_to_token_pool.free_slots) != req_total_size:
session_req_count = self.tree_cache.session_held_req_count()
if len(self.req_to_token_pool.free_slots) + session_req_count != req_total_size:
msg = (
"req_to_token_pool memory leak detected!"
f"available_size={len(self.req_to_token_pool.free_slots)}, "
f"session_held={session_req_count}, "
f"total_size={self.req_to_token_pool.size}\n"
)
raise_error_or_warn(
Expand Down
60 changes: 52 additions & 8 deletions python/sglang/srt/managers/session_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# ==============================================================================

import logging
import time
import uuid
from typing import Dict, Optional

Expand Down Expand Up @@ -65,25 +66,59 @@ def _str_helper(self, prefix=""):


class Session:
def __init__(
    self,
    capacity_of_str_len: int,
    session_id: Optional[str] = None,
    streaming: bool = False,
    timeout: Optional[float] = None,
):
    """Create a session and start its inactivity clock.

    Args:
        capacity_of_str_len: Maximum string length capacity for the session.
        session_id: Session ID to use; a random UUID hex string is generated
            when None.
        streaming: Whether this is a streaming session.
        timeout: Inactivity timeout in seconds, or None to never time out.
    """
    self.session_id = session_id if session_id is not None else uuid.uuid4().hex
    self.capacity_of_str_len = capacity_of_str_len
    self.streaming = streaming
    self.timeout = timeout
    # Monotonic clock: the timeout check must not be affected by wall-clock
    # adjustments.
    self.last_active_time: float = time.monotonic()
    # Request nodes tracked by this session, keyed by request id.
    self.req_nodes: Dict[str, SessionReqNode] = {}

def is_timed_out(self) -> bool:
    """Return True when the session has been inactive for longer than its timeout.

    A session whose ``timeout`` is None never times out. Inactivity is
    measured on the monotonic clock, starting from ``last_active_time``.
    """
    if self.timeout is None:
        return False
    return time.monotonic() - self.last_active_time > self.timeout

def create_req(self, req: TokenizedGenerateReqInput, tokenizer, vocab_size: int):
assert req.session_params is not None
self.last_active_time = time.monotonic()
session_params = req.session_params

last_req_node = None
last_req = None
abort = False
if session_params.replace:
abort_message = ""
if self.streaming:
# Streaming sessions: only simple appends allowed; reject otherwise.
if session_params.replace:
abort = True
abort_message = "Streaming sessions do not support replace."
elif session_params.drop_previous_output:
abort = True
abort_message = (
"Streaming sessions do not support drop_previous_output."
)
elif session_params.offset and session_params.offset != 0:
abort = True
abort_message = "Streaming sessions do not support offset."
elif self.req_nodes:
assert len(self.req_nodes) == 1
_, last_req_node = self.req_nodes.popitem()
last_req = last_req_node.req
elif session_params.replace:
if session_params.rid is None:
for _, req_node in self.req_nodes.items():
req_node.clear(self.req_nodes)
else:
if session_params.rid not in self.req_nodes:
abort = True
abort_message = "Invalid request session id"
else:
last_req_node = self.req_nodes[session_params.rid]
last_req_node.abort()
Expand All @@ -93,18 +128,22 @@ def create_req(self, req: TokenizedGenerateReqInput, tokenizer, vocab_size: int)
if session_params.rid is not None:
if session_params.rid not in self.req_nodes:
abort = True
abort_message = "Invalid request session id"
else:
last_req_node = self.req_nodes[session_params.rid]
last_req = last_req_node.req
if not last_req.finished():
logging.warning(
"The request in a session is appending to a request that hasn't finished."
)
abort = True
abort_message = "Session request is appending to a request that hasn't finished."
logging.warning(abort_message)

if last_req is not None:
# trim bos token if it is an append
if tokenizer is not None and req.input_ids[0] == tokenizer.bos_token_id:
if (
tokenizer is not None
and req.input_ids
and req.input_ids[0] == tokenizer.bos_token_id
):
req.input_ids = req.input_ids[1:]

input_ids = (
Expand Down Expand Up @@ -136,14 +175,15 @@ def create_req(self, req: TokenizedGenerateReqInput, tokenizer, vocab_size: int)
else:
input_ids = req.input_ids
input_ids_unpadded = req.input_ids

new_req = Req(
rid=req.rid,
origin_input_text=None,
origin_input_ids=input_ids,
origin_input_ids_unpadded=input_ids_unpadded,
sampling_params=req.sampling_params,
lora_id=req.lora_id,
session_id=self.session_id,
session=self,
custom_logit_processor=req.custom_logit_processor,
stream=req.stream,
return_logprob=req.return_logprob,
Expand All @@ -156,7 +196,11 @@ def create_req(self, req: TokenizedGenerateReqInput, tokenizer, vocab_size: int)
new_req.tokenizer = tokenizer

if abort:
new_req.set_finish_with_abort("Invalid request session id")
new_req.set_finish_with_abort(abort_message)
elif self.streaming:
if last_req is not None:
last_req.session = None
self.req_nodes[req.rid] = SessionReqNode(new_req)
else:
new_req_node = SessionReqNode(new_req, last_req_node)
self.req_nodes[req.rid] = new_req_node
Expand Down
7 changes: 7 additions & 0 deletions python/sglang/srt/mem_cache/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,13 @@ def release_kv_cache(req: Req, tree_cache: BasePrefixCache, is_insert: bool = Tr

tree_cache.cache_finished_req(req, is_insert=is_insert)

# FIXME: SessionAwareCache.cache_finished_req sets req_pool_idx = None to
# transfer KV ownership to the SessionSlot, so we skip the remaining
# cleanup (overalloc free + pool slot free). This means over-allocated
# tokens from speculative decoding are NOT freed between turns.
if req.req_pool_idx is None:
Comment thread
hnyls2002 marked this conversation as resolved.
return

start_p, end_p = req.pop_overallocated_kv_cache()

global_server_args = get_global_server_args()
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/mem_cache/memory_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def alloc(self, reqs: list[Req]) -> Optional[List[int]]:
reusing = [i for i, r in enumerate(reqs) if r.req_pool_idx is not None]
if not any(r.is_dllm() for r in reqs):
assert (
len(reusing) <= 1
sum(1 for i in reusing if reqs[i].is_chunked > 0) <= 1
), "only one chunked request may reuse req_pool_idx in a batch"
assert all(
reqs[i].is_chunked > 0 or reqs[i].kv_committed_len > 0 for i in reusing
Expand Down
Loading
Loading