Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions python/sglang/srt/entrypoints/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@
OpenSessionReqInput,
ParseFunctionCallReq,
PauseGenerationReqInput,
PinPrefixReqInput,
ProfileReqInput,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
Expand Down Expand Up @@ -845,25 +844,6 @@ async def hicache_storage_backend_status():
}


@app.api_route("/hicache/pin_prefix", methods=["POST"])
@auth_level(AuthLevel.ADMIN_OPTIONAL)
async def pin_prefix(obj: PinPrefixReqInput):
    """Handle POST /hicache/pin_prefix: pin a prefix (given as token ids).

    The request is refused with the shared "admin API key missing" response
    when the server was started without an admin API key. Otherwise the
    token ids and TTL are forwarded to the tokenizer manager and the outcome
    is returned as JSON: HTTP 200 on success, HTTP 400 on failure.
    """
    manager = _global_state.tokenizer_manager
    if not manager.server_args.admin_api_key:
        return _admin_api_key_missing_response()

    result = await manager.pin_prefix(obj.token_ids, obj.ttl_seconds)

    if result.success:
        status_text, http_code = "ok", 200
    else:
        status_text, http_code = "error", HTTPStatus.BAD_REQUEST

    payload = {
        "status": status_text,
        "nodes_pinned": result.nodes_pinned,
        "message": result.message,
    }
    return ORJSONResponse(content=payload, status_code=http_code)


@app.api_route("/start_profile", methods=["GET", "POST"])
@auth_level(AuthLevel.ADMIN_OPTIONAL)
async def start_profile_async(obj: Optional[ProfileReqInput] = None):
Expand Down
4 changes: 0 additions & 4 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,17 +284,13 @@ class Envs:
SGLANG_HICACHE_DECODE_OFFLOAD_STRIDE = EnvInt(None)
SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR = EnvStr(None)
SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR = EnvStr(None)
# Max fraction of cache (by token count) that can be pinned; 0 = disable pinning.
SGLANG_HICACHE_MAX_PINNED_RATIO = EnvFloat(0.0)

# Staging buffer for heterogeneous TP KV transfer
SGLANG_DISAGG_STAGING_BUFFER = EnvBool(False)
SGLANG_DISAGG_STAGING_BUFFER_SIZE_MB = EnvInt(64)
SGLANG_DISAGG_STAGING_POOL_SIZE_MB = EnvInt(4096)
# TODO(yangminl): remove SGLANG_STAGING_USE_TORCH and the torch fallback in
# staging_buffer.py once Triton kernels are fully validated in production.
SGLANG_STAGING_USE_TORCH = EnvBool(False)

# Mooncake KV Transfer
SGLANG_MOONCAKE_CUSTOM_MEM_POOL = EnvStr(None)
ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE = EnvBool(False)
Expand Down
15 changes: 0 additions & 15 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,21 +1181,6 @@ class DetachHiCacheStorageReqOutput(BaseReq):
message: str = ""


@dataclass
class PinPrefixReqInput(BaseReq):
    """Request to pin a prefix, identified by its token ids, so it resists eviction."""

    # Token ids of the prefix to pin; a fresh empty list is created per instance.
    token_ids: List[int] = field(default_factory=list)
    # Lifetime of the pin in seconds; the default of 300 is 5 minutes.
    ttl_seconds: int = 300  # TTL in seconds, default 5 minutes


@dataclass
class PinPrefixReqOutput(BaseReq):
    """Result of a pin-prefix request."""

    # Whether the pin request succeeded (required, no default).
    success: bool
    # Number of nodes that were pinned; defaults to 0.
    nodes_pinned: int = 0
    # Free-form detail or rejection reason; defaults to an empty string.
    message: str = ""


@dataclass
class PauseGenerationReqInput(BaseReq):
"""
Expand Down
34 changes: 0 additions & 34 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,6 @@
LoadLoRAAdapterReqOutput,
OpenSessionReqInput,
PauseGenerationReqInput,
PinPrefixReqInput,
PinPrefixReqOutput,
ProfileReq,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
Expand Down Expand Up @@ -1185,7 +1183,6 @@ def init_request_dispatcher(self):
(ClearHiCacheReqInput, self.clear_hicache_storage_wrapped),
(AttachHiCacheStorageReqInput, self.attach_hicache_storage_wrapped),
(DetachHiCacheStorageReqInput, self.detach_hicache_storage_wrapped),
(PinPrefixReqInput, self.pin_prefix_wrapped),
(AbortReq, self.abort_request),
(OpenSessionReqInput, self.open_session),
(CloseSessionReqInput, self.close_session),
Expand Down Expand Up @@ -3009,37 +3006,6 @@ def detach_hicache_storage_wrapped(

return DetachHiCacheStorageReqOutput(success=False, message=msg)

def pin_prefix_wrapped(self, recv_req: PinPrefixReqInput):
    """Serve a pin-prefix request against this scheduler's tree cache.

    The request is rejected (success=False, nodes_pinned=0) when the tree
    cache has no ``pin_prefix`` method, when its ``_max_pinned_tokens``
    budget is missing or not positive, or when the cache pinned no nodes.
    Otherwise a success output is returned whose message reports the node
    count and TTL, followed by any partial-rejection reason from the cache.
    """

    def _rejected(reason):
        # Every failure path shares the same shape; only the message differs.
        return PinPrefixReqOutput(success=False, nodes_pinned=0, message=reason)

    cache = self.tree_cache
    if not hasattr(cache, "pin_prefix"):
        return _rejected("PIN requires --enable-hierarchical-cache")
    if getattr(cache, "_max_pinned_tokens", 0) <= 0:
        return _rejected(
            "Pinning is disabled (SGLANG_HICACHE_MAX_PINNED_RATIO is 0)"
        )

    ttl = recv_req.ttl_seconds
    pinned_count, reason = cache.pin_prefix(recv_req.token_ids, ttl)
    if pinned_count == 0:
        return _rejected(reason or "No matching prefix found in cache to pin")

    summary = f"Pinned {pinned_count} nodes (ttl={ttl}s)"
    if reason:
        # Partial success: keep the count but surface why pinning stopped.
        summary = f"{summary}; {reason}"
    return PinPrefixReqOutput(
        success=True, nodes_pinned=pinned_count, message=summary
    )

def flush_cache(self):
"""Flush the memory pool and cache."""
if self.is_fully_idle():
Expand Down
22 changes: 0 additions & 22 deletions python/sglang/srt/managers/tokenizer_communicator_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@
LoadLoRAAdapterReqOutput,
LoRAUpdateOutput,
OpenSessionReqInput,
PinPrefixReqInput,
PinPrefixReqOutput,
ProfileReq,
ProfileReqOutput,
ProfileReqType,
Expand Down Expand Up @@ -216,9 +214,6 @@ def init_communicators(self: TokenizerManager, server_args: ServerArgs):
self.detach_hicache_storage_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.pin_prefix_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.profile_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
Expand Down Expand Up @@ -309,10 +304,6 @@ def _get_communicator_dispatcher(self: TokenizerManager):
DetachHiCacheStorageReqOutput,
self.detach_hicache_storage_communicator.handle_recv,
),
(
PinPrefixReqOutput,
self.pin_prefix_communicator.handle_recv,
),
(
FlushCacheReqOutput,
self.flush_cache_communicator.handle_recv,
Expand Down Expand Up @@ -421,19 +412,6 @@ async def detach_hicache_storage(
self.server_args.hicache_storage_backend_extra_config = None
return out

async def pin_prefix(
    self: TokenizerManager, token_ids: List[int], ttl_seconds: int = 300
) -> PinPrefixReqOutput:
    """Send a pin-prefix request through the communicator and merge the replies.

    The merged success flag and message come from
    ``_Communicator.merge_results``; ``nodes_pinned`` is the sum of the
    per-reply counts.
    """
    request = PinPrefixReqInput(token_ids=token_ids, ttl_seconds=ttl_seconds)
    replies = await self.pin_prefix_communicator(request)

    merged_ok, merged_message = _Communicator.merge_results(replies)

    pinned_total = 0
    for reply in replies:
        pinned_total += reply.nodes_pinned

    return PinPrefixReqOutput(
        success=merged_ok, nodes_pinned=pinned_total, message=merged_message
    )

async def start_profile(
self: TokenizerManager,
output_dir: Optional[str] = None,
Expand Down
125 changes: 1 addition & 124 deletions python/sglang/srt/mem_cache/hiradix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
import threading
import time
from queue import Empty
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional

import torch

from sglang.srt.environ import envs
from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
from sglang.srt.mem_cache.base_prefix_cache import (
DecLockRefParams,
Expand Down Expand Up @@ -161,18 +160,6 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs):

self.evictable_host_leaves = set()

# Pin budget: max tokens that can be pinned = ratio * host pool capacity.
pin_ratio = envs.SGLANG_HICACHE_MAX_PINNED_RATIO.get()
if pin_ratio < 0 or pin_ratio >= 1:
raise ValueError(
f"SGLANG_HICACHE_MAX_PINNED_RATIO must be in [0, 1), got {pin_ratio}"
)
self._max_pinned_tokens = int(self.token_to_kv_pool_host.size * pin_ratio)
self.pinned_size_ = 0
logger.info(
"Pin budget: %d tokens (ratio=%.3f)", self._max_pinned_tokens, pin_ratio
)

super().__init__(params=params)

def shutdown(self):
Expand Down Expand Up @@ -589,7 +576,6 @@ def reset(self):
# Clear per-request tracking dicts
self.prefetch_loaded_tokens_by_reqid.clear()
self.evictable_host_leaves.clear()
self.pinned_size_ = 0
super().reset()

def get_height(self, node: TreeNode):
Expand Down Expand Up @@ -729,79 +715,6 @@ def loading_check(self):
def evictable_size(self):
    """Return the current value of the ``evictable_size_`` counter."""
    return self.evictable_size_

def _is_pinned(self, node: TreeNode) -> bool:
    """Return True while ``node`` holds a pin whose expiry has not passed.

    ``pin_expiry`` is an absolute ``time.monotonic()`` deadline; a value that
    is not positive means the node is not pinned.
    """
    deadline = node.pin_expiry
    if deadline <= 0:
        return False
    return time.monotonic() <= deadline

def _clear_pin(self, node: TreeNode):
    """Drop the pin on ``node`` and give back what the pin was holding.

    If the node carries a pin (``pin_expiry > 0``), its key length is taken
    off ``pinned_size_`` and one ``host_ref_counter`` hold is released; both
    are clamped at zero. The pin fields are reset unconditionally.
    """
    if node.pin_expiry > 0:
        released_tokens = len(node.key)
        self.pinned_size_ = max(self.pinned_size_ - released_tokens, 0)
        node.host_ref_counter = max(node.host_ref_counter - 1, 0)
    node.pin_expiry, node.pin_ttl = 0.0, 0

def pin_prefix(
    self, token_ids: List[int], ttl_seconds: int = 300
) -> Tuple[int, Optional[str]]:
    """Pin the nodes along a token prefix so they resist eviction.

    Walks the tree from the root following ``token_ids`` (converted to a
    radix key and truncated to a multiple of ``page_size``). Each node on
    the path that is not yet pinned is charged against the pin budget
    (``_max_pinned_tokens``) and takes one ``host_ref_counter`` hold; nodes
    that are already pinned only have their expiry and TTL extended.

    Args:
        token_ids: Token ids of the prefix to pin.
        ttl_seconds: Pin lifetime in seconds; must be positive.

    Returns:
        ``(nodes_pinned, reject_reason)``. ``reject_reason`` is ``None``
        unless the TTL was invalid or the pin budget ran out; in the budget
        case ``nodes_pinned`` may still be positive (partial pin).
    """
    if self.disable or not token_ids:
        return (0, None)

    # A non-positive TTL yields a pin that is expired on creation. Worse, if
    # ``time.monotonic() + ttl_seconds`` is <= 0 the node would keep
    # ``pin_expiry == 0`` while its host_ref_counter hold and budget charge
    # stay in place: _clear_pin only releases them when ``pin_expiry > 0``,
    # so they would leak and be charged again on every repeated call.
    if ttl_seconds <= 0:
        return (0, f"ttl_seconds must be positive, got {ttl_seconds}")

    key, _ = self.maybe_bigram_convert(self._to_radix_key(token_ids))
    if self.page_size != 1:
        # Only whole pages can be matched; drop the trailing partial page.
        page_aligned_len = len(key) // self.page_size * self.page_size
        key = key[:page_aligned_len]
    if len(key) == 0:
        return (0, None)

    expiry = time.monotonic() + ttl_seconds
    nodes_pinned = 0
    budget_exceeded = False
    node = self.root_node
    child_key = self.get_child_key_fn(key)

    while len(key) > 0 and child_key in node.children:
        child = node.children[child_key]
        prefix_len = self.key_match_fn(child.key, key)

        # First pin on this node: check budget, then acquire hold
        if child.pin_expiry == 0:
            if self.pinned_size_ + len(child.key) > self._max_pinned_tokens:
                budget_exceeded = True
                break
            child.host_ref_counter += 1
            self.pinned_size_ += len(child.key)

            # Eagerly back up to host so eviction finds pinned nodes
            # already backuped and never enters the write_back drain
            # path, which would leak lock_ref on in-flight
            # write-through entries. No-op under write_back policy.
            self._inc_hit_count(child)

        # Extend expiry and store TTL for refresh-on-hit
        child.pin_expiry = max(child.pin_expiry, expiry)
        child.pin_ttl = max(child.pin_ttl, ttl_seconds)
        nodes_pinned += 1

        # The key diverges inside this node: it has been pinned as a whole
        # (no split is performed here), and the walk stops.
        if prefix_len < len(child.key):
            break

        node = child
        key = key[prefix_len:]
        if len(key):
            child_key = self.get_child_key_fn(key)

    logger.info(
        "[PIN] pin_prefix: nodes_pinned=%d, ttl=%ds", nodes_pinned, ttl_seconds
    )
    if budget_exceeded:
        msg = (
            f"Pin budget exhausted "
            f"({self.pinned_size_}/{self._max_pinned_tokens} tokens pinned)"
        )
        if nodes_pinned == 0:
            return (0, msg)
        return (nodes_pinned, f"prefix partially pinned; {msg}")
    return (nodes_pinned, None)

def _to_radix_key(self, token_ids: List[int]) -> RadixKey:
"""Convert raw token_ids to a RadixKey for tree walking.

Expand Down Expand Up @@ -880,26 +793,6 @@ def evict(self, params: EvictParams) -> EvictResult:
if x.lock_ref > 0:
continue

if self._is_pinned(x):
# Still active: demote to host if possible
if x.backuped:
num_evicted += self._evict_backuped(x)
continue
written = self.write_backup(x, write_back=True)
if written > 0:
num_evicted += written
write_back_nodes.append(x)
continue # backup succeeded, pin holds on host
# Host full -- drop pin so GPU can be freed
self._clear_pin(x)
logger.warning(
"[PIN] evict: can't backup node %d to host, releasing pin",
x.id,
)
elif x.pin_expiry > 0:
# Expired pin: clear and fall through to normal eviction
self._clear_pin(x)

if not x.backuped:
if self.cache_controller.write_policy == "write_back":
# write to host if the node is not backuped
Expand Down Expand Up @@ -965,11 +858,6 @@ def evict_host(self, num_tokens: int):
if not x.evicted:
continue

# Expire stale pins before checking host_ref_counter
if x.pin_expiry > 0 and time.monotonic() > x.pin_expiry:
self._clear_pin(x)

# node is protected from eviction as it has ongoing prefetch, backup, or pin
if x.host_ref_counter > 0:
continue

Expand Down Expand Up @@ -1348,9 +1236,6 @@ def _insert_helper_host(
while len(key) > 0 and child_key in node.children.keys():
node = node.children[child_key]
node.last_access_time = time.monotonic()
# Refresh pin TTL on host insert hit
if self._is_pinned(node):
node.pin_expiry = time.monotonic() + node.pin_ttl
prefix_len = self.key_match_fn(node.key, key)
key = key[prefix_len:]
host_value = host_value[prefix_len:]
Expand Down Expand Up @@ -1386,9 +1271,6 @@ def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
while len(key) > 0 and child_key in node.children.keys():
child = node.children[child_key]
child.last_access_time = time.monotonic()
# Refresh pin TTL on cache hit
if self._is_pinned(child):
child.pin_expiry = time.monotonic() + child.pin_ttl
prefix_len = self.key_match_fn(child.key, key)
if prefix_len < len(child.key):
new_node = self._split_node(child.key, child, prefix_len)
Expand All @@ -1413,11 +1295,6 @@ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
new_node.parent = child.parent
new_node.lock_ref = child.lock_ref
new_node.pin_expiry = child.pin_expiry
new_node.pin_ttl = child.pin_ttl
# If child is pinned, new parent inherits a host_ref_counter hold
if child.pin_expiry > 0:
new_node.host_ref_counter += 1
new_node.key = child.key[:split_len]
new_node.hit_count = child.hit_count

Expand Down
4 changes: 0 additions & 4 deletions python/sglang/srt/mem_cache/radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,6 @@ def __init__(self, id: Optional[int] = None, priority: int = 0):
self.key: RadixKey = None
self.value: Optional[torch.Tensor] = None
self.lock_ref = 0
self.pin_expiry: float = (
0.0 # absolute expiry time (time.monotonic()), 0 = not pinned
)
self.pin_ttl: int = 0 # original TTL in seconds, for refresh-on-hit
self.last_access_time = time.monotonic()
self.creation_time = time.monotonic()

Expand Down
Loading