Hicache Storage Layer Prototype #7704
```diff
@@ -25,6 +25,8 @@
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 
+from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
+
 logger = logging.getLogger(__name__)
```
```diff
@@ -159,6 +161,57 @@ def clear(self):
         self.buffers.queue.clear()
 
 
+class StorageOperation:
+    counter = 0
+
+    def __init__(
+        self,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        self.host_indices = host_indices
+        self.token_ids = token_ids
+        self.last_hash = last_hash
+        self.completed_tokens = 0
+        self.hash_value = []
+
+        self.id = StorageOperation.counter
+        StorageOperation.counter += 1
+
+    def __lt__(self, other: "StorageOperation"):
+        return self.id < other.id
+
+
+class PrefetchOperation(StorageOperation):
+    def __init__(
+        self,
+        request_id: str,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        self.request_id = request_id
+
+        self._done_flag = False
+        self._lock = threading.Lock()
+
+        super().__init__(host_indices, token_ids, last_hash)
+
+    def increment(self, num_tokens: int):
+        with self._lock:
+            if self._done_flag:
+                return
+            self.completed_tokens += num_tokens
+
+    def mark_done(self):
+        with self._lock:
+            self._done_flag = True
+
+    def is_done(self) -> bool:
+        return self._done_flag
+
+
 class HiCacheController:
 
     def __init__(
```
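The lock in `PrefetchOperation` exists because two threads race on `completed_tokens`: the IO thread increments it page by page while the controller may terminate the operation at any time. Once `mark_done()` has run, further `increment()` calls become no-ops, so the token count the controller reads can no longer drift. A minimal sketch of that interaction (tensor and token contents are placeholders):

```python
import threading

import torch

# Assumes StorageOperation / PrefetchOperation as defined in this diff.
op = PrefetchOperation(
    request_id="req-0",
    host_indices=torch.arange(1024),
    token_ids=list(range(1024)),
)

def io_worker(page_size: int = 64):
    # Plays the role of the prefetch IO thread, reporting per-page progress.
    for _ in range(1024 // page_size):
        op.increment(page_size)

t = threading.Thread(target=io_worker)
t.start()
op.mark_done()  # controller terminates the prefetch early
t.join()
# completed_tokens is now frozen; any increment() after mark_done() was dropped.
print(op.completed_tokens)
```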
```diff
@@ -169,6 +222,8 @@ def __init__(
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
+        storage_backend: Optional[str] = None,
+        prefetch_threshold: int = 256,
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
```

Collaborator (on the new parameters): Wonder if we can pass `tp_rank` and `pp_rank` information to this function. When dealing with file systems that aren't optimized for handling numerous small files, it might be more practical to have each rank correspond to a single large file. Including model-parallelism (tp/pp) information in filenames would help us manage sglang instances and their associated KVCache files more effectively.
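A hypothetical sketch of that suggestion, not part of this PR: thread `tp_rank` and `pp_rank` into the controller and fold them into the storage key, so each rank's pages land in their own namespace (or one large per-rank file):

```python
# Hypothetical: illustrates the reviewer's suggestion; none of these names
# exist in the PR. Encodes model-parallel rank into the storage key so each
# rank can be mapped to a single large file rather than many small ones.
def rank_scoped_key(page_hash: str, tp_rank: int, pp_rank: int) -> str:
    return f"tp{tp_rank}_pp{pp_rank}/{page_hash}"
```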
```diff
@@ -186,6 +241,21 @@ def __init__(
         else:
             self.io_backend = io_backend
 
+        self.enable_storage = False
+        # todo: move backend initialization to storage backend module
+        if storage_backend is not None:
+            if storage_backend == "file":
+                self.storage_backend = HiCacheFile()
+                self.enable_storage = True
+                # tracking prefetch operation progress
+                self.ongoing_prefetch: dict[str, PrefetchOperation] = {}
+                # todo: threshold policy for prefetching
+                self.prefetch_threshold = prefetch_threshold
+            else:
+                raise NotImplementedError(
+                    f"Unsupported storage backend: {storage_backend}"
+                )
+
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
```
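Only the `"file"` backend (`HiCacheFile`) is wired up so far, but the controller code below touches a narrow backend surface: `exists(key)`, `get(key)` returning `None` on a miss, and `set(key, value)`. As a sketch (not part of the PR), a dict-backed stand-in with that surface is enough to exercise the prefetch and backup threads in tests:

```python
from typing import Dict, Optional

import torch

class InMemoryBackend:
    """Hypothetical stand-in exposing the interface the controller uses."""

    def __init__(self):
        self._pages: Dict[str, torch.Tensor] = {}

    def exists(self, key: str) -> bool:
        return key in self._pages

    def get(self, key: str) -> Optional[torch.Tensor]:
        # None signals a miss, matching the check in prefetch_io_aux_func.
        return self._pages.get(key)

    def set(self, key: str, value: torch.Tensor) -> None:
        self._pages[key] = value.clone()
```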
```diff
@@ -218,9 +288,26 @@ def __init__(
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
         )
+
         self.write_thread.start()
         self.load_thread.start()
 
+        if self.enable_storage:
+            self.prefetch_thread = threading.Thread(
+                target=self.prefetch_thread_func, daemon=True
+            )
+            self.backup_thread = threading.Thread(
+                target=self.backup_thread_func, daemon=True
+            )
+            self.prefetch_queue = Queue()
+            self.backup_queue = Queue()
+
+            self.prefetch_revoke_queue = Queue()
+            self.ack_backup_queue = Queue()
+
+            self.prefetch_thread.start()
+            self.backup_thread.start()
+
     def reset(self):
         self.stop_event.set()
         self.write_thread.join()
```
```diff
@@ -232,6 +319,13 @@ def reset(self):
         self.load_buffer.clear()
         self.ack_write_queue.queue.clear()
         self.ack_load_queue.queue.clear()
+        if self.enable_storage:
+            self.prefetch_thread.join()
+            self.backup_thread.join()
+            self.prefetch_queue.queue.clear()
+            self.backup_queue.queue.clear()
+            self.prefetch_revoke_queue.queue.clear()
+            self.ack_backup_queue.queue.clear()
 
         self.write_thread = threading.Thread(
             target=self.write_thread_func_direct, daemon=True
```
```diff
@@ -243,6 +337,16 @@ def reset(self):
         self.write_thread.start()
         self.load_thread.start()
 
+        if self.enable_storage:
+            self.prefetch_thread = threading.Thread(
+                target=self.prefetch_thread_func, daemon=True
+            )
+            self.backup_thread = threading.Thread(
+                target=self.backup_thread_func, daemon=True
+            )
+            self.prefetch_thread.start()
+            self.backup_thread.start()
+
     def write(
         self,
         device_indices: torch.Tensor,
```
```diff
@@ -383,3 +487,147 @@ def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
             raise ValueError(
                 f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
             )
+
+    def prefetch(
+        self,
+        request_id: str,
+        host_indices: torch.Tensor,
+        new_input_tokens: List[int],
+        last_hash: Optional[str] = None,
+    ) -> int:
+        """
+        Prefetch KV caches from storage backend to host memory.
+        """
+        operation = PrefetchOperation(
+            request_id, host_indices, new_input_tokens, last_hash
+        )
+        self.ongoing_prefetch[request_id] = operation
+        self.prefetch_queue.put(operation)
+
+    def terminate_prefetch(self, request_id: str):
+        operation = self.ongoing_prefetch.pop(request_id, None)
+        if operation is None:
+            raise ValueError(
+                f"Request ID {request_id} not found in ongoing prefetches."
+            )
+        operation.mark_done()
+        return operation.completed_tokens, operation.hash_value
```
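From the caller's side the flow appears to be: start a prefetch against pre-allocated host indices, let IO proceed in the background, then call `terminate_prefetch` to stop it and learn how many tokens actually landed. A sketch, assuming a constructed `controller` plus `indices` and `tokens` for the request:

```python
# Sketch: `controller`, `indices`, and `tokens` are assumed to exist.
controller.prefetch("req-42", indices, tokens, last_hash=None)

# ... scheduling continues while the prefetch threads do IO ...

# Stop the operation and collect its progress. Host memory beyond
# `completed` tokens was already freed by prefetch_io_aux_func.
completed, page_hashes = controller.terminate_prefetch("req-42")
print(f"prefetched {completed} tokens across {len(page_hashes)} pages")
```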
```diff
+    def prefetch_io_aux_func(self):
+        """
+        Auxiliary function conducting IO operations for prefetching.
+        """
+        while not self.stop_event.is_set():
+            try:
+                operation = self.prefetch_buffer.get(block=True, timeout=1)
+                for h in operation.hash_value:
+                    page_data = self.storage_backend.get(h)
+                    if page_data is None:
+                        logger.warning(
+                            f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
+                        )
+                        break
+                    self.mem_pool_host.set_from_flat_data_page(
+                        operation.host_indices[operation.completed_tokens],
+                        page_data,
+                    )
+                    operation.increment(self.page_size)
+                    if operation.is_done():
+                        # operation terminated by controller, release pre-allocated memory
+                        self.mem_pool_host.free(
+                            operation.host_indices[operation.completed_tokens :]
+                        )
+                        break
+            except Empty:
+                continue
```
```diff
+    def prefetch_thread_func(self):
+        """
+        Manage prefetching operations from storage backend to host memory.
+        """
+        self.prefetch_buffer = Queue()
+        aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True)
+        aux_thread.start()
+        while (not self.stop_event.is_set()) or not self.prefetch_queue.empty():
+            try:
+                operation = self.prefetch_queue.get(block=True, timeout=1)
+                if operation is None:
+                    continue
+
+                last_hash = operation.last_hash
+                tokens_to_fetch = operation.token_ids
+
+                storage_hit_count = 0
+                remaining_tokens = len(tokens_to_fetch)
+                hash_value = []
+                while remaining_tokens >= self.page_size:
+                    last_hash = get_hash_str(
+                        tokens_to_fetch[
+                            storage_hit_count : storage_hit_count + self.page_size
+                        ],
+                        last_hash,
+                    )
+                    if self.storage_backend.exists(last_hash):
+                        storage_hit_count += self.page_size
+                        hash_value.append(last_hash)
+                        remaining_tokens -= self.page_size
+                    else:
+                        break
+
+                if storage_hit_count < self.prefetch_threshold:
+                    # skip prefetching when the hit is too small to be worthwhile
+                    self.prefetch_revoke_queue.put(operation.request_id)
+                else:
+                    operation.hash_value = hash_value
+                    logger.debug(
+                        f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
+                    )
+                    self.prefetch_buffer.put(operation)
+
+            except Empty:
+                continue
```
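The matching loop walks the request one page at a time, chaining each page's hash onto the previous one via `get_hash_str`, and stops at the first page absent from storage, so a hit is always a prefix of the request; operations below `prefetch_threshold` are revoked instead of fetched. A standalone sketch of the chaining, with a stand-in hash (the real construction lives in `hicache_storage`):

```python
import hashlib
from typing import Callable, List, Optional

PAGE_SIZE = 64  # illustrative; the controller uses self.page_size

def hash_page(token_ids: List[int], prior: Optional[str]) -> str:
    # Stand-in for get_hash_str: folds the previous page's hash into this one.
    h = hashlib.sha256()
    if prior is not None:
        h.update(prior.encode())
    h.update(b",".join(str(t).encode() for t in token_ids))
    return h.hexdigest()

def storage_resident_prefix(
    tokens: List[int], exists: Callable[[str], bool]
) -> List[str]:
    """Per-page hashes for the longest prefix of `tokens` found in storage."""
    hashes: List[str] = []
    last = None
    for i in range(0, len(tokens) - len(tokens) % PAGE_SIZE, PAGE_SIZE):
        last = hash_page(tokens[i : i + PAGE_SIZE], last)
        if not exists(last):
            break
        hashes.append(last)
    return hashes
```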
```diff
+    def write_storage(
+        self,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ) -> int:
+        """
+        Write KV caches from host memory to storage backend.
+        """
+        operation = StorageOperation(host_indices, token_ids, last_hash)
+        self.backup_queue.put(operation)
+        return operation.id
```
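`write_storage` is fire-and-forget: it returns the operation id immediately, and completion arrives later on `ack_backup_queue` as an `(id, hash_value)` pair. A sketch of correlating the two, again assuming a constructed `controller` with `indices` and `tokens`:

```python
# Sketch: `controller`, `indices`, and `tokens` are assumed to exist.
op_id = controller.write_storage(indices, tokens, last_hash=None)

# Later, drain acknowledgements and match them to pending operations.
acked_id, page_hashes = controller.ack_backup_queue.get()
assert acked_id == op_id
print(f"backup {op_id} persisted {len(page_hashes)} pages")
```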
```diff
+    def backup_thread_func(self):
+        """
+        Manage backup operations from host memory to storage backend.
+        """
+        while not self.stop_event.is_set():
+            try:
+                operation = self.backup_queue.get(block=True, timeout=1)
+                if operation is None:
+                    continue
+
+                last_hash = operation.last_hash
+                tokens_to_backup = operation.token_ids
+
+                for i in range(0, len(tokens_to_backup), self.page_size):
+                    last_hash = get_hash_str(
+                        tokens_to_backup[i : i + self.page_size], last_hash
+                    )
+                    # todo: handle failures in storage backend
+                    self.storage_backend.set(
+                        last_hash,
+                        self.mem_pool_host.get_flat_data_page(
+                            operation.host_indices[i]
+                        ),
+                    )
+                    operation.completed_tokens += self.page_size
+                    operation.hash_value.append(last_hash)
+
+                self.ack_backup_queue.put((operation.id, operation.hash_value))
+
+            except Empty:
+                continue
```

Collaborator (on `self.storage_backend.set(...)`): Can `batch_set` be used here to boost I/O? I/O coalescing is required in scenarios like 3fs.

Collaborator (author): Yes, these will be abstracted out for different options.
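A hypothetical sketch of the coalescing the reviewer asks about: build the whole batch of keys and pages for an operation, then hand the backend a single `batch_set` call (a method this PR does not define) instead of one `set` per page:

```python
from sglang.srt.mem_cache.hicache_storage import get_hash_str

# Hypothetical: batch_set is not part of this PR's backend interface.
# One coalesced write per operation suits backends such as 3fs.
def backup_batched(storage_backend, mem_pool_host, operation, page_size: int):
    keys, pages = [], []
    last_hash = operation.last_hash
    for i in range(0, len(operation.token_ids), page_size):
        last_hash = get_hash_str(operation.token_ids[i : i + page_size], last_hash)
        keys.append(last_hash)
        pages.append(mem_pool_host.get_flat_data_page(operation.host_indices[i]))
    storage_backend.batch_set(keys, pages)  # single coalesced IO call
    operation.completed_tokens += len(keys) * page_size
    operation.hash_value.extend(keys)
```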