
[Bugfix][KV Transfer] Reject NixlConnector + expandable_segments:True #41237

Open

esmeetu wants to merge 2 commits into vllm-project:main from esmeetu:fix-nixl-expandable-segments-conflict

Conversation

esmeetu (Member) commented Apr 29, 2026

Purpose

NixlConnector pins the KV cache via ibv_reg_mr once at startup. When PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True, PyTorch's CUDA VMM allocator can later remap KV cache virtual addresses to different physical pages — leaving the registered IB rkey pointing at stale physical pages.

The two settings are fundamentally incompatible: ibv_reg_mr expects stable physical mappings for the lifetime of the MR, while expandable_segments explicitly trades that guarantee for memory-pool flexibility.

In practice this combination silently corrupts inter-node KV transfers and only surfaces at the first request, with errors like:

  • IBV_WC_REM_ACCESS_ERR (synd 0x13 vend 0x88 hw_synd 0/0) on the decode side
  • makeXferReq: remote agent '<uuid>' was invalidated in between prepXferDlist and this call
  • NIXL_ERR_REMOTE_DISCONNECT / NIXL_ERR_NOT_FOUND / NIXL_ERR_CANCELED
  • 0 Gbps RDMA traffic despite NIXL appearing healthy at startup; benchmark stalls indefinitely

The prefill side shows zero errors because the allocator misbehavior happens at the CUDA VMM layer, below NIXL's awareness. Diagnosing this from the symptoms takes hours; preventing the misconfiguration up front is a single if check.

This change validates the combination at config-load time in KVTransferConfig.__post_init__ and fails fast with an actionable error.

This affects only NixlConnector — other connectors do not pin GPU memory through ibv_reg_mr, so the same env var is harmless for them and is left untouched.
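
A minimal sketch of that check, assuming a dataclass-style KVTransferConfig with a kv_connector field (the surrounding config machinery is simplified; the error text follows the message quoted under Test Result):

    import os
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class KVTransferConfig:
        kv_connector: Optional[str] = None  # e.g. "NixlConnector"

        def __post_init__(self) -> None:
            # Fail fast: an ibv_reg_mr-registered KV cache cannot survive
            # the VMM remapping that expandable_segments permits.
            if self.kv_connector == "NixlConnector":
                conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
                if "expandable_segments:True" in conf:
                    raise ValueError(
                        "NixlConnector is incompatible with "
                        "PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True. "
                        "Unset expandable_segments:True or remove it from "
                        "PYTORCH_CUDA_ALLOC_CONF."
                    )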

Test Plan

pytest tests/v1/kv_connector/unit/test_config.py -v

Adds three cases (sketched below) covering:

  1. NixlConnector + expandable_segments:True → raises with a clear message
  2. NixlConnector + expandable_segments:False (or other allocator config) → allowed
  3. Non-NIXL connector + expandable_segments:True → unaffected
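
A sketch of what those three cases might look like, assuming a monkeypatch-based setup and an importable KVTransferConfig (import path assumed); SharedStorageConnector stands in for an arbitrary non-NIXL connector:

    import pytest

    from vllm.config import KVTransferConfig  # import path assumed

    def test_nixl_rejects_expandable_segments(monkeypatch):
        monkeypatch.setenv("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
        with pytest.raises(ValueError, match="expandable_segments"):
            KVTransferConfig(kv_connector="NixlConnector")

    def test_nixl_allows_other_alloc_conf(monkeypatch):
        monkeypatch.setenv("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:False")
        KVTransferConfig(kv_connector="NixlConnector")  # should not raise

    def test_non_nixl_connector_unaffected(monkeypatch):
        monkeypatch.setenv("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
        KVTransferConfig(kv_connector="SharedStorageConnector")  # should not raise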

Test Result

Before this change, a real GB200 NVL72 1P/1D run with PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True on the prefill side accepts the config, allocates and registers KV memory, then produces 66 NIXL transfer failures at the first request and the benchmark hangs at 12/320 progress. Removing expandable_segments:True (the only change) fixes it: 0 NIXL errors, benchmark progresses normally to 132/320 within ~2 minutes.

After this change, the same misconfiguration is rejected at startup with:

ValueError: NixlConnector is incompatible with PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True. PyTorch's CUDA VMM allocator can remap KV cache virtual addresses to different physical pages, invalidating the IB memory regions registered by NIXL. Unset expandable_segments:True (or remove it from PYTORCH_CUDA_ALLOC_CONF).


@claude (Bot) left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify (Bot) added the v1, bug, and kv-connector labels on Apr 29, 2026

@gemini-code-assist (Bot) left a comment


Code Review

This pull request adds a validation check to prevent the use of 'expandable_segments:True' in 'PYTORCH_CUDA_ALLOC_CONF' when using 'NixlConnector', as it causes memory remapping that invalidates RDMA registrations. The review suggests parsing the configuration string properly rather than relying on simple substring matching, which can produce false positives.

Comment thread: vllm/config/kv_transfer.py (Outdated)
Comment on lines +117 to +118
conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
if "expandable_segments:True" in conf:

Severity: high

Checking for a substring in a comma-separated configuration string is fragile. If PYTORCH_CUDA_ALLOC_CONF contains expandable_segments:True,other_config:1, the check if "expandable_segments:True" in conf works, but a value such as expandable_segments:True1 would also match, producing a false positive. It is safer to parse the configuration string properly:

    # split(':', 1) guards against values that themselves contain ':'
    conf = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '')
    conf_dict = dict(item.split(':', 1) for item in conf.split(',') if ':' in item)
    if conf_dict.get('expandable_segments') == 'True':


@chatgpt-codex-connector (Bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 863fe56a3b


Comment thread: vllm/config/kv_transfer.py (Outdated)
Comment on lines +117 to +118
conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
if "expandable_segments:True" in conf:

P1: Validate expandable_segments from PYTORCH_ALLOC_CONF too

This guard only reads PYTORCH_CUDA_ALLOC_CONF, so it is bypassed when users set PYTORCH_ALLOC_CONF (the primary allocator env var; PYTORCH_CUDA_ALLOC_CONF is just an alias). In that case NixlConnector still starts with expandable_segments:True, and the exact RDMA corruption/hang this change is meant to prevent can still occur. Please check both env var names (and add a test for the PYTORCH_ALLOC_CONF path) so the fail-fast behavior is reliable.
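
A sketch of a guard that parses both variable names instead of substring-matching just one of them (the helper name is illustrative):

    import os

    def expandable_segments_enabled() -> bool:
        # Either env var can enable the feature, so parse both rather than
        # substring-matching a single one.
        for var in ("PYTORCH_ALLOC_CONF", "PYTORCH_CUDA_ALLOC_CONF"):
            conf = os.environ.get(var, "")
            pairs = dict(item.split(":", 1) for item in conf.split(",") if ":" in item)
            if pairs.get("expandable_segments") == "True":
                return True
        return False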


@NickLucche (Collaborator) left a comment


I believe this might be true for more than just nixl.
Let me do a quick check for other cases, but thanks a lot for the great work bringing this up for PD @esmeetu !

esmeetu (Member, Author) commented Apr 29, 2026

@NickLucche, I forgot about this: #40812. So sleep mode must be enabled for expandable segments to work.

NixlConnector pins KV cache memory once via ibv_reg_mr at startup. When
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True, PyTorch's CUDA VMM
allocator can later remap KV cache virtual addresses to different physical
pages, leaving the registered IB rkey pointing at stale physical pages.
This produces RDMA failures at runtime: IBV_WC_REM_ACCESS_ERR (synd 0x13),
"remote agent invalidated in between prepXferDlist and this call",
and NIXL_ERR_REMOTE_DISCONNECT.

The two settings are fundamentally incompatible in general: ibv_reg_mr
expects stable physical mappings for the lifetime of the MR, while
expandable_segments explicitly trades that guarantee for memory-pool
flexibility.

Sleep mode is exempt: CuMemAllocator.use_memory_pool toggles
expandable_segments off around its pool (see vllm-project#40812), so the KV cache
allocated within that context lands on stable physical pages even when
the env var is set globally.

Validate this combination at config-load time so it fails fast with an
actionable message instead of silently producing RDMA errors that surface
only at the first inter-node KV transfer.

Signed-off-by: inf-yasong <yasong.wang@inferact.ai>
@esmeetu force-pushed the fix-nixl-expandable-segments-conflict branch from 863fe56 to 269932e on April 29, 2026 at 14:52
esmeetu (Member, Author) commented Apr 29, 2026

I just added new gate logic that checks both PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True" and --enable-sleep-mode. PTAL.
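
A sketch of that gate, combining the env-var check with the sleep-mode exemption from the commit message above (the helper name, parameters, and error wording are illustrative):

    import os
    from typing import Optional

    def validate_alloc_conf(kv_connector: Optional[str], enable_sleep_mode: bool) -> None:
        # Sleep mode is exempt: CuMemAllocator.use_memory_pool toggles
        # expandable_segments off around its pool (see #40812), so the KV
        # cache allocated inside that context stays on stable physical pages.
        if kv_connector != "NixlConnector" or enable_sleep_mode:
            return
        conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
        if "expandable_segments:True" in conf:
            raise ValueError(
                "NixlConnector with expandable_segments:True requires "
                "--enable-sleep-mode."
            )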

@NickLucche (Collaborator) left a comment


@dtcccc I suppose mooncake would have a similar issue given this invalidates kv addresses, and similarly moriioo (guessing, not really familiar with moriioo).

@esmeetu I would raise this check to "if any connector is present", given we can't guarantee that for OOT connectors.

Generally speaking, given my current (perhaps superficial) understanding of this allocator, I am actually afraid that using it might sweep under the rug a deeper issue that may arise when using a static allocator.
I believe that, ideally, as an inference engine we should be able to predict our allocations, and the added dynamic factor here might actually work against us.
We could also consider disabling it altogether for KV cache allocations.

dtcccc (Contributor) commented Apr 30, 2026

> @dtcccc I suppose mooncake would have a similar issue given this invalidates kv addresses, and similarly moriioo (guessing, not really familiar with moriioo).

Yes, MooncakeConnector would be affected the same way.

@NickLucche (Collaborator) commented:

@esmeetu can we change the connector check to "any connector" rather than nixl-specific?

Per @NickLucche's review feedback: NixlConnector is not the only
connector that pins KV cache memory and gets corrupted when PyTorch's
CUDA VMM allocator remaps physical pages. MooncakeConnector has the
same vulnerability (confirmed by @dtcccc), and we can't enumerate
every in-tree and out-of-tree connector that does similar pinning.

Apply the rejection whenever any KV connector is configured, rather
than only for NixlConnector. The sleep-mode exemption and the
expandable_segments:True trigger are unchanged.

Tests are reorganized: parametrize the rejection test over multiple
connector names (including a hypothetical OOT one), keep the
sleep-mode and benign-alloc-conf cases, and replace the
'non-NIXL connector is allowed' case with a 'no connector is allowed'
case (since the new behavior rejects non-NIXL connectors too).

Signed-off-by: inf-yasong <yasong.wang@inferact.ai>
Signed-off-by: esmeetu <jasonailu87@gmail.com>
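
A sketch of the reorganized tests this commit message describes (the import path and the OOT connector name are illustrative):

    import pytest

    from vllm.config import KVTransferConfig  # import path assumed

    @pytest.mark.parametrize(
        "connector",
        ["NixlConnector", "MooncakeConnector", "SomeOOTConnector"],
    )
    def test_any_connector_rejects_expandable_segments(monkeypatch, connector):
        monkeypatch.setenv("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
        with pytest.raises(ValueError, match="expandable_segments"):
            KVTransferConfig(kv_connector=connector)

    def test_no_connector_is_allowed(monkeypatch):
        monkeypatch.setenv("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
        KVTransferConfig(kv_connector=None)  # should not raise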
esmeetu (Member, Author) commented May 5, 2026

@NickLucche SGTM. Updated.


Labels

bug, kv-connector, v1



3 participants