diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py
index 92d509732dc..8cb237cf171 100644
--- a/python/sglang/srt/managers/cache_controller.py
+++ b/python/sglang/srt/managers/cache_controller.py
@@ -149,6 +149,7 @@ def __init__(
         self,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         mem_pool_host: HostKVCache,
+        page_size: int,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
     ):
@@ -156,6 +157,7 @@ def __init__(
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
+        self.page_size = page_size
 
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -184,7 +186,12 @@ def __init__(
         self.load_stream = torch.cuda.Stream()
         self.write_thread = threading.Thread(
-            target=self.write_thread_func_buffer, daemon=True
+            target=(
+                self.write_thread_func_buffer
+                if self.page_size == 1
+                else self.write_thread_func_direct
+            ),
+            daemon=True,
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
         )
@@ -205,7 +212,12 @@ def reset(self):
         self.ack_load_queue.queue.clear()
         self.write_thread = threading.Thread(
-            target=self.write_thread_func_buffer, daemon=True
+            target=(
+                self.write_thread_func_buffer
+                if self.page_size == 1
+                else self.write_thread_func_direct
+            ),
+            daemon=True,
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
         )
@@ -260,10 +272,12 @@ def write_thread_func_direct(self):
         while not self.stop_event.is_set():
             try:
                 operation = self.write_queue.get(block=True, timeout=1)
-                operation.data = self.mem_pool_device.get_flat_data(
-                    operation.device_indices
+                self.mem_pool_host.write_page_all_layers(
+                    operation.host_indices,
+                    operation.device_indices,
+                    self.mem_pool_device,
                 )
-                self.mem_pool_host.transfer(operation.host_indices, operation.data)
+                self.write_stream.synchronize()
                 self.mem_pool_host.complete_io(operation.host_indices)
                 for node_id in operation.node_ids:
                     if node_id != 0:
@@ -320,12 +334,21 @@ def load_thread_func_layer_by_layer(self):
 
             self.layer_done_counter.reset()
             for i in range(self.mem_pool_host.layer_num):
-                flat_data = self.mem_pool_host.get_flat_data_by_layer(
-                    batch_operation.host_indices, i
-                )
-                self.mem_pool_device.transfer_per_layer(
-                    batch_operation.device_indices, flat_data, i
-                )
+                if self.page_size == 1:
+                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
+                        batch_operation.host_indices, i
+                    )
+                    self.mem_pool_device.transfer_per_layer(
+                        batch_operation.device_indices, flat_data, i
+                    )
+                else:
+                    self.mem_pool_host.load_page_per_layer(
+                        batch_operation.host_indices,
+                        batch_operation.device_indices,
+                        self.mem_pool_device,
+                        i,
+                    )
+                    self.load_stream.synchronize()
                 self.layer_done_counter.increment()
 
             self.mem_pool_host.complete_io(batch_operation.host_indices)
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 4a434cc5ad6..e917f42b114 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1282,7 +1282,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         ]
 
         if self.enable_hierarchical_cache:
-            self.tree_cache.read_to_load_cache()
+            self.tree_cache.ready_to_load_cache()
 
         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py
index 748960f195e..61b52b50562 100644
--- a/python/sglang/srt/mem_cache/hiradix_cache.py
+++ b/python/sglang/srt/mem_cache/hiradix_cache.py
@@ -16,7 +16,6 @@
     TokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
-from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match
 
 logger = logging.getLogger(__name__)
 
@@ -31,29 +30,25 @@ def __init__(
         self,
         page_size: int,
         hicache_ratio: float,
     ):
-        if page_size != 1:
-            raise ValueError(
-                "Page size larger than 1 is not yet supported in HiRadixCache."
-            )
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
             self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio
+                self.kv_cache, hicache_ratio, page_size
             )
         elif isinstance(self.kv_cache, MLATokenToKVPool):
             self.token_to_kv_pool_host = MLATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio
+                self.kv_cache, hicache_ratio, page_size
             )
         else:
-            raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.")
+            raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
 
         self.tp_group = tp_cache_group
-        self.page_size = page_size
 
         self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
             token_to_kv_pool_allocator,
             self.token_to_kv_pool_host,
+            page_size,
             load_cache_event=self.load_cache_event,
         )
@@ -65,7 +60,7 @@ def __init__(
         self.write_through_threshold = 1
         self.load_back_threshold = 10
         super().__init__(
-            req_to_token_pool, token_to_kv_pool_allocator, self.page_size, disable=False
+            req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
         )
 
     def reset(self):
@@ -299,18 +294,26 @@ def init_load_back(
 
         return last_node, prefix_indices
 
-    def read_to_load_cache(self):
+    def ready_to_load_cache(self):
         self.load_cache_event.set()
 
     def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
-        if self.disable:
-            return [], self.root_node
+        empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
+        if self.disable or len(key) == 0:
+            if include_evicted:
+                return empty_value, self.root_node, self.root_node
+            else:
+                return empty_value, self.root_node
+
+        if self.page_size != 1:
+            page_aligned_len = len(key) // self.page_size * self.page_size
+            key = key[:page_aligned_len]
 
         value, last_node = self._match_prefix_helper(self.root_node, key)
         if value:
             value = torch.cat(value)
         else:
-            value = torch.tensor([], dtype=torch.int64)
+            value = empty_value
 
         last_node_global = last_node
         while last_node.evicted:
@@ -323,11 +326,13 @@ def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
 
     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.time()
+        child_key = self.get_child_key_fn(key)
         value = []
-        while len(key) > 0 and key[0] in node.children.keys():
-            child = node.children[key[0]]
+
+        while len(key) > 0 and child_key in node.children.keys():
+            child = node.children[child_key]
             child.last_access_time = time.time()
-            prefix_len = _key_match(child.key, key)
+            prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
                 if not new_node.evicted:
@@ -339,12 +344,16 @@ def _match_prefix_helper(self, node: TreeNode, key: List):
             value.append(child.value)
             node = child
             key = key[prefix_len:]
+
+            if len(key):
+                child_key = self.get_child_key_fn(key)
+
         return value, node
 
     def _split_node(self, key, child: TreeNode, split_len: int):
         # child node split into new_node -> child
         new_node = TreeNode()
-        new_node.children = {key[split_len]: child}
+        new_node.children = {self.get_child_key_fn(key[split_len:]): child}
         new_node.parent = child.parent
         new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
@@ -361,7 +370,7 @@ def _split_node(self, key, child: TreeNode, split_len: int):
             child.host_value = child.host_value[split_len:]
         child.parent = new_node
         child.key = child.key[split_len:]
-        new_node.parent.children[key[0]] = new_node
+        new_node.parent.children[self.get_child_key_fn(key)] = new_node
         return new_node
 
     def _insert_helper(self, node: TreeNode, key: List, value):
@@ -369,52 +378,53 @@ def _insert_helper(self, node: TreeNode, key: List, value):
         node.last_access_time = time.time()
         if len(key) == 0:
             return 0
 
-        if key[0] in node.children.keys():
-            child = node.children[key[0]]
-            prefix_len = _key_match(child.key, key)
+        child_key = self.get_child_key_fn(key)
+        total_prefix_length = 0
+
+        while len(key) > 0 and child_key in node.children.keys():
+            node = node.children[child_key]
+            node.last_access_time = time.time()
+            prefix_len = self.key_match_fn(node.key, key)
 
-            if prefix_len == len(child.key):
-                if child.evicted:
+            if prefix_len == len(node.key):
+                if node.evicted:
                     # change the reference if the node is evicted
                     # this often happens in the case of KV cache recomputation
-                    child.value = value[:prefix_len]
-                    self.token_to_kv_pool_host.update_synced(child.host_value)
-                    self.evictable_size_ += len(value[:prefix_len])
-                    return self._insert_helper(
-                        child, key[prefix_len:], value[prefix_len:]
-                    )
+                    node.value = value[:prefix_len]
+                    self.token_to_kv_pool_host.update_synced(node.host_value)
+                    self.evictable_size_ += len(node.value)
                 else:
-                    self.inc_hit_count(child)
-                    return prefix_len + self._insert_helper(
-                        child, key[prefix_len:], value[prefix_len:]
-                    )
-
-            # partial match, split the node
-            new_node = self._split_node(child.key, child, prefix_len)
-            if new_node.evicted:
-                new_node.value = value[:prefix_len]
-                self.token_to_kv_pool_host.update_synced(new_node.host_value)
-                self.evictable_size_ += len(new_node.value)
-                return self._insert_helper(
-                    new_node, key[prefix_len:], value[prefix_len:]
-                )
+                    self.inc_hit_count(node)
+                    total_prefix_length += prefix_len
             else:
-                self.inc_hit_count(new_node)
-                return prefix_len + self._insert_helper(
-                    new_node, key[prefix_len:], value[prefix_len:]
-                )
+                # partial match, split the node
+                new_node = self._split_node(node.key, node, prefix_len)
+                if new_node.evicted:
+                    new_node.value = value[:prefix_len]
+                    self.token_to_kv_pool_host.update_synced(new_node.host_value)
+                    self.evictable_size_ += len(new_node.value)
+                else:
+                    self.inc_hit_count(new_node)
+                    total_prefix_length += prefix_len
+                node = new_node
+
+            key = key[prefix_len:]
+            value = value[prefix_len:]
+
+            if len(key):
+                child_key = self.get_child_key_fn(key)
 
         if len(key):
             new_node = TreeNode()
             new_node.parent = node
             new_node.key = key
             new_node.value = value
-            node.children[key[0]] = new_node
+            node.children[child_key] = new_node
             self.evictable_size_ += len(value)
             if self.cache_controller.write_policy == "write_through":
                 self.write_backup(new_node)
-        return 0
+        return total_prefix_length
 
     def _collect_leaves_device(self):
         def is_leaf(node):
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index ae50df11441..31866e01074 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -608,8 +608,9 @@ def __init__(
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
-        pin_memory: bool = False,  # no need to use pin memory with the double buffering
-        device: str = "cpu",
+        pin_memory: bool,
+        device: str,
+        page_size: int,
     ):
         assert (
             host_to_device_ratio >= 1
@@ -620,8 +621,11 @@ def __init__(
         self.host_to_device_ratio = host_to_device_ratio
         self.pin_memory = pin_memory
         self.device = device
+        self.page_size = page_size
 
         self.size = int(device_pool.size * host_to_device_ratio)
+        # Align the host memory pool size to the page size
+        self.size = self.size - (self.size % self.page_size)
         self.dtype = device_pool.store_dtype
         self.size_per_token = self.get_size_per_token()
 
@@ -775,10 +779,13 @@ def __init__(
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
-        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        page_size: int,
+        pin_memory: bool = True,
         device: str = "cpu",
     ):
-        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
+        super().__init__(
+            device_pool, host_to_device_ratio, pin_memory, device, page_size
+        )
 
     def get_size_per_token(self):
         self.head_num = self.device_pool.head_num
@@ -811,16 +818,48 @@ def get_flat_data_by_layer(self, indices, layer_id):
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, :, indices] = flat_data
 
+    def write_page_all_layers(self, host_indices, device_indices, device_pool):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            for j in range(self.layer_num):
+                self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
+                    device_pool.k_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+                self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
+                    device_pool.v_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+
+    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            device_pool.k_buffer[layer_id][d_index : d_index + self.page_size].copy_(
+                self.kv_buffer[0, layer_id, h_index : h_index + self.page_size],
+                non_blocking=True,
+            )
+            device_pool.v_buffer[layer_id][d_index : d_index + self.page_size].copy_(
+                self.kv_buffer[1, layer_id, h_index : h_index + self.page_size],
+                non_blocking=True,
+            )
+
 
 class MLATokenToKVPoolHost(HostKVCache):
     def __init__(
         self,
         device_pool: MLATokenToKVPool,
         host_to_device_ratio: float,
-        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        page_size: int,
+        pin_memory: bool = True,
         device: str = "cpu",
     ):
-        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
+        super().__init__(
+            device_pool, host_to_device_ratio, pin_memory, device, page_size
+        )
 
     def get_size_per_token(self):
         self.kv_lora_rank = self.device_pool.kv_lora_rank
@@ -857,3 +896,24 @@ def get_flat_data_by_layer(self, indices, layer_id):
 
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, indices] = flat_data
+
+    def write_page_all_layers(self, host_indices, device_indices, device_pool):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            for j in range(self.layer_num):
+                self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
+                    device_pool.kv_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+
+    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            device_pool.kv_buffer[layer_id][d_index : d_index + self.page_size].copy_(
+                self.kv_buffer[layer_id, h_index : h_index + self.page_size],
+                non_blocking=True,
+            )
diff --git a/python/sglang/srt/mem_cache/paged_allocator.py b/python/sglang/srt/mem_cache/paged_allocator.py
index 37c4a8e5eb2..f6f3e23d853 100644
--- a/python/sglang/srt/mem_cache/paged_allocator.py
+++ b/python/sglang/srt/mem_cache/paged_allocator.py
@@ -190,6 +190,30 @@ def __init__(
     def available_size(self):
         return len(self.free_pages) * self.page_size
 
+    def get_kvcache(self):
+        return self._kvcache
+
+    def alloc(self, need_size: int):
+        # page-aligned allocation, returning contiguous indices of pages
+        if self.debug_mode:
+            assert (
+                need_size % self.page_size == 0
+            ), "The allocation size should be page-aligned"
+
+        num_pages = need_size // self.page_size
+        if num_pages > len(self.free_pages):
+            return None
+
+        out_pages = self.free_pages[:num_pages]
+        self.free_pages = self.free_pages[num_pages:]
+
+        out_indices = (
+            out_pages[:, None] * self.page_size
+            + torch.arange(self.page_size, device=self.device)
+        ).reshape(-1)
+
+        return out_indices
+
     def alloc_extend(
         self,
         prefix_lens: torch.Tensor,
diff --git a/test/srt/test_hicache.py b/test/srt/test_hicache.py
index ac7a27f04f0..d651aa047bc 100644
--- a/test/srt/test_hicache.py
+++ b/test/srt/test_hicache.py
@@ -12,7 +12,7 @@
 )
 
 
-class TestPageSize(CustomTestCase):
+class TestHiCache(CustomTestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_MODEL_NAME_FOR_TEST
@@ -21,7 +21,9 @@ def setUpClass(cls):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--enable-hierarchical-cache"],
+            other_args=[
+                "--enable-hierarchical-cache",
+            ],
         )
 
     @classmethod
diff --git a/test/srt/test_hicache_mla.py b/test/srt/test_hicache_mla.py
index 8396615f3d5..71418470af6 100644
--- a/test/srt/test_hicache_mla.py
+++ b/test/srt/test_hicache_mla.py
@@ -21,7 +21,10 @@ def setUpClass(cls):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--trust-remote-code", "--enable-hierarchical-cache"],
+            other_args=[
+                "--trust-remote-code",
+                "--enable-hierarchical-cache",
+            ],
         )
 
     @classmethod
diff --git a/test/srt/test_hicache_page.py b/test/srt/test_hicache_page.py
new file mode 100644
index 00000000000..f237af51b74
--- /dev/null
+++ b/test/srt/test_hicache_page.py
@@ -0,0 +1,49 @@
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+
+class TestHiCachePage(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--enable-hierarchical-cache",
+                "--page-size",
+                "32",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+
+        metrics = run_eval(args)
+        self.assertGreaterEqual(metrics["score"], 0.65)
+
+
+if __name__ == "__main__":
+    unittest.main()
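
Note on the paged index math (an illustrative sketch, not part of the patch above): the new alloc in paged_allocator.py and the write_page_all_layers / load_page_per_layer helpers in memory_pool.py all rely on the same invariant, namely that page p owns the contiguous token slots [p * page_size, (p + 1) * page_size). The standalone Python below replays that arithmetic; the function name alloc_page_aligned and the demo values are hypothetical, chosen only to make the math concrete.

import torch

# A minimal sketch, assuming the invariant above: free_pages holds free page
# ids, and allocating need_size token slots means popping
# need_size // page_size pages and expanding each page id into its run of
# token indices.
def alloc_page_aligned(free_pages, need_size, page_size):
    assert need_size % page_size == 0, "allocation size should be page-aligned"
    num_pages = need_size // page_size
    if num_pages > len(free_pages):
        return None, free_pages  # not enough free pages
    out_pages, free_pages = free_pages[:num_pages], free_pages[num_pages:]
    # broadcast each page id against [0, page_size) and flatten
    out_indices = (
        out_pages[:, None] * page_size + torch.arange(page_size)
    ).reshape(-1)
    return out_indices, free_pages

free = torch.tensor([3, 7, 2])  # hypothetical free page ids
idx, free = alloc_page_aligned(free, 8, 4)
print(idx)       # tensor([12, 13, 14, 15, 28, 29, 30, 31])
# Because each page's slots are contiguous, idx[::4] recovers one start index
# per page; this is the same stride trick the new write_page_all_layers /
# load_page_per_layer loops use via device_indices[:: self.page_size].
print(idx[::4])  # tensor([12, 28])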