-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[P/D disagg] - support decode side radix cache #19746
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 57 commits
73f769c
c316c42
0268551
5fa0819
49cf9e6
95dc9a7
c21f1cd
aa6a6e6
5b93682
64cf052
9873033
c782926
5e8fb59
eacb121
2c3cf34
4a30312
90a4a8b
b990f29
e141484
018c311
6628e96
58d52c7
05a8e65
4cd3ee8
6d1171c
9ed1495
6b68b47
d1f3d74
824c974
40f8848
8698b5f
10e47e9
74ac023
5187e30
3cb66bc
bed7d4a
67bfc8c
8bc1310
ee7bee3
8da42a4
963ea7f
540ee41
a34a99d
53294d6
03bf260
9ac9bb7
7856688
4523ebb
6058b91
a6b6e8b
a0ebaad
01333d4
7e5e8a1
83c76ca
f035678
421bfd6
4f2c34e
c6e70aa
b24f58d
8ea6649
b0bcc35
f7327da
7f20917
2a83901
00c8672
58ec674
2d2f990
839f0c7
fd86f5a
a05b933
1aa3ce1
b731e82
e6541bb
5daee09
02483bd
da927b8
b666f15
5175ef7
2623dd1
79db56a
d6b462d
260d1e0
98109bc
fd8c021
f57ca13
b41b1ff
317d66f
6b7c98d
be816d6
5175e5a
d1aec3a
a3d3ce4
7d7156c
f6b9601
eaee8c1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @xiezhq-hermann, we need your help. The logic of these tree cache changes looks reasonable, but I am not sure whether it could introduce a memory leak or a bug when combined with other features, and much of the logic here is not guarded by the decode radix cache server args. We need an expert to check this. |
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,8 +44,15 @@ class TransferInfo: | |
| dst_aux_index: int | ||
| required_dst_info_num: int | ||
| dst_state_indices: List[int] | ||
| decode_prefix_len: Optional[int] = None # for decode radix cache | ||
|
|
||
| def is_dummy(self): | ||
| # A transfer is "dummy" only for CP non-authoritative ranks. | ||
| # When dst_kv_indices is empty due to a decode-side radix cache | ||
| # full hit (decode_prefix_len > 0), the transfer is NOT dummy -- | ||
| # aux/state data still needs to be sent. | ||
| if self.dst_kv_indices.size == 0 and self.decode_prefix_len: | ||
|
ishandhanani marked this conversation as resolved.
|
||
| return False | ||
| return self.dst_kv_indices.size == 0 | ||
|
|
||
| @classmethod | ||
|
|
@@ -65,6 +72,9 @@ def from_zmq(cls, msg: List[bytes]): | |
| dst_aux_index=int(msg[5].decode("ascii")), | ||
| required_dst_info_num=int(msg[6].decode("ascii")), | ||
| dst_state_indices=dst_state_indices, | ||
| decode_prefix_len=( | ||
| int(msg[8].decode("ascii")) if len(msg) > 8 and msg[8] != b"" else None | ||
| ), # hacky just add it into the message that will be sent | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -883,39 +893,44 @@ def add_transfer_request( | |
| assert len(chunked_dst_kv_indice) == len(kv_indices) | ||
| assert req.agent_name in self.decode_kv_args_table | ||
|
|
||
| notif = ( | ||
| f"{req.room}_kv_{chunk_id}_{int(is_last)}_{self.kv_args.engine_rank}" | ||
| ) | ||
| decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size | ||
|
|
||
| if self.is_mla_backend or (decode_tp_size == self.attn_tp_size): | ||
| kv_xfer_handle = self.send_kvcache( | ||
| req.agent_name, | ||
| kv_indices, | ||
| self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, | ||
| chunked_dst_kv_indice, | ||
| self.decode_kv_args_table[req.agent_name].gpu_id, | ||
| notif, | ||
| ) | ||
| else: | ||
| kv_xfer_handle = self.send_kvcache_slice( | ||
| req.agent_name, | ||
| kv_indices, | ||
| self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, | ||
| chunked_dst_kv_indice, | ||
| self.decode_kv_args_table[req.agent_name].gpu_id, | ||
| notif, | ||
| prefill_tp_size=self.attn_tp_size, | ||
| decode_tp_size=decode_tp_size, | ||
| decode_tp_rank=self.decode_kv_args_table[ | ||
| req.agent_name | ||
| ].decode_tp_rank, | ||
| dst_kv_item_len=self.decode_kv_args_table[ | ||
| req.agent_name | ||
| ].dst_kv_item_len, | ||
| # Skip KV RDMA transfer when there are no pages to send | ||
| # (e.g., decode-side radix cache matched the entire prefix). | ||
| # Aux data is still sent below when is_last=True. | ||
| if len(kv_indices) > 0: | ||
| notif = ( | ||
| f"{req.room}_kv_{chunk_id}_{int(is_last)}_{self.kv_args.pp_rank}" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This looks like a rebase issue: it reverted a fix for the hang when TP P>D. I will fix it in #23967. |
||
| ) | ||
|
|
||
| handles.append(kv_xfer_handle) | ||
| if self.is_mla_backend or (decode_tp_size == self.attn_tp_size): | ||
| kv_xfer_handle = self.send_kvcache( | ||
| req.agent_name, | ||
| kv_indices, | ||
| self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, | ||
| chunked_dst_kv_indice, | ||
| self.decode_kv_args_table[req.agent_name].gpu_id, | ||
| notif, | ||
| ) | ||
| else: | ||
| kv_xfer_handle = self.send_kvcache_slice( | ||
| req.agent_name, | ||
| kv_indices, | ||
| self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, | ||
| chunked_dst_kv_indice, | ||
| self.decode_kv_args_table[req.agent_name].gpu_id, | ||
| notif, | ||
| prefill_tp_size=self.attn_tp_size, | ||
| decode_tp_size=decode_tp_size, | ||
| decode_tp_rank=self.decode_kv_args_table[ | ||
| req.agent_name | ||
| ].decode_tp_rank, | ||
| dst_kv_item_len=self.decode_kv_args_table[ | ||
| req.agent_name | ||
| ].dst_kv_item_len, | ||
| ) | ||
|
|
||
| handles.append(kv_xfer_handle) | ||
| # Only the last chunk we need to send the aux data. | ||
| if is_last: | ||
| if state_indices is not None: | ||
|
|
@@ -936,16 +951,24 @@ def add_transfer_request( | |
| handles.append(state_xfer_handle) | ||
|
|
||
| assert aux_index is not None | ||
| # When no KV pages were sent (decode-side cache hit), | ||
| # encode pp_rank in aux notif so receiver can mark | ||
| # expected_kvs_per_pp[pp_rank] = 0. | ||
| if len(kv_indices) == 0: | ||
| aux_notif = f"{req.room}_aux_nokv_{self.kv_args.pp_rank}" | ||
| else: | ||
| aux_notif = f"{req.room}_aux" | ||
| aux_xfer_handle = self.send_aux( | ||
| req.agent_name, | ||
| aux_index, | ||
| self.decode_kv_args_table[req.agent_name].dst_aux_ptrs, | ||
| req.dst_aux_index, | ||
| f"{req.room}_aux", | ||
| aux_notif, | ||
| ) | ||
| handles.append(aux_xfer_handle) | ||
| if is_last: | ||
| del self.transfer_infos[bootstrap_room] | ||
| self.req_to_decode_prefix_len.pop(bootstrap_room, None) | ||
| return handles | ||
|
|
||
| def update_transfer_status(self): | ||
|
|
@@ -978,6 +1001,15 @@ def update_transfer_status(self): | |
| ) | ||
| elif components[1] == "aux": | ||
| self.transfer_statuses[room].received_aux = True | ||
| # Handle "nokv" marker: no KV pages were sent for | ||
| # this pp_rank (decode-side radix cache hit). | ||
| if len(components) > 3 and components[2] == "nokv": | ||
| pp_rank = int(components[3]) | ||
| self.transfer_statuses[room].expected_kvs_per_pp[pp_rank] = 0 | ||
| if self.transfer_statuses[room].num_pp_ranks_expected is None: | ||
| self.transfer_statuses[room].num_pp_ranks_expected = ( | ||
| self.required_prefill_response_num_table.get(room, 1) | ||
| ) | ||
| elif components[1] == "state": | ||
| pp_rank = int(components[2]) if len(components) > 2 else 0 | ||
| self.transfer_statuses[room].received_state_per_pp.add(pp_rank) | ||
|
|
@@ -1019,6 +1051,14 @@ def bootstrap_thread(): | |
| ].required_dst_info_num | ||
| logger.debug(f"got info {room=} {agent_name=} {required_dst_info_num=}") | ||
| if len(self.transfer_infos[room]) == required_dst_info_num: | ||
| self.req_to_decode_prefix_len[room] = next( | ||
| ( | ||
| info.decode_prefix_len | ||
| for info in self.transfer_infos[room].values() | ||
| if info.decode_prefix_len is not None | ||
| ), | ||
| 0, | ||
| ) | ||
| logger.debug(f"{room=} is bootstrapped") | ||
| self.update_status(room, KVPoll.WaitingForInput) | ||
|
|
||
|
|
@@ -1113,6 +1153,7 @@ def send_metadata( | |
| kv_indices: npt.NDArray[np.int32], | ||
| aux_index: Optional[int] = None, | ||
| state_indices: Optional[List[int]] = None, | ||
| decode_prefix_len: Optional[int] = None, | ||
| ): | ||
| if self.bootstrap_infos is None: | ||
| logger.error( | ||
|
|
@@ -1146,6 +1187,7 @@ def send_metadata( | |
| if not is_dummy and state_indices is not None | ||
| else b"" | ||
| ), | ||
| str(decode_prefix_len or 0).encode("ascii"), | ||
| ] | ||
| ) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: consider putting this pop logic in the
`clear()` function. I can help with this when I do the follow-up PR; we can keep it this way for now.