Skip to content

[V0 Deprecation] Refactor kv cache from list to element#37487

Merged
vllm-bot merged 6 commits intomainfrom
wentao-kv_cache-no-list
Mar 24, 2026
Merged

[V0 Deprecation] Refactor kv cache from list to element#37487
vllm-bot merged 6 commits intomainfrom
wentao-kv_cache-no-list

Conversation

@yewentao256
Copy link
Copy Markdown
Member

@yewentao256 yewentao256 commented Mar 18, 2026

Purpose

A follow up for #37195 of removing the virtual engine, this PR further refactor the kv cache from list to element to clean the code

Tests in CI

Signed-off-by: yewentao256 <zhyanwentao@126.com>
@yewentao256 yewentao256 added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 18, 2026
@mergify mergify bot added deepseek Related to DeepSeek models qwen Related to Qwen models rocm Related to AMD ROCm v1 kv-connector labels Mar 18, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Mar 18, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the kv_cache by removing the outer list wrapper, simplifying its structure from a list of one element to just the element itself (a tensor or a tuple of tensors). This change is consistently applied across various components, including attention layers, mamba-based layers, and their corresponding test files. The modifications simplify code by removing unnecessary [0] indexing when accessing the kv_cache. The change in _cleanup_profiling_kv_cache is a good addition that makes the cleanup logic more robust to the different types of kv_cache. The refactoring appears to be correct and improves code clarity.

@hmellor
Copy link
Copy Markdown
Member

hmellor commented Mar 19, 2026

The NIXL failure seems like it might be relevant?

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Comment on lines +184 to +194
@@ -185,15 +185,13 @@ def inject_kv_into_layer(
if kv_cache_attr is None:
continue

kv_cache_layer = kv_cache_attr[0]

filename = self._generate_filename_debug(
layer_name, request.token_ids, request.mm_hashes
)
kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda()
if isinstance(attn_metadata, dict):
inject_kv_into_layer(
kv_cache_layer,
kv_cache_attr,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we rename this to kv_cache_layer?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks! And also fix the previous CI issue

Signed-off-by: yewentao256 <zhyanwentao@126.com>
@yewentao256 yewentao256 enabled auto-merge (squash) March 20, 2026 22:09
@vllm-bot vllm-bot merged commit c59a132 into main Mar 24, 2026
81 of 84 checks passed
@vllm-bot vllm-bot deleted the wentao-kv_cache-no-list branch March 24, 2026 03:10
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Mar 24, 2026
RhizoNymph pushed a commit to RhizoNymph/vllm that referenced this pull request Mar 26, 2026
HenryTangDev pushed a commit to HenryTangMain/vllm that referenced this pull request Mar 27, 2026
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
…#37487)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
nithinvc pushed a commit to nithinvc/vllm that referenced this pull request Mar 27, 2026
…#37487)

Signed-off-by: yewentao256 <zhyanwentao@126.com>

Signed-off-by: Nithin Chalapathi <nithin.ch10@gmail.com>
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Mar 30, 2026
…#37487)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
EricccYang pushed a commit to EricccYang/vllm that referenced this pull request Apr 1, 2026
…#37487)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: EricccYang <yangyang4991@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

deepseek Related to DeepSeek models kv-connector qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

3 participants