
[HiCache][WIP] support spec decode+hicache storage #17776

Closed

alphabetc1 wants to merge 1 commit into sgl-project:main from alphabetc1:feat/hicache_spec_decode_l3

Conversation

@alphabetc1 (Collaborator) commented Jan 26, 2026

Motivation

Based on #17338.
Support speculative decoding with HiCache storage.

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @alphabetc1, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the HiCache system by introducing comprehensive support for speculative decoding. It achieves this by implementing a dual KV cache management system, allowing separate storage and retrieval of key-value states for both the main and a 'draft' model. This architectural change is crucial for optimizing inference speed in large language models by enabling parallel processing and early prediction mechanisms.

Highlights

  • Draft Model KV Cache Integration: Introduced new attributes and methods within CacheController to manage a separate KV cache pool and storage backend specifically for draft models, enabling distinct handling of main and draft model states.
  • Speculative Decoding Support in Cache Operations: Modified CacheOperation, StorageOperation, and PrefetchOperation classes to include a use_draft or is_draft flag, allowing cache write, load, prefetch, and eviction operations to conditionally target either the main or draft KV cache pools and storage backends.
  • Scheduler Integration for Draft KV Pools: Added a new _register_draft_kv_pool_for_hicache method in the Scheduler to dynamically initialize and register the draft model's KV cache pools with the CacheController when speculative decoding is active.
  • HiRadixCache Enhancements for Dual Cache Management: Updated HiRadixCache to track ongoing prefetch and backup operations for both main and draft models separately, and adjusted control queue draining and prefetch progress checks to accommodate the dual cache system.


@gemini-code-assist bot left a comment

Code Review

The pull request introduces support for speculative decoding with HiCache storage by integrating a draft KV pool and associated storage backend. The changes involve extending CacheOperation, StorageOperation, and PrefetchOperation classes with an is_draft flag, and modifying the HiCacheController and HiRadixCache to manage separate memory pools and storage operations for the draft model. This is a significant feature addition that enhances the system's capabilities for speculative decoding.
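To make the dual-pool routing concrete, here is a minimal, self-contained sketch of how an is_draft flag can select between a main and a draft host pool. It is not the PR's actual code: HostPool, CacheOp, DualPoolController, and release() are hypothetical stand-ins for the PR's CacheOperation, mem_pool_host / mem_pool_host_draft, and controller methods.

    # Hypothetical sketch: route an operation to the main or draft host pool
    # based on an is_draft flag, mirroring how the PR threads the flag through
    # write/load/prefetch/evict paths. HostPool, CacheOp, and release() are
    # illustrative stand-ins, not SGLang's real classes or signatures.
    from dataclasses import dataclass
    from typing import List


    class HostPool:
        """Minimal stand-in for a host-side KV pool."""

        def __init__(self, name: str):
            self.name = name

        def free(self, indices: List[int]) -> None:
            print(f"{self.name} pool: freeing {len(indices)} host indices")


    @dataclass
    class CacheOp:
        host_indices: List[int]
        is_draft: bool = False  # selects the draft pool/backend when True


    class DualPoolController:
        def __init__(self) -> None:
            self.mem_pool_host = HostPool("main")
            self.mem_pool_host_draft = HostPool("draft")

        def release(self, op: CacheOp) -> None:
            # Pick the pool according to the flag, as the PR does per operation.
            pool = self.mem_pool_host_draft if op.is_draft else self.mem_pool_host
            pool.free(op.host_indices)


    ctrl = DualPoolController()
    ctrl.release(CacheOp(host_indices=[0, 1, 2]))                 # main pool
    ctrl.release(CacheOp(host_indices=[0, 1, 2], is_draft=True))  # draft pool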

Comment on lines +540 to +543
ops = self.write_queue
self.write_queue = []
if not self.has_draft_kv_pool:
    ops = [CacheOperation.merge_ops(ops)]

Severity: high

The start_writing method's behavior changes significantly based on self.has_draft_kv_pool. When self.has_draft_kv_pool is False, all operations in self.write_queue are merged into a single CacheOperation before processing. However, when self.has_draft_kv_pool is True, the method iterates over individual operations without merging them. This difference in batching strategy could lead to performance variations, potentially reducing throughput when the draft KV pool is active due to more frequent, smaller GPU transfers. Consider if merging operations is still beneficial when self.has_draft_kv_pool is true, or if the current approach is intentional for specific reasons (e.g., finer-grained control over draft operations).
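If batching turns out to matter here, one option (sketched below; this is a suggestion, not the PR's code) is to keep the merge but perform it per target pool, so both start_writing and start_loading still issue one batched transfer per pool. It assumes the PR's CacheOperation.merge_ops accepts a homogeneous list and that the merged operation retains that group's is_draft flag.

    # Illustrative alternative, not the PR's code: merge the drained queue per
    # target pool instead of skipping the merge when a draft pool exists.
    # Assumes merge_ops (e.g. CacheOperation.merge_ops) handles a homogeneous
    # list and that the merged op keeps that group's is_draft flag.
    def drain_and_merge(queue, merge_ops):
        ops, queue[:] = list(queue), []
        main_ops = [op for op in ops if not op.is_draft]
        draft_ops = [op for op in ops if op.is_draft]
        merged = []
        if main_ops:
            merged.append(merge_ops(main_ops))   # one batched main-pool transfer
        if draft_ops:
            merged.append(merge_ops(draft_ops))  # one batched draft-pool transfer
        return merged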

Comment on lines +622 to +625
ops = self.load_queue
self.load_queue = []
if not self.has_draft_kv_pool:
ops = [CacheOperation.merge_ops(ops)]

Severity: high

Similar to start_writing, the start_loading method's behavior changes based on self.has_draft_kv_pool. When self.has_draft_kv_pool is False, all operations in self.load_queue are merged into a single CacheOperation before processing. When self.has_draft_kv_pool is True, the method iterates over individual operations without merging them. This could lead to performance differences between the two modes. If batching is generally more efficient for GPU transfers, consider if merging operations is still desirable when the draft KV pool is active.

Comment on lines +521 to +526
elif not torch.equal(draft_indices, host_indices):
self.mem_pool_host_draft.free(draft_indices)
logger.warning(
"Draft HiCache host indices desynced. Disable draft hicache."
)
self.has_draft_kv_pool = False

Severity: medium

The check elif not torch.equal(draft_indices, host_indices): followed by self.mem_pool_host_draft.free(draft_indices) in the write method implies a strong assumption that draft_indices and host_indices are always identical whenever self.has_draft_kv_pool is true; if they ever differ, the draft HiCache is disabled. This design choice should be clearly documented, explaining why the indices must be identical and what happens if they diverge. Also, ensure that freeing draft_indices in this elif branch does not inadvertently cause issues when host_indices is still expected to be managed by the main mem_pool_host.
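For illustration, a hedged sketch of the invariant this comment relies on: if the draft host pool is allocated in lockstep with the main one, the two index tensors should match, and any mismatch is treated as desync. The alloc()/free() calls and the alloc_host_pair helper below are hypothetical names, not the PR's API.

    # Hypothetical sketch of the lockstep-allocation invariant: both pools are
    # asked for the same number of tokens, and the draft path is abandoned if
    # the returned indices ever differ. alloc()/free() and alloc_host_pair are
    # illustrative names, not SGLang's actual interfaces.
    import torch

    def alloc_host_pair(mem_pool_host, mem_pool_host_draft, num_tokens):
        host_indices = mem_pool_host.alloc(num_tokens)
        draft_indices = mem_pool_host_draft.alloc(num_tokens)
        if draft_indices is None or not torch.equal(draft_indices, host_indices):
            # Desync: drop the draft-side allocation and let the caller disable
            # the draft HiCache, mirroring the fallback in the snippet above.
            if draft_indices is not None:
                mem_pool_host_draft.free(draft_indices)
            return host_indices, None
        return host_indices, draft_indices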

Comment on lines +697 to +797
if req_id in self.ongoing_prefetch:
    # todo: more policies for prefetch progress such as timeout
    # the current policy is to prefetch with best effort and terminate when queuing is over
    last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
        req_id
    ]

    if not self.can_terminate_prefetch(operation):
        return False
    if operation.host_indices is None:
        # prefetch has not been issued due to insufficient host memory
        return True

    completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
        operation
    )
    logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

    min_completed_tokens = completed_tokens
    if self.tp_world_size > 1:
        # synchronize TP workers to make the same update to hiradix cache
        completed_tokens_tensor = torch.tensor(
            min_completed_tokens, dtype=torch.int
        )
        torch.distributed.all_reduce(
            completed_tokens_tensor,
            op=torch.distributed.ReduceOp.MIN,
            group=self.tp_group,
        )
        min_completed_tokens = completed_tokens_tensor.item()
    fetched_token_ids = token_ids[:min_completed_tokens]
    written_indices = host_indices[:min_completed_tokens]
    matched_length = self._insert_helper_host(
        last_host_node,
        RadixKey(
            token_ids=fetched_token_ids,
            extra_key=last_host_node.key.extra_key,
        ),
        written_indices,
        hash_value[: min_completed_tokens // self.page_size],
    )

    self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
    self.cache_controller.append_host_mem_release(
        host_indices[min_completed_tokens:completed_tokens]
    )
    last_host_node.release_host()
    del self.ongoing_prefetch[req_id]
    self.cache_controller.prefetch_tokens_occupied -= len(token_ids)

    if self.enable_storage_metrics:
        self.storage_metrics_collector.log_prefetched_tokens(
            min_completed_tokens - matched_length
        )

if req_id in self.ongoing_prefetch_draft:
    (
        draft_last_host_node,
        draft_token_ids,
        draft_host_indices,
        draft_operation,
    ) = self.ongoing_prefetch_draft[req_id]
    if draft_operation.host_indices is not None:
        draft_completed_tokens, draft_hash_value = (
            self.cache_controller.terminate_prefetch(draft_operation)
        )
        draft_min_completed_tokens = draft_completed_tokens
        if self.tp_world_size > 1:
            draft_completed_tokens_tensor = torch.tensor(
                draft_min_completed_tokens, dtype=torch.int
            )
            torch.distributed.all_reduce(
                draft_completed_tokens_tensor,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
            draft_min_completed_tokens = draft_completed_tokens_tensor.item()
        draft_fetched_token_ids = draft_token_ids[:draft_min_completed_tokens]
        draft_written_indices = draft_host_indices[:draft_min_completed_tokens]
        draft_matched_length = self._insert_helper_host(
            draft_last_host_node,
            RadixKey(
                token_ids=draft_fetched_token_ids,
                extra_key=draft_last_host_node.key.extra_key,
            ),
            draft_written_indices,
            draft_hash_value[: draft_min_completed_tokens // self.page_size],
        )
        self.cache_controller.mem_pool_host_draft.free(
            draft_host_indices[:draft_matched_length]
        )
        self.cache_controller.append_host_mem_release(
            draft_host_indices[
                draft_min_completed_tokens:draft_completed_tokens
            ],
            is_draft=True,
        )
        draft_last_host_node.release_host()
        del self.ongoing_prefetch_draft[req_id]
        self.cache_controller.prefetch_tokens_occupied -= len(draft_token_ids)

Severity: medium

The logic within check_prefetch_progress for handling self.ongoing_prefetch and self.ongoing_prefetch_draft is almost identical, leading to significant code duplication. This section could be refactored into a helper method that takes the ongoing_prefetch map and the operation as arguments, improving readability and maintainability.

    def _process_prefetch_operation(self, req_id: str, pending_map: dict, is_draft: bool):
        last_host_node, token_ids, host_indices, operation = pending_map[req_id]

        if operation.host_indices is None:
            # prefetch has not been issued due to insufficient host memory
            return True

        if not self.can_terminate_prefetch(operation):
            return False

        completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
            operation
        )
        logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

        min_completed_tokens = completed_tokens
        if self.tp_world_size > 1:
            # synchronize TP workers to make the same update to hiradix cache
            completed_tokens_tensor = torch.tensor(
                min_completed_tokens, dtype=torch.int
            )
            torch.distributed.all_reduce(
                completed_tokens_tensor,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
            min_completed_tokens = completed_tokens_tensor.item()
        fetched_token_ids = token_ids[:min_completed_tokens]
        written_indices = host_indices[:min_completed_tokens]
        matched_length = self._insert_helper_host(
            last_host_node,
            RadixKey(
                token_ids=fetched_token_ids,
                extra_key=last_host_node.key.extra_key,
            ),
            written_indices,
            hash_value[: min_completed_tokens // self.page_size],
        )

        if is_draft:
            self.cache_controller.mem_pool_host_draft.free(
                host_indices[:matched_length]
            )
            self.cache_controller.append_host_mem_release(
                host_indices[min_completed_tokens:completed_tokens],
                is_draft=True,
            )
        else:
            self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
            self.cache_controller.append_host_mem_release(
                host_indices[min_completed_tokens:completed_tokens]
            )
        last_host_node.release_host()
        del pending_map[req_id]
        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)

        if self.enable_storage_metrics and not is_draft: # Assuming metrics are only for main
            self.storage_metrics_collector.log_prefetched_tokens(
                min_completed_tokens - matched_length
            )
        return True

    def check_prefetch_progress(self, req_id: str) -> bool:
        if (
            req_id not in self.ongoing_prefetch
            and req_id not in self.ongoing_prefetch_draft
        ):
            # there is no ongoing prefetch for this request or it has been revoked
            return True

        if req_id in self.ongoing_prefetch_draft:
            _, _, _, draft_operation = self.ongoing_prefetch_draft[req_id]
            if (
                draft_operation.host_indices is not None
                and not self.can_terminate_prefetch(draft_operation)
            ):
                return False
            self._process_prefetch_operation(req_id, self.ongoing_prefetch_draft, True)

        if req_id in self.ongoing_prefetch:
            return self._process_prefetch_operation(req_id, self.ongoing_prefetch, False)

        return True

@alphabetc1 alphabetc1 marked this pull request as draft January 27, 2026 07:28
@alphabetc1 alphabetc1 closed this Mar 10, 2026