From 78cf99966c3e2fd044233f5b91853d6a323f97aa Mon Sep 17 00:00:00 2001 From: shuwenn <47200617+alphabetc1@users.noreply.github.com> Date: Tue, 27 Jan 2026 02:17:30 +0800 Subject: [PATCH] [HiCache] support spec decode+hicache storage --- .../sglang/srt/managers/cache_controller.py | 295 +++++++++++++----- python/sglang/srt/managers/scheduler.py | 54 ++++ python/sglang/srt/mem_cache/hiradix_cache.py | 230 ++++++++++---- 3 files changed, 446 insertions(+), 133 deletions(-) diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 80174584e51d..25260913215f 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -107,11 +107,13 @@ def __init__( device_indices: torch.Tensor, node_id: int, priority: Optional[int] = None, + use_draft: bool = False, ): self.host_indices = host_indices self.device_indices = device_indices self.node_ids = [node_id] self.data = None + self.use_draft = use_draft self.id = CacheOperation.counter CacheOperation.counter += 1 @@ -130,7 +132,10 @@ def merge_ops(ops: List[CacheOperation]) -> CacheOperation: priority = min(op.priority for op in ops) for op in ops: node_ids.extend(op.node_ids) - merged_op = CacheOperation(host_indices, device_indices, -1, priority) + use_draft = all(op.use_draft for op in ops) + merged_op = CacheOperation( + host_indices, device_indices, -1, priority, use_draft=use_draft + ) merged_op.node_ids = node_ids return merged_op @@ -197,6 +202,7 @@ def __init__( last_hash: Optional[str] = None, hash_value: Optional[List[str]] = None, prefix_keys: Optional[List[str]] = None, + is_draft: bool = False, ): self.host_indices = host_indices self.token_ids = token_ids @@ -204,6 +210,7 @@ def __init__( self.completed_tokens = 0 self.hash_value = hash_value if hash_value is not None else [] self.prefix_keys = prefix_keys + self.is_draft = is_draft self.id = StorageOperation.counter StorageOperation.counter += 1 @@ -220,6 +227,7 @@ def __init__( token_ids: List[int], last_hash: Optional[str] = None, prefix_keys: Optional[List[str]] = None, + is_draft: bool = False, ): self.request_id = request_id @@ -227,7 +235,13 @@ def __init__( self._terminated_flag = False self.start_time = time.monotonic() - super().__init__(host_indices, token_ids, last_hash, prefix_keys=prefix_keys) + super().__init__( + host_indices, + token_ids, + last_hash, + prefix_keys=prefix_keys, + is_draft=is_draft, + ) def increment(self, num_tokens: int): with self._lock: @@ -265,6 +279,11 @@ def __init__( self.mem_pool_device_allocator = token_to_kv_pool_allocator self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache() self.mem_pool_host = mem_pool_host + self.has_draft_kv_pool = False + self.mem_pool_device_draft = None + self.mem_pool_host_draft = None + self.storage_backend_draft = None + self.draft_storage_prefix = "draft_" self.write_policy = write_policy self.page_size = page_size self.io_backend = io_backend @@ -377,6 +396,50 @@ def __init__( self.prefetch_thread.start() self.backup_thread.start() + def set_draft_kv_pool(self, draft_kv_pool, draft_host_kv_pool): + """ + Set draft model KV pools for speculative decoding. 
+
+        Args:
+            draft_kv_pool: The draft model's device KV cache pool
+            draft_host_kv_pool: The draft model's host KV cache pool
+        """
+        self.has_draft_kv_pool = True
+        self.mem_pool_device_draft = draft_kv_pool
+        self.mem_pool_host_draft = draft_host_kv_pool
+
+        if self.enable_storage and self.storage_backend_draft is None:
+            from sglang.srt.mem_cache.storage import StorageBackendFactory
+
+            try:
+                self.storage_backend_draft = StorageBackendFactory.create_backend(
+                    self.storage_backend_type,
+                    self.storage_config,
+                    self.mem_pool_host_draft,
+                )
+            except ValueError as e:
+                logger.error(f"Failed to create draft storage backend: {e}")
+                self.storage_backend_draft = None
+                self.has_draft_kv_pool = False
+                return
+
+            self.storage_backend_draft.register_mem_pool_host(self.mem_pool_host_draft)
+
+    def _apply_storage_prefix(self, keys: List[str], is_draft: bool) -> List[str]:
+        if not is_draft:
+            return keys
+        return [f"{self.draft_storage_prefix}{key}" for key in keys]
+
+    def _apply_prefix_keys(self, prefix_keys: Optional[List[str]], is_draft: bool):
+        if not prefix_keys:
+            return prefix_keys
+        return self._apply_storage_prefix(prefix_keys, is_draft)
+
+    def _select_storage_backend(self, is_draft: bool):
+        if is_draft and self.storage_backend_draft is not None:
+            return self.storage_backend_draft
+        return self.storage_backend
+
     def _generate_storage_config(
         self,
         model_name: Optional[str] = None,
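
The _apply_storage_prefix and _apply_prefix_keys helpers above keep target-model and draft-model pages in one shared storage backend by giving draft keys their own namespace. A minimal, self-contained sketch of the idea (illustrative names only, not the controller's real API):

    # Draft pages reuse the target page hashes but live under a "draft_"
    # namespace, so a single shared backend can hold both models' KV pages
    # without key collisions.
    DRAFT_PREFIX = "draft_"

    def apply_storage_prefix(keys, is_draft):
        return [f"{DRAFT_PREFIX}{k}" for k in keys] if is_draft else list(keys)

    keys = ["ab12", "cd34"]
    assert apply_storage_prefix(keys, is_draft=False) == ["ab12", "cd34"]
    assert apply_storage_prefix(keys, is_draft=True) == ["draft_ab12", "draft_cd34"]
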
+ ) + self.has_draft_kv_pool = False + else: + use_draft = True self.write_queue.append( - CacheOperation(host_indices, device_indices, node_id, priority) + CacheOperation( + host_indices, device_indices, node_id, priority, use_draft=use_draft + ) ) self.start_writing() return host_indices @@ -456,30 +537,42 @@ def write( def start_writing(self) -> None: if len(self.write_queue) == 0: return + ops = self.write_queue + self.write_queue = [] + if not self.has_draft_kv_pool: + ops = [CacheOperation.merge_ops(ops)] - op = CacheOperation.merge_ops(self.write_queue) - host_indices, device_indices = self.move_indices(op) - self.write_queue.clear() + for op in ops: + host_indices, device_indices = self.move_indices(op) - start_event = device_module.Event() - finish_event = device_module.Event() + start_event = device_module.Event() + finish_event = device_module.Event() - start_event.record() - with device_module.stream(self.write_stream): - start_event.wait(self.write_stream) - self.mem_pool_host.backup_from_device_all_layer( - self.mem_pool_device, host_indices, device_indices, self.io_backend + start_event.record() + with device_module.stream(self.write_stream): + start_event.wait(self.write_stream) + self.mem_pool_host.backup_from_device_all_layer( + self.mem_pool_device, host_indices, device_indices, self.io_backend + ) + if op.use_draft and self.mem_pool_device_draft is not None: + self.mem_pool_host_draft.backup_from_device_all_layer( + self.mem_pool_device_draft, + host_indices, + device_indices, + self.io_backend, + ) + finish_event.record() + # NOTE: We must save the host indices and device indices here, + # this is because we need to guarantee that these tensors are + # still alive when the write stream is executing. + if host_indices.is_cuda: + host_indices.record_stream(self.write_stream) + if device_indices.is_cuda: + device_indices.record_stream(self.write_stream) + + self.ack_write_queue.append( + HiCacheAck(start_event, finish_event, op.node_ids) ) - finish_event.record() - # NOTE: We must save the host indices and device indices here, - # this is because we need to guarantee that these tensors are - # still alive when the write stream is executing. 
@@ -456,30 +537,42 @@ def write(
     def start_writing(self) -> None:
         if len(self.write_queue) == 0:
             return
+        ops = self.write_queue
+        self.write_queue = []
+        if not self.has_draft_kv_pool:
+            ops = [CacheOperation.merge_ops(ops)]

-        op = CacheOperation.merge_ops(self.write_queue)
-        host_indices, device_indices = self.move_indices(op)
-        self.write_queue.clear()
+        for op in ops:
+            host_indices, device_indices = self.move_indices(op)

-        start_event = device_module.Event()
-        finish_event = device_module.Event()
+            start_event = device_module.Event()
+            finish_event = device_module.Event()

-        start_event.record()
-        with device_module.stream(self.write_stream):
-            start_event.wait(self.write_stream)
-            self.mem_pool_host.backup_from_device_all_layer(
-                self.mem_pool_device, host_indices, device_indices, self.io_backend
+            start_event.record()
+            with device_module.stream(self.write_stream):
+                start_event.wait(self.write_stream)
+                self.mem_pool_host.backup_from_device_all_layer(
+                    self.mem_pool_device, host_indices, device_indices, self.io_backend
+                )
+                if op.use_draft and self.mem_pool_device_draft is not None:
+                    self.mem_pool_host_draft.backup_from_device_all_layer(
+                        self.mem_pool_device_draft,
+                        host_indices,
+                        device_indices,
+                        self.io_backend,
+                    )
+                finish_event.record()
+                # NOTE: We must save the host indices and device indices here
+                # because we need to guarantee that these tensors are still
+                # alive while the write stream is executing.
+                if host_indices.is_cuda:
+                    host_indices.record_stream(self.write_stream)
+                if device_indices.is_cuda:
+                    device_indices.record_stream(self.write_stream)
+
+            self.ack_write_queue.append(
+                HiCacheAck(start_event, finish_event, op.node_ids)
             )
-            finish_event.record()
-            # NOTE: We must save the host indices and device indices here,
-            # this is because we need to guarantee that these tensors are
-            # still alive when the write stream is executing.
-            if host_indices.is_cuda:
-                host_indices.record_stream(self.write_stream)
-            if device_indices.is_cuda:
-                device_indices.record_stream(self.write_stream)
-
-        self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))

     def load(
         self,
@@ -494,7 +587,13 @@
         if device_indices is None:
             return None
         self.load_queue.append(
-            CacheOperation(host_indices, device_indices, node_id, priority)
+            CacheOperation(
+                host_indices,
+                device_indices,
+                node_id,
+                priority,
+                use_draft=self.has_draft_kv_pool,
+            )
         )
         return device_indices
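
start_writing above (and start_loading below) stops merging queued operations once a draft pool is registered: merge_ops only keeps use_draft when every merged operation carries it, so folding a draft-enabled operation into a mixed batch would silently drop the draft copy. A sketch of that semantics with simplified stand-in types:

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class ToyOp:
        node_ids: List[int]
        use_draft: bool

    def merge_ops(ops):
        # Mirrors CacheOperation.merge_ops: the merged operation is
        # draft-enabled only if every constituent operation is.
        merged = ToyOp(node_ids=[], use_draft=all(op.use_draft for op in ops))
        for op in ops:
            merged.node_ids.extend(op.node_ids)
        return merged

    assert merge_ops([ToyOp([1], True), ToyOp([2], True)]).use_draft is True
    assert merge_ops([ToyOp([1], True), ToyOp([2], False)]).use_draft is False
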
@@ -520,40 +619,52 @@ def move_indices(self, op: CacheOperation):
     def start_loading(self) -> int:
         if len(self.load_queue) == 0:
             return -1
+        ops = self.load_queue
+        self.load_queue = []
+        if not self.has_draft_kv_pool:
+            ops = [CacheOperation.merge_ops(ops)]

         producer_id = self.layer_done_counter.update_producer()
-        op = CacheOperation.merge_ops(self.load_queue)
-        host_indices, device_indices = self.move_indices(op)
-        self.load_queue.clear()
         producer_event = self.layer_done_counter.events[producer_id]

         producer_event.start_event.record()
         with device_module.stream(self.load_stream):
             producer_event.start_event.wait(self.load_stream)
-            for i in range(self.layer_num):
-                self.mem_pool_host.load_to_device_per_layer(
-                    self.mem_pool_device,
-                    host_indices,
-                    device_indices,
-                    i,
-                    self.io_backend,
+            for op in ops:
+                host_indices, device_indices = self.move_indices(op)
+                for i in range(self.layer_num):
+                    self.mem_pool_host.load_to_device_per_layer(
+                        self.mem_pool_device,
+                        host_indices,
+                        device_indices,
+                        i,
+                        self.io_backend,
+                    )
+                    if op.use_draft and self.mem_pool_device_draft is not None:
+                        self.mem_pool_host_draft.load_to_device_per_layer(
+                            self.mem_pool_device_draft,
+                            host_indices,
+                            device_indices,
+                            i,
+                            self.io_backend,
+                        )
+                    producer_event.complete(i)
+                # NOTE: We must save the host indices and device indices here
+                # because we need to guarantee that these tensors are still
+                # alive while the load stream is executing.
+                if host_indices.is_cuda:
+                    host_indices.record_stream(self.load_stream)
+                if device_indices.is_cuda:
+                    device_indices.record_stream(self.load_stream)
+
+                self.ack_load_queue.append(
+                    HiCacheAck(
+                        start_event=producer_event.start_event,
+                        finish_event=producer_event.finish_event,
+                        node_ids=op.node_ids,
+                    )
                 )
-            producer_event.complete(i)
-            # NOTE: We must save the host indices and device indices here,
-            # this is because we need to guarantee that these tensors are
-            # still alive when the load stream is executing.
-            if host_indices.is_cuda:
-                host_indices.record_stream(self.load_stream)
-            if device_indices.is_cuda:
-                device_indices.record_stream(self.load_stream)
-
-        self.ack_load_queue.append(
-            HiCacheAck(
-                start_event=producer_event.start_event,
-                finish_event=producer_event.finish_event,
-                node_ids=op.node_ids,
-            )
-        )
+
         return producer_id

     def evict_device(self, device_indices: torch.Tensor) -> int:
@@ -563,8 +674,9 @@
     def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
         if not backup_only:
             raise ValueError("Other eviction policies are not supported yet.")
-
         self.mem_pool_host.free(host_indices)
+        if self.has_draft_kv_pool and self.mem_pool_host_draft is not None:
+            self.mem_pool_host_draft.free(host_indices)
         return len(host_indices)

     def prefetch(
         self,
@@ -574,12 +686,18 @@ def prefetch(
         new_input_tokens: List[int],
         last_hash: Optional[str] = None,
         prefix_keys: Optional[List[str]] = None,
+        is_draft: bool = False,
     ) -> PrefetchOperation:
         """
         Prefetch KV caches from storage backend to host memory.
         """
         operation = PrefetchOperation(
-            request_id, host_indices, new_input_tokens, last_hash, prefix_keys
+            request_id,
+            host_indices,
+            new_input_tokens,
+            last_hash,
+            prefix_keys,
+            is_draft=is_draft,
         )
         self.prefetch_queue.put(operation)
         return operation
@@ -588,19 +706,20 @@ def terminate_prefetch(self, operation):
         operation.mark_terminate()
         return operation.completed_tokens, operation.hash_value

-    def append_host_mem_release(self, host_indices: torch.Tensor):
+    def append_host_mem_release(
+        self, host_indices: torch.Tensor, is_draft: bool = False
+    ):
         if host_indices.numel() == 0:
             return
         pages = host_indices.split(self.mem_pool_host.page_size)
         for page in pages:
-            self.host_mem_release_queue.put(page)
+            self.host_mem_release_queue.put((page, is_draft))

     def _page_get_zero_copy(
         self, operation, hash_values, host_indices, extra_info=None
     ):
-        results = self.storage_backend.batch_get_v1(
-            hash_values, host_indices, extra_info
-        )
+        backend = self._select_storage_backend(operation.is_draft)
+        results = backend.batch_get_v1(hash_values, host_indices, extra_info)
         inc = 0
         for i in range(len(hash_values)):
             if not results[i]:
@@ -613,10 +732,12 @@
     # todo: deprecate
     def _generic_page_get(self, operation, hash_values, host_indices, extra_info=None):
-        dummy_page_dst = [
-            self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
-        ]
-        page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
+        mem_pool_host = (
+            self.mem_pool_host_draft if operation.is_draft else self.mem_pool_host
+        )
+        dummy_page_dst = [mem_pool_host.get_dummy_flat_data_page() for _ in hash_values]
+        backend = self._select_storage_backend(operation.is_draft)
+        page_data = backend.batch_get(hash_values, dummy_page_dst)
         if page_data is None:
             return
         for i in range(len(hash_values)):
@@ -627,7 +748,7 @@
                 break
            # Must set the data before increasing the completed tokens.
            # Otherwise this page may be read before being set.
- self.mem_pool_host.set_from_flat_data_page( + mem_pool_host.set_from_flat_data_page( host_indices[i * self.page_size], page_data[i], ) @@ -636,9 +757,10 @@ def _generic_page_get(self, operation, hash_values, host_indices, extra_info=Non def _page_transfer(self, operation): # Transfer batch by batch - prefix_keys = operation.prefix_keys + prefix_keys = self._apply_prefix_keys(operation.prefix_keys, operation.is_draft) for i in range(0, len(operation.hash_value), self.storage_batch_size): batch_hashes = operation.hash_value[i : i + self.storage_batch_size] + batch_hashes = self._apply_storage_prefix(batch_hashes, operation.is_draft) batch_host_indices = operation.host_indices[ i * self.page_size : (i + len(batch_hashes)) * self.page_size ] @@ -655,7 +777,9 @@ def _page_transfer(self, operation): break # Some operations fail or operation terminated by controller if prefix_keys and len(prefix_keys) > 0: - prefix_keys += batch_hashes + prefix_keys += self._apply_storage_prefix( + batch_hashes, operation.is_draft + ) def prefetch_io_aux_func(self): """ @@ -667,7 +791,8 @@ def prefetch_io_aux_func(self): self._page_transfer(operation) # operation terminated by controller, release pre-allocated memory self.append_host_mem_release( - operation.host_indices[operation.completed_tokens :] + operation.host_indices[operation.completed_tokens :], + is_draft=operation.is_draft, ) except Empty: continue @@ -686,6 +811,7 @@ def _storage_hit_query(self, operation) -> tuple[list[str], int]: last_hash = operation.last_hash tokens_to_fetch = operation.token_ids prefix_keys = operation.prefix_keys.copy() if operation.prefix_keys else None + prefix_keys = self._apply_prefix_keys(prefix_keys, operation.is_draft) storage_query_count = 0 hash_value = [] @@ -704,7 +830,9 @@ def _storage_hit_query(self, operation) -> tuple[list[str], int]: ) batch_hashes.append(last_hash) extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys) - hit_page_num = self.storage_backend.batch_exists(batch_hashes, extra_info) + query_hashes = self._apply_storage_prefix(batch_hashes, operation.is_draft) + backend = self._select_storage_backend(operation.is_draft) + hit_page_num = backend.batch_exists(query_hashes, extra_info) hash_value.extend(batch_hashes[:hit_page_num]) storage_query_count += hit_page_num * self.page_size if hit_page_num < len(batch_hashes): @@ -740,8 +868,12 @@ def prefetch_thread_func(self): if storage_hit_count < self.prefetch_threshold: # not to prefetch if not enough benefits - self.prefetch_revoke_queue.put(operation.request_id) - self.append_host_mem_release(operation.host_indices) + self.prefetch_revoke_queue.put( + (operation.request_id, operation.is_draft) + ) + self.append_host_mem_release( + operation.host_indices, is_draft=operation.is_draft + ) logger.debug( f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})." ) @@ -751,7 +883,8 @@ def prefetch_thread_func(self): ] # free the pre-allocated memory for pages that are not hit self.append_host_mem_release( - operation.host_indices[storage_hit_count:] + operation.host_indices[storage_hit_count:], + is_draft=operation.is_draft, ) operation.host_indices = operation.host_indices[:storage_hit_count] logger.debug( @@ -768,41 +901,51 @@ def write_storage( token_ids: List[int], hash_value: Optional[List[str]] = None, prefix_keys: Optional[List[str]] = None, + is_draft: bool = False, ) -> int: """ Write KV caches from host memory to storage backend. 
""" operation = StorageOperation( - host_indices, token_ids, hash_value=hash_value, prefix_keys=prefix_keys + host_indices, + token_ids, + hash_value=hash_value, + prefix_keys=prefix_keys, + is_draft=is_draft, ) self.backup_queue.put(operation) return operation.id # todo: deprecate def _generic_page_set(self, hash_values, host_indices, extra_info=None) -> bool: + is_draft = extra_info is not None and getattr(extra_info, "is_draft", False) + mem_pool_host = self.mem_pool_host_draft if is_draft else self.mem_pool_host data = [ - self.mem_pool_host.get_data_page(host_indices[i * self.page_size]) + mem_pool_host.get_data_page(host_indices[i * self.page_size]) for i in range(len(hash_values)) ] - return self.storage_backend.batch_set(hash_values, data) + backend = self._select_storage_backend(is_draft) + return backend.batch_set(hash_values, data) def _page_set_zero_copy(self, hash_values, host_indices, extra_info=None) -> bool: - return all( - self.storage_backend.batch_set_v1(hash_values, host_indices, extra_info) - ) + is_draft = extra_info is not None and getattr(extra_info, "is_draft", False) + backend = self._select_storage_backend(is_draft) + return all(backend.batch_set_v1(hash_values, host_indices, extra_info)) # Backup batch by batch def _page_backup(self, operation): # Backup batch by batch - prefix_keys = operation.prefix_keys + prefix_keys = self._apply_prefix_keys(operation.prefix_keys, operation.is_draft) for i in range(0, len(operation.hash_value), self.storage_batch_size): batch_hashes = operation.hash_value[i : i + self.storage_batch_size] + batch_hashes = self._apply_storage_prefix(batch_hashes, operation.is_draft) batch_host_indices = operation.host_indices[ i * self.page_size : (i + len(batch_hashes)) * self.page_size ] # Set one batch token, and record if success. 
            # todo: allow partial success
            extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys)
+            extra_info.is_draft = operation.is_draft
            success = self.page_set_func(batch_hashes, batch_host_indices, extra_info)
            if not success:
                logger.warning(
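
_select_storage_backend and the is_draft flag threaded through the get/set paths above route draft pages to the draft backend when one exists and fall back to the shared backend otherwise. A stand-in sketch of the fallback rule (toy class, not the real controller):

    class ToyController:
        # Stand-in covering only the backend routing.
        def __init__(self, backend, draft_backend=None):
            self.storage_backend = backend
            self.storage_backend_draft = draft_backend

        def select_storage_backend(self, is_draft):
            if is_draft and self.storage_backend_draft is not None:
                return self.storage_backend_draft
            return self.storage_backend

    assert ToyController("main").select_storage_backend(True) == "main"
    assert ToyController("main", "draft").select_storage_backend(True) == "draft"
    assert ToyController("main", "draft").select_storage_backend(False) == "main"
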
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 0c68283f4e2d..234f6142da76 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -667,6 +667,8 @@ def init_cache_with_memory_pool(self):
                 self.tp_worker.register_hicache_layer_transfer_counter(
                     self.tree_cache.cache_controller.layer_done_counter
                 )
+                if not self.spec_algorithm.is_none():
+                    self._register_draft_kv_pool_for_hicache(server_args)
         elif self.is_hybrid_swa:
             from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache

@@ -735,6 +737,58 @@ def init_running_status(self):
         self.forward_sleep_time = None
         self._engine_paused = False

+    def _register_draft_kv_pool_for_hicache(self, server_args: ServerArgs):
+        """Register draft model KV pool with HiCache for speculative decoding."""
+        if self.draft_worker is None:
+            return
+
+        from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
+        from sglang.srt.mem_cache.memory_pool_host import (
+            MHATokenToKVPoolHost,
+            MLATokenToKVPoolHost,
+        )
+
+        draft_runner = None
+        if hasattr(self.draft_worker, "draft_worker"):
+            draft_runner = getattr(self.draft_worker, "draft_worker").model_runner
+        elif hasattr(self.draft_worker, "draft_model_runner"):
+            draft_runner = getattr(self.draft_worker, "draft_model_runner")
+        elif hasattr(self.draft_worker, "model_runner"):
+            draft_runner = getattr(self.draft_worker, "model_runner")
+
+        if draft_runner is None:
+            return
+
+        draft_kv_pool = draft_runner.token_to_kv_pool
+        if isinstance(draft_kv_pool, MHATokenToKVPool):
+            draft_host_kv_pool = MHATokenToKVPoolHost(
+                draft_kv_pool,
+                server_args.hicache_ratio,
+                server_args.hicache_size,
+                self.page_size,
+                server_args.hicache_mem_layout,
+                allocator_type=server_args.hicache_storage_backend,
+            )
+        elif isinstance(draft_kv_pool, MLATokenToKVPool):
+            draft_host_kv_pool = MLATokenToKVPoolHost(
+                draft_kv_pool,
+                server_args.hicache_ratio,
+                server_args.hicache_size,
+                self.page_size,
+                server_args.hicache_mem_layout,
+                allocator_type=server_args.hicache_storage_backend,
+            )
+        else:
+            logger.warning(
+                f"Draft KV pool type {type(draft_kv_pool).__name__} not supported for HiCache"
+            )
+            return
+
+        if hasattr(self.tree_cache, "cache_controller"):
+            self.tree_cache.cache_controller.set_draft_kv_pool(
+                draft_kv_pool, draft_host_kv_pool
+            )
+
     def init_chunked_prefill(self):
         # Init chunked prefill
         self.chunked_prefill_size = self.server_args.chunked_prefill_size
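
The scheduler hunk above picks the host pool class by the draft pool's device type and bails out with a warning for anything else. That dispatch reduces to a type-keyed table; a self-contained sketch with toy classes standing in for the real MHA/MLA pool types:

    class MHAPool: pass
    class MLAPool: pass

    class MHAHostPool:
        def __init__(self, device_pool):
            self.device_pool = device_pool

    class MLAHostPool:
        def __init__(self, device_pool):
            self.device_pool = device_pool

    HOST_POOL_FOR = {MHAPool: MHAHostPool, MLAPool: MLAHostPool}

    def make_host_pool(device_pool):
        # Unsupported pool types are skipped, mirroring the warning-and-return
        # branch in _register_draft_kv_pool_for_hicache.
        host_cls = HOST_POOL_FOR.get(type(device_pool))
        return host_cls(device_pool) if host_cls else None

    assert isinstance(make_host_pool(MHAPool()), MHAHostPool)
    assert make_host_pool(object()) is None
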
diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py
index 853baef6bd2a..b6c14b2dc7a3 100644
--- a/python/sglang/srt/mem_cache/hiradix_cache.py
+++ b/python/sglang/srt/mem_cache/hiradix_cache.py
@@ -138,7 +138,9 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs):
         self.ongoing_load_back = {}
         # record the ongoing prefetch requests
         self.ongoing_prefetch = {}
+        self.ongoing_prefetch_draft = {}
         self.ongoing_backup = {}
+        self.ongoing_backup_draft = {}
         # todo: dynamically adjust the threshold
         self.write_through_threshold = (
             1 if server_args.hicache_write_policy == "write_through" else 2
@@ -289,6 +291,16 @@ def write_backup_storage(self, node: TreeNode):
         )
         self.ongoing_backup[operation_id] = node
         node.protect_host()
+        if self.cache_controller.has_draft_kv_pool:
+            draft_operation_id = self.cache_controller.write_storage(
+                node.host_value,
+                node.key,
+                node.hash_value,
+                prefix_keys,
+                is_draft=True,
+            )
+            self.ongoing_backup_draft[draft_operation_id] = node
+            node.protect_host()

     def _inc_hit_count(self, node: TreeNode, chunked=False):
         # skip the hit count update for chunked requests
@@ -575,8 +587,11 @@ def drain_storage_control_queues(self):

         # process prefetch revokes
         for _ in range(n_revoke):
-            req_id = cc.prefetch_revoke_queue.get()
-            info = self.ongoing_prefetch.pop(req_id, None)
+            req_id, is_draft = cc.prefetch_revoke_queue.get()
+            pending_map = (
+                self.ongoing_prefetch_draft if is_draft else self.ongoing_prefetch
+            )
+            info = pending_map.pop(req_id, None)
             if info is not None:
                 last_host_node, token_ids, _, _ = info
                 last_host_node.release_host()
@@ -587,7 +602,11 @@
         for _ in range(n_backup):
             operation = cc.ack_backup_queue.get()
             ack_id = operation.id
-            entry = self.ongoing_backup.pop(ack_id, None)
+            entry = (
+                self.ongoing_backup_draft.pop(ack_id, None)
+                if operation.is_draft
+                else self.ongoing_backup.pop(ack_id, None)
+            )
             if entry is not None:
                 entry.release_host()
                 if self.enable_storage_metrics:
@@ -597,11 +616,19 @@
         # release host memory
         host_indices_list = []
+        draft_host_indices_list = []
         for _ in range(n_release):
-            host_indices_list.append(cc.host_mem_release_queue.get())
+            host_indices, is_draft = cc.host_mem_release_queue.get()
+            if is_draft:
+                draft_host_indices_list.append(host_indices)
+            else:
+                host_indices_list.append(host_indices)
         if host_indices_list:
             host_indices = torch.cat(host_indices_list, dim=0)
             cc.mem_pool_host.free(host_indices)
+        if draft_host_indices_list and cc.has_draft_kv_pool:
+            draft_host_indices = torch.cat(draft_host_indices_list, dim=0)
+            cc.mem_pool_host_draft.free(draft_host_indices)

     # Timeout is linearly increasing with the number of pages
     def _prefetch_timeout_check_linear_func(self, operation: PrefetchOperation):
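
drain_storage_control_queues above now receives (page, is_draft) tuples on the release queue and frees each batch against its own pool. The partitioning step, sketched standalone with strings in place of index tensors:

    from queue import Empty, Queue

    release_queue = Queue()
    for item in [("page0", False), ("page1", True), ("page2", False)]:
        release_queue.put(item)

    main_pages, draft_pages = [], []
    while True:
        try:
            page, is_draft = release_queue.get_nowait()
        except Empty:
            break
        (draft_pages if is_draft else main_pages).append(page)

    # Each pool then frees only its own pages in one batched call.
    assert main_pages == ["page0", "page2"]
    assert draft_pages == ["page1"]
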
@@ -652,74 +679,134 @@ def can_terminate_prefetch(self, operation: PrefetchOperation):
         return can_terminate

     def check_prefetch_progress(self, req_id: str) -> bool:
-        if req_id not in self.ongoing_prefetch:
+        if (
+            req_id not in self.ongoing_prefetch
+            and req_id not in self.ongoing_prefetch_draft
+        ):
             # there is no ongoing prefetch for this request or it has been revoked
             return True

-        # todo: more policies for prefetch progress such as timeout
-        # the current policy is to prefetch with best effort and terminate when queuing is over
-        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
-            req_id
-        ]
+        if req_id in self.ongoing_prefetch_draft:
+            _, _, _, draft_operation = self.ongoing_prefetch_draft[req_id]
+            if (
+                draft_operation.host_indices is not None
+                and not self.can_terminate_prefetch(draft_operation)
+            ):
+                return False

-        if operation.host_indices is None:
-            # prefetch has not been issued due to insufficient host memory
-            return True
+        if req_id in self.ongoing_prefetch:
+            # todo: more policies for prefetch progress such as timeout
+            # the current policy is to prefetch with best effort and terminate when queuing is over
+            last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
+                req_id
+            ]

-        if not self.can_terminate_prefetch(operation):
-            return False
+            if operation.host_indices is None:
+                # prefetch has not been issued due to insufficient host memory
+                return True

-        completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
-            operation
-        )
-        logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
+            if not self.can_terminate_prefetch(operation):
+                return False

-        min_completed_tokens = completed_tokens
-        if self.tp_world_size > 1:
-            # synchrnoize TP workers to make the same update to hiradix cache
-            completed_tokens_tensor = torch.tensor(
-                min_completed_tokens, dtype=torch.int
-            )
-            torch.distributed.all_reduce(
-                completed_tokens_tensor,
-                op=torch.distributed.ReduceOp.MIN,
-                group=self.tp_group,
+            completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
+                operation
             )
-            min_completed_tokens = completed_tokens_tensor.item()
-        fetched_token_ids = token_ids[:min_completed_tokens]
-        written_indices = host_indices[:min_completed_tokens]
-        matched_length = self._insert_helper_host(
-            last_host_node,
-            RadixKey(
-                token_ids=fetched_token_ids, extra_key=last_host_node.key.extra_key
-            ),
-            written_indices,
-            hash_value[: min_completed_tokens // self.page_size],
-        )
+            logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

-        self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
-        self.cache_controller.append_host_mem_release(
-            host_indices[min_completed_tokens:completed_tokens]
-        )
-        last_host_node.release_host()
-        del self.ongoing_prefetch[req_id]
-        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
+            min_completed_tokens = completed_tokens
+            if self.tp_world_size > 1:
+                # synchronize TP workers to make the same update to hiradix cache
+                completed_tokens_tensor = torch.tensor(
+                    min_completed_tokens, dtype=torch.int
+                )
+                torch.distributed.all_reduce(
+                    completed_tokens_tensor,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_group,
+                )
+                min_completed_tokens = completed_tokens_tensor.item()
+            fetched_token_ids = token_ids[:min_completed_tokens]
+            written_indices = host_indices[:min_completed_tokens]
+            matched_length = self._insert_helper_host(
+                last_host_node,
+                RadixKey(
+                    token_ids=fetched_token_ids,
+                    extra_key=last_host_node.key.extra_key,
+                ),
+                written_indices,
+                hash_value[: min_completed_tokens // self.page_size],
+            )

-        if self.enable_storage_metrics:
-            self.storage_metrics_collector.log_prefetched_tokens(
-                min_completed_tokens - matched_length
+            self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
+            self.cache_controller.append_host_mem_release(
+                host_indices[min_completed_tokens:completed_tokens]
             )
+            last_host_node.release_host()
+            del self.ongoing_prefetch[req_id]
+            self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
+
+            if self.enable_storage_metrics:
+                self.storage_metrics_collector.log_prefetched_tokens(
+                    min_completed_tokens - matched_length
+                )
+
+        if req_id in self.ongoing_prefetch_draft:
+            (
+                draft_last_host_node,
+                draft_token_ids,
+                draft_host_indices,
+                draft_operation,
+            ) = self.ongoing_prefetch_draft[req_id]
+            if draft_operation.host_indices is not None:
+                draft_completed_tokens, draft_hash_value = (
+                    self.cache_controller.terminate_prefetch(draft_operation)
+                )
+                draft_min_completed_tokens = draft_completed_tokens
+                if self.tp_world_size > 1:
+                    draft_completed_tokens_tensor = torch.tensor(
+                        draft_min_completed_tokens, dtype=torch.int
+                    )
+                    torch.distributed.all_reduce(
+                        draft_completed_tokens_tensor,
+                        op=torch.distributed.ReduceOp.MIN,
+                        group=self.tp_group,
+                    )
+                    draft_min_completed_tokens = draft_completed_tokens_tensor.item()
+                draft_fetched_token_ids = draft_token_ids[:draft_min_completed_tokens]
+                draft_written_indices = draft_host_indices[:draft_min_completed_tokens]
+                draft_matched_length = self._insert_helper_host(
+                    draft_last_host_node,
+                    RadixKey(
+                        token_ids=draft_fetched_token_ids,
+                        extra_key=draft_last_host_node.key.extra_key,
+                    ),
+                    draft_written_indices,
+                    draft_hash_value[: draft_min_completed_tokens // self.page_size],
+                )
+                self.cache_controller.mem_pool_host_draft.free(
+                    draft_host_indices[:draft_matched_length]
+                )
+                self.cache_controller.append_host_mem_release(
+                    draft_host_indices[
+                        draft_min_completed_tokens:draft_completed_tokens
+                    ],
+                    is_draft=True,
+                )
+                draft_last_host_node.release_host()
+            del self.ongoing_prefetch_draft[req_id]
+            self.cache_controller.prefetch_tokens_occupied -= len(draft_token_ids)

         return True

     def terminate_prefetch(self, req_id: str):
-        if req_id not in self.ongoing_prefetch:
-            return
-
-        _, _, _, operation = self.ongoing_prefetch[req_id]
-        if operation.host_indices is None:
-            return
-        operation.mark_terminate()
+        if req_id in self.ongoing_prefetch:
+            _, _, _, operation = self.ongoing_prefetch[req_id]
+            if operation.host_indices is not None:
+                operation.mark_terminate()
+        if req_id in self.ongoing_prefetch_draft:
+            _, _, _, operation = self.ongoing_prefetch_draft[req_id]
+            if operation.host_indices is not None:
+                operation.mark_terminate()

     def match_prefix(self, params: MatchPrefixParams):
         key = params.key
@@ -798,6 +885,35 @@ def prefetch_from_storage(
         )
         self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)

+        if self.cache_controller.has_draft_kv_pool:
+            draft_host_indices = self.cache_controller.mem_pool_host_draft.alloc(
+                prefetch_length
+            )
+            if draft_host_indices is None or not torch.equal(
+                draft_host_indices, host_indices
+            ):
+                if draft_host_indices is not None:
+                    self.cache_controller.mem_pool_host_draft.free(draft_host_indices)
+                logger.warning(
+                    "Draft HiCache prefetch indices desynced. Skipping draft prefetch."
+                )
+                return
+            draft_operation = self.cache_controller.prefetch(
+                req_id,
+                draft_host_indices,
+                new_input_tokens,
+                last_hash,
+                prefix_keys,
+                is_draft=True,
+            )
+            self.ongoing_prefetch_draft[req_id] = (
+                last_host_node,
+                new_input_tokens,
+                draft_host_indices,
+                draft_operation,
+            )
+            self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
+
     def _insert_helper_host(
         self, node: TreeNode, key: RadixKey, host_value, hash_value
     ):
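
Taken together, the patch wires the draft model into HiCache in three steps: the scheduler registers the draft pools with the controller, device-to-host writes mirror into the draft host pool at the same indices, and host-to-storage backups go out twice with the draft copy under prefixed keys. A stand-in sketch of that controller surface, illustrative only and not the real class:

    class ToyCacheController:
        # Stand-in for HiCacheController's draft registration and key routing.
        def __init__(self):
            self.has_draft_kv_pool = False
            self.draft_storage_prefix = "draft_"

        def set_draft_kv_pool(self, draft_kv_pool, draft_host_kv_pool):
            self.mem_pool_device_draft = draft_kv_pool
            self.mem_pool_host_draft = draft_host_kv_pool
            self.has_draft_kv_pool = True

        def storage_keys(self, hashes, is_draft):
            if not is_draft:
                return list(hashes)
            return [f"{self.draft_storage_prefix}{h}" for h in hashes]

    controller = ToyCacheController()
    controller.set_draft_kv_pool(draft_kv_pool=object(), draft_host_kv_pool=object())
    assert controller.has_draft_kv_pool
    # The same page hash is backed up once per namespace:
    assert controller.storage_keys(["ab12"], is_draft=False) == ["ab12"]
    assert controller.storage_keys(["ab12"], is_draft=True) == ["draft_ab12"]
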