Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 50 additions & 19 deletions python/sglang/srt/mem_cache/hi_mamba_radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,13 +511,22 @@ def check_hicache_events(self):
self.cache_controller.storage_backend.get_stats()
)

def _protect_host_node(self, node: TreeNode):
def _protect_host_node(self, node: TreeNode, protect_mamba: bool = True):
node.protect_host()
self.evictable_full_host_leaves.discard(node)
if protect_mamba:
node.protect_host_mamba()
if self.mamba_host_lru_list.in_list(node):
self.mamba_host_lru_list.remove_node(node)

def _release_host_node(self, node: TreeNode):
def _release_host_node(self, node: TreeNode, release_mamba: bool = True):
node.release_host()
if node.host_ref_counter == 0:
if release_mamba:
node.release_host_mamba()
if node.host_mamba_ref_counter == 0 and node.mamba_host_value is not None:
if not self.mamba_host_lru_list.in_list(node):
self.mamba_host_lru_list.insert_mru(node)
if node.host_ref_counter == 0 and node.host_mamba_ref_counter == 0:
self._update_full_host_leaf_status(node)

def _discard_from_leaf_sets(self, node: TreeNode):
Expand All @@ -544,6 +553,7 @@ def _update_full_host_leaf_status(self, node: TreeNode):
or not node.backuped
or node == self.root_node
or node.host_ref_counter > 0
or node.host_mamba_ref_counter > 0
):
self.evictable_full_host_leaves.discard(node)
return
Expand Down Expand Up @@ -632,7 +642,10 @@ def _evict_host_leaf(self, node: TreeNode) -> int:
assert node.mamba_value is None, f"has device mamba, {node.id=}"
assert (
node.host_ref_counter == 0
), f"in use, {node.id=} {node.host_ref_counter=}"
), f"host kv in use, {node.id=} {node.host_ref_counter=}"
assert (
node.host_mamba_ref_counter == 0
), f"host mamba in use, {node.id=} {node.host_mamba_ref_counter=}"

full_num_evicted = self.cache_controller.evict_host(node.host_value)
node.host_value = None
Expand Down Expand Up @@ -665,7 +678,11 @@ def _delete_tombstone_leaf(self, node: TreeNode) -> None:

self._discard_from_leaf_sets(node)

if node.backuped and node.host_ref_counter == 0:
if (
node.backuped
and node.host_ref_counter == 0
and node.host_mamba_ref_counter == 0
):
self.cache_controller.evict_host(node.host_value)
node.host_value = None

Expand Down Expand Up @@ -782,15 +799,21 @@ def evict_mamba_host(self, num_mamba_hosts: int) -> int:
num_evicted = 0
while num_evicted < num_mamba_hosts and self.mamba_host_lru_list.in_list(x):
x_next = self.mamba_host_lru_list.get_prev_no_lock(x)
if x.host_ref_counter > 0:
x = x_next
continue

if x in self.evictable_full_host_leaves:
# Leaf: evictable_full_host_leaves guarantees both counters == 0
assert (
x.host_ref_counter == 0
), f"evict host leaf: host_ref_counter != 0 with {x.id=} {x.host_ref_counter=}"
assert (
x.host_mamba_ref_counter == 0
), f"evict host leaf: host_mamba_ref_counter != 0 with {x.id=} {x.host_mamba_ref_counter=}"
self._evict_host_leaf(x)
num_evicted += 1
else:
# internal host node: free host mamba only (tombstone)
# Internal host node
assert (
x.host_mamba_ref_counter == 0
), f"evict host mamba internal: host_mamba_ref_counter != 0 with {x.id=} {x.host_mamba_ref_counter=}"
self.mamba_host_lru_list.remove_node(x)
self.mamba_pool_host.free(x.mamba_host_value)
x.mamba_host_value = None
Expand Down Expand Up @@ -830,7 +853,7 @@ def evict_mamba(self, mamba_num: int) -> int:
# Leaf: evict KV + mamba atomically
assert (
x.full_lock_ref == 0
), f"evict leaf node invalid with {x.id=} {x.full_lock_ref=}"
), f"evict device leaf: full_lock_ref mismatch with {x.id=} {x.full_lock_ref=} {x.mamba_lock_ref=}"

x_next = self.mamba_lru_list.get_prev_no_lock(x)
_, mamba_evicted = self._evict_device_leaf(x)
Expand Down Expand Up @@ -1467,9 +1490,10 @@ def _force_release_pending_storage_ops(self):
logger.exception("Force release pending prefetch ops failed.")

try:
for ack_id, node in list(self.ongoing_backup.items()):
for ack_id, entry in list(self.ongoing_backup.items()):
try:
self._release_host_node(node)
node, mamba_host_protected = entry
self._release_host_node(node, release_mamba=mamba_host_protected)
except Exception:
logger.exception(
"Failed to release host protection for backup op %s", ack_id
Expand Down Expand Up @@ -1521,7 +1545,8 @@ def _drain_backup():
ack_id = operation.id
entry = self.ongoing_backup.pop(ack_id, None)
if entry is not None:
self._release_host_node(entry)
node, mamba_host_protected = entry
self._release_host_node(node, release_mamba=mamba_host_protected)
if log_metrics and self.enable_storage_metrics:
self.storage_metrics_collector.log_backuped_tokens(
operation.completed_tokens
Expand Down Expand Up @@ -1721,8 +1746,9 @@ def write_backup_storage(self, node: TreeNode):
prefix_keys,
extra_pools=extra_pools,
)
self.ongoing_backup[operation_id] = node
self._protect_host_node(node)
mamba_host_protected = extra_pools is not None
self.ongoing_backup[operation_id] = (node, mamba_host_protected)
self._protect_host_node(node, protect_mamba=mamba_host_protected)

def prefetch_from_storage(
self,
Expand All @@ -1743,7 +1769,7 @@ def prefetch_from_storage(
):
return

self._protect_host_node(last_host_node)
self._protect_host_node(last_host_node, protect_mamba=False)

# Allocate host KV memory
host_indices = self._alloc_with_evict(
Expand All @@ -1752,16 +1778,21 @@ def prefetch_from_storage(
self.evict_host,
)
if host_indices is None:
self._release_host_node(last_host_node)
self._release_host_node(last_host_node, release_mamba=False)
return

# Allocate host mamba slot
extra_pools = self.mamba_prefetch_alloc(new_input_tokens, last_hash)
if extra_pools is None:
self.cache_controller.mem_pool_host.free(host_indices)
self._release_host_node(last_host_node)
self._release_host_node(last_host_node, release_mamba=False)
return

# mamba is also being loaded, protect host mamba as well
last_host_node.protect_host_mamba()
if self.mamba_host_lru_list.in_list(last_host_node):
self.mamba_host_lru_list.remove_node(last_host_node)

operation = self.cache_controller.prefetch(
req_id,
host_indices,
Expand Down
16 changes: 14 additions & 2 deletions python/sglang/srt/mem_cache/mamba_radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(self, id: Optional[int] = None):

self.hit_count = 0
self.host_ref_counter = 0
self.host_mamba_ref_counter = 0
# store the host indices of KV cache
self.host_value = None
# store hash values of each pages
Expand Down Expand Up @@ -122,16 +123,27 @@ def mamba_backuped(self):
return self.mamba_host_value is not None

def protect_host(self):
"""Protect the host value from eviction."""
"""Protect the host KV value from eviction."""
self.host_ref_counter += 1

def release_host(self):
"""Release the host value, allowing it to be evicted."""
"""Release the host KV value, allowing it to be evicted."""
if self.host_ref_counter > 0:
self.host_ref_counter -= 1
else:
raise RuntimeError("Host reference counter is already zero.")

def protect_host_mamba(self):
"""Protect the host mamba value from eviction."""
self.host_mamba_ref_counter += 1

def release_host_mamba(self):
"""Release the host mamba value, allowing it to be evicted."""
if self.host_mamba_ref_counter > 0:
self.host_mamba_ref_counter -= 1
else:
raise RuntimeError("Host mamba reference counter is already zero.")

def get_last_hash_value(self) -> Optional[str]:
"""Returns the hash value of the last page in this node."""
if self.hash_value is None or len(self.hash_value) == 0:
Expand Down
Loading