diff --git a/python/sglang/srt/mem_cache/hi_mamba_radix_cache.py b/python/sglang/srt/mem_cache/hi_mamba_radix_cache.py index 6130d9502752..fafc6d246e3c 100644 --- a/python/sglang/srt/mem_cache/hi_mamba_radix_cache.py +++ b/python/sglang/srt/mem_cache/hi_mamba_radix_cache.py @@ -511,13 +511,22 @@ def check_hicache_events(self): self.cache_controller.storage_backend.get_stats() ) - def _protect_host_node(self, node: TreeNode): + def _protect_host_node(self, node: TreeNode, protect_mamba: bool = True): node.protect_host() self.evictable_full_host_leaves.discard(node) + if protect_mamba: + node.protect_host_mamba() + if self.mamba_host_lru_list.in_list(node): + self.mamba_host_lru_list.remove_node(node) - def _release_host_node(self, node: TreeNode): + def _release_host_node(self, node: TreeNode, release_mamba: bool = True): node.release_host() - if node.host_ref_counter == 0: + if release_mamba: + node.release_host_mamba() + if node.host_mamba_ref_counter == 0 and node.mamba_host_value is not None: + if not self.mamba_host_lru_list.in_list(node): + self.mamba_host_lru_list.insert_mru(node) + if node.host_ref_counter == 0 and node.host_mamba_ref_counter == 0: self._update_full_host_leaf_status(node) def _discard_from_leaf_sets(self, node: TreeNode): @@ -544,6 +553,7 @@ def _update_full_host_leaf_status(self, node: TreeNode): or not node.backuped or node == self.root_node or node.host_ref_counter > 0 + or node.host_mamba_ref_counter > 0 ): self.evictable_full_host_leaves.discard(node) return @@ -632,7 +642,10 @@ def _evict_host_leaf(self, node: TreeNode) -> int: assert node.mamba_value is None, f"has device mamba, {node.id=}" assert ( node.host_ref_counter == 0 - ), f"in use, {node.id=} {node.host_ref_counter=}" + ), f"host kv in use, {node.id=} {node.host_ref_counter=}" + assert ( + node.host_mamba_ref_counter == 0 + ), f"host mamba in use, {node.id=} {node.host_mamba_ref_counter=}" full_num_evicted = self.cache_controller.evict_host(node.host_value) node.host_value = None @@ -665,7 +678,11 @@ def _delete_tombstone_leaf(self, node: TreeNode) -> None: self._discard_from_leaf_sets(node) - if node.backuped and node.host_ref_counter == 0: + if ( + node.backuped + and node.host_ref_counter == 0 + and node.host_mamba_ref_counter == 0 + ): self.cache_controller.evict_host(node.host_value) node.host_value = None @@ -782,15 +799,21 @@ def evict_mamba_host(self, num_mamba_hosts: int) -> int: num_evicted = 0 while num_evicted < num_mamba_hosts and self.mamba_host_lru_list.in_list(x): x_next = self.mamba_host_lru_list.get_prev_no_lock(x) - if x.host_ref_counter > 0: - x = x_next - continue - if x in self.evictable_full_host_leaves: + # Leaf: evictable_full_host_leaves guarantees both counters == 0 + assert ( + x.host_ref_counter == 0 + ), f"evict host leaf: host_ref_counter != 0 with {x.id=} {x.host_ref_counter=}" + assert ( + x.host_mamba_ref_counter == 0 + ), f"evict host leaf: host_mamba_ref_counter != 0 with {x.id=} {x.host_mamba_ref_counter=}" self._evict_host_leaf(x) num_evicted += 1 else: - # internal host node: free host mamba only (tombstone) + # Internal host node + assert ( + x.host_mamba_ref_counter == 0 + ), f"evict host mamba internal: host_mamba_ref_counter != 0 with {x.id=} {x.host_mamba_ref_counter=}" self.mamba_host_lru_list.remove_node(x) self.mamba_pool_host.free(x.mamba_host_value) x.mamba_host_value = None @@ -830,7 +853,7 @@ def evict_mamba(self, mamba_num: int) -> int: # Leaf: evict KV + mamba atomically assert ( x.full_lock_ref == 0 - ), f"evict leaf node invalid with {x.id=} {x.full_lock_ref=}" + ), f"evict device leaf: full_lock_ref mismatch with {x.id=} {x.full_lock_ref=} {x.mamba_lock_ref=}" x_next = self.mamba_lru_list.get_prev_no_lock(x) _, mamba_evicted = self._evict_device_leaf(x) @@ -1467,9 +1490,10 @@ def _force_release_pending_storage_ops(self): logger.exception("Force release pending prefetch ops failed.") try: - for ack_id, node in list(self.ongoing_backup.items()): + for ack_id, entry in list(self.ongoing_backup.items()): try: - self._release_host_node(node) + node, mamba_host_protected = entry + self._release_host_node(node, release_mamba=mamba_host_protected) except Exception: logger.exception( "Failed to release host protection for backup op %s", ack_id @@ -1521,7 +1545,8 @@ def _drain_backup(): ack_id = operation.id entry = self.ongoing_backup.pop(ack_id, None) if entry is not None: - self._release_host_node(entry) + node, mamba_host_protected = entry + self._release_host_node(node, release_mamba=mamba_host_protected) if log_metrics and self.enable_storage_metrics: self.storage_metrics_collector.log_backuped_tokens( operation.completed_tokens @@ -1721,8 +1746,9 @@ def write_backup_storage(self, node: TreeNode): prefix_keys, extra_pools=extra_pools, ) - self.ongoing_backup[operation_id] = node - self._protect_host_node(node) + mamba_host_protected = extra_pools is not None + self.ongoing_backup[operation_id] = (node, mamba_host_protected) + self._protect_host_node(node, protect_mamba=mamba_host_protected) def prefetch_from_storage( self, @@ -1743,7 +1769,7 @@ def prefetch_from_storage( ): return - self._protect_host_node(last_host_node) + self._protect_host_node(last_host_node, protect_mamba=False) # Allocate host KV memory host_indices = self._alloc_with_evict( @@ -1752,16 +1778,21 @@ def prefetch_from_storage( self.evict_host, ) if host_indices is None: - self._release_host_node(last_host_node) + self._release_host_node(last_host_node, release_mamba=False) return # Allocate host mamba slot extra_pools = self.mamba_prefetch_alloc(new_input_tokens, last_hash) if extra_pools is None: self.cache_controller.mem_pool_host.free(host_indices) - self._release_host_node(last_host_node) + self._release_host_node(last_host_node, release_mamba=False) return + # mamba is also being loaded, protect host mamba as well + last_host_node.protect_host_mamba() + if self.mamba_host_lru_list.in_list(last_host_node): + self.mamba_host_lru_list.remove_node(last_host_node) + operation = self.cache_controller.prefetch( req_id, host_indices, diff --git a/python/sglang/srt/mem_cache/mamba_radix_cache.py b/python/sglang/srt/mem_cache/mamba_radix_cache.py index d02eb7d9b3d3..d07702cf1efd 100644 --- a/python/sglang/srt/mem_cache/mamba_radix_cache.py +++ b/python/sglang/srt/mem_cache/mamba_radix_cache.py @@ -87,6 +87,7 @@ def __init__(self, id: Optional[int] = None): self.hit_count = 0 self.host_ref_counter = 0 + self.host_mamba_ref_counter = 0 # store the host indices of KV cache self.host_value = None # store hash values of each pages @@ -122,16 +123,27 @@ def mamba_backuped(self): return self.mamba_host_value is not None def protect_host(self): - """Protect the host value from eviction.""" + """Protect the host KV value from eviction.""" self.host_ref_counter += 1 def release_host(self): - """Release the host value, allowing it to be evicted.""" + """Release the host KV value, allowing it to be evicted.""" if self.host_ref_counter > 0: self.host_ref_counter -= 1 else: raise RuntimeError("Host reference counter is already zero.") + def protect_host_mamba(self): + """Protect the host mamba value from eviction.""" + self.host_mamba_ref_counter += 1 + + def release_host_mamba(self): + """Release the host mamba value, allowing it to be evicted.""" + if self.host_mamba_ref_counter > 0: + self.host_mamba_ref_counter -= 1 + else: + raise RuntimeError("Host mamba reference counter is already zero.") + def get_last_hash_value(self) -> Optional[str]: """Returns the hash value of the last page in this node.""" if self.hash_value is None or len(self.hash_value) == 0: