Skip to content

[Model Runner v2] fix pd accuracy#42888

Closed
ZJY0516 wants to merge 2 commits into
vllm-project:mainfrom
ZJY0516:fix-mdrv2-pd
Closed

[Model Runner v2] fix pd accuracy#42888
ZJY0516 wants to merge 2 commits into
vllm-project:mainfrom
ZJY0516:fix-mdrv2-pd

Conversation

@ZJY0516
Copy link
Copy Markdown
Member

@ZJY0516 ZJY0516 commented May 17, 2026

Purpose

FIX #42846
Related error: https://buildkite.com/vllm/ci/builds/66617/canvas?jid=019e38db-bb6b-4ab4-8378-fa658fc52ac8&tab=output

_sync_block_size_with_kernel doubles self.num_blocks and halves self.block_size so NIXL descriptors walk the cache at the kernel block_size. That matches V1's KV cache layout (V1 reshapes the cache to kernel granularity in gpu_model_runner._reshape_kv_cache_tensors), but not V2's — V2 (gpu/attn_utils._reshape_kv_cache) keeps the tensor at logical block_size.

Fix

Skip the num_blocks *= ratio / block_size //= ratio step under V2 and pin _physical_blocks_per_logical_kv_block = 1 so the rest of the NIXL worker (byte accounting in register_kv_caches, descriptor strides in _build_fa_local / _build_fa_remote, _logical_to_kernel_block_ids) stays 1:1 with V2's logical-grain cache layout. V1 path is unchanged.

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
@ZJY0516 ZJY0516 added the ready ONLY add when PR is ready to merge/full CI is needed label May 17, 2026
@mergify mergify Bot added bug Something isn't working kv-connector labels May 17, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates the KV cache registration logic in the NIXL worker to correctly validate tensor dimensions when the V2 model runner is enabled by using logical block counts. A critical issue was identified where related variables still use physical units for stride and block length calculations, which could lead to incorrect memory addressing and potential out-of-bounds access.

Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py Outdated
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: d8e8ac794c

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py Outdated
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
@ZJY0516 ZJY0516 changed the title [Bugfix] fix dim check in nixl worker [[Model Runner v2]] fix pd accuracy May 18, 2026
@ZJY0516 ZJY0516 changed the title [[Model Runner v2]] fix pd accuracy [Model Runner v2] fix pd accuracy May 18, 2026
@chfeng-cs
Copy link
Copy Markdown
Contributor

chfeng-cs commented May 18, 2026

I filed the original issue #42846 and have a PR #42872 open that takes a different approach — worth comparing the two.

Your fix skips the num_blocks *= ratio / block_size //= ratio step on the NIXL side and pins _physical_blocks_per_logical_kv_block = 1, so NIXL operates at logical-block granularity. This resolves the immediate crash, but the KV cache tensor registered by MRV2 remains at logical shape (1813, 2, 128, ...) rather than the kernel shape (3626, 2, 64, ...) that FlashInfer expects. The block table also stays at logical granularity, so the block IDs passed during P/D transfer are logical rather than kernel IDs.

#42872 instead fixes the reshape path directly: _reshape_kv_cache now computes kernel_num_blocks = num_blocks * (block_size // kernel_block_size) and passes kernel_block_size as the shape block size, so the registered tensor is at kernel granularity from the start. BlockTables is also updated to expand logical block IDs into kernel block IDs at write time via map_to_kernel_blocks.

If I've mischaracterized your approach, happy to be corrected. Happy to discuss which approach is preferable — or whether parts of both are worth combining.

@ZJY0516
Copy link
Copy Markdown
Member Author

ZJY0516 commented May 18, 2026

This resolves the immediate crash, but the KV cache tensor registered by MRV2 remains at logical shape (1813, 2, 128, ...) rather than the kernel shape (3626, 2, 64, ...) that FlashInfer expects. The block table also stays at logical granularity, so the block IDs passed during P/D transfer are logical rather than kernel IDs.

I believe this matches V2's intended design — keep the KV cache at the logical shape (rather than reshape to kernel grain like V1 does) and defer any kernel-grain conversion to attention metadata construction.

The PR also passes the accuracy tests in CI

@chfeng-cs
Copy link
Copy Markdown
Contributor

chfeng-cs commented May 18, 2026

Thanks for looking into this. I think the key question is whether MRV2 actually has a later logical->kernel block conversion for the dense FlashInfer path.

From what I can see, the FlashInfer metadata builder only uses the block size from its kv_cache_spec:

  • AttentionGroup.create_metadata_builders() only changes the spec block size if kernel_block_size is explicitly passed in.
  • FlashInferMetadataBuilder then sets page_size = self.kv_cache_spec.block_size.
  • The block table is passed through CommonAttentionMetadata into FlashInfer/TRTLLM paths directly.

So unless MRV2 passes kernel_block_size into the metadata builder and reshapes the KV cache accordingly, the dense FlashInfer path still operates as if page_size is 128. I do not see a separate metadata-time conversion that turns logical block ids into kernel block ids for this path.

That is why I am hesitant about fixing this in the NIXL worker. Forcing _physical_blocks_per_logical_kv_block = 1 makes NIXL accept the current logical KV cache shape, but it does not address the underlying mismatch between:

  • KV manager block size: 128
  • FlashInfer supported kernel block size: 64
  • expected KV cache view / block table granularity for the backend

In other words, NIXL is probably just the first component that asserts on the mismatch. The more general fix should be in MRV2 KV cache initialization / block table construction, matching the MRV1 behavior around prepare_kernel_block_sizes().

If there is a specific MRV2 metadata path that performs this logical->kernel conversion for dense FlashInfer, could you point me to it? Thanks.

Copy link
Copy Markdown
Member

@NickLucche NickLucche left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure it's working as is for all cases, especially with hybrid ssm.
I was hoping we could clarify the necessity for this change in MRv2, given it's breaking interface for all connectors, not just nixl.
I am reverting it here #42766

@ZJY0516 ZJY0516 removed the ready ONLY add when PR is ready to merge/full CI is needed label May 18, 2026
@ZJY0516 ZJY0516 closed this May 19, 2026
@njhill njhill added the v2 label May 20, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working kv-connector v2

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug][CI] NIXL + FlashInfer fails with Qwen3 MRV2 and --block-size 128

4 participants