Skip to content

[NIXL] refine decoder side post process for heterogeneous BlockSize and kv_layout#30275

Merged
NickLucche merged 7 commits intovllm-project:mainfrom
xuechendi:dev/decode_KV_post_process
Jan 9, 2026
Merged

[NIXL] refine decoder side post process for heterogeneous BlockSize and kv_layout#30275
NickLucche merged 7 commits intovllm-project:mainfrom
xuechendi:dev/decode_KV_post_process

Conversation

@xuechendi
Copy link
Copy Markdown
Contributor

@xuechendi xuechendi commented Dec 8, 2025

Purpose

We have supported heterogeneous BlockSize and kv_layout in seperate post process methods.
This PR is to clean up and use single method to post_process for cases.

What is changed in this PR:

I removed permute_device_kv and blocksize_post_process, and move the logic into post_process_device_kv_on_receive as single post_process function with 3 options:

if enable_permute_local_kv and block_size_ratio > 1:
    _kv_postprocess_blksize_and_layout(
        cache, indices, block_size_ratio
    )
elif enable_permute_local_kv:
    _kv_postprocess_layout(cache, indices)
else:
    _kv_postprocess_blksize(cache, indices, block_size_ratio)

Test Plan

Test with heterogeneous KV_layout + heterogeneous block_size

DECODER_KV_LAYOUT=NHD PREFILL_BLOCK_SIZE=16 DECODE_BLOCK_SIZE=64 bash tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh 2>&1 | tee nixl_hetero_layout_blksize.log

GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES="deepseek-ai/DeepSeek-V2-Lite-Chat" DECODER_KV_LAYOUT=NHD PREFILL_BLOCK_SIZE=16 DECODE_BLOCK_SIZE=64 bash tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh 2>&1 | tee nixl_hetero_layout_blksize_MLA.log

=> Passed accuracy test

Test with heterogeneous KV_layout

DECODER_KV_LAYOUT=NHD PREFILL_BLOCK_SIZE=16 DECODE_BLOCK_SIZE=16 bash tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh 2>&1 | tee nixl_hetero_layout_blksize.log

GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES="deepseek-ai/DeepSeek-V2-Lite-Chat" DECODER_KV_LAYOUT=NHD PREFILL_BLOCK_SIZE=16 DECODE_BLOCK_SIZE=16 bash tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh 2>&1 | tee nixl_hetero_layout_blksize_MLA.log

=> Passed accuracy test

Test with heterogeneous heterogeneous block_size

PREFILL_BLOCK_SIZE=16 DECODE_BLOCK_SIZE=64 bash tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh 2>&1 | tee nixl_hetero_layout_blksize.log

GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES="deepseek-ai/DeepSeek-V2-Lite-Chat" PREFILL_BLOCK_SIZE=16 DECODE_BLOCK_SIZE=64 bash tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh 2>&1 | tee nixl_hetero_layout_blksize_MLA.log

=> Passed accuracy test

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Note

Cursor Bugbot is generating a summary for commit bff8eaa. Configure here.


Note

Consolidates decoder-side KV post-processing into a single path with shared utils to handle heterogeneous block size and layout.

  • Adds kv_postprocess_blksize_on_receive, kv_postprocess_layout_on_receive, and kv_postprocess_blksize_and_layout_on_receive in utils.py
  • Replaces permute_device_kv and blocksize_post_process with unified post_process_device_kv_on_receive in nixl_connector.py, selecting behavior based on enable_permute_local_kv and block_size_ratio
  • Updates get_finished to batch block IDs per ratio and invoke the new post-process; minor logging and tensor creation tweaks

Written by Cursor Bugbot for commit bff8eaa. This will update automatically on new commits. Configure here.


Note

Cursor Bugbot is generating a summary for commit 15ff574. Configure here.

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the post-processing logic for heterogeneous BlockSize and kv_layout, which is a good direction for code cleanup. However, the implementation introduces several issues. There are critical bugs in the tensor reshape operations within the new helper functions (_kv_postprocess_layout, _kv_postprocess_blksize, and _kv_postprocess_blksize_and_layout), which will likely lead to runtime errors or corrupted KV cache data. Additionally, there's a redundant index_select operation that should be removed to improve performance. These issues need to be addressed to ensure the correctness and efficiency of the new implementation.

@mergify
Copy link
Copy Markdown

mergify bot commented Dec 16, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xuechendi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 16, 2025
@xuechendi xuechendi force-pushed the dev/decode_KV_post_process branch from edc6d6e to 1753de4 Compare December 17, 2025 22:30
@mergify mergify bot removed the needs-rebase label Dec 17, 2025
@mergify
Copy link
Copy Markdown

mergify bot commented Dec 17, 2025

Hi @xuechendi, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@xuechendi xuechendi force-pushed the dev/decode_KV_post_process branch from 1753de4 to 010e76a Compare December 17, 2025 22:43
@mergify
Copy link
Copy Markdown

mergify bot commented Dec 18, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xuechendi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
@xuechendi xuechendi force-pushed the dev/decode_KV_post_process branch from 010e76a to 002a105 Compare December 18, 2025 18:05
@mergify mergify bot removed the needs-rebase label Dec 18, 2025
Copy link
Copy Markdown
Collaborator

@NickLucche NickLucche left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for refactoring this @xuechendi !
Left a few comments, overall this is looking pretty good.

blocks_to_update.permute(0, 2, 1, 3), block_size_ratio
).permute(0, 2, 1, 3)
cache.index_copy_(0, indices, permuted_blocks)
device = sample_cache.device
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't this self.device?

split_k_and_v = not (
self.use_mla or self._use_pallas or self.kv_topo.is_kv_layout_blocks_first
)
assert block_size_ratio >= 1, "Only nP < nD supported currently."
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could probably use debug log here stating what's being post-processed

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used logger.info_once(), is that ok?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think they serve two different purposes, a debug log would provide info on the proceeding of the transfer operation per-request which I think is ok being debug.
info_once may still be useful for the end user, although in theory we could later allow deployments where P1 block_size != P2 block_size=D block_size, hence the log info_once would fall short in reporting that.

Comment on lines +1746 to +1747
):
def _kv_postprocess_blksize(cache, indices, block_size_ratio):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could add a short comment at the top of post_process_device_kv_on_receive now that the 3 functions can be moved out

):
block_ids_for_blocksize_post_process[block_size_ratio].append(
meta.local_block_ids
meta.local_physical_block_ids
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok this was a bug then right

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I missed that in previous PR

@mergify
Copy link
Copy Markdown

mergify bot commented Dec 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xuechendi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 19, 2025
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
…ocess

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
@xuechendi xuechendi force-pushed the dev/decode_KV_post_process branch from 1d2394d to f0befd7 Compare December 19, 2025 19:43
@mergify mergify bot removed the needs-rebase label Dec 19, 2025
@xuechendi xuechendi requested a review from NickLucche December 19, 2025 19:46
@xuechendi
Copy link
Copy Markdown
Contributor Author

@NickLucche , Thanks for the review, I have resolved all comments and rebased.

@mergify
Copy link
Copy Markdown

mergify bot commented Jan 8, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xuechendi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 8, 2026
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
@mergify mergify bot removed the needs-rebase label Jan 8, 2026
Copy link
Copy Markdown
Collaborator

@NickLucche NickLucche left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a nit on logging. LGTM , thanks @xuechendi !

split_k_and_v = not (
self.use_mla or self._use_pallas or self.kv_topo.is_kv_layout_blocks_first
)
assert block_size_ratio >= 1, "Only nP < nD supported currently."
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think they serve two different purposes, a debug log would provide info on the proceeding of the transfer operation per-request which I think is ok being debug.
info_once may still be useful for the end user, although in theory we could later allow deployments where P1 block_size != P2 block_size=D block_size, hence the log info_once would fall short in reporting that.

@xuechendi
Copy link
Copy Markdown
Contributor Author

Just a nit on logging. LGTM , thanks @xuechendi !

Thanks, @NickLucche , you're right, I updated it to logger.debug

@NickLucche NickLucche enabled auto-merge (squash) January 9, 2026 19:21
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 9, 2026
@NickLucche NickLucche merged commit 9457812 into vllm-project:main Jan 9, 2026
56 checks passed
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…nd kv_layout (vllm-project#30275)

Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

kv-connector ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants