
[NIXL][BUG FIX] Fix both failing issue and accuracy issue with nixl + host_buffer on CUDA#30419

Merged
NickLucche merged 14 commits intovllm-project:mainfrom
xuechendi:debug/block_id_mistatch_schedule_to_finish
Dec 18, 2025

Conversation

@xuechendi
Contributor

@xuechendi xuechendi commented Dec 10, 2025

Purpose

This PR builds on #30420, which should be merged first.

Two issues were detected and resolved in this PR:

  1. Fix a bug introduced by [NIXL] Add remote_request_id to kv_transfer_params #29665 when running PD with the CPU host buffer
  2. Fix an accuracy issue when running PD with the CPU host buffer, described in [Bug]: NIXL PD disaggregate with host_buffer has accuracy issue - Prefill scheduled num_block mismatch at update_state_after_alloc and request_finished #30358

Test Plan

PREFILL_BLOCK_SIZE=16 DECODE_BLOCK_SIZE=16 bash tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh --kv_buffer_device cpu

Qwen3-0.6B
Before: accuracy ~0.3
After: accuracy = 0.4109

Root Cause and proposed change

Previously:

When the scheduler calls schedule, only brand-new requests get self.connector.update_state_after_alloc.

=> However, if a prefill request is chunked into smaller requests, self.connector.update_state_after_alloc registers only part of the block_ids (those of the first prefill chunk) into the nixl_metadata.

--

Solution:

Add another call to self.connector.update_state_after_alloc when processing the running queue, so that when subsequent chunks are scheduled, their block_ids continue to be added to the nixl metadata.
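The bug and the proposed fix can be sketched roughly as follows. Everything here is a hypothetical simplification for illustration; only the method name update_state_after_alloc comes from the actual connector interface.

```python
# Minimal sketch of the chunked-prefill bug described above.
# FakeConnector is a hypothetical stand-in, not vLLM's NIXL connector.

class FakeConnector:
    """Stands in for the connector's per-request block registry."""

    def __init__(self) -> None:
        self.nixl_metadata: dict[str, list[int]] = {}

    def update_state_after_alloc(self, req_id: str, new_blocks: list[int]) -> None:
        # Accumulate block IDs across calls rather than overwriting them.
        self.nixl_metadata.setdefault(req_id, []).extend(new_blocks)


connector = FakeConnector()

# Chunk 1: the brand-new request is scheduled from the waiting queue.
connector.update_state_after_alloc("req-0", [0, 1])

# Buggy behavior: chunk 2 is scheduled from the running queue and never
# triggers update_state_after_alloc, so nixl_metadata only holds [0, 1].
# With the fix, the running-queue pass registers the new blocks as well:
connector.update_state_after_alloc("req-0", [2, 3])

assert connector.nixl_metadata["req-0"] == [0, 1, 2, 3]
```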

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request aims to fix a failing issue and an accuracy issue related to NIXL with a CPU host buffer. The change in nixl_connector.py to use .get() for remote_request_id is a good defensive measure against potential KeyError exceptions. However, I've identified a critical issue in the new logic added to scheduler.py. The new call to update_state_after_alloc only passes new_blocks, which for chunked prefills, results in an incomplete list of blocks being registered for transfer. This would lead to data corruption and accuracy problems, which is likely the very issue this PR is trying to solve. I have provided a suggestion to fix this by passing all of the request's blocks.

@xuechendi
Contributor Author

@KuntaiDu, could you take a review? I'm not sure whether LMCache also needs update_state_after_alloc to get the latest block_ids.

@NickLucche
Collaborator

NickLucche commented Dec 11, 2025

cc @markmc @njhill for the scheduler changes.


@xuechendi xuechendi force-pushed the debug/block_id_mistatch_schedule_to_finish branch from 75538df to fea492b Compare December 11, 2025 15:44
@xuechendi xuechendi requested a review from markmc December 11, 2025 16:09
@orozery
Collaborator

orozery commented Dec 12, 2025

@markmc @njhill I don't think this change is necessary, as the block_ids of subsequent allocations are available to connectors via SchedulerOutput.scheduled_cached_reqs.new_block_ids, which is passed to build_connector_meta.

@xuechendi you can look at the offloading connector for reference.

Now that I think of it, there's also another bug in the nixl connector, triggered when a chunked-prefill request gets preempted.
Currently, it looks like the nixl connector does not detect that.
For the offloading connector we also have the same bug, which should be solved by #29870.
Basically, we detect preemptions via SchedulerOutput.preempted_req_ids in build_connector_meta.
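The approach described above can be sketched as follows. The dataclasses are simplified stand-ins for vLLM's real scheduler types, and ConnectorScheduler is hypothetical; only the field names scheduled_cached_reqs.new_block_ids and preempted_req_ids come from the comment above.

```python
# Sketch of accumulating block IDs in build_connector_meta and dropping
# state for preempted requests. Simplified types, not vLLM's actual ones.
from dataclasses import dataclass, field


@dataclass
class CachedReqData:
    req_ids: list[str]
    new_block_ids: list[list[int]]  # parallel to req_ids


@dataclass
class SchedulerOutput:
    scheduled_cached_reqs: CachedReqData
    preempted_req_ids: set[str] = field(default_factory=set)


class ConnectorScheduler:
    def __init__(self) -> None:
        self._request_block_ids: dict[str, list[int]] = {}

    def build_connector_meta(self, out: SchedulerOutput) -> dict[str, list[int]]:
        # Preempted requests lose their blocks; stop tracking them.
        for req_id in out.preempted_req_ids:
            self._request_block_ids.pop(req_id, None)
        # Extend tracked blocks with this step's new allocations.
        cached = out.scheduled_cached_reqs
        for req_id, new_blocks in zip(cached.req_ids, cached.new_block_ids):
            self._request_block_ids.setdefault(req_id, []).extend(new_blocks)
        return self._request_block_ids


sched = ConnectorScheduler()
sched.build_connector_meta(SchedulerOutput(CachedReqData(["req-0"], [[0, 1]])))
meta = sched.build_connector_meta(SchedulerOutput(CachedReqData(["req-0"], [[2]])))
assert meta["req-0"] == [0, 1, 2]

# Preemption drops req-0's tracked blocks entirely.
meta = sched.build_connector_meta(SchedulerOutput(CachedReqData([], []), {"req-0"}))
assert "req-0" not in meta
```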

@xuechendi
Contributor Author

xuechendi commented Dec 12, 2025

@orozery
Oh, I see, so the OffloadingConnector leverages scheduler_output to obtain the new_block_ids.
I can do that in nixl_connector as well.

@xuechendi xuechendi force-pushed the debug/block_id_mistatch_schedule_to_finish branch from fea492b to fc1ff41 Compare December 12, 2025 20:40
@xuechendi
Contributor Author

@NickLucche @markmc @orozery, I have now switched the fix to follow the approach used in OffloadingConnectorScheduler. I have verified accuracy locally with the host buffer and it looks good. Please help review.
This PR still depends on #30420 being merged first.

@xuechendi xuechendi force-pushed the debug/block_id_mistatch_schedule_to_finish branch from fc1ff41 to 5943fdf Compare December 12, 2025 21:06
@mergify

mergify bot commented Dec 15, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xuechendi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 15, 2025
Comment on lines +697 to +698
if is_partial:
continue
Collaborator

I was wondering whether it would be better to save what we have instead of waiting for the entire request to be available.
But then I saw that self.copy_blocks is blocking, so I guess it does not matter.
@NickLucche your thoughts?

Collaborator

I think there's interesting developments for copying chunks, but for now I would treat this PR as a bug fix and prioritize getting the feature in a working state.
We can leave this optimization for future work.

Contributor Author

@xuechendi xuechendi Dec 15, 2025

@orozery done. Changed back to copying immediately.

Comment on lines +475 to +476
# list of GPU block IDs per request
self._request_block_ids: dict[ReqId, list[int]] = {}
Collaborator

Instead of tracking the block IDs of all requests, let's just track the block IDs of requests that need saving.
You can discard this new variable, and use self._reqs_need_save to track the block IDs.

Collaborator

agree we can re-use the existing container

Contributor Author

done

Collaborator

@NickLucche NickLucche left a comment

Thanks for the work @xuechendi and the great review @orozery !
Left a few comments but things look good overall.

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
@xuechendi
Contributor Author

Thanks! @NickLucche, I added the comment:

if not is_partial:
      # For non-partial prefills, once new req_meta is scheduled, it
      # can be removed from _reqs_need_save.
      # For partial prefill case, we will retain the request in
      # _reqs_need_save until all blocks are scheduled with req_meta.
      # Therefore, only pop if `not is_partial`.
      self._reqs_need_save.pop(req_id)

@NickLucche NickLucche added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 17, 2025
@NickLucche NickLucche enabled auto-merge (squash) December 17, 2025 09:29
@NickLucche NickLucche disabled auto-merge December 17, 2025 09:39
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Comment on lines +763 to +765
# Clear _reqs_need_save if a request is aborted as partial prefill.
self._reqs_need_save.pop(request.request_id, None)

Collaborator

I think you need to move this upward a few lines, at least before if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:

Contributor Author

Oh, I see, thanks, updated

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Comment on lines +761 to +762
# Clear _reqs_need_save if a request is aborted as partial prefill.
self._reqs_need_save.pop(request.request_id, None)
Collaborator

I think this should work, but it seems more fragile to me.
I would go further and put this pop right at the beginning.
Then, you could also remove the entire is_partial check from build_connector_meta.

Contributor Author

_reqs_need_save is originally used to buffer which requests should create req_meta for saving to the host, so its life cycle runs from "scheduled" to "all request metadata are created".

If we go with your proposal, the life cycle becomes "scheduled" to "request ends", which changes the original design.

@NickLucche, do you think we should do that?

I assume the fix here is just to handle a corner case where the request was aborted?

Collaborator

I think @orozery is proposing to only clear the id on request finished, so either terminal block was processed or abort/error.
I think it makes sense but I don't necessarily think the current implementation is fragile.

Hence I don't have a strong opinion here, this could also be done in a separate PR, as long as we maximize clarity for these cases.

Contributor Author

@NickLucche @orozery, let's do that in a separate PR. Since the other queues (_reqs_to_recv, _reqs_need_send, etc.) are cleared in build_connector_meta, I'd prefer not to add refactoring scope to this PR (which is originally an accuracy fix). Once this one is merged, I can open a new PR and we can work out a better design there.

@mergify

mergify bot commented Dec 18, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xuechendi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
@mergify

mergify bot commented Dec 18, 2025

Hi @xuechendi, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@xuechendi xuechendi force-pushed the debug/block_id_mistatch_schedule_to_finish branch from d0d2752 to a7d268d Compare December 18, 2025 15:20
@NickLucche NickLucche enabled auto-merge (squash) December 18, 2025 17:09
@markmc markmc dismissed their stale review December 18, 2025 22:10

Unblocking

@NickLucche NickLucche merged commit 6ca74bc into vllm-project:main Dec 18, 2025
52 checks passed
@github-project-automation github-project-automation bot moved this to Done in NVIDIA Dec 18, 2025
yugong333 pushed a commit to yugong333/vllm that referenced this pull request Dec 22, 2025
… host_buffer on CUDA (vllm-project#30419)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Majid-Taheri pushed a commit to Majid-Taheri/vllm that referenced this pull request Dec 23, 2025
… host_buffer on CUDA (vllm-project#30419)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Signed-off-by: Ubuntu <mjtaheri68@gmail.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
… host_buffer on CUDA (vllm-project#30419)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
… host_buffer on CUDA (vllm-project#30419)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>

Labels

kv-connector nvidia ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

5 participants