[HiCache][WIP] support spec decode+hicache storage #17776
Conversation
Summary of Changes (Gemini Code Assist)
This pull request enhances the HiCache system with support for speculative decoding. It introduces dual KV cache management, allowing separate storage and retrieval of key-value states for both the main model and the draft model, so that speculative decoding, which relies on a draft model proposing tokens for the main model to verify, can reuse cached KV states from hierarchical storage.
Code Review
The pull request introduces support for speculative decoding with HiCache storage by integrating a draft KV pool and associated storage backend. The changes involve extending CacheOperation, StorageOperation, and PrefetchOperation classes with an is_draft flag, and modifying the HiCacheController and HiRadixCache to manage separate memory pools and storage operations for the draft model. This is a significant feature addition that enhances the system's capabilities for speculative decoding.
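For orientation, the sketch below shows the general shape of the is_draft plumbing described above. Everything except the is_draft flag itself is a simplified stand-in for illustration, not the actual SGLang class definitions.

```python
from dataclasses import dataclass, field
from typing import List

import torch


@dataclass
class StorageOperationSketch:
    # Simplified stand-in for the operation classes extended in this PR.
    host_indices: torch.Tensor
    token_ids: List[int]
    hash_value: List[str] = field(default_factory=list)
    # When True, the controller routes this operation to the draft host pool
    # (mem_pool_host_draft) and the draft storage backend instead of the main ones.
    is_draft: bool = False
```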
```python
ops = self.write_queue
self.write_queue = []
if not self.has_draft_kv_pool:
    ops = [CacheOperation.merge_ops(ops)]
```
The start_writing method's behavior changes significantly based on self.has_draft_kv_pool. When self.has_draft_kv_pool is False, all operations in self.write_queue are merged into a single CacheOperation before processing. However, when self.has_draft_kv_pool is True, the method iterates over individual operations without merging them. This difference in batching strategy could lead to performance variations, potentially reducing throughput when the draft KV pool is active due to more frequent, smaller GPU transfers. Consider if merging operations is still beneficial when self.has_draft_kv_pool is true, or if the current approach is intentional for specific reasons (e.g., finer-grained control over draft operations).
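If batching is worth preserving, one option is to merge per pool instead of skipping the merge entirely. A minimal sketch, assuming each queued CacheOperation carries the is_draft flag introduced in this PR (the same grouping would apply to start_loading):

```python
# Sketch only: keep batching even when the draft pool is active by merging
# operations per pool rather than not merging at all.
ops = self.write_queue
self.write_queue = []
if ops:
    main_ops = [op for op in ops if not op.is_draft]
    draft_ops = [op for op in ops if op.is_draft]
    ops = []
    if main_ops:
        ops.append(CacheOperation.merge_ops(main_ops))
    if draft_ops:
        ops.append(CacheOperation.merge_ops(draft_ops))
```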
```python
ops = self.load_queue
self.load_queue = []
if not self.has_draft_kv_pool:
    ops = [CacheOperation.merge_ops(ops)]
```
Similar to start_writing, the start_loading method's behavior changes based on self.has_draft_kv_pool. When self.has_draft_kv_pool is False, all operations in self.load_queue are merged into a single CacheOperation before processing. When self.has_draft_kv_pool is True, the method iterates over individual operations without merging them. This could lead to performance differences between the two modes. If batching is generally more efficient for GPU transfers, consider if merging operations is still desirable when the draft KV pool is active.
```python
elif not torch.equal(draft_indices, host_indices):
    self.mem_pool_host_draft.free(draft_indices)
    logger.warning(
        "Draft HiCache host indices desynced. Disable draft hicache."
    )
    self.has_draft_kv_pool = False
```
The check elif not torch.equal(draft_indices, host_indices): followed by self.mem_pool_host_draft.free(draft_indices) in the write method implies a strong assumption that draft_indices and host_indices are always identical whenever self.has_draft_kv_pool is true; if they ever differ, the draft HiCache is disabled. This design choice should be documented, explaining why the indices must match and what happens when they diverge. Also, ensure that freeing draft_indices in this elif branch does not inadvertently cause issues while host_indices is still expected to be managed by the main mem_pool_host.
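One way to make the assumption explicit is to centralize the check behind a small helper that documents the invariant. A minimal sketch; the helper name and its placement on the controller are hypothetical:

```python
# Sketch only: centralize the "draft and main host indices are allocated in
# lockstep" assumption behind one documented check.
def _draft_indices_in_sync(self, host_indices, draft_indices) -> bool:
    """Draft HiCache mirrors the main host pool, so both allocations are
    expected to return identical index tensors. On divergence, drop the
    draft copy and fall back to main-model-only HiCache."""
    if torch.equal(draft_indices, host_indices):
        return True
    self.mem_pool_host_draft.free(draft_indices)
    logger.warning("Draft HiCache host indices desynced. Disable draft hicache.")
    self.has_draft_kv_pool = False
    return False
```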
```python
if req_id in self.ongoing_prefetch:
    # todo: more policies for prefetch progress such as timeout
    # the current policy is to prefetch with best effort and terminate when queuing is over
    last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
        req_id
    ]

    if not self.can_terminate_prefetch(operation):
        return False
    if operation.host_indices is None:
        # prefetch has not been issued due to insufficient host memory
        return True

    completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
        operation
    )
    logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

    min_completed_tokens = completed_tokens
    if self.tp_world_size > 1:
        # synchronize TP workers to make the same update to hiradix cache
        completed_tokens_tensor = torch.tensor(
            min_completed_tokens, dtype=torch.int
        )
        torch.distributed.all_reduce(
            completed_tokens_tensor,
            op=torch.distributed.ReduceOp.MIN,
            group=self.tp_group,
        )
        min_completed_tokens = completed_tokens_tensor.item()
    fetched_token_ids = token_ids[:min_completed_tokens]
    written_indices = host_indices[:min_completed_tokens]
    matched_length = self._insert_helper_host(
        last_host_node,
        RadixKey(
            token_ids=fetched_token_ids,
            extra_key=last_host_node.key.extra_key,
        ),
        written_indices,
        hash_value[: min_completed_tokens // self.page_size],
    )

    self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
    self.cache_controller.append_host_mem_release(
        host_indices[min_completed_tokens:completed_tokens]
    )
    last_host_node.release_host()
    del self.ongoing_prefetch[req_id]
    self.cache_controller.prefetch_tokens_occupied -= len(token_ids)

    if self.enable_storage_metrics:
        self.storage_metrics_collector.log_prefetched_tokens(
            min_completed_tokens - matched_length
        )

if req_id in self.ongoing_prefetch_draft:
    (
        draft_last_host_node,
        draft_token_ids,
        draft_host_indices,
        draft_operation,
    ) = self.ongoing_prefetch_draft[req_id]
    if draft_operation.host_indices is not None:
        draft_completed_tokens, draft_hash_value = (
            self.cache_controller.terminate_prefetch(draft_operation)
        )
        draft_min_completed_tokens = draft_completed_tokens
        if self.tp_world_size > 1:
            draft_completed_tokens_tensor = torch.tensor(
                draft_min_completed_tokens, dtype=torch.int
            )
            torch.distributed.all_reduce(
                draft_completed_tokens_tensor,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
            draft_min_completed_tokens = draft_completed_tokens_tensor.item()
        draft_fetched_token_ids = draft_token_ids[:draft_min_completed_tokens]
        draft_written_indices = draft_host_indices[:draft_min_completed_tokens]
        draft_matched_length = self._insert_helper_host(
            draft_last_host_node,
            RadixKey(
                token_ids=draft_fetched_token_ids,
                extra_key=draft_last_host_node.key.extra_key,
            ),
            draft_written_indices,
            draft_hash_value[: draft_min_completed_tokens // self.page_size],
        )
        self.cache_controller.mem_pool_host_draft.free(
            draft_host_indices[:draft_matched_length]
        )
        self.cache_controller.append_host_mem_release(
            draft_host_indices[
                draft_min_completed_tokens:draft_completed_tokens
            ],
            is_draft=True,
        )
        draft_last_host_node.release_host()
        del self.ongoing_prefetch_draft[req_id]
        self.cache_controller.prefetch_tokens_occupied -= len(draft_token_ids)
```
The logic within check_prefetch_progress for handling self.ongoing_prefetch and self.ongoing_prefetch_draft is almost identical, leading to significant code duplication. This section could be refactored into a helper method that takes the ongoing_prefetch map and the operation as arguments, improving readability and maintainability.
```python
def _process_prefetch_operation(self, req_id: str, pending_map: dict, is_draft: bool):
    last_host_node, token_ids, host_indices, operation = pending_map[req_id]
    if operation.host_indices is None:
        # prefetch has not been issued due to insufficient host memory
        return True
    if not self.can_terminate_prefetch(operation):
        return False
    completed_tokens, hash_value = self.cache_controller.terminate_prefetch(operation)
    logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
    min_completed_tokens = completed_tokens
    if self.tp_world_size > 1:
        # synchronize TP workers to make the same update to hiradix cache
        completed_tokens_tensor = torch.tensor(min_completed_tokens, dtype=torch.int)
        torch.distributed.all_reduce(
            completed_tokens_tensor,
            op=torch.distributed.ReduceOp.MIN,
            group=self.tp_group,
        )
        min_completed_tokens = completed_tokens_tensor.item()
    fetched_token_ids = token_ids[:min_completed_tokens]
    written_indices = host_indices[:min_completed_tokens]
    matched_length = self._insert_helper_host(
        last_host_node,
        RadixKey(
            token_ids=fetched_token_ids,
            extra_key=last_host_node.key.extra_key,
        ),
        written_indices,
        hash_value[: min_completed_tokens // self.page_size],
    )
    if is_draft:
        self.cache_controller.mem_pool_host_draft.free(host_indices[:matched_length])
        self.cache_controller.append_host_mem_release(
            host_indices[min_completed_tokens:completed_tokens],
            is_draft=True,
        )
    else:
        self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
        self.cache_controller.append_host_mem_release(
            host_indices[min_completed_tokens:completed_tokens]
        )
    last_host_node.release_host()
    del pending_map[req_id]
    self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
    if self.enable_storage_metrics and not is_draft:  # assuming metrics are only for the main model
        self.storage_metrics_collector.log_prefetched_tokens(
            min_completed_tokens - matched_length
        )
    return True


def check_prefetch_progress(self, req_id: str) -> bool:
    if (
        req_id not in self.ongoing_prefetch
        and req_id not in self.ongoing_prefetch_draft
    ):
        # there is no ongoing prefetch for this request or it has been revoked
        return True
    if req_id in self.ongoing_prefetch_draft:
        _, _, _, draft_operation = self.ongoing_prefetch_draft[req_id]
        if (
            draft_operation.host_indices is not None
            and not self.can_terminate_prefetch(draft_operation)
        ):
            return False
        self._process_prefetch_operation(req_id, self.ongoing_prefetch_draft, True)
    if req_id in self.ongoing_prefetch:
        return self._process_prefetch_operation(req_id, self.ongoing_prefetch, False)
    return True
```
Motivation
Based on #17338.
Support speculative decoding together with HiCache storage.
Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label,/rerun-failed-ci,/tag-and-rerun-ci