-
-
Notifications
You must be signed in to change notification settings - Fork 15k
[NIXL][BUG FIX] Fix both failing issue and accuracy issue with nixl + host_buffer on CUDA #30419
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7a7a75b
877c335
86429a8
2509c98
d2f0c47
ae3f679
3a9042e
1ab4093
f45d50a
b78d55b
a5f18f3
dcb0cfa
f3216ae
a7d268d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -23,7 +23,11 @@ | |||||||||||||||||||||
| from vllm.attention.backends.abstract import AttentionMetadata | ||||||||||||||||||||||
| from vllm.attention.selector import get_attn_backend | ||||||||||||||||||||||
| from vllm.config import VllmConfig | ||||||||||||||||||||||
| from vllm.distributed.kv_transfer.kv_connector.utils import EngineId, TpKVTopology | ||||||||||||||||||||||
| from vllm.distributed.kv_transfer.kv_connector.utils import ( | ||||||||||||||||||||||
| EngineId, | ||||||||||||||||||||||
| TpKVTopology, | ||||||||||||||||||||||
| yield_req_data, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| from vllm.distributed.kv_transfer.kv_connector.v1.base import ( | ||||||||||||||||||||||
| CopyBlocksOp, | ||||||||||||||||||||||
| KVConnectorBase_V1, | ||||||||||||||||||||||
|
|
@@ -481,7 +485,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): | |||||||||||||||||||||
| # New requests are added by update_state_after_alloc in | ||||||||||||||||||||||
| # the scheduler. Used to make metadata passed to Worker. | ||||||||||||||||||||||
| self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} | ||||||||||||||||||||||
| self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {} | ||||||||||||||||||||||
| self._reqs_need_save: dict[ReqId, Request] = {} | ||||||||||||||||||||||
| # Reqs to send and their expiration time | ||||||||||||||||||||||
| self._reqs_need_send: dict[ReqId, float] = {} | ||||||||||||||||||||||
| self._reqs_in_batch: set[ReqId] = set() | ||||||||||||||||||||||
|
|
@@ -627,16 +631,7 @@ def update_state_after_alloc( | |||||||||||||||||||||
| if self.use_host_buffer and params.get("do_remote_decode"): | ||||||||||||||||||||||
| # NOTE: when accelerator is not directly supported by Nixl, | ||||||||||||||||||||||
| # prefilled blocks need to be saved to host memory before transfer. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # save all blocks | ||||||||||||||||||||||
| block_ids = blocks.get_block_ids()[0] | ||||||||||||||||||||||
| # TODO: skip the blocks that are already in the host xfer buffer. | ||||||||||||||||||||||
| # Currently, the host xfer buffer block is 1-to-1 mapped to device | ||||||||||||||||||||||
| # kv blocks, so host blocks won't be flushed as long as its device | ||||||||||||||||||||||
| # block is not overwritten; and it will be safe to skip saving them | ||||||||||||||||||||||
| # to host xfer buffer. | ||||||||||||||||||||||
| if block_ids: | ||||||||||||||||||||||
| self._reqs_need_save[request.request_id] = (request, block_ids) | ||||||||||||||||||||||
| self._reqs_need_save[request.request_id] = request | ||||||||||||||||||||||
| elif params.get("do_remote_prefill"): | ||||||||||||||||||||||
| if params.get("remote_block_ids"): | ||||||||||||||||||||||
| if all( | ||||||||||||||||||||||
|
|
@@ -688,21 +683,39 @@ def build_connector_meta( | |||||||||||||||||||||
| kv_transfer_params=req.kv_transfer_params, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| for req_id, (req, block_ids) in self._reqs_need_save.items(): | ||||||||||||||||||||||
| # NOTE: For the prefill side, there might be a chance that an early added | ||||||||||||||||||||||
| # request is a chunked prefill, so we need to check if new blocks are added | ||||||||||||||||||||||
| for req_id, new_block_id_groups, _ in yield_req_data(scheduler_output): | ||||||||||||||||||||||
| req_to_save = self._reqs_need_save.get(req_id) | ||||||||||||||||||||||
| if req_to_save is None or new_block_id_groups is None: | ||||||||||||||||||||||
| continue | ||||||||||||||||||||||
| req = req_to_save | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| assert req.kv_transfer_params is not None | ||||||||||||||||||||||
| meta.add_new_req_to_save( | ||||||||||||||||||||||
| request_id=req_id, | ||||||||||||||||||||||
| local_block_ids=block_ids, | ||||||||||||||||||||||
| local_block_ids=new_block_id_groups[0], | ||||||||||||||||||||||
| kv_transfer_params=req.kv_transfer_params, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| assert scheduler_output.num_scheduled_tokens is not None | ||||||||||||||||||||||
| num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] | ||||||||||||||||||||||
| is_partial = ( | ||||||||||||||||||||||
| req.num_computed_tokens + num_scheduled_tokens | ||||||||||||||||||||||
xuechendi marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||
| ) < req.num_prompt_tokens | ||||||||||||||||||||||
|
Comment on lines
+702
to
+704
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This might work...
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yep let's keep this here this is a tmp buffer, its function terminates here once metadata are built
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @NickLucche @xuechendi I think you are missing an edge case: vllm/vllm/v1/core/sched/scheduler.py Lines 1353 to 1362 in 811cdf5
When the scheduler finish requests before they finish processing, you will not clear
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Which is every time a request is aborted too, nice catch @orozery !
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, @orozery ! @NickLucche , I've updated the PR with clean up in |
||||||||||||||||||||||
| if not is_partial: | ||||||||||||||||||||||
| # For non-partial prefills, once new req_meta is scheduled, it | ||||||||||||||||||||||
| # can be removed from _reqs_need_save. | ||||||||||||||||||||||
| # For partial prefill case, we will retain the request in | ||||||||||||||||||||||
| # _reqs_need_save until all blocks are scheduled with req_meta. | ||||||||||||||||||||||
| # Therefore, only pop if `not is_partial`. | ||||||||||||||||||||||
| self._reqs_need_save.pop(req_id) | ||||||||||||||||||||||
NickLucche marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| meta.reqs_to_send = self._reqs_need_send | ||||||||||||||||||||||
| meta.reqs_in_batch = self._reqs_in_batch | ||||||||||||||||||||||
| meta.reqs_not_processed = self._reqs_not_processed | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Clear the list once workers start the transfers | ||||||||||||||||||||||
| self._reqs_need_recv.clear() | ||||||||||||||||||||||
| self._reqs_need_save.clear() | ||||||||||||||||||||||
| self._reqs_in_batch = set() | ||||||||||||||||||||||
| self._reqs_not_processed = set() | ||||||||||||||||||||||
| self._reqs_need_send = {} | ||||||||||||||||||||||
|
|
@@ -748,6 +761,8 @@ def request_finished( | |||||||||||||||||||||
| # Also include the case of a P/D Prefill request with immediate | ||||||||||||||||||||||
| # block free (eg abort). Stop tracking this request. | ||||||||||||||||||||||
| self._reqs_not_processed.add(request.request_id) | ||||||||||||||||||||||
| # Clear _reqs_need_save if a request is aborted as partial prefill. | ||||||||||||||||||||||
| self._reqs_need_save.pop(request.request_id, None) | ||||||||||||||||||||||
|
Comment on lines
+764
to
+765
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this should work, but it seems more fragile to me.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since If we go with your proposal, the life cycle becomes from "scheduled" to "request ends" @NickLucche, do you think we should do that? I assume the fix here is to just handle a corner case when request was aborted ?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think @orozery is proposing to only clear the id on request finished, so either terminal block was processed or abort/error. Hence I don't have a strong opinion here, this could also be done in a separate PR, as long as we maximize clarity for these cases.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @NickLucche @orozery , let's do that in separate PR, since other queues |
||||||||||||||||||||||
| return False, None | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # TODO: check whether block_ids actually ever be 0. If not we could | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.