Enable Cross layers KV cache layout at NIXL Connector V2 #33339

Merged: NickLucche merged 94 commits into vllm-project:main from liranschour:nixl_kv_cont_cross_layers on Feb 5, 2026
Contributor

@liranschour liranschour commented Jan 29, 2026

Purpose

Enable the NIXL Connector to use the new continuous cross-layer KV cache layout described in the RFC and implemented in #27743.

In the runs below, TTFT improves by more than 2x and tok/s by roughly 1.5x to 1.9x, due to a dramatic reduction in fragmentation of transfer buffers.
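
To illustrate why the layout affects fragmentation, here is a minimal counting model; it is not the connector's code. It assumes, hypothetically, that with a per-layer layout every block of a request needs one transfer descriptor per layer, while with a cross-layer layout each block is a single contiguous region and physically adjacent blocks can be merged into one descriptor. The function names, block ids, and layer count are illustrative assumptions.

```python
def per_layer_descriptors(block_ids: list[int], num_layers: int) -> int:
    """One descriptor per (layer, block) pair: each layer has its own tensor."""
    return len(block_ids) * num_layers


def cross_layer_descriptors(block_ids: list[int]) -> int:
    """One descriptor per run of consecutive block ids.

    Assumes a block holds all layers contiguously, so adjacent blocks
    form one contiguous memory region. Block ids are assumed unique.
    """
    runs = 0
    previous = None
    for block_id in sorted(block_ids):
        # A new run starts whenever the id does not directly follow the last one.
        if previous is None or block_id != previous + 1:
            runs += 1
        previous = block_id
    return runs


if __name__ == "__main__":
    blocks = [10, 11, 12, 13, 40, 41]  # hypothetical block ids for one request
    print(per_layer_descriptors(blocks, num_layers=28))  # 6 blocks * 28 layers
    print(cross_layer_descriptors(blocks))  # two contiguous runs
```

Under these assumptions the descriptor count no longer scales with the number of layers, and it shrinks further whenever a request's blocks happen to be adjacent.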

Tested with P != D using run_accuracy_test.sh (P=1, D=2).

| branch | num reqs | input len | TTFT | ITL | tok/s | Desc/transfer |
|---|---|---|---|---|---|---|
| main | 1000 | 16 | 18756.42 | 5.35 | 5288.41 | 56 |
| kv_cross_layers | 1000 | 16 | 8494.44 | 8.74 | 8572.56 | 1 |
| main | 128 | 1024 | 1660.84 | 9.40 | 37945.20 | 3528 |
| kv_cross_layers | 128 | 1024 | 686.98 | 9.26 | 55418.76 | 1 |
| main | 128 | 10240 | 11140.52 | 42.78 | 62339.74 | 34000 |
| kv_cross_layers | 128 | 10240 | 5226.71 | 14.41 | 117631.48 | 422 |
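
The per-configuration ratios can be checked with a short script. This is only arithmetic on the numbers reported above; the `RESULTS` and `speedups` names are introduced here for illustration.

```python
# (num_reqs, input_len) -> {branch: (TTFT, tok/s)}, copied from the table above.
RESULTS = {
    (1000, 16): {
        "main": (18756.42, 5288.41),
        "kv_cross_layers": (8494.44, 8572.56),
    },
    (128, 1024): {
        "main": (1660.84, 37945.20),
        "kv_cross_layers": (686.98, 55418.76),
    },
    (128, 10240): {
        "main": (11140.52, 62339.74),
        "kv_cross_layers": (5226.71, 117631.48),
    },
}


def speedups(results):
    """Return {(num_reqs, input_len): (ttft_speedup, tok_per_s_speedup)}."""
    out = {}
    for key, branches in results.items():
        base_ttft, base_tps = branches["main"]
        new_ttft, new_tps = branches["kv_cross_layers"]
        # TTFT is lower-is-better, tok/s is higher-is-better.
        out[key] = (base_ttft / new_ttft, new_tps / base_tps)
    return out


if __name__ == "__main__":
    for (num_reqs, input_len), (ttft_x, tps_x) in speedups(RESULTS).items():
        print(f"reqs={num_reqs} len={input_len}: TTFT {ttft_x:.2f}x, tok/s {tps_x:.2f}x")
```

With these numbers TTFT improves by roughly 2.1x to 2.4x and tok/s by roughly 1.5x to 1.9x.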

Test Plan

Test Result



liranschour and others added 30 commits December 7, 2025 13:31
Signed-off-by: Liran Schour <lirans@il.ibm.com>
Co-authored-by: Or Ozeri <or@ozery.com>
Signed-off-by: liranschour <liranschour@users.noreply.github.com>
"CROSS_LAYERS_BLOCKS=True GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case
"CROSS_LAYERS_BLOCKS=True GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"CROSS_LAYERS_BLOCKS=True GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
)
Collaborator

you can refactor to just add CROSS_LAYERS_BLOCKS=True to tp_configs, assuming all above are compatible.

Collaborator

@NickLucche NickLucche left a comment

LGTM, only one nit

Comment on lines +65 to +67
else
echo "CROSS_LAYERS_BLOCKS is not set, skipping --enable-cross-layers runs."
fi
Collaborator

nit: no need to echo out disabled options imo

Contributor Author

Removed that echo

Comment on lines +365 to +366
if current_platform.device_type != "cpu"
else -2
Collaborator

qq: this is untested on cpu right?

Collaborator

I don't think we need this special case.
We should be able to correctly set block_size_position using test_shape even when running on CPU.

Contributor Author

Removed this special case.
Setting block_size_position only by kv_cache_shape.
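
One way to derive the position from the shape alone, sketched under assumptions: build a probe shape whose dimensions all differ, locate the dimension equal to the block size, and express it as a negative index. `find_block_size_position` and the probe shapes below are hypothetical and are not taken from the connector.

```python
def find_block_size_position(kv_cache_shape: tuple[int, ...], block_size: int) -> int:
    """Return the negative index of the single dimension equal to block_size."""
    matches = [i for i, dim in enumerate(kv_cache_shape) if dim == block_size]
    if len(matches) != 1:
        raise ValueError(
            f"expected exactly one dim equal to {block_size}, got {kv_cache_shape}"
        )
    # Negative index so the result does not depend on leading dimensions.
    return matches[0] - len(kv_cache_shape)


if __name__ == "__main__":
    # Hypothetical probe shapes; distinct sizes keep the block dimension unambiguous.
    # Layout (2, num_blocks, block_size, num_heads, head_size):
    print(find_block_size_position((2, 1234, 16, 8, 64), block_size=16))  # -3
    # Layout (2, num_blocks, num_heads, block_size, head_size):
    print(find_block_size_position((2, 1234, 8, 16, 64), block_size=16))  # -2
```

Because the lookup depends only on the shape, no branch on the device type is needed in this sketch.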

expected_base_addrs: list[int]
expected_num_entries: int
kv_caches: dict[str, torch.Tensor]
if connector.prefer_cross_layer_blocks:
Collaborator

This assumes that connector.prefer_cross_layer_blocks was correctly derived from the test's enable_cross_layers parameter.
Can you assert that?

Contributor Author

Added an assert for that
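
Such a check might look like the following sketch. `FakeConnector` is a hypothetical stand-in for the connector under test and models only the attribute named in the review; `enable_cross_layers` is the test parameter mentioned above. The actual assert added in the PR is not shown in this thread.

```python
from dataclasses import dataclass


@dataclass
class FakeConnector:
    # Hypothetical stand-in; only the attribute named in the review is modeled.
    prefer_cross_layer_blocks: bool


def check_cross_layers_flag(connector: FakeConnector, enable_cross_layers: bool) -> None:
    """Fail fast if the connector did not pick up the test parameter."""
    assert connector.prefer_cross_layer_blocks == enable_cross_layers, (
        f"prefer_cross_layer_blocks={connector.prefer_cross_layer_blocks}, "
        f"expected {enable_cross_layers}"
    )


if __name__ == "__main__":
    check_cross_layers_flag(FakeConnector(prefer_cross_layer_blocks=True), True)
    check_cross_layers_flag(FakeConnector(prefer_cross_layer_blocks=False), False)
```

Placing the check before the layout-dependent branch means a misconfigured test fails on the flag itself rather than on a later address or entry-count mismatch.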

@orozery added the ready label (ONLY add when PR is ready to merge/full CI is needed) Feb 3, 2026
liranschour and others added 2 commits February 3, 2026 13:46
@liranschour liranschour requested a review from orozery February 3, 2026 14:17
@NickLucche NickLucche enabled auto-merge (squash) February 4, 2026 22:55
@NickLucche NickLucche merged commit 8322d4e into vllm-project:main Feb 5, 2026
48 checks passed
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
…t#33339)

Signed-off-by: Liran Schour <lirans@il.ibm.com>
Signed-off-by: liranschour <liranschour@users.noreply.github.com>
Co-authored-by: Or Ozeri <or@ozery.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026
…t#33339)

Signed-off-by: Liran Schour <lirans@il.ibm.com>
Signed-off-by: liranschour <liranschour@users.noreply.github.com>
Co-authored-by: Or Ozeri <or@ozery.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Labels

documentation, kv-connector, ready, v1
