-
Notifications
You must be signed in to change notification settings - Fork 5.2k
[PD]: Add support for HiSparse to directly transfer the cache from Prefill to Decode DRAM. #21591
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b22b672
6941052
57375d3
4a5b54f
1f6ce8a
5bd2254
96e6a1b
bd18a9f
0d4ab1b
f183073
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -127,6 +127,8 @@ class KVArgsRegisterInfo: | |||
| # for mamba state different tp slice transfer | ||||
| dst_state_item_lens: list[int] | ||||
| dst_state_dim_per_tensor: list[int] | ||||
| # HiSparse: decode host pool stores KV at token granularity | ||||
| enable_hisparse: bool = False | ||||
|
Comment on lines
+130
to
+131
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can be updated too. We keep the |
||||
| staging: Optional[StagingRegisterInfo] = None | ||||
|
|
||||
| @classmethod | ||||
|
|
@@ -152,7 +154,10 @@ def from_zmq(cls, msg: List[bytes]): | |||
| if len(msg) > 11 and len(msg[11]) > 0 | ||||
| else [] | ||||
| ), | ||||
| staging=StagingRegisterInfo.from_zmq_fields(msg, 12), | ||||
| enable_hisparse=( | ||||
| msg[12].decode("ascii") == "1" if len(msg) > 12 else False | ||||
| ), | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can move
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also please double check on this line:
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CC: @YAMY1234
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. okko
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ShangmingCai updated
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LGTM |
||||
| staging=StagingRegisterInfo.from_zmq_fields(msg, 13), | ||||
| ) | ||||
|
|
||||
|
|
||||
|
|
@@ -694,6 +699,49 @@ def send_kvcache( | |||
| executor=executor, | ||||
| ) | ||||
|
|
||||
| def send_kvcache_hisparse( | ||||
| self, | ||||
| mooncake_session_id: str, | ||||
| prefill_kv_indices: npt.NDArray[np.int32], | ||||
| dst_kv_ptrs: list[int], | ||||
| dst_kv_indices: npt.NDArray[np.int32], | ||||
| page_index_slice: slice, | ||||
| executor: concurrent.futures.ThreadPoolExecutor, | ||||
| ): | ||||
| """HiSparse transfer: prefill page_size > decode host page_size=1. | ||||
|
|
||||
| Receives page-level prefill_kv_indices and the full token-level | ||||
| dst_kv_indices. Expands both to token granularity before transfer. | ||||
| """ | ||||
| page_size = self.kv_args.page_size | ||||
| per_token_item_lens = [il // page_size for il in self.kv_args.kv_item_lens] | ||||
|
|
||||
| # Expand page-level src indices to token-level | ||||
| base = np.repeat(prefill_kv_indices * page_size, page_size) | ||||
| offsets = np.tile(np.arange(page_size, dtype=np.int32), len(prefill_kv_indices)) | ||||
| expanded_src = base + offsets | ||||
|
|
||||
| # Expand page-level index_slice to token-level for dst | ||||
| token_start = page_index_slice.start * page_size | ||||
| token_end = min(page_index_slice.stop * page_size, len(dst_kv_indices)) | ||||
| expanded_dst = dst_kv_indices[token_start:token_end] | ||||
|
|
||||
| # Clip src to match dst length (last page may be partial) | ||||
| expanded_src = expanded_src[: len(expanded_dst)] | ||||
|
|
||||
| logger.debug( | ||||
| f"Send KVCache for hisparse: {expanded_src.shape} -> {expanded_dst.shape}" | ||||
| ) | ||||
| return self._send_kvcache_generic( | ||||
| mooncake_session_id=mooncake_session_id, | ||||
| src_data_ptrs=self.kv_args.kv_data_ptrs, | ||||
| dst_data_ptrs=dst_kv_ptrs, | ||||
| item_lens=per_token_item_lens, | ||||
| prefill_data_indices=expanded_src, | ||||
| dst_data_indices=expanded_dst, | ||||
| executor=executor, | ||||
| ) | ||||
|
|
||||
| def send_kvcache_slice( | ||||
| self, | ||||
| mooncake_session_id: str, | ||||
|
|
@@ -1165,13 +1213,23 @@ def transfer_worker( | |||
| self.attn_tp_size | ||||
| == target_rank_registration_info.dst_attn_tp_size | ||||
| ): | ||||
| ret = self.send_kvcache( | ||||
| req.mooncake_session_id, | ||||
| kv_chunk.prefill_kv_indices, | ||||
| target_rank_registration_info.dst_kv_ptrs, | ||||
| chunked_dst_kv_indice, | ||||
| executor, | ||||
| ) | ||||
| if target_rank_registration_info.enable_hisparse: | ||||
| ret = self.send_kvcache_hisparse( | ||||
| req.mooncake_session_id, | ||||
| kv_chunk.prefill_kv_indices, | ||||
| target_rank_registration_info.dst_kv_ptrs, | ||||
| req.dst_kv_indices, | ||||
| kv_chunk.index_slice, | ||||
| executor, | ||||
| ) | ||||
| else: | ||||
| ret = self.send_kvcache( | ||||
| req.mooncake_session_id, | ||||
| kv_chunk.prefill_kv_indices, | ||||
| target_rank_registration_info.dst_kv_ptrs, | ||||
| chunked_dst_kv_indice, | ||||
| executor, | ||||
| ) | ||||
| elif ( | ||||
| self.enable_staging | ||||
| and staging_strategy is not None | ||||
|
|
@@ -1715,6 +1773,7 @@ def _register_kv_args(self): | |||
| dst_tp_rank = str(tp_rank).encode("ascii") | ||||
| dst_attn_tp_size = str(self.kv_mgr.attn_tp_size).encode("ascii") | ||||
| dst_kv_item_len = str(kv_item_len).encode("ascii") | ||||
| enable_hisparse = b"1" if self.kv_mgr.server_args.enable_hisparse else b"0" | ||||
|
|
||||
| if ( | ||||
| self.kv_mgr.enable_staging | ||||
|
|
@@ -1743,6 +1802,7 @@ def _register_kv_args(self): | |||
| dst_kv_item_len, | ||||
| packed_state_item_lens, | ||||
| packed_state_dim_per_tensor, | ||||
| enable_hisparse, | ||||
| packed_staging_base_ptr, | ||||
| staging_total_size_str, | ||||
| ] | ||||
|
|
||||
Uh oh!
There was an error while loading. Please reload this page.