-
Notifications
You must be signed in to change notification settings - Fork 5.3k
[P/D disagg] - support decode side radix cache #19746
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
73f769c
c316c42
0268551
5fa0819
49cf9e6
95dc9a7
c21f1cd
aa6a6e6
5b93682
64cf052
9873033
c782926
5e8fb59
eacb121
2c3cf34
4a30312
90a4a8b
b990f29
e141484
018c311
6628e96
58d52c7
05a8e65
4cd3ee8
6d1171c
9ed1495
6b68b47
d1f3d74
824c974
40f8848
8698b5f
10e47e9
74ac023
5187e30
3cb66bc
bed7d4a
67bfc8c
8bc1310
ee7bee3
8da42a4
963ea7f
540ee41
a34a99d
53294d6
03bf260
9ac9bb7
7856688
4523ebb
6058b91
a6b6e8b
a0ebaad
01333d4
7e5e8a1
83c76ca
f035678
421bfd6
4f2c34e
c6e70aa
b24f58d
8ea6649
b0bcc35
f7327da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @xiezhq-hermann, we need your help. The logic of these tree cache changes looks reasonable, but I am not quite sure whether it will introduce a potential memory leak or bug when used with other features, and many of the logics here are not protected by the decode radix cache server args. We need an expert to check on this. |
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -324,7 +324,7 @@ def pop_bootstrapped( | |
| self.scheduler.tree_cache.release_aborted_request(req.rid) | ||
| continue | ||
|
|
||
| # KV.WaitingForInput - init here | ||
| # KV.WaitingForInput - decode is ready to receive. initialize the kv sender | ||
| req.time_stats.set_bootstrap_done_time() | ||
| num_kv_indices = len(req.origin_input_ids) | ||
| if self.req_to_metadata_buffer_idx_allocator.available_size() == 0: | ||
|
|
@@ -335,7 +335,19 @@ def pop_bootstrapped( | |
| ) | ||
| assert req.metadata_buffer_index is not None | ||
|
|
||
| num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size) | ||
| # Cal number of pages to send | ||
| # if decode has a cached prefix, we need to send the delta indices | ||
| # otherwise, send the entire request | ||
| decode_prefix_len = ( | ||
| req.disagg_kv_sender.kv_mgr.req_to_decode_prefix_len.pop( | ||
| req.bootstrap_room, 0 | ||
| ) | ||
| ) | ||
| req.start_send_idx = decode_prefix_len | ||
| num_kv_indices_to_send = num_kv_indices - decode_prefix_len | ||
| num_pages = kv_to_page_num( | ||
| num_kv_indices_to_send, self.token_to_kv_pool.page_size | ||
| ) | ||
| req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index) | ||
|
|
||
| bootstrapped_reqs.append(req) | ||
|
|
@@ -768,12 +780,20 @@ def send_kv_chunk( | |
| # if not the last chunk and the last page is partial, delay the last partial page to the next send | ||
| end_idx = end_idx - end_idx % page_size | ||
|
|
||
| if end_idx < start_idx: | ||
| logger.debug( | ||
| "send_kv_chunk skip: rid=%s start_send_idx=%s end_idx=%s", | ||
| req.rid, | ||
| start_idx, | ||
| end_idx, | ||
| ) | ||
| return | ||
|
|
||
|
Comment on lines
+783
to
+791
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it be better if we return early here?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. I changed the |
||
| kv_indices = ( | ||
| self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx] | ||
| .cpu() | ||
| .numpy() | ||
| ) | ||
| req.start_send_idx = end_idx | ||
| state_indices = None | ||
| if last_chunk: | ||
| self.disagg_metadata_buffers.set_buf(req) | ||
|
|
@@ -820,9 +840,14 @@ def send_kv_chunk( | |
| state_indices = kv_to_page_indices(state_indices, page_size) | ||
|
|
||
| page_indices = kv_to_page_indices(kv_indices, page_size) | ||
| # Skip empty non-last chunks for all backends. For empty last chunks, | ||
| # only NIXL currently defines the aux/state-only completion path used | ||
| # by decode-side radix cache; keep a conservative early return for | ||
| # other backends until they implement the same semantics. | ||
| if len(page_indices) == 0: | ||
| logger.info( | ||
| f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty" | ||
| ) | ||
| return | ||
| if not last_chunk: | ||
| return | ||
| if self.transfer_backend != TransferBackend.NIXL: | ||
| return | ||
| req.disagg_kv_sender.send(page_indices, state_indices) | ||
| req.start_send_idx = end_idx | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: consider putting this pop logic in the
clear()function.I can help with this when I am doing that following PR, we can keep it this way now.