diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py
index 0fedbcf3d99e..ac4baace15e7 100644
--- a/python/sglang/srt/managers/cache_controller.py
+++ b/python/sglang/srt/managers/cache_controller.py
@@ -613,7 +613,7 @@ def _generate_storage_config(
             pp_size=self.pp_size,
             is_mla_model=is_mla_backend,
             enable_storage_metrics=self.enable_storage_metrics,
-            is_page_first_layout=self.mem_pool_host.layout == "page_first",
+            layout=self.mem_pool_host.layout,
             model_name=model_name,
             tp_lcm_size=tp_lcm_size,
             should_split_heads=should_split_heads,
@@ -827,6 +827,50 @@ def _page_get_zero_copy(
             inc += self.page_size
         operation.increment(inc)
 
+    @staticmethod
+    def _count_consecutive_true(results: List[bool]) -> int:
+        for i, ok in enumerate(results):
+            if not ok:
+                return i
+        return len(results)
+
+    def _kv_get_pages(self, hash_values, host_indices, extra_info=None) -> int:
+        if self.storage_backend_type == "mooncake":
+            results = self.storage_backend.batch_get_v1(
+                hash_values, host_indices, extra_info
+            )
+            return self._count_consecutive_true(results)
+
+        dummy_page_dst = [
+            self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
+        ]
+        page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
+        if page_data is None:
+            return 0
+        success_pages = 0
+        for i in range(len(hash_values)):
+            if page_data[i] is None:
+                break
+            self.mem_pool_host.set_from_flat_data_page(
+                host_indices[i * self.page_size], page_data[i]
+            )
+            success_pages += 1
+        return success_pages
+
+    def _kv_set_pages(self, hash_values, host_indices, extra_info=None) -> bool:
+        if self.storage_backend_type == "mooncake":
+            return all(
+                self.storage_backend.batch_set_v1(hash_values, host_indices, extra_info)
+            )
+        data = [
+            self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
+            for i in range(len(hash_values))
+        ]
+        return self.storage_backend.batch_set(hash_values, data)
+
+    def _storage_hit_page_num(self, batch_hashes, extra_info=None) -> int:
+        return self.storage_backend.batch_exists(batch_hashes, extra_info)
+
     # todo: deprecate
     def _generic_page_get(self, operation, hash_values, host_indices, extra_info=None):
         dummy_page_dst = [
@@ -922,7 +966,7 @@ def _storage_hit_query(self, operation) -> tuple[list[str], int]:
             )
             batch_hashes.append(last_hash)
         extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys)
-        hit_page_num = self.storage_backend.batch_exists(batch_hashes, extra_info)
+        hit_page_num = self._storage_hit_page_num(batch_hashes, extra_info)
         hash_value.extend(batch_hashes[:hit_page_num])
         storage_query_count += hit_page_num * self.page_size
         if hit_page_num < len(batch_hashes):
diff --git a/python/sglang/srt/mem_cache/hi_mamba_radix_cache.py b/python/sglang/srt/mem_cache/hi_mamba_radix_cache.py
index 7ce972a85b6f..8a65ab9fd6d5 100644
--- a/python/sglang/srt/mem_cache/hi_mamba_radix_cache.py
+++ b/python/sglang/srt/mem_cache/hi_mamba_radix_cache.py
@@ -12,23 +12,31 @@
 import torch
 
-from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
 from sglang.srt.mem_cache.base_prefix_cache import (
-    DecLockRefParams,
-    DecLockRefResult,
     EvictParams,
     EvictResult,
-    IncLockRefResult,
     MatchPrefixParams,
     MatchResult,
 )
+from sglang.srt.mem_cache.hicache_storage import PoolTransfer, PoolTransferResult
+from sglang.srt.mem_cache.hybrid_cache.hybrid_cache_controller import (
+    HybridCacheController,
+    PrefetchOperation,
+)
 from sglang.srt.mem_cache.mamba_radix_cache import (
+    LRUList,
     MambaRadixCache,
     TreeNode,
get_last_access_time, ) -from sglang.srt.mem_cache.memory_pool import HybridLinearKVPool -from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost +from sglang.srt.mem_cache.memory_pool import HybridLinearKVPool, HybridReqToTokenPool +from sglang.srt.mem_cache.memory_pool_host import ( + HostPoolGroup, + MambaPoolHost, + MHATokenToKVPoolHost, + MLATokenToKVPoolHost, + PoolEntry, +) from sglang.srt.mem_cache.radix_cache import ( RadixKey, compute_node_hash_values, @@ -44,6 +52,41 @@ logger = logging.getLogger(__name__) +class HostLRUList(LRUList): + def __init__(self): + super().__init__(mamba=True) + self.prv = "host_mamba_prev" + self.nxt = "host_mamba_next" + setattr(self.head, self.nxt, self.tail) + setattr(self.tail, self.prv, self.head) + + def reset_node_mru(self, node): + assert node.id in self.cache, f"Resetting node {node.id=} not in host mamba lru" + assert ( + node.mamba_host_value is not None + ), f"Resetting host mamba tombstone node in lru list: {node.id=}" + self._remove_node(node) + self._add_node(node) + + def insert_mru(self, node): + assert ( + node.mamba_host_value is not None + ), f"Inserting host mamba tombstone node in lru list: {node.id=}" + assert ( + node.id not in self.cache + ), f"Inserting node {node.id=} already in host mamba lru list" + self.cache[node.id] = node + self._add_node(node) + + def remove_node(self, node: TreeNode): + assert node.id in self.cache, f"Removing node {node.id=} not in host mamba lru" + assert ( + node.mamba_host_value is not None + ), f"Removing host mamba tombstone node from lru list: {node.id=}" + del self.cache[node.id] + self._remove_node(node) + + class HiMambaRadixCache(MambaRadixCache): """Hierarchical cache for hybrid Mamba models. @@ -65,13 +108,33 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs): bind_to_closest_numa_node_cuda() self.page_size = params.page_size - kvcache = params.token_to_kv_pool_allocator.get_kvcache() - - if isinstance(kvcache, HybridLinearKVPool): - kvcache = kvcache.full_kv_pool - self.kvcache = kvcache + self.hybrid_kv_cache = params.token_to_kv_pool_allocator.get_kvcache() + if not isinstance(self.hybrid_kv_cache, HybridLinearKVPool): + raise ValueError( + "HiMambaRadixCache requires HybridLinearKVPool for hybrid SSM models." + ) + if not isinstance(params.req_to_token_pool, HybridReqToTokenPool): + raise ValueError( + "HiMambaRadixCache requires HybridReqToTokenPool for hybrid SSM models." 
+ ) - self.full_kv_pool_host = MHATokenToKVPoolHost( + self.hybrid_model_layer_ids: list[int] = [] + full_layer_ids = sorted( + self.hybrid_kv_cache.full_attention_layer_id_mapping.keys() + ) + mamba_layer_ids = sorted(params.req_to_token_pool.mamba_map.keys()) + self.hybrid_model_layer_ids = sorted(set(full_layer_ids) | set(mamba_layer_ids)) + self.transfer_layer_num = len(self.hybrid_model_layer_ids) + self.hybrid_kv_cache.set_model_layer_id_mapping(self.hybrid_model_layer_ids) + params.req_to_token_pool.set_model_layer_id_mapping(self.hybrid_model_layer_ids) + + self.kvcache = self.hybrid_kv_cache.full_kv_pool + kv_host_pool_cls = ( + MLATokenToKVPoolHost + if self.hybrid_kv_cache.use_mla + else MHATokenToKVPoolHost + ) + self.full_kv_pool_host = kv_host_pool_cls( self.kvcache, server_args.hicache_ratio, server_args.hicache_size, @@ -79,6 +142,49 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs): server_args.hicache_mem_layout, allocator_type=server_args.hicache_storage_backend, ) + self.mamba_pool_host = MambaPoolHost( + params.req_to_token_pool.mamba_pool, + server_args.hicache_ratio, + server_args.hicache_size, + allocator_type=server_args.hicache_storage_backend, + ) + + full_layer_mapping = dict(self.hybrid_kv_cache.full_attention_layer_id_mapping) + mamba_layer_mapping = dict(params.req_to_token_pool.mamba_map) + + def kv_layer_mapper(model_layer_local_id: int) -> Optional[int]: + if not 0 <= model_layer_local_id < len(self.hybrid_model_layer_ids): + return None + return full_layer_mapping.get( + self.hybrid_model_layer_ids[model_layer_local_id] + ) + + def mamba_layer_mapper(model_layer_local_id: int) -> Optional[int]: + if not 0 <= model_layer_local_id < len(self.hybrid_model_layer_ids): + return None + return mamba_layer_mapping.get( + self.hybrid_model_layer_ids[model_layer_local_id] + ) + + self.host_pool_group = HostPoolGroup( + [ + PoolEntry( + name="kv", + host_pool=self.full_kv_pool_host, + device_pool=self.kvcache, + layer_mapper=kv_layer_mapper, + is_primary_index_anchor=True, + ), + PoolEntry( + name="mamba", + host_pool=self.mamba_pool_host, + device_pool=params.req_to_token_pool.mamba_pool, + layer_mapper=mamba_layer_mapper, + host_evict_fn=self.evict_mamba_host, + device_evict_fn=self.evict_mamba, + ), + ] + ) self.tp_group = params.tp_cache_group self.tp_world_size = ( @@ -104,9 +210,9 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs): self.prefetch_stop_policy = server_args.hicache_storage_prefetch_policy self.load_cache_event = threading.Event() - self.cache_controller = HiCacheController( + self.cache_controller = HybridCacheController( params.token_to_kv_pool_allocator, - self.full_kv_pool_host, + self.host_pool_group, params.page_size, self.tp_group, load_cache_event=self.load_cache_event, @@ -118,6 +224,10 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs): storage_backend_extra_config=extra_config, pp_rank=params.pp_rank, pp_size=params.pp_size, + transfer_layer_num=self.transfer_layer_num, + ) + params.req_to_token_pool.register_layer_transfer_counter( + self.cache_controller.layer_done_counter ) self._apply_storage_runtime_config( storage_backend=server_args.hicache_storage_backend, @@ -145,6 +255,7 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs): self.evictable_full_device_leaves: set[TreeNode] = set() self.evictable_full_host_leaves: set[TreeNode] = set() + self.mamba_host_lru_list = HostLRUList() # Detach storage backend automatically on process shutdown 
atexit.register(self.shutdown) @@ -153,8 +264,10 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs): def reset(self) -> None: TreeNode.counter = 0 + self._flush_pending_storage_backups_before_reset() self.cache_controller.reset() self.full_kv_pool_host.clear() + self.mamba_pool_host.clear() self.ongoing_write_through = {} self.ongoing_load_back = {} self.ongoing_prefetch = {} @@ -162,33 +275,60 @@ def reset(self) -> None: self.prefetch_loaded_tokens_by_reqid.clear() self.evictable_full_device_leaves.clear() self.evictable_full_host_leaves.clear() + self.mamba_host_lru_list = HostLRUList() + logger.info( + "HiMambaRadixCache reset completed: host_kv_available=%s host_mamba_available=%s", + self.full_kv_pool_host.available_size(), + self.mamba_pool_host.available_size(), + ) super().reset() def write_backup(self, node: TreeNode, write_back=False): + # If mamba host slot already exists, refresh its LRU position. + if node.mamba_value is not None and node.mamba_host_value is not None: + if self.mamba_host_lru_list.in_list(node): + self.mamba_host_lru_list.reset_node_mru(node) + + extra_pools = self.backup_transfers(node) host_indices = self.cache_controller.write( device_indices=node.value, node_id=node.id, + extra_pools=extra_pools, ) if host_indices is None: self.evict_host(len(node.value)) host_indices = self.cache_controller.write( device_indices=node.value, node_id=node.id, + extra_pools=extra_pools, ) if host_indices is not None: node.host_value = host_indices + if extra_pools is not None: + self.backup_commit(node, extra_pools) assert len(node.host_value) > 0 self.ongoing_write_through[node.id] = node if not write_back: # no need to lock nodes if write back self.inc_lock_ref(node) + if extra_pools is not None: + logger.info( + "HiCache mamba offload prepared for node %s: kv_tokens=%s mamba_states=%s", + node.id, + len(node.host_value), + ( + len(node.mamba_host_value) + if node.mamba_host_value is not None + else 0 + ), + ) else: return 0 return len(host_indices) def load_back( - self, node: TreeNode, mem_quota: Optional[int] = None + self, node: TreeNode, mem_quota: Optional[int] = None, req=None ) -> Optional[torch.Tensor]: """Load full KV back from host.""" last_hit_node = node @@ -201,34 +341,52 @@ def load_back( else: ancestor_node = node - result = self.inc_lock_ref(ancestor_node) - delta = result.delta + mamba_restore_nodes = [] + if last_hit_node.mamba_backuped and last_hit_node.mamba_evicted: + mamba_restore_nodes.append(last_hit_node) - full_host_indices = torch.cat([n.host_value for n in nodes_to_load]) - if (len(full_host_indices) < self.load_back_threshold) or ( - len(full_host_indices) > mem_quota + delta - if mem_quota is not None - else False + delta = self.inc_lock_ref(ancestor_node) + + if nodes_to_load: + full_host_indices = torch.cat([n.host_value for n in nodes_to_load]) + else: + full_host_indices = torch.empty((0,), dtype=torch.int64, device="cpu") + + if len(full_host_indices) > 0 and ( + (len(full_host_indices) < self.load_back_threshold) + or ( + len(full_host_indices) > mem_quota + delta + if mem_quota is not None + else False + ) ): # skip loading back if the total size is too small or exceeding the memory quota self.dec_lock_ref(ancestor_node) return None + mamba_pools = self.restore_transfers(last_hit_node, mamba_restore_nodes, req) full_device_indices = self.cache_controller.load( host_indices=full_host_indices, node_id=last_hit_node.id, + extra_pools=mamba_pools, ) if full_device_indices is None: 
self.evict(EvictParams(num_tokens=len(full_host_indices))) + mamba_pools = self.restore_transfers( + last_hit_node, mamba_restore_nodes, req + ) full_device_indices = self.cache_controller.load( host_indices=full_host_indices, node_id=last_hit_node.id, + extra_pools=mamba_pools, ) self.dec_lock_ref(ancestor_node) if full_device_indices is None: # no sufficient GPU memory to load back KV caches return None + self.restore_commit(mamba_restore_nodes, mamba_pools) + offset = 0 for n in nodes_to_load: n_len = len(n.host_value) @@ -237,16 +395,15 @@ def load_back( self.full_lru_list.insert_mru(n) self.full_evictable_size_ += n_len - - if n.mamba_value is not None: - if self.mamba_lru_list.in_list(n): - self.mamba_lru_list.reset_node_mru(n) - else: - self.mamba_lru_list.insert_mru(n) - self.mamba_evictable_size_ += len(n.mamba_value) - self._update_leaf_status(n) + for n in mamba_restore_nodes: + if self.mamba_lru_list.in_list(n): + self.mamba_lru_list.reset_node_mru(n) + else: + self.mamba_lru_list.insert_mru(n) + self.mamba_evictable_size_ += len(n.mamba_value) + self._update_leaf_status(ancestor_node) self.inc_lock_ref(last_hit_node) @@ -259,16 +416,25 @@ def init_load_back( last_node: TreeNode, host_hit_length: int, mem_quota: Optional[int] = None, + req=None, ): - if last_node.evicted: - loading_values = self.load_back(last_node, mem_quota) + if last_node.evicted or ( + last_node.mamba_value is None and last_node.mamba_host_value is not None + ): + loading_values = self.load_back(last_node, mem_quota, req=req) if loading_values is not None: logger.debug( f"loading back {len(loading_values)} tokens for node {last_node.id}" ) return loading_values, last_node - while last_node.evicted: + while last_node is not self.root_node and ( + last_node.evicted + or ( + last_node.mamba_value is None + and last_node.mamba_host_value is not None + ) + ): last_node = last_node.parent return ( @@ -343,9 +509,6 @@ def loading_check(self): def ready_to_load_host_cache(self) -> int: return self.cache_controller.start_loading() - def flush_write_through_acks(self) -> None: - self.writing_check() - def check_hicache_events(self): self.writing_check() self.loading_check() @@ -399,28 +562,119 @@ def _update_full_host_leaf_status(self, node: TreeNode): return self.evictable_full_host_leaves.add(node) - def _evict_to_host(self, node: TreeNode) -> int: - """Evict full KV to host. Mamba stays on device. Returns num evicted.""" - num_full = len(node.value) + def _free_gpu_mamba(self, node: TreeNode) -> int: + """Free GPU mamba on a node. Returns num mamba tokens freed.""" + if node.mamba_value is None: + return 0 + mamba_num = len(node.mamba_value) + self.req_to_token_pool.mamba_pool.free(node.mamba_value) + if node.mamba_lock_ref > 0: + self.mamba_protected_size_ -= mamba_num + node.mamba_lock_ref = 0 + else: + self.mamba_evictable_size_ -= mamba_num + if self.mamba_lru_list.in_list(node): + self.mamba_lru_list.remove_node(node) + node.mamba_value = None + return mamba_num + + def _evict_to_host(self, node: TreeNode) -> Tuple[int, int]: + """Evict a backuped device node to host: free GPU KV + GPU mamba. + + Node stays in the tree as evicted+backuped. + Caller must ensure node.backuped is True before calling. + Returns (full_num_evicted, mamba_num_evicted). 
+ """ + assert not node.evicted, f"_evict_to_host on already-evicted node, {node.id=}" + assert node.backuped, f"_evict_to_host on non-backuped node, {node.id=}" - if not node.backuped: - if self.cache_controller.write_policy == "write_back": - self.write_backup(node, write_back=True) - self.writing_check(write_back=True) + num_full = len(node.value) self.cache_controller.evict_device(node.value) self.full_evictable_size_ -= num_full if self.full_lru_list.in_list(node): self.full_lru_list.remove_node(node) + mamba_num = self._free_gpu_mamba(node) + node.value = None self._update_leaf_status(node) self._update_full_device_leaf_status(node.parent) - return num_full + return num_full, mamba_num + + def _evict_regular_leaf(self, node: TreeNode) -> Tuple[int, int]: + """Evict a non-backuped device leaf: free GPU KV + mamba, delete from tree. + + Used in write_through when a node was never backed up. + Returns (full_num_evicted, mamba_num_evicted). + """ + assert not node.evicted, f"_evict_regular on already-evicted node, {node.id=}" + assert not node.backuped, f"_evict_regular on backuped node, {node.id=}" + assert len(node.children) == 0, f"_evict_regular on non-leaf, {node.id=}" + + num_full = len(node.value) + + self.cache_controller.evict_device(node.value) + self.full_evictable_size_ -= num_full + if self.full_lru_list.in_list(node): + self.full_lru_list.remove_node(node) + + mamba_num = self._free_gpu_mamba(node) + + if node.mamba_host_value is not None: + if self.mamba_host_lru_list.in_list(node): + self.mamba_host_lru_list.remove_node(node) + self.mamba_pool_host.free(node.mamba_host_value) + node.mamba_host_value = None + + node.value = None + self._discard_from_leaf_sets(node) + + parent = node.parent + key = self.get_child_key_fn(node.key) + v = parent.children.pop(key, None) + assert v == node, f"parent does not have child key, {key}" + + self._update_leaf_status(parent) + self._iteratively_delete_tombstone_leaf(node) + return num_full, mamba_num + + def _evict_host_leaf_node(self, x: TreeNode) -> int: + """Evict a host-resident leaf: free host KV + host mamba, delete from tree, cascade. + + Returns num host KV tokens evicted. 
+ """ + assert x.evicted, f"host leaf not evicted, {x.id=}" + assert x.backuped, f"host leaf not backuped, {x.id=}" + assert x.mamba_value is None, f"host leaf has GPU mamba, {x.id=}" + assert ( + x.host_ref_counter == 0 + ), f"host leaf in use, {x.id=} {x.host_ref_counter=}" + + full_num_evicted = self.cache_controller.evict_host(x.host_value) + x.host_value = None + + if x.mamba_host_value is not None: + if self.mamba_host_lru_list.in_list(x): + self.mamba_host_lru_list.remove_node(x) + self.mamba_pool_host.free(x.mamba_host_value) + x.mamba_host_value = None + + self._discard_from_leaf_sets(x) + parent = x.parent + key = self.get_child_key_fn(x.key) + v = parent.children.pop(key, None) + assert v == x, f"parent does not have child key, {key}" + + self._update_leaf_status(parent) + self._iteratively_delete_tombstone_leaf(x) + + return full_num_evicted def _delete_tombstone_leaf(self, node: TreeNode) -> None: """Remove a tombstone leaf from the tree and free HiCache resources.""" assert node.mamba_value is None, f"node has mamba value, {node.id=}" + assert node.mamba_host_value is None, f"node has mamba host value, {node.id=}" assert len(node.children) == 0, f"leaf node has children, {node.id=}" parent = node.parent key = self.get_child_key_fn(node.key) @@ -449,6 +703,8 @@ def _iteratively_delete_tombstone_leaf( break if node.parent.mamba_value is not None: break + if node.parent.mamba_host_value is not None: + break if node.parent.full_lock_ref > 0 or node.parent.mamba_lock_ref > 0: break @@ -467,6 +723,22 @@ def _iteratively_delete_tombstone_leaf( return node, full_num_evicted, mamba_num_evicted + def _evict_device_leaf(self, x: TreeNode) -> Tuple[int, int]: + """Evict a device leaf node, choosing the right strategy: + + - backuped: demote to host via _evict_to_host (node stays in tree) + - not backuped + write_back: write_backup first, then demote + - not backuped + write_through: _evict_regular_leaf (delete from tree) + """ + if not x.backuped: + if self.cache_controller.write_policy == "write_back": + self.write_backup(x, write_back=True) + self.writing_check(write_back=True) + return self._evict_to_host(x) + else: + return self._evict_regular_leaf(x) + return self._evict_to_host(x) + def evict(self, params: EvictParams) -> EvictResult: if self.disable: return EvictResult() @@ -485,15 +757,16 @@ def evict(self, params: EvictParams) -> EvictResult: if x not in self.evictable_full_device_leaves: continue - evicted_full = self._evict_to_host(x) + evicted_full, evicted_mamba = self._evict_device_leaf(x) full_num_evicted += evicted_full + mamba_num_evicted += evicted_mamba parent = x.parent if parent in self.evictable_full_device_leaves: heapq.heappush(eviction_heap, (parent.last_access_time, parent)) if params.mamba_num > 0: - mamba_num_evicted = self.evict_mamba(params.mamba_num) + mamba_num_evicted += self.evict_mamba(params.mamba_num) return EvictResult( num_tokens_evicted=full_num_evicted, @@ -501,27 +774,7 @@ def evict(self, params: EvictParams) -> EvictResult: ) def evict_host(self, num_tokens: int): - if self.enable_storage: - host_leaves = list(self.evictable_full_host_leaves) - heap = [(n.last_access_time, n) for n in host_leaves] - heapq.heapify(heap) - - num_evicted = 0 - while num_evicted < num_tokens and heap: - _, x = heapq.heappop(heap) - if x not in self.evictable_full_host_leaves: - continue - - num_evicted += self.cache_controller.evict_host(x.host_value) - x.host_value = None - - self.evictable_full_host_leaves.discard(x) - self._update_full_host_leaf_status(x.parent) 
- if x.parent in self.evictable_full_host_leaves: - heapq.heappush(heap, (x.parent.last_access_time, x.parent)) - return - - # Non-L3 path: evict host leaves and clean up tree + """Evict host-resident leaf nodes: free host KV + mamba, delete from tree, cascade.""" heap = [(n.last_access_time, n) for n in self.evictable_full_host_leaves] heapq.heapify(heap) @@ -531,28 +784,50 @@ def evict_host(self, num_tokens: int): if x not in self.evictable_full_host_leaves: continue - num_evicted += self.cache_controller.evict_host(x.host_value) - x.host_value = None + num_evicted += self._evict_host_leaf_node(x) + + if x.parent in self.evictable_full_host_leaves: + heapq.heappush(heap, (x.parent.last_access_time, x.parent)) + + def evict_mamba_host(self, num_mamba_hosts: int) -> int: + """Evict host mamba states. - if x.mamba_value is not None: - if self.mamba_lru_list.in_list(x): - self.mamba_lru_list.remove_node(x) - self.mamba_evictable_size_ -= len(x.mamba_value) - self.req_to_token_pool.mamba_pool.free(x.mamba_value) - x.mamba_value = None + Internal host node: free host mamba only (tombstone). + Host leaf node: same as Full host evict — _evict_host_leaf_node frees + host KV + mamba, deletes from tree, cascades. + """ + if self.disable or num_mamba_hosts <= 0: + return 0 - self.evictable_full_host_leaves.discard(x) + x = self.mamba_host_lru_list.get_lru_no_lock() + num_evicted = 0 + while num_evicted < num_mamba_hosts and self.mamba_host_lru_list.in_list(x): + x_next = self.mamba_host_lru_list.get_prev_no_lock(x) + if x.host_ref_counter > 0: + x = x_next + continue - parent = x.parent - child_key = self.get_child_key_fn(x.key) - v = parent.children.pop(child_key, None) - assert v == x, f"parent does not have child key, {x.id=}" + if x in self.evictable_full_host_leaves: + # Host leaf: evict host KV + mamba, delete from tree + self._evict_host_leaf_node(x) + num_evicted += 1 + else: + # Internal host node: free host mamba only (tombstone) + self.mamba_host_lru_list.remove_node(x) + self.mamba_pool_host.free(x.mamba_host_value) + x.mamba_host_value = None + num_evicted += 1 - self._update_leaf_status(parent) - if parent in self.evictable_full_host_leaves: - heapq.heappush(heap, (parent.last_access_time, parent)) + x = x_next + return num_evicted def evict_mamba(self, mamba_num: int) -> int: + """Evict mamba states. + + Internal node: tombstone — free GPU mamba only, KV stays on GPU. + Leaf node: same as Full evict — _evict_to_host moves KV+mamba to host, + node stays in tree, then cascade tombstone parent device leaves. 
+ """ if self.disable or mamba_num <= 0: return 0 @@ -562,64 +837,25 @@ def evict_mamba(self, mamba_num: int) -> int: assert x.mamba_value is not None, f"node has no mamba value, {x.id=}" assert x != self.root_node, f"root node is not evictable, {x.id=}" assert x.mamba_lock_ref == 0, f"node is in use, {x.id=}" + assert ( + not x.evicted + ), f"evicted node should not be in mamba_lru_list, {x.id=}" - if x.evicted: - self.req_to_token_pool.mamba_pool.free(x.mamba_value) - mamba_num_evicted += len(x.mamba_value) + if len(x.children) > 0: + # Internal: free GPU mamba only, KV stays on GPU (tombstone) x_next = self.mamba_lru_list.get_prev_no_lock(x) - self.mamba_lru_list.remove_node(x) - self.mamba_evictable_size_ -= len(x.mamba_value) - x.mamba_value = None - - if len(x.children) == 0: - self._delete_tombstone_leaf(x) - _, _, cascade_mamba = self._iteratively_delete_tombstone_leaf(x) - mamba_num_evicted += cascade_mamba - elif len(x.children) > 0: - self.req_to_token_pool.mamba_pool.free(x.mamba_value) mamba_num_evicted += len(x.mamba_value) - x_next = self.mamba_lru_list.get_prev_no_lock(x) self.mamba_lru_list.remove_node(x) self._tombstone_internal_node(x) else: + # Leaf: evict KV + mamba atomically assert ( x.full_lock_ref == 0 ), f"evict leaf node invalid with {x.id=} {x.full_lock_ref=}" - if not x.backuped: - if self.cache_controller.write_policy == "write_back": - self.write_backup(x, write_back=True) - self.writing_check(write_back=True) - - self.cache_controller.evict_device(x.value) - self.full_evictable_size_ -= len(x.value) - - self.req_to_token_pool.mamba_pool.free(x.mamba_value) - mamba_num_evicted += len(x.mamba_value) - x_next = self.mamba_lru_list.get_prev_no_lock(x) - if self.full_lru_list.in_list(x): - self.full_lru_list.remove_node(x) - self.mamba_lru_list.remove_node(x) - self.mamba_evictable_size_ -= len(x.mamba_value) - - if x.backuped: - self.cache_controller.evict_host(x.host_value) - x.host_value = None - - x.value = None - x.mamba_value = None - - self._discard_from_leaf_sets(x) - - parent = x.parent - child_key = self.get_child_key_fn(x.key) - v = parent.children.pop(child_key, None) - assert v == x, f"parent does not have child key, {x.id=}" - - self._update_leaf_status(parent) - _, _, cascade_mamba = self._iteratively_delete_tombstone_leaf(x) - mamba_num_evicted += cascade_mamba + _, mamba_evicted = self._evict_device_leaf(x) + mamba_num_evicted += mamba_evicted if not self.mamba_lru_list.in_list(x_next): x_next = self.mamba_lru_list.get_lru_no_lock() @@ -630,6 +866,11 @@ def evict_mamba(self, mamba_num: int) -> int: def _unevict_node(self, node: TreeNode, fresh_value: torch.Tensor): """Restore an evicted node with fresh device KV from the request.""" + assert node.evicted, f"_unevict_node on non-evicted node, {node.id=}" + # invariant: evicted => no GPU mamba + assert ( + node.mamba_value is None + ), f"evicted node should not have GPU mamba, {node.id=}" n = len(fresh_value) node.value = fresh_value.clone() @@ -654,7 +895,6 @@ def _insert_helper( value, mamba_value, chunked: bool = False, - prev_prefix_len: int = 0, ) -> Tuple[int, bool]: assert mamba_value is not None, "Mamba value should not be None here." node.last_access_time = get_last_access_time() @@ -685,11 +925,12 @@ def _insert_helper( node = new_node if node.evicted: + # Unevicted nodes take ownership of the request's KV pages. + # Do NOT count them in total_prefix_length, otherwise + # cache_finished_req / cache_unfinished_req will free those + # pages even though the tree now references them. 
self._unevict_node(node, value[:prefix_len]) else: - if prev_prefix_len < total_prefix_length + prefix_len: - start = max(0, prev_prefix_len - total_prefix_length) - self.token_to_kv_pool_allocator.free(value[start:prefix_len]) total_prefix_length += prefix_len self._inc_hit_count(node, chunked) @@ -783,7 +1024,7 @@ def _match_prefix_helper( if child.evicted and not child.backuped: break - if node.mamba_value is not None: + if node.mamba_value is not None or node.mamba_backuped: best_value_len = len(value) best_last_node = node @@ -802,11 +1043,11 @@ def _match_prefix_helper( if len(key): child_key = self.get_child_key_fn(key) - if node.mamba_value is not None: + if node.mamba_value is not None or node.mamba_backuped: best_value_len = len(value) best_last_node = node - deepest_node = node + deepest_node = best_last_node return value, best_last_node, best_value_len, deepest_node def _match_post_processor( @@ -849,33 +1090,26 @@ def _match_post_processor( else: mamba_branching_seqlen = None - # last_device_node & host_hit_length: from best_last_node (mamba boundary) host_hit_length = 0 last_device_node = best_last_node - while last_device_node.evicted: + while last_device_node is not self.root_node and last_device_node.evicted: host_hit_length += len(last_device_node.host_value) last_device_node = last_device_node.parent last_host_node = best_last_node - - # last_host_backup_node: from deepest_node, find backuped ancestor - last_host_backup_node = deepest_node - while ( - last_host_backup_node is not self.root_node - and not last_host_backup_node.backuped - ): - last_host_backup_node = last_host_backup_node.parent + while last_host_node is not self.root_node and not last_host_node.backuped: + last_host_node = last_host_node.parent mamba_node = best_last_node if cow_mamba and mamba_node.mamba_value is not None: if req.mamba_pool_idx is None: - dst_index = self.req_to_token_pool.mamba_pool.alloc(1) - if dst_index is None: - self.inc_lock_ref(mamba_node) - self.evict_mamba(1) - dst_index = self.req_to_token_pool.mamba_pool.alloc(1) - self.dec_lock_ref(mamba_node) - assert dst_index is not None, "Can not alloc mamba cache" + dst_index = self._alloc_with_evict( + self.req_to_token_pool.mamba_pool, + 1, + self.evict_mamba, + lock_node=mamba_node, + error_message="Can not alloc mamba cache", + ) src_index = mamba_node.mamba_value self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index) req.mamba_pool_idx = dst_index[0] @@ -894,7 +1128,7 @@ def _match_post_processor( device_indices=value, last_device_node=last_device_node, last_host_node=last_host_node, - last_host_backup_node=last_host_backup_node, + last_host_backup_node=last_host_node, host_hit_length=host_hit_length, mamba_branching_seqlen=mamba_branching_seqlen, ) @@ -995,9 +1229,9 @@ def sanity_check(self): return super().sanity_check() - def inc_lock_ref(self, node: TreeNode) -> IncLockRefResult: + def inc_lock_ref(self, node: TreeNode) -> Optional[int]: if self.disable: - return IncLockRefResult(delta=0) + return 0 delta = 0 if node.mamba_value is not None: @@ -1021,13 +1255,11 @@ def inc_lock_ref(self, node: TreeNode) -> IncLockRefResult: self.evictable_full_device_leaves.discard(node) node.full_lock_ref += 1 node = node.parent - return IncLockRefResult(delta=delta) + return delta - def dec_lock_ref( - self, node: TreeNode, params: Optional[DecLockRefParams] = None - ) -> DecLockRefResult: + def dec_lock_ref(self, node: TreeNode): if self.disable: - return DecLockRefResult(delta=0) + return 0 delta = 0 @@ -1053,7 +1285,7 @@ def 
dec_lock_ref( if node.full_lock_ref == 0: self._update_full_device_leaf_status(node) node = node.parent - return DecLockRefResult(delta=delta) + return delta # ---- L3 Support ---- @@ -1183,6 +1415,7 @@ def attach_storage_backend( prefetch_threshold=prefetch_threshold, model_name=served_model_name, storage_backend_extra_config=extra_config, + host_pools=self.host_pool_group.entries, ) except Exception as e: logger.exception( @@ -1219,6 +1452,14 @@ def detach_storage_backend(self) -> tuple: self.storage_metrics_collector = None return True, "Detached HiCache storage backend successfully." + def prefetch_abort(self, pool_transfers: Optional[list[PoolTransfer]]) -> None: + """Free any allocated mamba host slots on prefetch abort/revoke.""" + for transfer in pool_transfers or []: + if transfer.name == "mamba": + if transfer.host_indices is not None: + self.mamba_pool_host.free(transfer.host_indices) + break + def _force_release_pending_storage_ops(self): cc = self.cache_controller @@ -1236,6 +1477,12 @@ def _force_release_pending_storage_ops(self): logger.exception( "Failed to free host indices for prefetch %s", req_id ) + try: + self.prefetch_abort(getattr(_operation, "pool_transfers", None)) + except Exception: + logger.exception( + "Failed to release mamba host indices for prefetch %s", req_id + ) try: self._release_host_node(last_host_node) except Exception: @@ -1295,7 +1542,8 @@ def _drain_revoke(): for req_id in _drain_queue(cc.prefetch_revoke_queue, n_revoke): info = self.ongoing_prefetch.pop(req_id, None) if info is not None: - last_host_node, token_ids, _, _ = info + last_host_node, token_ids, _, operation = info + self.prefetch_abort(operation.pool_transfers) self._release_host_node(last_host_node) cc.prefetch_tokens_occupied -= len(token_ids) if cc.prefetch_tokens_occupied < 0: @@ -1498,12 +1746,13 @@ def write_backup_storage(self, node: TreeNode): if self.hicache_storage_pass_prefix_keys else None ) - + extra_pools = self.archive_transfers(node) operation_id = self.cache_controller.write_storage( node.host_value, node.key, node.hash_value, prefix_keys, + extra_pools=extra_pools, ) self.ongoing_backup[operation_id] = node self._protect_host_node(node) @@ -1528,15 +1777,36 @@ def prefetch_from_storage( return self._protect_host_node(last_host_node) - host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length) - if host_indices is None: - self.evict_host(prefetch_length) - host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length) + host_indices = self._alloc_with_evict( + self.cache_controller.mem_pool_host, + prefetch_length, + self.evict_host, + ) if host_indices is None: self._release_host_node(last_host_node) return + extra_pools = None + prepared = self.prefetch_prepare(new_input_tokens, last_hash) + if prefetch_length > 0 and prepared is None: + self.cache_controller.mem_pool_host.free(host_indices) + self._release_host_node(last_host_node) + return + if prepared is not None: + _, extra_pools = prepared + if extra_pools is not None: + logger.info( + "HiCache mamba prefetch scheduled for request %s: kv_hit_pages=%s mamba_states=%s", + req_id, + prefetch_length // self.page_size, + 1, + ) operation = self.cache_controller.prefetch( - req_id, host_indices, new_input_tokens, last_hash, prefix_keys + req_id, + host_indices, + new_input_tokens, + last_hash, + prefix_keys, + extra_pools=extra_pools, ) self.ongoing_prefetch[req_id] = ( last_host_node, @@ -1577,6 +1847,11 @@ def check_prefetch_progress(self, req_id: str) -> bool: min_completed_tokens = 
completed_tokens_tensor.item() fetched_token_ids = token_ids[:min_completed_tokens] written_indices = host_indices[:min_completed_tokens] + mamba_host_indices = None + for transfer in operation.pool_transfers or []: + if transfer.name == "mamba": + mamba_host_indices = transfer.host_indices + break matched_length = self._insert_helper_host( last_host_node, RadixKey( @@ -1585,9 +1860,17 @@ def check_prefetch_progress(self, req_id: str) -> bool: ), written_indices, hash_value[: min_completed_tokens // self.page_size], + mamba_host_indices, + operation.pool_storage_result, ) self.cache_controller.mem_pool_host.free(host_indices[:matched_length]) + mamba_loaded = self.prefetch_commit( + operation.pool_transfers, + matched_length, + min_completed_tokens, + operation.pool_storage_result, + ) self.cache_controller.append_host_mem_release( host_indices[min_completed_tokens:completed_tokens] ) @@ -1600,11 +1883,24 @@ def check_prefetch_progress(self, req_id: str) -> bool: if self.enable_storage_metrics: self.storage_metrics_collector.log_prefetched_tokens(loaded_from_storage) + if loaded_from_storage > 0 and operation.pool_transfers: + logger.info( + "HiCache mamba prefetch completed for request %s: prefetched_tokens=%s mamba_states=%s", + req_id, + loaded_from_storage, + int(mamba_loaded), + ) return True def _insert_helper_host( - self, node: TreeNode, key: RadixKey, host_value, hash_value + self, + node: TreeNode, + key: RadixKey, + host_value, + hash_value, + mamba_host_value: Optional[torch.Tensor] = None, + pool_storage_result: Optional[PoolTransferResult] = None, ): node.last_access_time = get_last_access_time() if len(key) == 0: @@ -1614,6 +1910,10 @@ def _insert_helper_host( matched_length = 0 host_value_inserted = False + final_mamba_node: Optional[TreeNode] = None + has_mamba = pool_storage_result is None or ( + pool_storage_result.extra_pool_hit_pages.get("mamba", 0) >= 1 + ) while len(key) > 0 and child_key in node.children.keys(): node = node.children[child_key] node.last_access_time = get_last_access_time() @@ -1624,6 +1924,8 @@ def _insert_helper_host( if node.evicted and not node.backuped: node.host_value = host_value[:prefix_len].clone() host_value_inserted = True + if prefix_len == len(key): + final_mamba_node = node self._update_full_host_leaf_status(node) if node.parent is not None: self._update_full_host_leaf_status(node.parent) @@ -1654,8 +1956,13 @@ def _insert_helper_host( new_node.host_value = host_value.clone() new_node.hash_value = hash_value node.children[child_key] = new_node + final_mamba_node = new_node self._update_full_host_leaf_status(new_node) self._update_full_host_leaf_status(node) + if final_mamba_node is not None and mamba_host_value is not None and has_mamba: + final_mamba_node.mamba_host_value = mamba_host_value.clone() + if not self.mamba_host_lru_list.in_list(final_mamba_node): + self.mamba_host_lru_list.insert_mru(final_mamba_node) return matched_length def release_aborted_request(self, rid: str): @@ -1674,4 +1981,215 @@ def release_aborted_request(self, rid: str): self._release_host_node(last_host_node) del self.ongoing_prefetch[rid] self.cache_controller.append_host_mem_release(host_indices[:completed_tokens]) + self.prefetch_abort(operation.pool_transfers) self.cache_controller.prefetch_tokens_occupied -= len(token_ids) + + def _flush_pending_storage_backups_before_reset(self) -> None: + if not self.enable_storage: + return + + self.writing_check(write_back=True) + deadline = time.monotonic() + 30.0 + last_log_time = 0.0 + while time.monotonic() < 
deadline: + self.drain_storage_control_queues() + backup_qsize = self.cache_controller.backup_queue.qsize() + ack_backup_qsize = self.cache_controller.ack_backup_queue.qsize() + ongoing_backup = len(self.ongoing_backup) + ongoing_write = len(self.ongoing_write_through) + if ( + backup_qsize == 0 + and ack_backup_qsize == 0 + and ongoing_backup == 0 + and ongoing_write == 0 + ): + return + time.sleep(0.05) + + logger.warning( + "Timed out waiting for HiCache storage backups to drain before reset: " + "ongoing_write=%s ongoing_backup=%s backup_queue=%s ack_backup_queue=%s", + len(self.ongoing_write_through), + len(self.ongoing_backup), + self.cache_controller.backup_queue.qsize(), + self.cache_controller.ack_backup_queue.qsize(), + ) + + def _alloc_with_evict( + self, + pool, + size: int, + evict_fn, + lock_node: Optional[TreeNode] = None, + error_message: Optional[str] = None, + ) -> Optional[torch.Tensor]: + indices = pool.alloc(size) + if indices is None: + if lock_node is not None: + self.inc_lock_ref(lock_node) + evict_fn(size) + indices = pool.alloc(size) + if lock_node is not None: + self.dec_lock_ref(lock_node) + if indices is None and error_message is not None: + raise RuntimeError(error_message) + return indices + + def _last_page_index(self, token_count: int) -> Optional[int]: + if token_count <= 0: + return None + return token_count // self.page_size - 1 + + def backup_transfers(self, node: TreeNode) -> Optional[list[PoolTransfer]]: + """PoolTransfers for D→H backup: mamba device → mamba host.""" + if node.mamba_value is None: + return None + return [ + PoolTransfer( + name="mamba", + host_indices=node.mamba_host_value, + device_indices=node.mamba_value, + ) + ] + + def backup_commit(self, node: TreeNode, transfers: list[PoolTransfer]) -> None: + """After D→H backup succeeds: store auto-allocated mamba host indices into node.""" + if not transfers: + return + mamba_host = transfers[0].host_indices + if node.mamba_host_value is None and mamba_host is not None: + node.mamba_host_value = mamba_host + self.mamba_host_lru_list.insert_mru(node) + + def archive_transfers(self, node: TreeNode) -> Optional[list[PoolTransfer]]: + """PoolTransfers for H→Storage archive (write_backup_storage).""" + mamba_host_value = getattr(node, "mamba_host_value", None) + if mamba_host_value is None or not node.hash_value: + return None + return [ + PoolTransfer( + name="mamba", + host_indices=mamba_host_value, + keys=[node.hash_value[-1]], + hit_policy="trailing_pages", + ) + ] + + def prefetch_prepare( + self, + token_ids: List[int], + last_hash: Optional[str], + ) -> Optional[tuple[torch.Tensor, list[PoolTransfer]]]: + """Alloc mamba host slot and build PoolTransfers for Storage→H prefetch. + + Returns (host_handle, transfers) so the caller can track the allocation, + or None if allocation fails (caller should cancel the prefetch). 
+ """ + if not token_ids: + return None + mamba_host_index = self._alloc_with_evict( + self.mamba_pool_host, 1, self.evict_mamba_host + ) + if mamba_host_index is None: + return None + last_page_hash = last_hash + for start in range(0, len(token_ids), self.page_size): + last_page_hash = self.cache_controller.get_hash_str( + token_ids[start : start + self.page_size], last_page_hash + ) + transfers = [ + PoolTransfer( + name="mamba", + host_indices=mamba_host_index, + keys=[last_page_hash], + hit_policy="trailing_pages", + ) + ] + return mamba_host_index, transfers + + def restore_transfers( + self, + last_hit_node: TreeNode, + nodes_to_restore: list[TreeNode], + req, + ) -> Optional[list[PoolTransfer]]: + """PoolTransfers for H→D restore (load_back).""" + mamba_host_list: list[torch.Tensor] = [] + for node in nodes_to_restore: + if not node.mamba_backuped: + continue + mamba_host_list.append(node.mamba_host_value) + + transfers: list[PoolTransfer] = [] + # Transfer from host to node's device mamba space + if mamba_host_list: + transfers.append( + PoolTransfer( + name="mamba", + host_indices=torch.cat(mamba_host_list), + device_indices=None, # will be allocated by cache_controller + ) + ) + + # Transfer from host to request's device mamba space + if ( + req is not None + and last_hit_node in nodes_to_restore + and last_hit_node.mamba_host_value is not None + ): + if req.mamba_pool_idx is None: + req.mamba_pool_idx = self._alloc_with_evict( + self.req_to_token_pool.mamba_pool, + len(last_hit_node.mamba_host_value), + self.evict_mamba, + lock_node=last_hit_node, + error_message="Can not alloc request mamba cache for host load back", + )[0] + transfers.append( + PoolTransfer( + name="mamba", + host_indices=last_hit_node.mamba_host_value, + device_indices=req.mamba_pool_idx.unsqueeze(0), + ) + ) + + return transfers if transfers else None + + def restore_commit( + self, + extra_nodes: list[TreeNode], + transfers: Optional[list[PoolTransfer]], + ) -> None: + """After H→D restore succeeds: write back controller-allocated mamba device indices.""" + if not extra_nodes or not transfers or transfers[0].device_indices is None: + return + mamba_device = transfers[0].device_indices + offset = 0 + for n in extra_nodes: + n_len = len(n.mamba_host_value) + n.mamba_value = mamba_device[offset : offset + n_len].clone() + offset += n_len + + def prefetch_commit( + self, + pool_transfers: Optional[list[PoolTransfer]], + matched_length: int, + min_completed_tokens: int, + result: PoolTransferResult, + ) -> bool: + """After Storage→H prefetch completes: free mamba host slot if not inserted into tree. + + Returns True if a mamba state was successfully loaded into the radix tree. + """ + mamba_host_indices = None + for transfer in pool_transfers or []: + if transfer.name == "mamba": + mamba_host_indices = transfer.host_indices + break + if mamba_host_indices is None: + return False + # mamba covers the entire prefix as a single page; loaded means >= 1 page hit. 
+ mamba_loaded = result.extra_pool_hit_pages.get("mamba", 0) >= 1 + if matched_length == min_completed_tokens or not mamba_loaded: + self.mamba_pool_host.free(mamba_host_indices) + return mamba_loaded diff --git a/python/sglang/srt/mem_cache/hicache_storage.py b/python/sglang/srt/mem_cache/hicache_storage.py index e76cfebd4463..6cffaa1069d0 100644 --- a/python/sglang/srt/mem_cache/hicache_storage.py +++ b/python/sglang/srt/mem_cache/hicache_storage.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import hashlib import logging import os from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, List, Optional +from enum import Enum +from typing import Any, List, Literal, Optional import torch @@ -13,6 +16,12 @@ logger = logging.getLogger(__name__) +def _pool_name_key(pool_name: PoolName | str | None) -> Optional[str]: + if pool_name is None: + return None + return pool_name.value if isinstance(pool_name, Enum) else str(pool_name) + + def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str: hasher = hashlib.sha256() @@ -21,24 +30,16 @@ def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str: for t in token_ids: if isinstance(t, tuple): - # EAGLE bigram mode: hash both elements to uniquely identify the bigram for elem in t: hasher.update(elem.to_bytes(4, byteorder="little", signed=False)) else: - # Regular mode: single integer token hasher.update(t.to_bytes(4, byteorder="little", signed=False)) return hasher.hexdigest() def hash_str_to_int64(hash_str: str) -> int: - """Convert SHA256 hex string to signed 64-bit integer for events. - - Takes first 16 hex characters (64 bits) and converts to signed int64 range. - """ - # Take first 16 hex chars to get 64-bit value uint64_val = int(hash_str[:16], 16) - # Convert to signed int64 range [-2^63, 2^63-1] if uint64_val >= 2**63: return uint64_val - 2**64 return uint64_val @@ -52,7 +53,7 @@ class HiCacheStorageConfig: pp_size: int is_mla_model: bool enable_storage_metrics: bool - is_page_first_layout: bool + layout: Literal["layer_first", "page_first", "page_first_direct", "page_head"] model_name: Optional[str] tp_lcm_size: Optional[int] = None should_split_heads: bool = False @@ -65,27 +66,93 @@ class HiCacheStorageExtraInfo: extra_info: Optional[dict] = None -class HiCacheStorage(ABC): - """ - HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache. - It abstracts the underlying storage mechanism, allowing different implementations to be used. 
- """ +class PoolName(str, Enum): + KV = "kv" + MAMBA = "mamba" + NSA = "nsa" + - # todo, the page size of storage backend does not have to be the same as the same as host memory pool +class PoolHitPolicy(str, Enum): + ALL_PAGES = "all_pages" + TRAILING_PAGES = "trailing_pages" + + +@dataclass +class PoolTransfer: + name: PoolName + host_indices: Optional[torch.Tensor] = None + device_indices: Optional[torch.Tensor] = None + keys: Optional[List[str]] = None + hit_policy: PoolHitPolicy = PoolHitPolicy.ALL_PAGES + use_anchor_host_indices: bool = False + use_anchor_device_indices: bool = False + + +@dataclass +class PoolTransferResult: + kv_hit_pages: int + extra_pool_hit_pages: dict[str, int] + + @classmethod + def empty(cls) -> "PoolTransferResult": + return cls(0, {}) + + @staticmethod + def _count_consecutive_true(results: List[bool]) -> int: + for i, ok in enumerate(results): + if not ok: + return i + return len(results) + + def update_kv_hit_pages(self, kv_hit_pages: int) -> None: + self.kv_hit_pages = max(self.kv_hit_pages, kv_hit_pages) + + def update_extra_pool_hit_pages(self, results: dict[str, List[bool]]) -> None: + for name, rs in results.items(): + self.extra_pool_hit_pages[name] = self.extra_pool_hit_pages.get(name, 0) + ( + self._count_consecutive_true(rs) + ) + + +class HiCacheStorage(ABC): + _NSA_INDEXER_SUFFIX = "__nsa_idx" def register_mem_pool_host(self, mem_pool_host: HostKVCache): self.mem_pool_host = mem_pool_host + def register_mem_host_pool_v2(self, host_pool: HostKVCache, host_pool_name): + if not hasattr(self, "registered_pools"): + self.registered_pools = {} + self.registered_pools[host_pool_name] = host_pool + + def batch_exists_v2( + self, + keys: List[str], + pool_transfers: Optional[List[PoolTransfer]] = None, + extra_info: Optional[HiCacheStorageExtraInfo] = None, + ) -> PoolTransferResult: + raise NotImplementedError() + + def batch_get_v2( + self, + transfers: List[PoolTransfer], + extra_info: Optional["HiCacheStorageExtraInfo"] = None, + ) -> dict[str, List[bool]]: + raise NotImplementedError() + + def batch_set_v2( + self, + transfers: List[PoolTransfer], + extra_info: Optional["HiCacheStorageExtraInfo"] = None, + ) -> dict[str, List[bool]]: + raise NotImplementedError() + def batch_get_v1( self, keys: List[str], host_indices: torch.Tensor, extra_info: Optional[HiCacheStorageExtraInfo] = None, ) -> List[bool]: - """ - Retrieve values for multiple keys. - Returns a list of booleans indicating success for each key. - """ pass def batch_set_v1( @@ -94,10 +161,6 @@ def batch_set_v1( host_indices: torch.Tensor, extra_info: Optional[HiCacheStorageExtraInfo] = None, ) -> List[bool]: - """ - Store multiple key-value pairs. - Returns a list of booleans indicating success for each key. - """ pass @abstractmethod @@ -107,13 +170,8 @@ def get( target_location: Optional[Any] = None, target_sizes: Optional[Any] = None, ) -> torch.Tensor | None: - """ - Retrieve the value associated with the given key. - Returns None if the key does not exist. - """ pass - # TODO: Deprecate @abstractmethod def batch_get( self, @@ -121,10 +179,6 @@ def batch_get( target_locations: Optional[Any] = None, target_sizes: Optional[Any] = None, ) -> List[torch.Tensor | None] | int: - """ - Retrieve values for multiple keys. - Returns a list of tensors or None for each key. - """ pass @abstractmethod @@ -135,13 +189,8 @@ def set( target_location: Optional[Any] = None, target_sizes: Optional[Any] = None, ) -> bool: - """ - Store the value associated with the given key. 
- Returns True if the operation was successful, False otherwise. - """ pass - # TODO: Deprecate @abstractmethod def batch_set( self, @@ -150,31 +199,17 @@ def batch_set( target_locations: Optional[Any] = None, target_sizes: Optional[Any] = None, ) -> bool: - """ - Store multiple key-value pairs. - Returns True if all operations were successful, False otherwise. - """ pass @abstractmethod def exists(self, key: str) -> bool: - """ - Check if the key exists in the storage. - Returns True if the key exists, False otherwise. - """ pass - # TODO: Use a finer-grained return type (e.g., List[bool]) def batch_exists( self, keys: List[str], extra_info: Optional[HiCacheStorageExtraInfo] = None ) -> int: - """ - Check if the keys exist in the storage. - return the number of consecutive existing keys from the start. - Can be overridden by subclasses for more efficient implementation. - """ - for i in range(len(keys)): - if not self.exists(keys[i]): + for i, key in enumerate(keys): + if not self.exists(key): return i return len(keys) @@ -186,7 +221,6 @@ def get_stats(self): class HiCacheFile(HiCacheStorage): - def __init__( self, storage_config: HiCacheStorageConfig, file_path: str = "/tmp/hicache" ): @@ -206,19 +240,31 @@ def __init__( if not os.path.exists(self.file_path) and tp_rank == 0: os.makedirs(self.file_path) - logger.info(f"Created HiCacheFile storage directory at {self.file_path}") + logger.info("Created HiCacheFile storage directory at %s", self.file_path) def _get_suffixed_key(self, key: str) -> str: return key + self.config_suffix + def _component_key(self, key: str, pool_name: PoolName | str | None = None) -> str: + pool_name_key = _pool_name_key(pool_name) + if pool_name_key in (None, "__default__", PoolName.KV.value): + return self._get_suffixed_key(key) + if pool_name_key == PoolName.NSA.value: + return self._get_suffixed_key(f"{key}{self._NSA_INDEXER_SUFFIX}") + return self._get_suffixed_key(f"{key}.{pool_name_key}") + + def _component_path(self, key: str, pool_name: PoolName | str | None = None) -> str: + return os.path.join( + self.file_path, f"{self._component_key(key, pool_name)}.bin" + ) + def get( self, key: str, target_location: torch.Tensor, target_sizes: Optional[Any] = None, ) -> torch.Tensor | None: - key = self._get_suffixed_key(key) - tensor_path = os.path.join(self.file_path, f"{key}.bin") + tensor_path = self._component_path(key) try: expected = target_location.numel() * target_location.element_size() with open(tensor_path, "rb", buffering=0) as f: @@ -227,7 +273,7 @@ def get( raise IOError(f"Short read for {key}") return target_location except FileNotFoundError: - logger.warning(f"Failed to fetch {key} from HiCacheFile storage.") + logger.warning("Failed to fetch %s from HiCacheFile storage.", key) return None def batch_get( @@ -251,16 +297,15 @@ def set( target_sizes: Optional[Any] = None, ) -> bool: if self.exists(key): - logger.debug(f"Key {key} already exists. Skipped.") + logger.debug("Key %s already exists. 
Skipped.", key) return True - key = self._get_suffixed_key(key) - tensor_path = os.path.join(self.file_path, f"{key}.bin") + tensor_path = self._component_path(key) try: value.contiguous().view(dtype=torch.uint8).numpy().tofile(tensor_path) return True except Exception as e: - logger.error(f"Failed to save tensor {key}: {e}") + logger.error("Failed to save tensor %s: %s", key, e) return False def batch_set( @@ -276,9 +321,134 @@ def batch_set( return True def exists(self, key: str) -> bool: - key = self._get_suffixed_key(key) - tensor_path = os.path.join(self.file_path, f"{key}.bin") - return os.path.exists(tensor_path) + return os.path.exists(self._component_path(key)) + + def _has_component(self, key: str, pool_name: PoolName | str | None = None) -> bool: + return os.path.exists(self._component_path(key, pool_name)) + + def batch_exists_v2( + self, + keys: List[str], + pool_transfers: Optional[List[PoolTransfer]] = None, + extra_info: Optional[HiCacheStorageExtraInfo] = None, + ) -> PoolTransferResult: + kv_pages = next( + ( + i + for i, key in enumerate(keys) + if not self._has_component(key, PoolName.KV) + ), + len(keys), + ) + + hit_count: dict[str, int] = {PoolName.KV.value: kv_pages} if kv_pages else {} + final_pages = kv_pages + + for transfer in pool_transfers or []: + if final_pages == 0: + break + name = transfer.name + if transfer.hit_policy == PoolHitPolicy.ALL_PAGES: + boundary = next( + ( + i + for i in range(kv_pages) + if not self._has_component(keys[i], name) + ), + kv_pages, + ) + else: + trailing = max(1, len(transfer.keys) if transfer.keys else 1) + boundary = 0 + for prefix_len in range(kv_pages, 0, -1): + if all( + self._has_component(keys[i], name) + for i in range(max(0, prefix_len - trailing), prefix_len) + ): + boundary = prefix_len + break + if boundary: + hit_count[_pool_name_key(name)] = boundary + final_pages = min(final_pages, boundary) + + if pool_transfers: + logger.info( + "HiCacheFile batch_exists_v2: kv_pages=%s final_pages=%s hit_count=%s first_key=%s last_key=%s", + kv_pages, + final_pages, + hit_count, + keys[0] if keys else None, + keys[final_pages - 1] if final_pages > 0 else None, + ) + + return PoolTransferResult(final_pages, hit_count) + + def _log_key(self, pool_name: PoolName | str, key: str) -> str: + pool_name_key = _pool_name_key(pool_name) + if pool_name_key == PoolName.KV.value: + return key + if pool_name_key == PoolName.NSA.value: + return f"{key}{self._NSA_INDEXER_SUFFIX}" + return f"{key}.{pool_name_key}" + + def _read_page(self, pool_name, key: str, host_pool, page_offset: int) -> bool: + storage_key = self._log_key(pool_name, key) + data_page = self.get(storage_key, host_pool.get_dummy_flat_data_page()) + if data_page is None: + return False + host_pool.set_from_flat_data_page(page_offset, data_page) + return True + + def _write_page(self, pool_name, key: str, host_pool, page_offset: int) -> bool: + storage_key = self._log_key(pool_name, key) + data_page = host_pool.get_data_page(page_offset, flat=True) + return self.set(storage_key, data_page) + + def _batch_io_v2(self, transfers: List[PoolTransfer], op_fn): + results: dict[str, List[bool]] = {} + for transfer in transfers: + transfer_name = _pool_name_key(transfer.name) + host_pool = self.registered_pools[transfer_name] + keys = transfer.keys or [] + page_size = getattr(host_pool, "page_size", 1) or 1 + expected = len(keys) * page_size + host_indices = transfer.host_indices + + if host_indices is None or host_indices.numel() != expected: + logger.error( + "%s indices length 
mismatch for %s: expected %s, got %s", + op_fn.__name__, + transfer.name, + expected, + host_indices.numel() if host_indices is not None else 0, + ) + results[transfer_name] = [False] * len(keys) + continue + + results[transfer_name] = [ + op_fn( + transfer.name, + key, + host_pool, + host_indices[i * page_size].item(), + ) + for i, key in enumerate(keys) + ] + return results + + def batch_get_v2( + self, + transfers: List[PoolTransfer], + extra_info: Optional["HiCacheStorageExtraInfo"] = None, + ) -> dict[str, List[bool]]: + return self._batch_io_v2(transfers, self._read_page) + + def batch_set_v2( + self, + transfers: List[PoolTransfer], + extra_info: Optional["HiCacheStorageExtraInfo"] = None, + ) -> dict[str, List[bool]]: + return self._batch_io_v2(transfers, self._write_page) def clear(self) -> bool: try: @@ -289,5 +459,5 @@ def clear(self) -> bool: logger.info("Cleared all entries in HiCacheFile storage.") return True except Exception as e: - logger.error(f"Failed to clear HiCacheFile storage: {e}") + logger.error("Failed to clear HiCacheFile storage: %s", e) return False diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index d1e9cf1ae89d..d776ef08a77b 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -25,15 +25,26 @@ MatchPrefixParams, MatchResult, ) +from sglang.srt.mem_cache.hicache_storage import ( + PoolHitPolicy, + PoolName, + PoolTransfer, +) +from sglang.srt.mem_cache.hybrid_cache.hybrid_cache_controller import ( + HybridCacheController, +) from sglang.srt.mem_cache.memory_pool import ( MHATokenToKVPool, MLATokenToKVPool, NSATokenToKVPool, ) from sglang.srt.mem_cache.memory_pool_host import ( + HostPoolGroup, MHATokenToKVPoolHost, MLATokenToKVPoolHost, + NSAIndexerHostPool, NSATokenToKVPoolHost, + PoolEntry, ) from sglang.srt.mem_cache.radix_cache import ( RadixCache, @@ -70,6 +81,23 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs): self.page_size = params.page_size self.kv_cache = params.token_to_kv_pool_allocator.get_kvcache() + self.use_nsa_pool_controller = isinstance(self.kv_cache, NSATokenToKVPool) + + if ( + self.use_nsa_pool_controller + and server_args.hicache_storage_backend == "mooncake" + and server_args.hicache_mem_layout + not in ["page_first", "page_first_direct"] + ): + server_args.hicache_mem_layout = ( + "page_first_direct" + if server_args.hicache_io_backend == "direct" + else "page_first" + ) + logger.warning( + "Mooncake storage backend with NSA requires page_first layout, " + f"switching to {server_args.hicache_mem_layout}." 
+ ) if isinstance(self.kv_cache, MHATokenToKVPool): self.token_to_kv_pool_host = MHATokenToKVPoolHost( @@ -123,22 +151,62 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs): self.prefetch_stop_policy = server_args.hicache_storage_prefetch_policy self.load_cache_event = threading.Event() - self.cache_controller = HiCacheController( - params.token_to_kv_pool_allocator, - self.token_to_kv_pool_host, - self.page_size, - self.tp_group, - load_cache_event=self.load_cache_event, - write_policy=server_args.hicache_write_policy, - io_backend=server_args.hicache_io_backend, - storage_backend=server_args.hicache_storage_backend, - prefetch_threshold=prefetch_threshold, - model_name=server_args.served_model_name, - storage_backend_extra_config=extra_config, - pp_rank=self.pp_rank, - pp_size=self.pp_size, - enable_storage_metrics=self.enable_storage_metrics, - ) + if self.use_nsa_pool_controller: + if server_args.hicache_storage_backend not in (None, "file", "mooncake"): + raise ValueError( + "NSA pool-based HiCache only supports file and mooncake storage backends." + ) + self.nsa_indexer_host_pool = NSAIndexerHostPool(self.token_to_kv_pool_host) + self.host_pool_group = HostPoolGroup( + [ + PoolEntry( + name=PoolName.KV, + host_pool=self.token_to_kv_pool_host, + device_pool=self.kv_cache, + layer_mapper=lambda layer_id: layer_id, + is_primary_index_anchor=True, + ), + PoolEntry( + name=PoolName.NSA, + host_pool=self.nsa_indexer_host_pool, + device_pool=self.kv_cache, + layer_mapper=lambda layer_id: layer_id, + ), + ] + ) + self.cache_controller = HybridCacheController( + params.token_to_kv_pool_allocator, + self.host_pool_group, + self.page_size, + self.tp_group, + load_cache_event=self.load_cache_event, + write_policy=server_args.hicache_write_policy, + io_backend=server_args.hicache_io_backend, + storage_backend=server_args.hicache_storage_backend, + prefetch_threshold=prefetch_threshold, + model_name=server_args.served_model_name, + storage_backend_extra_config=extra_config, + pp_rank=self.pp_rank, + pp_size=self.pp_size, + enable_storage_metrics=self.enable_storage_metrics, + ) + else: + self.cache_controller = HiCacheController( + params.token_to_kv_pool_allocator, + self.token_to_kv_pool_host, + self.page_size, + self.tp_group, + load_cache_event=self.load_cache_event, + write_policy=server_args.hicache_write_policy, + io_backend=server_args.hicache_io_backend, + storage_backend=server_args.hicache_storage_backend, + prefetch_threshold=prefetch_threshold, + model_name=server_args.served_model_name, + storage_backend_extra_config=extra_config, + pp_rank=self.pp_rank, + pp_size=self.pp_size, + enable_storage_metrics=self.enable_storage_metrics, + ) self._apply_storage_runtime_config( storage_backend=server_args.hicache_storage_backend, prefetch_threshold=prefetch_threshold, @@ -257,6 +325,11 @@ def attach_storage_backend( prefetch/backup paths. Caller must ensure there are no running/queued requests to avoid races. """ + if self.use_nsa_pool_controller and storage_backend not in ("file", "mooncake"): + return ( + False, + "NSA pool-based HiCache only supports file and mooncake storage backends.", + ) # Validate inputs first (no side effects). 
if hicache_storage_prefetch_policy is not None: allowed = ["best_effort", "wait_complete", "timeout"] @@ -338,12 +411,15 @@ def attach_storage_backend( ) try: - self.cache_controller.attach_storage_backend( + attach_kwargs = dict( storage_backend=storage_backend, prefetch_threshold=prefetch_threshold, model_name=served_model_name, storage_backend_extra_config=extra_config, ) + if self.use_nsa_pool_controller: + attach_kwargs["host_pools"] = self.host_pool_group.entries + self.cache_controller.attach_storage_backend(**attach_kwargs) except Exception as e: logger.exception( f"Failed to attach storage backend '{storage_backend}': {e}" @@ -662,7 +738,11 @@ def write_backup_storage(self, node: TreeNode): ) operation_id = self.cache_controller.write_storage( - node.host_value, node.key, node.hash_value, prefix_keys + node.host_value, + node.key, + node.hash_value, + prefix_keys, + extra_pools=self.nsa_archive_transfers(node), ) self.ongoing_backup[operation_id] = node node.protect_host() @@ -1207,6 +1287,10 @@ def check_prefetch_progress(self, req_id: str) -> bool: logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens") min_completed_tokens = completed_tokens + if self.use_nsa_pool_controller: + min_completed_tokens = ( + self.cache_controller.get_usable_prefetch_token_count(operation) + ) if self.tp_world_size > 1: # synchrnoize TP workers to make the same update to hiradix cache completed_tokens_tensor = torch.tensor( @@ -1336,7 +1420,12 @@ def prefetch_from_storage( # no sufficient host memory for prefetch return operation = self.cache_controller.prefetch( - req_id, host_indices, new_input_tokens, last_hash, prefix_keys + req_id, + host_indices, + new_input_tokens, + last_hash, + prefix_keys, + extra_pools=self.nsa_prefetch_transfers(), ) self.ongoing_prefetch[req_id] = ( last_host_node, @@ -1346,6 +1435,55 @@ def prefetch_from_storage( ) self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens) + def nsa_backup_transfers(self) -> Optional[list[PoolTransfer]]: + if not self.use_nsa_pool_controller: + return None + return [ + PoolTransfer( + name=PoolName.NSA, + hit_policy=PoolHitPolicy.ALL_PAGES, + use_anchor_host_indices=True, + use_anchor_device_indices=True, + ) + ] + + def nsa_archive_transfers(self, node: TreeNode) -> Optional[list[PoolTransfer]]: + if not self.use_nsa_pool_controller or not node.hash_value: + return None + return [ + PoolTransfer( + name=PoolName.NSA, + keys=node.hash_value, + hit_policy=PoolHitPolicy.ALL_PAGES, + use_anchor_host_indices=True, + use_anchor_device_indices=True, + ) + ] + + def nsa_prefetch_transfers(self) -> Optional[list[PoolTransfer]]: + if not self.use_nsa_pool_controller: + return None + return [ + PoolTransfer( + name=PoolName.NSA, + hit_policy=PoolHitPolicy.ALL_PAGES, + use_anchor_host_indices=True, + use_anchor_device_indices=True, + ) + ] + + def nsa_restore_transfers(self) -> Optional[list[PoolTransfer]]: + if not self.use_nsa_pool_controller: + return None + return [ + PoolTransfer( + name=PoolName.NSA, + hit_policy=PoolHitPolicy.ALL_PAGES, + use_anchor_host_indices=True, + use_anchor_device_indices=True, + ) + ] + def _insert_helper_host( self, node: TreeNode, key: RadixKey, host_value, hash_value ): diff --git a/python/sglang/srt/mem_cache/hybrid_cache/__init__.py b/python/sglang/srt/mem_cache/hybrid_cache/__init__.py new file mode 100644 index 000000000000..9881313609aa --- /dev/null +++ b/python/sglang/srt/mem_cache/hybrid_cache/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git 
a/python/sglang/srt/mem_cache/hybrid_cache/hybrid_cache_controller.py b/python/sglang/srt/mem_cache/hybrid_cache/hybrid_cache_controller.py new file mode 100644 index 000000000000..fa3de98731e2 --- /dev/null +++ b/python/sglang/srt/mem_cache/hybrid_cache/hybrid_cache_controller.py @@ -0,0 +1,579 @@ +from __future__ import annotations + +import logging +import threading +import time +from typing import TYPE_CHECKING, Any, List, Optional + +import torch + +from sglang.srt.managers.cache_controller import CacheOperation as BaseCacheOperation +from sglang.srt.managers.cache_controller import ( + HiCacheAck, +) +from sglang.srt.managers.cache_controller import ( + HiCacheController as BaseHiCacheController, +) +from sglang.srt.managers.cache_controller import ( + LayerDoneCounter, +) +from sglang.srt.managers.cache_controller import ( + StorageOperation as BaseStorageOperation, +) +from sglang.srt.mem_cache.hicache_storage import ( + HiCacheStorageExtraInfo, + PoolHitPolicy, + PoolTransfer, + PoolTransferResult, +) +from sglang.srt.mem_cache.memory_pool_host import PoolEntry +from sglang.srt.utils import get_device_module + +if TYPE_CHECKING: + from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator + +logger = logging.getLogger(__name__) +device_module = get_device_module() + + +def _pool_name_key(pool_name) -> str: + return pool_name.value if hasattr(pool_name, "value") else str(pool_name) + + +class CacheOperation(BaseCacheOperation): + def __init__( + self, + host_indices: torch.Tensor, + device_indices: torch.Tensor, + node_id: int, + priority: Optional[int] = None, + pool_transfers: Optional[list[PoolTransfer]] = None, + ): + super().__init__(host_indices, device_indices, node_id, priority) + self.pool_transfers = pool_transfers + + @staticmethod + def merge_pool_transfers( + ops: List["CacheOperation"], + ) -> Optional[list[PoolTransfer]]: + grouped: dict[str, list[PoolTransfer]] = {} + for op in ops: + for transfer in op.pool_transfers or []: + grouped.setdefault(_pool_name_key(transfer.name), []).append(transfer) + if not grouped: + return None + + def cat_or_none(tensors): + parts = [x for x in tensors if x is not None] + return torch.cat(parts) if parts else None + + merged = [] + for transfers in grouped.values(): + first = transfers[0] + merged.append( + PoolTransfer( + name=first.name, + host_indices=cat_or_none(t.host_indices for t in transfers), + device_indices=cat_or_none(t.device_indices for t in transfers), + keys=[k for t in transfers if t.keys for k in t.keys] or None, + hit_policy=first.hit_policy, + use_anchor_host_indices=first.use_anchor_host_indices, + use_anchor_device_indices=first.use_anchor_device_indices, + ) + ) + return merged + + @staticmethod + def merge_ops(ops: List["CacheOperation"]) -> "CacheOperation": + if len(ops) == 1: + return ops[0] + host_indices = torch.cat([op.host_indices for op in ops]) + device_indices = torch.cat([op.device_indices for op in ops]) + node_ids = [] + priority = min(op.priority for op in ops) + for op in ops: + node_ids.extend(op.node_ids) + merged = CacheOperation( + host_indices, + device_indices, + -1, + priority, + pool_transfers=CacheOperation.merge_pool_transfers(ops), + ) + merged.node_ids = node_ids + return merged + + +class StorageOperation(BaseStorageOperation): + def __init__( + self, + host_indices: torch.Tensor, + token_ids: List[int], + last_hash: Optional[str] = None, + hash_value: Optional[List[str]] = None, + prefix_keys: Optional[List[str]] = None, + pool_transfers: Optional[list[PoolTransfer]] = 
None, + ): + super().__init__(host_indices, token_ids, last_hash, hash_value, prefix_keys) + self.pool_transfers = pool_transfers + self.pool_storage_result = PoolTransferResult.empty() + + +class PrefetchOperation(StorageOperation): + def __init__( + self, + request_id: str, + host_indices: torch.Tensor, + token_ids: List[int], + last_hash: Optional[str] = None, + prefix_keys: Optional[List[str]] = None, + pool_transfers: Optional[list[PoolTransfer]] = None, + ): + self.request_id = request_id + self._lock = threading.Lock() + self._terminated_flag = False + self.start_time = time.monotonic() + super().__init__( + host_indices, + token_ids, + last_hash, + prefix_keys=prefix_keys, + pool_transfers=pool_transfers, + ) + + def increment(self, num_tokens: int): + with self._lock: + if self._terminated_flag: + return False + self.completed_tokens += num_tokens + return True + + def mark_terminate(self): + with self._lock: + self._terminated_flag = True + + def is_terminated(self) -> bool: + return self._terminated_flag + + +class HybridCacheController(BaseHiCacheController): + def __init__( + self, + token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, + mem_pool_host: Any, + page_size: int, + tp_group: torch.distributed.ProcessGroup, + load_cache_event: threading.Event, + write_policy: str = "write_through_selective", + io_backend: str = "", + storage_backend: Optional[str] = None, + prefetch_threshold: int = 256, + model_name: Optional[str] = None, + storage_backend_extra_config: Optional[dict] = None, + pp_rank: int = 0, + pp_size: int = 1, + transfer_layer_num: Optional[int] = None, + enable_storage_metrics: bool = False, + ): + self.transfer_layer_num = transfer_layer_num + startup_storage_backend = storage_backend + super().__init__( + token_to_kv_pool_allocator=token_to_kv_pool_allocator, + mem_pool_host=mem_pool_host, + page_size=page_size, + tp_group=tp_group, + load_cache_event=load_cache_event, + write_policy=write_policy, + io_backend=io_backend, + storage_backend=None, + prefetch_threshold=prefetch_threshold, + model_name=model_name, + storage_backend_extra_config=storage_backend_extra_config, + pp_rank=pp_rank, + pp_size=pp_size, + enable_storage_metrics=enable_storage_metrics, + ) + self.transfer_layer_num = self.transfer_layer_num or self.layer_num + if self.transfer_layer_num != self.layer_done_counter.num_layers: + self.layer_done_counter = LayerDoneCounter(self.transfer_layer_num) + self.mem_pool_device.register_layer_transfer_counter( + self.layer_done_counter + ) + if startup_storage_backend is not None: + self.attach_storage_backend( + storage_backend=startup_storage_backend, + prefetch_threshold=prefetch_threshold, + model_name=model_name, + storage_backend_extra_config=storage_backend_extra_config, + host_pools=getattr(mem_pool_host, "entries", None), + ) + + def attach_storage_backend( + self, + storage_backend: str, + prefetch_threshold: int = 256, + model_name: Optional[str] = None, + storage_backend_extra_config: Optional[dict] = None, + host_pools: Optional[list[PoolEntry]] = None, + ): + super().attach_storage_backend( + storage_backend=storage_backend, + prefetch_threshold=prefetch_threshold, + model_name=model_name, + storage_backend_extra_config=storage_backend_extra_config, + ) + + for entry in host_pools or []: + self.storage_backend.register_mem_host_pool_v2( + entry.host_pool, _pool_name_key(entry.name) + ) + logger.info("Using pool-based interface for hybrid storage operations") + + def reset(self): + super().reset() + if self.enable_storage: + 
self.host_mem_release_queue.queue.clear() + self.prefetch_tokens_occupied = 0 + + def write( + self, + device_indices: torch.Tensor, + priority: Optional[int] = None, + node_id: int = -1, + extra_pools: Optional[list[PoolTransfer]] = None, + ) -> Optional[torch.Tensor]: + host_indices = self.mem_pool_host.alloc(len(device_indices)) + if host_indices is None: + return None + pool_transfers = self._resolve_pool_transfers_allocation( + extra_pools, alloc_host=True + ) + if pool_transfers is None and extra_pools: + self.mem_pool_host.free(host_indices) + return None + + self.write_queue.append( + CacheOperation( + host_indices, + device_indices, + node_id, + priority, + pool_transfers=pool_transfers or None, + ) + ) + self.start_writing() + return host_indices + + def start_writing(self) -> None: + if not self.write_queue: + return + op = CacheOperation.merge_ops(self.write_queue) + host_indices, device_indices = self.move_indices(op) + self.write_queue.clear() + start_event = device_module.Event() + finish_event = device_module.Event() + start_event.record() + with device_module.stream(self.write_stream): + start_event.wait(self.write_stream) + self.mem_pool_host.backup_from_device_all_layer( + self.mem_pool_device, + host_indices, + device_indices, + self.io_backend, + pool_transfers=op.pool_transfers, + ) + finish_event.record() + if host_indices.is_cuda: + host_indices.record_stream(self.write_stream) + if device_indices.is_cuda: + device_indices.record_stream(self.write_stream) + self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids)) + + def load( + self, + host_indices: torch.Tensor, + priority: Optional[int] = None, + node_id: int = -1, + extra_pools: Optional[list[PoolTransfer]] = None, + ) -> Optional[torch.Tensor]: + need_load_kv = host_indices.numel() > 0 + if need_load_kv: + device_indices = self.mem_pool_device_allocator.alloc(len(host_indices)) + if device_indices is None: + return None + else: + device_indices = torch.empty((0,), dtype=torch.int64, device=self.device) + + pool_transfers = self._resolve_pool_transfers_allocation( + extra_pools, alloc_host=False + ) + if pool_transfers is None and extra_pools: + if need_load_kv: + self.mem_pool_device_allocator.free(device_indices) + return None + + self.load_queue.append( + CacheOperation( + host_indices, + device_indices, + node_id, + priority, + pool_transfers=pool_transfers or None, + ) + ) + return device_indices + + def start_loading(self) -> int: + if not self.load_queue: + return -1 + producer_id = self.layer_done_counter.update_producer() + op = CacheOperation.merge_ops(self.load_queue) + host_indices, device_indices = self.move_indices(op) + self.load_queue.clear() + producer_event = self.layer_done_counter.events[producer_id] + producer_event.start_event.record() + with device_module.stream(self.load_stream): + producer_event.start_event.wait(self.load_stream) + for i in range(self.transfer_layer_num): + self.mem_pool_host.load_to_device_per_layer( + self.mem_pool_device, + host_indices, + device_indices, + i, + self.io_backend, + pool_transfers=op.pool_transfers, + ) + producer_event.complete(i) + if host_indices.is_cuda: + host_indices.record_stream(self.load_stream) + if device_indices.is_cuda: + device_indices.record_stream(self.load_stream) + self.ack_load_queue.append( + HiCacheAck( + producer_event.start_event, + producer_event.finish_event, + op.node_ids, + ) + ) + return producer_id + + def prefetch( + self, + request_id: str, + host_indices: torch.Tensor, + new_input_tokens: List[int], 
+ last_hash: Optional[str] = None, + prefix_keys: Optional[List[str]] = None, + extra_pools: Optional[list[PoolTransfer]] = None, + ) -> PrefetchOperation: + operation = PrefetchOperation( + request_id, + host_indices, + new_input_tokens, + last_hash, + prefix_keys=prefix_keys, + pool_transfers=extra_pools, + ) + self.prefetch_queue.put(operation) + return operation + + def write_storage( + self, + host_indices: torch.Tensor, + token_ids: List[int], + hash_value: Optional[List[str]] = None, + prefix_keys: Optional[List[str]] = None, + extra_pools: Optional[list[PoolTransfer]] = None, + ) -> int: + operation = StorageOperation( + host_indices, + token_ids, + hash_value=hash_value, + prefix_keys=prefix_keys, + pool_transfers=extra_pools, + ) + self.backup_queue.put(operation) + return operation.id + + def _populate_transfer_keys( + self, + pool_transfers: list[PoolTransfer], + all_hashes: list[str], + kv_hit_pages: int, + ) -> None: + for transfer in pool_transfers: + if transfer.hit_policy == PoolHitPolicy.ALL_PAGES: + transfer.keys = all_hashes[:kv_hit_pages] + else: + trailing_n = len(transfer.keys) if transfer.keys else 1 + transfer.keys = all_hashes[ + max(0, kv_hit_pages - trailing_n) : kv_hit_pages + ] + + def _storage_hit_query(self, operation) -> tuple[list[str], int]: + last_hash = operation.last_hash + hash_value = [] + for start in range(0, len(operation.token_ids), self.page_size): + last_hash = self.get_hash_str( + operation.token_ids[start : start + self.page_size], last_hash + ) + hash_value.append(last_hash) + + extra_info = HiCacheStorageExtraInfo( + prefix_keys=operation.prefix_keys.copy() if operation.prefix_keys else None + ) + if operation.pool_transfers: + hit_result = self.storage_backend.batch_exists_v2( + hash_value, operation.pool_transfers, extra_info + ) + else: + kv_hit_count = self.storage_backend.batch_exists(hash_value, extra_info) + hit_result = PoolTransferResult( + kv_hit_pages=kv_hit_count, extra_pool_hit_pages={} + ) + + kv_hit_pages = hit_result.kv_hit_pages + operation.pool_storage_result.update_kv_hit_pages(kv_hit_pages) + + if kv_hit_pages > 0 and operation.pool_transfers: + self._populate_transfer_keys( + operation.pool_transfers, hash_value, kv_hit_pages + ) + + return hash_value[:kv_hit_pages], kv_hit_pages * self.page_size + + def _materialize_storage_transfers( + self, + pool_transfers: Optional[list[PoolTransfer]], + anchor_host_indices: torch.Tensor, + anchor_device_indices: Optional[torch.Tensor], + kv_pages: int, + ) -> Optional[list[PoolTransfer]]: + if not pool_transfers: + return None + materialized: list[PoolTransfer] = [] + for transfer in pool_transfers: + entry = self.mem_pool_host.entry_map.get(_pool_name_key(transfer.name)) + if entry is None: + continue + page_size = getattr(entry.host_pool, "page_size", self.page_size) or 1 + keys = transfer.keys or [] + if transfer.hit_policy == PoolHitPolicy.ALL_PAGES: + keys = keys[:kv_pages] + item_count = len(keys) * page_size + host_indices = transfer.host_indices + if host_indices is None and transfer.use_anchor_host_indices: + host_indices = anchor_host_indices + if host_indices is not None: + host_indices = host_indices[:item_count] + device_indices = transfer.device_indices + if device_indices is None and transfer.use_anchor_device_indices: + device_indices = anchor_device_indices + if device_indices is not None: + device_indices = device_indices[:item_count] + materialized.append( + PoolTransfer( + name=transfer.name, + host_indices=host_indices, + device_indices=device_indices, + 
keys=keys, + hit_policy=transfer.hit_policy, + use_anchor_host_indices=transfer.use_anchor_host_indices, + use_anchor_device_indices=transfer.use_anchor_device_indices, + ) + ) + return materialized + + def _page_transfer(self, operation): + super()._page_transfer(operation) + kv_pages = operation.completed_tokens // self.page_size + transfers = self._materialize_storage_transfers( + operation.pool_transfers, + operation.host_indices, + None, + kv_pages, + ) + if transfers and not operation.is_terminated(): + results = self.storage_backend.batch_get_v2(transfers) + operation.pool_storage_result.update_extra_pool_hit_pages(results) + + def _page_backup(self, operation): + super()._page_backup(operation) + kv_pages = operation.completed_tokens // self.page_size + transfers = self._materialize_storage_transfers( + operation.pool_transfers, + operation.host_indices, + None, + kv_pages, + ) + if transfers: + results = self.storage_backend.batch_set_v2(transfers) + operation.pool_storage_result.update_extra_pool_hit_pages(results) + + def get_usable_prefetch_token_count(self, operation: PrefetchOperation) -> int: + usable_pages = operation.completed_tokens // self.page_size + for transfer in operation.pool_transfers or []: + if transfer.hit_policy != PoolHitPolicy.ALL_PAGES: + continue + usable_pages = min( + usable_pages, + operation.pool_storage_result.extra_pool_hit_pages.get( + _pool_name_key(transfer.name), 0 + ), + ) + return usable_pages * self.page_size + + def _resolve_pool_transfers_allocation( + self, + extra_pools: Optional[list[PoolTransfer]], + alloc_host: bool, + ) -> Optional[list[PoolTransfer]]: + if not extra_pools: + return None + newly_allocated: list[tuple[PoolTransfer, Any, torch.Tensor]] = [] + for pool in extra_pools: + entry = self.mem_pool_host.entry_map.get(_pool_name_key(pool.name)) + if entry is None: + continue + if alloc_host: + if ( + pool.host_indices is not None + or pool.device_indices is None + or pool.use_anchor_host_indices + ): + continue + entry_pool, evict_fn, size = ( + entry.host_pool, + entry.host_evict_fn, + len(pool.device_indices), + ) + else: + if ( + pool.device_indices is not None + or pool.host_indices is None + or pool.use_anchor_device_indices + ): + continue + entry_pool, evict_fn, size = ( + entry.device_pool, + entry.device_evict_fn, + len(pool.host_indices), + ) + indices = entry_pool.alloc(size) + if indices is None and evict_fn: + evict_fn(size) + indices = entry_pool.alloc(size) + if indices is None: + for prev_pool, prev_entry_pool, prev_indices in newly_allocated: + prev_entry_pool.free(prev_indices) + if alloc_host: + prev_pool.host_indices = None + else: + prev_pool.device_indices = None + return None + if alloc_host: + pool.host_indices = indices + else: + pool.device_indices = indices + newly_allocated.append((pool, entry_pool, indices)) + return extra_pools diff --git a/python/sglang/srt/mem_cache/mamba_radix_cache.py b/python/sglang/srt/mem_cache/mamba_radix_cache.py index 147738e55813..992cf9549228 100644 --- a/python/sglang/srt/mem_cache/mamba_radix_cache.py +++ b/python/sglang/srt/mem_cache/mamba_radix_cache.py @@ -74,6 +74,7 @@ def __init__(self, id: Optional[int] = None): self.key: RadixKey = None self.value: Optional[torch.Tensor] = None self.mamba_value: Optional[torch.Tensor] = None + self.mamba_host_value: Optional[torch.Tensor] = None # invariant: for any node, if mamba_lock_ref is locked, full_lock_ref must be locked; # if full_lock_ref is locked, mamba_lock_ref doesn't need to be locked. 
So, # full_lock_ref is always >= mamba_lock_ref. @@ -98,6 +99,8 @@ def __init__(self, id: Optional[int] = None): self.next = None self.mamba_prev = None self.mamba_next = None + self.host_mamba_prev = None + self.host_mamba_next = None self.id = TreeNode.counter if id is None else id TreeNode.counter += 1 @@ -110,6 +113,14 @@ def evicted(self): def backuped(self): return self.host_value is not None + @property + def mamba_evicted(self): + return self.mamba_value is None + + @property + def mamba_backuped(self): + return self.mamba_host_value is not None + def protect_host(self): """Protect the host value from eviction.""" self.host_ref_counter += 1 diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index d98824420f7d..44c20b25855b 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -2,9 +2,11 @@ import logging import threading from collections import defaultdict +from dataclasses import dataclass from functools import wraps -from typing import Optional +from typing import Any, Callable, List, Optional +import numpy as np import psutil import torch @@ -19,6 +21,7 @@ ) from sglang.srt.mem_cache.memory_pool import ( KVCache, + MambaPool, MHATokenToKVPool, MLATokenToKVPool, NSATokenToKVPool, @@ -51,6 +54,10 @@ logger = logging.getLogger(__name__) +def _pool_name_key(pool_name) -> str: + return pool_name.value if hasattr(pool_name, "value") else str(pool_name) + + def synchronized(func): @wraps(func) def wrapper(self, *args, **kwargs): @@ -242,6 +249,9 @@ def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None: """ raise NotImplementedError() + def get_register_buffers(self) -> list[torch.Tensor]: + return [self.kv_buffer] + @synchronized def clear(self): # Initialize memory states and tracking structures. @@ -1074,6 +1084,399 @@ def get_page_buffer_meta(self, indices): return ptr_list, element_size_list +class MambaPoolHost(HostKVCache): + def __init__( + self, + device_pool: MambaPool, + host_to_device_ratio: float, + host_size: int, + pin_memory: bool = True, + device: str = "cpu", + allocator_type: str = "default", + ): + self.device_pool = device_pool + self.page_size = 1 + self.layout = "layer_first" + self.pin_memory = pin_memory + self.device = device + self.allocator = get_allocator_from_storage(allocator_type) + self.num_mamba_layers = device_pool.num_mamba_layers + self.conv_state_shapes = [ + conv_state.shape[2:] for conv_state in device_pool.mamba_cache.conv + ] + self.temporal_state_shape = device_pool.mamba_cache.temporal.shape[2:] + self.conv_dtype = device_pool.mamba_cache.conv[0].dtype + self.temporal_dtype = device_pool.mamba_cache.temporal.dtype + self.dtype = self.conv_dtype + self.size_per_token = self.get_size_per_token() + if host_size > 0: + self.size = int(host_size * 1e9 // self.size_per_token) + else: + self.size = int(device_pool.size * host_to_device_ratio) + self.page_num = self.size // self.page_size + 1 + self.size = self.page_num * self.page_size + + assert ( + self.size > device_pool.size + ), "The host memory should be larger than the device memory with the current protocol" + + host_mem = psutil.virtual_memory() + requested_bytes = self.size * self.size_per_token + ten_gb = 10 * (1024**3) + available_bytes = host_mem.available - ten_gb + if requested_bytes > available_bytes: + raise ValueError( + f"Not enough host memory available. 
Requesting " + f"{requested_bytes / 1e9:.2f} GB but only have " + f"{available_bytes / 1e9:.2f} GB free. Please reduce the " + f"size of the hierarchical cache." + ) + logger.info( + "Allocating %.2f GB host memory for hierarchical Mamba cache.", + requested_bytes / 1e9, + ) + + self.init_kv_buffer() + self.lock = threading.RLock() + self.clear() + + def _iter_serialized_page_tensors(self, index: int): + yield self.temporal_buffer[:, index : index + self.page_size] + for conv_buf in self.conv_buffer: + yield conv_buf[:, index : index + self.page_size] + + @staticmethod + def _flatten_tensor_bytes(tensor: torch.Tensor) -> torch.Tensor: + return tensor.contiguous().view(torch.uint8).reshape(-1) + + def init_kv_buffer(self): + alloc_func = ALLOC_MEMORY_FUNCS[self.device_pool.device] + temporal_dims = (self.num_mamba_layers, self.size) + self.temporal_state_shape + self.temporal_buffer = alloc_func( + temporal_dims, + dtype=self.temporal_dtype, + device=self.device, + pin_memory=self.pin_memory, + allocator=self.allocator, + ) + self.conv_buffer = [] + for conv_shape in self.conv_state_shapes: + conv_dims = (self.num_mamba_layers, self.size) + conv_shape + self.conv_buffer.append( + alloc_func( + conv_dims, + dtype=self.conv_dtype, + device=self.device, + pin_memory=self.pin_memory, + allocator=self.allocator, + ) + ) + + @synchronized + def clear(self): + self.mem_state = torch.zeros( + (self.size,), dtype=torch.uint8, device=self.device + ) + self.free_slots = torch.arange(self.size, dtype=torch.int64) + + def available_size(self): + return len(self.free_slots) + + @synchronized + def alloc(self, need_size: int) -> Optional[torch.Tensor]: + assert ( + need_size % self.page_size == 0 + ), "The requested size should be a multiple of the page size." + if need_size > self.available_size(): + return None + select_index = self.free_slots[:need_size] + self.free_slots = self.free_slots[need_size:] + return select_index + + @synchronized + def free(self, indices: torch.Tensor) -> int: + self.free_slots = torch.cat([self.free_slots, indices]) + return len(indices) + + def get_size_per_token(self): + conv_total_size = 0 + for conv_shape in self.conv_state_shapes: + conv_total_size += int(np.prod(conv_shape)) * self.conv_dtype.itemsize + temporal_size = ( + int(np.prod(self.temporal_state_shape)) * self.temporal_dtype.itemsize + ) + return (conv_total_size + temporal_size) * self.num_mamba_layers + + @staticmethod + def _item_size_per_index(tensor: torch.Tensor) -> int: + if tensor.shape[0] == 0: + return 0 + return int(tensor[0].numel() * tensor.element_size()) + + @staticmethod + def _copy_tensor( + src: torch.Tensor, + dst: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + io_backend: Optional[str], + ) -> None: + if src_indices.numel() == 0: + return + if io_backend == "kernel" and not (_is_npu or _is_xpu or _is_mps): + transfer_kv_per_layer_mla( + src=src, + dst=dst, + src_indices=src_indices, + dst_indices=dst_indices, + item_size=MambaPoolHost._item_size_per_index(src), + ) + return + if io_backend == "direct" and not (_is_npu or _is_xpu or _is_mps): + transfer_kv_direct( + src_layers=[src], + dst_layers=[dst], + src_indices=src_indices, + dst_indices=dst_indices, + page_size=1, + ) + return + + src_take = src.index_select(0, src_indices.to(src.device)) + dst.index_copy_(0, dst_indices.to(dst.device), src_take.to(dst.device)) + + def load_to_device_per_layer( + self, + device_pool, + host_indices, + device_indices, + layer_id, + io_backend="kernel", + ): + for conv_idx, 
host_conv_state in enumerate(self.conv_buffer): + self._copy_tensor( + host_conv_state[layer_id], + device_pool.mamba_cache.conv[conv_idx][layer_id], + host_indices, + device_indices, + io_backend, + ) + self._copy_tensor( + self.temporal_buffer[layer_id], + device_pool.mamba_cache.temporal[layer_id], + host_indices, + device_indices, + io_backend, + ) + + def backup_from_device_all_layer( + self, device_pool, host_indices, device_indices, io_backend="kernel" + ): + for layer_id in range(self.num_mamba_layers): + for conv_idx, host_conv_state in enumerate(self.conv_buffer): + self._copy_tensor( + device_pool.mamba_cache.conv[conv_idx][layer_id], + host_conv_state[layer_id], + device_indices, + host_indices, + io_backend, + ) + self._copy_tensor( + device_pool.mamba_cache.temporal[layer_id], + self.temporal_buffer[layer_id], + device_indices, + host_indices, + io_backend, + ) + + def get_data_page(self, index, flat: bool = True) -> torch.Tensor: + data_page = torch.cat( + [ + self._flatten_tensor_bytes(tensor) + for tensor in self._iter_serialized_page_tensors(index) + ] + ) + return data_page.flatten() if flat else data_page + + def get_dummy_flat_data_page(self) -> torch.Tensor: + return torch.zeros( + self.page_size * self.size_per_token, + dtype=torch.uint8, + device=self.device, + pin_memory=self.pin_memory, + ) + + def set_from_flat_data_page( + self, + index: int, + data_page: torch.Tensor, + ) -> None: + flat_bytes = data_page.contiguous().view(torch.uint8).reshape(-1) + expected_num_bytes = self.page_size * self.size_per_token + if flat_bytes.numel() != expected_num_bytes: + raise ValueError( + f"Invalid Mamba page size: expected {expected_num_bytes} bytes, " + f"got {flat_bytes.numel()} bytes." + ) + + start = 0 + for tensor in self._iter_serialized_page_tensors(index): + num_bytes = tensor.numel() * tensor.element_size() + tensor_bytes = flat_bytes[start : start + num_bytes] + start += num_bytes + restored = tensor_bytes.view(dtype=tensor.dtype).reshape(tensor.shape) + tensor.copy_(restored) + + def get_register_buffers(self) -> list[torch.Tensor]: + return [self.temporal_buffer, *self.conv_buffer] + + +@dataclass +class PoolEntry: + name: Any + host_pool: Any + device_pool: Any + layer_mapper: Callable[[int], Optional[int]] + is_primary_index_anchor: bool = False + host_evict_fn: Optional[Callable] = None + device_evict_fn: Optional[Callable] = None + + +class HostPoolGroup: + def __init__(self, entries: list[PoolEntry]): + if not entries: + raise ValueError("HostPoolGroup requires at least one pool entry.") + self.entries = entries + self.entry_map = {_pool_name_key(entry.name): entry for entry in entries} + self.anchor_entry = next( + (entry for entry in entries if entry.is_primary_index_anchor), + entries[0], + ) + + self.layout = self.anchor_entry.host_pool.layout + self.page_size = self.anchor_entry.host_pool.page_size + self.device = self.anchor_entry.host_pool.device + self.size = self.anchor_entry.host_pool.size + self.size_per_token = self.anchor_entry.host_pool.size_per_token + self.dtype = self.anchor_entry.host_pool.dtype + + def clear(self) -> None: + for entry in self.entries: + entry.host_pool.clear() + + def alloc(self, need_size: int) -> Optional[torch.Tensor]: + return self.anchor_entry.host_pool.alloc(need_size) + + def free(self, indices: torch.Tensor) -> int: + return self.anchor_entry.host_pool.free(indices) + + def get_data_page(self, index, flat: bool = True): + return self.anchor_entry.host_pool.get_data_page(index, flat) + + def 
get_dummy_flat_data_page(self): + return self.anchor_entry.host_pool.get_dummy_flat_data_page() + + def set_from_flat_data_page(self, index: int, data_page) -> None: + return self.anchor_entry.host_pool.set_from_flat_data_page(index, data_page) + + def get_ksize_per_token(self): + return self.anchor_entry.host_pool.get_ksize_per_token() + + def get_page_buffer_meta(self, indices): + return self.anchor_entry.host_pool.get_page_buffer_meta(indices) + + def get_register_buffers(self) -> list[torch.Tensor]: + return self.anchor_entry.host_pool.get_register_buffers() + + @staticmethod + def _resolve_host_indices(anchor_host_indices, transfer): + if transfer.host_indices is not None: + return transfer.host_indices + if getattr(transfer, "use_anchor_host_indices", False): + return anchor_host_indices + return None + + @staticmethod + def _resolve_device_indices(anchor_device_indices, transfer): + if transfer.device_indices is not None: + return transfer.device_indices + if getattr(transfer, "use_anchor_device_indices", False): + return anchor_device_indices + return None + + def load_to_device_per_layer( + self, + device_pool, + host_indices, + device_indices, + layer_id, + io_backend, + pool_transfers: Optional[list] = None, + ) -> None: + anchor = self.anchor_entry + local_layer_id = anchor.layer_mapper(layer_id) + if local_layer_id is not None and host_indices.numel() > 0: + anchor.host_pool.load_to_device_per_layer( + anchor.device_pool, + host_indices, + device_indices, + local_layer_id, + io_backend, + ) + + for transfer in pool_transfers or []: + entry = self.entry_map.get(_pool_name_key(transfer.name)) + if entry is None: + continue + transfer_host_indices = self._resolve_host_indices(host_indices, transfer) + transfer_device_indices = self._resolve_device_indices( + device_indices, transfer + ) + if transfer_host_indices is None or transfer_device_indices is None: + continue + local_layer_id = entry.layer_mapper(layer_id) + if local_layer_id is None: + continue + entry.host_pool.load_to_device_per_layer( + entry.device_pool, + transfer_host_indices, + transfer_device_indices, + local_layer_id, + io_backend, + ) + + def backup_from_device_all_layer( + self, + device_pool, + host_indices, + device_indices, + io_backend, + pool_transfers: Optional[list] = None, + ) -> None: + self.anchor_entry.host_pool.backup_from_device_all_layer( + self.anchor_entry.device_pool, + host_indices, + device_indices, + io_backend, + ) + for transfer in pool_transfers or []: + entry = self.entry_map.get(_pool_name_key(transfer.name)) + if entry is None: + continue + transfer_host_indices = self._resolve_host_indices(host_indices, transfer) + transfer_device_indices = self._resolve_device_indices( + device_indices, transfer + ) + if transfer_host_indices is None or transfer_device_indices is None: + continue + entry.host_pool.backup_from_device_all_layer( + entry.device_pool, + transfer_host_indices, + transfer_device_indices, + io_backend, + ) + + class NSATokenToKVPoolHost(MLATokenToKVPoolHost): device_pool: NSATokenToKVPool @@ -1307,3 +1710,125 @@ def backup_from_device_all_layer( self._backup_indexer_from_device_all_layer( device_pool, host_indices, device_indices, io_backend ) + + def get_indexer_page_views(self, host_indices: torch.Tensor) -> List[torch.Tensor]: + if host_indices.numel() == 0: + return [] + if host_indices.numel() % self.page_size != 0: + raise ValueError( + "Index buffer transfer expects page-aligned indices for NSA." 
+ ) + if self.layout not in ["page_first", "page_first_direct"]: + raise ValueError( + "Direct NSA indexer storage requires page_first/page_first_direct " + f"layout, got {self.layout}." + ) + + host_page_indices = ( + host_indices.reshape(-1, self.page_size)[:, 0] // self.page_size + ) + if host_page_indices.device.type != "cpu": + host_page_indices = host_page_indices.cpu() + return [ + self.index_k_with_scale_buffer[page_idx] + for page_idx in host_page_indices.tolist() + ] + + +class NSAIndexerHostPool: + def __init__(self, anchor_pool: NSATokenToKVPoolHost): + if anchor_pool.layout not in ["page_first", "page_first_direct"]: + raise ValueError( + "NSA storage/offload only supports page_first/page_first_direct " + f"layout, got {anchor_pool.layout}." + ) + self.anchor_pool = anchor_pool + self.page_size = anchor_pool.page_size + self.size = anchor_pool.size + self.layout = anchor_pool.layout + self.device = anchor_pool.device + self.pin_memory = anchor_pool.pin_memory + self.dtype = anchor_pool.indexer_dtype + self.size_per_token = ( + anchor_pool.layer_num + * anchor_pool.indexer_size_per_token + * anchor_pool.indexer_dtype.itemsize + ) + + def clear(self) -> None: + return + + def alloc(self, need_size: int) -> Optional[torch.Tensor]: + return self.anchor_pool.alloc(need_size) + + def free(self, indices: torch.Tensor) -> int: + return self.anchor_pool.free(indices) + + def load_to_device_per_layer( + self, + device_pool, + host_indices, + device_indices, + layer_id, + io_backend, + ): + self.anchor_pool._load_indexer_to_device_per_layer( + device_pool, host_indices, device_indices, layer_id, io_backend + ) + + def backup_from_device_all_layer( + self, device_pool, host_indices, device_indices, io_backend + ): + self.anchor_pool._backup_indexer_from_device_all_layer( + device_pool, host_indices, device_indices, io_backend + ) + + def get_data_page(self, index, flat: bool = True) -> torch.Tensor: + page_index = index // self.page_size + data_page = self.anchor_pool.index_k_with_scale_buffer[ + page_index : page_index + 1 + ] + return data_page.flatten() if flat else data_page + + def get_dummy_flat_data_page(self) -> torch.Tensor: + return torch.zeros( + ( + 1, + self.anchor_pool.layer_num, + 1, + self.anchor_pool.indexer_page_stride_size, + ), + dtype=self.dtype, + device=self.device, + pin_memory=self.pin_memory, + ).flatten() + + def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None: + page_index = index // self.page_size + if self.layout in ["page_first", "page_first_direct"]: + self.anchor_pool.index_k_with_scale_buffer[page_index : page_index + 1] = ( + data_page.reshape( + 1, + self.anchor_pool.layer_num, + 1, + self.anchor_pool.indexer_page_stride_size, + ) + ) + else: + raise ValueError(f"Unsupported layout: {self.layout}") + + def get_page_buffer_meta(self, indices): + page_indices, _ = self.anchor_pool._get_indexer_page_indices(indices, indices) + ptr_list = [] + element_size = ( + self.anchor_pool.layer_num + * self.anchor_pool.indexer_page_stride_size + * self.dtype.itemsize + ) + base_ptr = self.anchor_pool.index_k_with_scale_buffer.data_ptr() + for page_index in page_indices.tolist(): + ptr_list.append(base_ptr + page_index * element_size) + return ptr_list, [element_size] * len(ptr_list) + + def get_register_buffers(self) -> list[torch.Tensor]: + return [self.anchor_pool.index_k_with_scale_buffer] diff --git a/python/sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py b/python/sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py index 
14494d819808..9897458490db 100644 --- a/python/sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +++ b/python/sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py @@ -29,8 +29,11 @@ def test_with_page_size(self): config = HiCacheStorageConfig( tp_rank=0, tp_size=1, + pp_rank=0, + pp_size=1, is_mla_model=False, - is_page_first_layout=True, + enable_storage_metrics=False, + layout="page_first", model_name="test", ) for page_size in range(1, 3): diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py index 9aa82892d2f6..28200230fd43 100644 --- a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +++ b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py @@ -168,7 +168,7 @@ def __init__( dtype: torch.dtype, metadata_client: Hf3fsMetadataInterface, is_mla_model: bool = False, - is_page_first_layout: bool = False, + layout: str = "layer_first", use_mock_client: bool = False, enable_storage_metrics: bool = False, ): @@ -183,7 +183,7 @@ def __init__( self.dtype = dtype self.metadata_client = metadata_client self.is_mla_model = is_mla_model - self.is_page_first_layout = is_page_first_layout + self.layout = layout self.enable_storage_metrics = enable_storage_metrics self.numel = self.bytes_per_page // self.dtype.itemsize self.num_pages = self.file_size // self.bytes_per_page @@ -254,10 +254,10 @@ def from_env_config( use_mock_client = False if storage_config is not None: - rank, is_mla_model, is_page_first_layout = ( + rank, is_mla_model, layout = ( storage_config.tp_rank, storage_config.is_mla_model, - storage_config.is_page_first_layout, + storage_config.layout, ) if storage_config.extra_config is not None: @@ -265,10 +265,10 @@ def from_env_config( "use_mock_hf3fs_client", False ) else: - rank, is_mla_model, is_page_first_layout = ( + rank, is_mla_model, layout = ( 0, False, - False, + "layer_first", ) mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md" @@ -288,7 +288,7 @@ def from_env_config( client_timeout=5, dtype=dtype, metadata_client=Hf3fsLocalMetadataClient(), - is_page_first_layout=is_page_first_layout, + layout=layout, use_mock_client=use_mock_client, ) @@ -339,7 +339,7 @@ def from_env_config( dtype=dtype, metadata_client=metadata_client, is_mla_model=is_mla_model, - is_page_first_layout=is_page_first_layout, + layout=layout, use_mock_client=use_mock_client, enable_storage_metrics=storage_config.enable_storage_metrics, ) diff --git a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py index 2c815fd7e20d..d976b9d0dbac 100644 --- a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +++ b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py @@ -15,8 +15,15 @@ HiCacheStorage, HiCacheStorageConfig, HiCacheStorageExtraInfo, + PoolHitPolicy, + PoolName, + PoolTransfer, + PoolTransferResult, +) +from sglang.srt.mem_cache.memory_pool_host import ( + HostKVCache, + HostTensorAllocator, ) -from sglang.srt.mem_cache.memory_pool_host import HostKVCache, HostTensorAllocator from sglang.srt.observability.metrics_collector import StorageMetrics DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024 # 16 MB @@ -25,6 +32,10 @@ logger = logging.getLogger(__name__) +def _pool_name_key(pool_name) -> str: + return 
pool_name.value if hasattr(pool_name, "value") else str(pool_name) + + class MooncakeHostTensorAllocator(HostTensorAllocator): def __init__(self): super().__init__() @@ -275,7 +286,6 @@ def register_buffer(self, tensor: torch.Tensor): class MooncakeStore(HiCacheStorage, MooncakeBaseStore): - def __init__( self, storage_config: HiCacheStorageConfig = None, mem_pool: HostKVCache = None ): @@ -484,16 +494,29 @@ def register_mem_pool_host(self, mem_pool_host: HostKVCache): "page_first_direct", "page_head", ], "mooncake store storage backend only support page first or page first direct layout" - buffer = self.mem_pool_host.kv_buffer - try: - super().register_buffer(buffer) - except TypeError as err: - logger.error("Failed to register buffer to Mooncake Store: %s", err) - raise TypeError("Mooncake Store Register Buffer Error.") from err + for buffer in self.mem_pool_host.get_register_buffers(): + try: + super().register_buffer(buffer) + except TypeError as err: + logger.error("Failed to register buffer to Mooncake Store: %s", err) + raise TypeError("Mooncake Store Register Buffer Error.") from err bytes_per_page = mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size self.gb_per_page = bytes_per_page / (1 << 30) + def register_mem_host_pool_v2(self, host_pool: HostKVCache, host_pool_name): + super().register_mem_host_pool_v2(host_pool, _pool_name_key(host_pool_name)) + for buffer in host_pool.get_register_buffers(): + try: + super().register_buffer(buffer) + except TypeError as err: + logger.error( + "Failed to register pool buffer to Mooncake Store for %s: %s", + host_pool_name, + err, + ) + raise TypeError("Mooncake Store Register Pool Buffer Error.") from err + def _get_mha_split_heads_buffer_meta(self, keys, indices): ptr_list, element_size_list = ( self.mem_pool_host.get_split_heads_page_buffer_meta( @@ -536,6 +559,15 @@ def _batch_preprocess(self, keys, host_indices): else: return self._get_mha_buffer_meta(keys, host_indices) + def _get_extra_keys_for_nsa(self, keys: List[str]) -> List[str]: + if self.is_mla_backend: + suffix = self.mla_suffix + else: + suffix = self.mha_suffix + if suffix: + return [f"{key}_{suffix}{self._NSA_INDEXER_SUFFIX}" for key in keys] + return [f"{key}{self._NSA_INDEXER_SUFFIX}" for key in keys] + def _batch_postprocess(self, results: List[int], is_set_operate=False): """ refer to https://github.com/kvcache-ai/Mooncake/blob/main/mooncake-store/include/pybind_client.h @@ -645,6 +677,178 @@ def batch_set_v1( return self._batch_postprocess(set_results, is_set_operate=True) + def _apply_extra_backend_tag(self, keys: List[str]) -> List[str]: + if self.extra_backend_tag is None: + return keys + prefix = self.extra_backend_tag + return [f"{prefix}_{key}" for key in keys] + + def _get_pool_keys(self, pool_name, keys: List[str]) -> List[str]: + pool_name_key = _pool_name_key(pool_name) + tagged_keys = self._apply_extra_backend_tag(keys) + if pool_name_key == PoolName.KV.value: + return tagged_keys + if pool_name_key == PoolName.NSA.value: + return self._get_extra_keys_for_nsa(tagged_keys) + + if self.is_mla_backend: + suffix = self.mla_suffix + else: + suffix = self.mha_suffix + if suffix: + return [f"{key}_{suffix}_{pool_name_key}" for key in tagged_keys] + return [f"{key}_{pool_name_key}" for key in tagged_keys] + + def _batch_exists_for_pool(self, pool_name, keys: List[str]) -> int: + query_keys = self._get_pool_keys(pool_name, keys) + exist_result = self._batch_exist(query_keys) + for i, status in enumerate(exist_result): + if status != 1: + return i + return 
len(query_keys) + + def batch_exists_v2( + self, + keys: List[str], + pool_transfers: Optional[List[PoolTransfer]] = None, + extra_info: Optional[HiCacheStorageExtraInfo] = None, + ): + kv_hit_pages = self.batch_exists(keys, extra_info) + hit_count: dict[str, int] = ( + {PoolName.KV.value: kv_hit_pages} if kv_hit_pages else {} + ) + final_pages = kv_hit_pages + + for transfer in pool_transfers or []: + if final_pages == 0: + break + pool_name = _pool_name_key(transfer.name) + if transfer.hit_policy == PoolHitPolicy.ALL_PAGES: + boundary = self._batch_exists_for_pool( + transfer.name, keys[:kv_hit_pages] + ) + else: + trailing = max(1, len(transfer.keys) if transfer.keys else 1) + boundary = 0 + for prefix_len in range(kv_hit_pages, 0, -1): + trailing_keys = keys[max(0, prefix_len - trailing) : prefix_len] + if self._batch_exists_for_pool(transfer.name, trailing_keys) == len( + trailing_keys + ): + boundary = prefix_len + break + hit_count[pool_name] = boundary + final_pages = min(final_pages, boundary) + + return PoolTransferResult( + kv_hit_pages=final_pages, + extra_pool_hit_pages=hit_count, + ) + + def _transfer_meta_v2(self, transfer: PoolTransfer): + pool_name = _pool_name_key(transfer.name) + host_pool = self.registered_pools[pool_name] + host_indices = transfer.host_indices + keys = transfer.keys or [] + if host_indices is None: + raise ValueError(f"host_indices is required for {pool_name}") + + if pool_name == PoolName.KV.value: + key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess( + self._apply_extra_backend_tag(keys), host_indices + ) + return host_pool, key_strs, buffer_ptrs, buffer_sizes, None + + if hasattr(host_pool, "get_page_buffer_meta"): + key_strs = self._get_pool_keys(transfer.name, keys) + buffer_ptrs, buffer_sizes = host_pool.get_page_buffer_meta(host_indices) + return host_pool, key_strs, buffer_ptrs, buffer_sizes, None + + stage_pages = [ + host_pool.get_data_page( + host_indices[i * host_pool.page_size].item(), flat=True + ) + for i in range(len(keys)) + ] + for page in stage_pages: + super().register_buffer(page) + key_strs = self._get_pool_keys(transfer.name, keys) + buffer_ptrs = [page.data_ptr() for page in stage_pages] + buffer_sizes = [page.numel() * page.element_size() for page in stage_pages] + return host_pool, key_strs, buffer_ptrs, buffer_sizes, stage_pages + + def batch_get_v2( + self, + transfers: List[PoolTransfer], + extra_info: Optional[HiCacheStorageExtraInfo] = None, + ) -> dict[str, List[bool]]: + results: dict[str, List[bool]] = {} + for transfer in transfers: + pool_name = _pool_name_key(transfer.name) + host_pool, key_strs, buffer_ptrs, buffer_sizes, stage_pages = ( + self._transfer_meta_v2(transfer) + ) + get_results = self._get_batch_zero_copy_impl( + key_strs, buffer_ptrs, buffer_sizes + ) + if pool_name == PoolName.KV.value: + results[pool_name] = self._batch_postprocess( + get_results, is_set_operate=False + ) + continue + + pool_results = [res > 0 for res in get_results] + if stage_pages is not None: + for i, ok in enumerate(pool_results): + if not ok: + break + host_pool.set_from_flat_data_page( + transfer.host_indices[i * host_pool.page_size].item(), + stage_pages[i], + ) + results[pool_name] = pool_results + return results + + def batch_set_v2( + self, + transfers: List[PoolTransfer], + extra_info: Optional[HiCacheStorageExtraInfo] = None, + ) -> dict[str, List[bool]]: + results: dict[str, List[bool]] = {} + for transfer in transfers: + pool_name = _pool_name_key(transfer.name) + _host_pool, key_strs, buffer_ptrs, 
buffer_sizes, _stage_pages = ( + self._transfer_meta_v2(transfer) + ) + exist_result = self._batch_exist(key_strs) + set_keys = [] + set_buffer_ptrs = [] + set_buffer_sizes = [] + set_indices = [] + set_results = [-1] * len(key_strs) + for i in range(len(key_strs)): + if exist_result[i] != 1: + set_keys.append(key_strs[i]) + set_buffer_ptrs.append(buffer_ptrs[i]) + set_buffer_sizes.append(buffer_sizes[i]) + set_indices.append(i) + else: + set_results[i] = 0 + if set_keys: + put_results = self._put_batch_zero_copy_impl( + set_keys, set_buffer_ptrs, set_buffer_sizes + ) + for i, idx in enumerate(set_indices): + set_results[idx] = put_results[i] + + if pool_name == PoolName.KV.value: + results[pool_name] = self._batch_postprocess( + set_results, is_set_operate=True + ) + else: + results[pool_name] = [res == 0 for res in set_results] + return results + def set( self, key, @@ -771,10 +975,7 @@ def exists(self, key) -> bool: def batch_exists( self, keys, extra_info: Optional[HiCacheStorageExtraInfo] = None ) -> int: - # Apply extra_backend_tag prefix if available - if self.extra_backend_tag is not None: - prefix = self.extra_backend_tag - keys = [f"{prefix}_{key}" for key in keys] + keys = self._apply_extra_backend_tag(keys) if self.is_mla_backend: query_keys = [f"{key}_{self.mla_suffix}_k" for key in keys] @@ -800,9 +1001,7 @@ def batch_exists( return len(query_keys) // key_multiplier def close(self): - # MooncakeDistributedStore will automatically call the destructor, so - # it is unnecessary to close it manually. - pass + return def clear(self) -> None: self.store.remove_all() diff --git a/python/sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py b/python/sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py index ae4788cb5dd4..67643dad9da6 100644 --- a/python/sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +++ b/python/sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py @@ -155,8 +155,11 @@ def test_batch_operation(config: HiCacheStorageConfig): is_mla_model=False, tp_rank=0, tp_size=1, + pp_rank=0, + pp_size=1, + enable_storage_metrics=False, model_name=None, - is_page_first_layout=True, + layout="page_first", ) ) test_batch_operation( @@ -164,8 +167,11 @@ def test_batch_operation(config: HiCacheStorageConfig): is_mla_model=True, tp_rank=0, tp_size=1, + pp_rank=0, + pp_size=1, + enable_storage_metrics=False, model_name=None, - is_page_first_layout=True, + layout="page_first", ) ) test_batch_operation( @@ -173,8 +179,11 @@ def test_batch_operation(config: HiCacheStorageConfig): is_mla_model=False, tp_rank=1, tp_size=4, + pp_rank=0, + pp_size=1, + enable_storage_metrics=False, model_name=None, - is_page_first_layout=True, + layout="page_first", ) ) test_batch_operation( @@ -182,8 +191,11 @@ def test_batch_operation(config: HiCacheStorageConfig): is_mla_model=True, tp_rank=3, tp_size=8, + pp_rank=0, + pp_size=1, + enable_storage_metrics=False, model_name=None, - is_page_first_layout=True, + layout="page_first", ) ) logger.info(f"✅ All tests passed") diff --git a/python/sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py b/python/sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py index ad1796f0abd9..8a156eae81cd 100755 --- a/python/sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +++ b/python/sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py @@ -37,8 +37,11 @@ def setUp(self): self.storage_config = HiCacheStorageConfig( tp_rank=0, tp_size=2, + pp_rank=0, + pp_size=1, 
             is_mla_model=False,
-            is_page_first_layout=False,
+            enable_storage_metrics=False,
+            layout="layer_first",
             model_name="test_model",
         )

diff --git a/python/sglang/srt/mem_cache/test_hicache_pool_v2.py b/python/sglang/srt/mem_cache/test_hicache_pool_v2.py
new file mode 100644
index 000000000000..1cecd20a1e14
--- /dev/null
+++ b/python/sglang/srt/mem_cache/test_hicache_pool_v2.py
@@ -0,0 +1,227 @@
+from __future__ import annotations
+
+import os
+
+import torch
+
+from sglang.srt.mem_cache.hicache_storage import (
+    HiCacheFile,
+    HiCacheStorageConfig,
+    PoolHitPolicy,
+    PoolName,
+    PoolTransfer,
+)
+from sglang.srt.mem_cache.memory_pool_host import (
+    HostPoolGroup,
+    NSAIndexerHostPool,
+    PoolEntry,
+)
+
+
+class _RecordingPool:
+    def __init__(self, page_size: int = 2):
+        self.page_size = page_size
+        self.layout = "page_first"
+        self.device = "cpu"
+        self.size = 64
+        self.size_per_token = 1
+        self.dtype = torch.uint8
+        self.load_calls = []
+        self.backup_calls = []
+
+    def clear(self):
+        return
+
+    def alloc(self, need_size: int):
+        return torch.arange(need_size, dtype=torch.int64)
+
+    def free(self, indices: torch.Tensor):
+        return len(indices)
+
+    def load_to_device_per_layer(
+        self, device_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        self.load_calls.append((host_indices.clone(), device_indices.clone(), layer_id))
+
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ):
+        self.backup_calls.append((host_indices.clone(), device_indices.clone()))
+
+    def get_data_page(self, index, flat: bool = True):
+        return torch.zeros((4,), dtype=torch.uint8)
+
+    def get_dummy_flat_data_page(self):
+        return torch.zeros((4,), dtype=torch.uint8)
+
+    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor):
+        return
+
+    def get_ksize_per_token(self):
+        return 1
+
+    def get_page_buffer_meta(self, indices):
+        return [0], [4]
+
+    def get_register_buffers(self):
+        return []
+
+
+class _FlatPagePool:
+    def __init__(self, page_size: int = 2, page_bytes: int = 4):
+        self.page_size = page_size
+        self.page_bytes = page_bytes
+        self.layout = "page_first"
+        self.device = "cpu"
+        self.size = 64
+        self.size_per_token = 1
+        self.dtype = torch.uint8
+        self.pages: dict[int, torch.Tensor] = {}
+
+    def get_data_page(self, index, flat: bool = True):
+        return self.pages[index].clone()
+
+    def get_dummy_flat_data_page(self):
+        return torch.zeros((self.page_bytes,), dtype=torch.uint8)
+
+    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor):
+        self.pages[index] = data_page.clone()
+
+    def get_register_buffers(self):
+        return []
+
+    def get_ksize_per_token(self):
+        return 1
+
+
+class _FakeNSAAnchor:
+    def __init__(self):
+        self.layout = "page_first"
+        self.page_size = 2
+        self.size = 8
+        self.device = "cpu"
+        self.pin_memory = False
+        self.layer_num = 2
+        self.indexer_dtype = torch.uint8
+        self.indexer_size_per_token = 3
+        self.indexer_page_stride_size = 6
+        self.index_k_with_scale_buffer = torch.zeros((4, 2, 1, 6), dtype=torch.uint8)
+
+    def _get_indexer_page_indices(self, host_indices, device_indices):
+        host_page_indices = (
+            host_indices.reshape(-1, self.page_size)[:, 0] // self.page_size
+        )
+        device_page_indices = (
+            device_indices.reshape(-1, self.page_size)[:, 0] // self.page_size
+        )
+        return host_page_indices, device_page_indices
+
+
+def _storage_config(layout: str = "page_first") -> HiCacheStorageConfig:
+    return HiCacheStorageConfig(
+        tp_rank=0,
+        tp_size=1,
+        pp_rank=0,
+        pp_size=1,
+        is_mla_model=True,
+        enable_storage_metrics=False,
+        layout=layout,
+        model_name="test",
+    )
+
+
+def test_host_pool_group_reuses_anchor_indices():
+    anchor = _RecordingPool(page_size=2)
+    sidecar = _RecordingPool(page_size=2)
+    group = HostPoolGroup(
+        [
+            PoolEntry("kv", anchor, object(), lambda layer_id: layer_id, True),
+            PoolEntry("nsa", sidecar, object(), lambda layer_id: layer_id),
+        ]
+    )
+
+    anchor_host = torch.tensor([4, 5], dtype=torch.int64)
+    anchor_device = torch.tensor([14, 15], dtype=torch.int64)
+    transfer = PoolTransfer(
+        name=PoolName.NSA,
+        use_anchor_host_indices=True,
+        use_anchor_device_indices=True,
+    )
+
+    group.load_to_device_per_layer(
+        object(),
+        anchor_host,
+        anchor_device,
+        layer_id=0,
+        io_backend="direct",
+        pool_transfers=[transfer],
+    )
+    group.backup_from_device_all_layer(
+        object(),
+        anchor_host,
+        anchor_device,
+        io_backend="direct",
+        pool_transfers=[transfer],
+    )
+
+    assert sidecar.load_calls[0][0].tolist() == anchor_host.tolist()
+    assert sidecar.load_calls[0][1].tolist() == anchor_device.tolist()
+    assert sidecar.backup_calls[0][0].tolist() == anchor_host.tolist()
+    assert sidecar.backup_calls[0][1].tolist() == anchor_device.tolist()
+
+
+def test_nsa_indexer_host_pool_roundtrip_and_meta():
+    anchor = _FakeNSAAnchor()
+    pool = NSAIndexerHostPool(anchor)
+    data_page = torch.arange(12, dtype=torch.uint8)
+
+    pool.set_from_flat_data_page(2, data_page)
+    assert torch.equal(pool.get_data_page(2), data_page)
+
+    ptrs, sizes = pool.get_page_buffer_meta(torch.tensor([2, 3], dtype=torch.int64))
+    assert len(ptrs) == 1
+    assert sizes == [12]
+
+
+def test_file_backend_nsa_key_compatibility_and_prefix(tmp_path):
+    storage = HiCacheFile(_storage_config(), file_path=str(tmp_path))
+    kv_pool = _FlatPagePool()
+    nsa_pool = _FlatPagePool()
+    storage.register_mem_pool_host(kv_pool)
+    storage.register_mem_host_pool_v2(kv_pool, PoolName.KV.value)
+    storage.register_mem_host_pool_v2(nsa_pool, PoolName.NSA.value)
+
+    kv_pool.set_from_flat_data_page(0, torch.tensor([1, 2, 3, 4], dtype=torch.uint8))
+    kv_pool.set_from_flat_data_page(4, torch.tensor([5, 6, 7, 8], dtype=torch.uint8))
+    nsa_pool.set_from_flat_data_page(0, torch.tensor([9, 9, 9, 9], dtype=torch.uint8))
+
+    storage.batch_set_v2(
+        [
+            PoolTransfer(
+                name=PoolName.KV,
+                host_indices=torch.tensor([0, 1, 4, 5], dtype=torch.int64),
+                keys=["hash0", "hash2"],
+            ),
+            PoolTransfer(
+                name=PoolName.NSA,
+                host_indices=torch.tensor([0, 1], dtype=torch.int64),
+                keys=["hash0"],
+            ),
+        ]
+    )
+
+    assert os.path.exists(tmp_path / "hash0__nsa_idx_test.bin")
+
+    kv_only = storage.batch_exists_v2(["hash0", "hash1", "hash2"])
+    assert kv_only.kv_hit_pages == 1
+
+    hit = storage.batch_exists_v2(
+        ["hash0", "hash2"],
+        [
+            PoolTransfer(
+                name=PoolName.NSA,
+                hit_policy=PoolHitPolicy.ALL_PAGES,
+            )
+        ],
+    )
+    assert hit.kv_hit_pages == 1