
[Feature] Support CPU Offloading without Pytorch Pinned Memory that leads to doubled allocation#32993

Merged
vllm-bot merged 17 commits into vllm-project:main from wzhao18:wzhao/kimi-k2-dgx-offload
Feb 13, 2026

Conversation

@wzhao18
Contributor

@wzhao18 wzhao18 commented Jan 24, 2026

Purpose

Background:
The current PyTorch pin_memory implementation rounds allocations up to the next power of 2 (see Issue #150517, PR #171662). In the worst case, this results in 2x the expected CPU memory usage. For instance, running nvidia/Kimi-K2-Thinking-NVFP4 on a system with one GB300 GPU and 500 GB of CPU memory results in OOM, even though sufficient memory is available.

Script to demonstrate the issue:

import sys

import psutil
import torch

gb_size = int(sys.argv[1])

print(f"Allocating {gb_size} GB")
before = psutil.virtual_memory().used
# Keep a reference so the buffer is not freed before we measure.
t = torch.empty(gb_size * 1024**3, dtype=torch.uint8, device="cpu", pin_memory=True)
after = psutil.virtual_memory().used
print(f"Memory used: {(after - before) / 1024**3} GB")

Allocating 257 GB of pinned memory leads to 512 GB CPU allocation.
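The rounding behavior can be modeled in a few lines. This is a sketch of the power-of-2 rounding described above, not PyTorch's actual allocator code:

```python
def next_pow2(n: int) -> int:
    """Round n up to the next power of two (models the reported behavior)."""
    return 1 << (n - 1).bit_length()

req = 257 * 1024**3        # 257 GiB requested, as in the experiment above
alloc = next_pow2(req)     # what the allocator reportedly reserves
print(alloc // 1024**3)    # → 512
```

Anything just past a power-of-2 boundary, like 257 GiB, nearly doubles the reservation.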

Changes:
This PR provides a workaround that avoids PyTorch pinned memory in UVA-based (Unified Virtual Addressing) CPU offloading. In addition, it fixes various issues on the non-UVA offloading code path (explicit transfers between CPU and GPU).

There is a PyTorch PR in progress to address the doubled pinned-memory usage issue; this PR provides a temporary workaround in the meantime. The workaround is enabled through two environment variables:

  • VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY: disable using pytorch pin memory in CPU offloading
  • VLLM_WEIGHT_OFFLOADING_DISABLE_UVA: disable using UVA in CPU offloading

When neither env var is set, the original offloading logic (PyTorch pinned memory + UVA) is used.

Test Plan

Added testing for the env-var enabled code paths in tests/basic_correctness/test_cpu_offload.py.

Also tested nvidia/Kimi-K2-Thinking-NVFP4 on one GB300 GPU and 500 GB CPU memory:

On current main:

python3 -m vllm.entrypoints.openai.api_server \
    --model nvidia/Kimi-K2-Thinking-NVFP4 \
    --trust-remote-code \
    --cpu-offload-gb 350
[core.py:935]   File "/vllm/model_executor/models/utils.py", line 707, in make_layers
[core.py:935]     maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
[core.py:935]   File "/vllm/model_executor/models/utils.py", line 652, in maybe_offload_to_cpu
[core.py:935]     cpu_data = torch.empty_strided(
[core.py:935]                ^^^^^^^^^^^^^^^^^^^^
[core.py:935]   File "/vllm/.venv/lib/python3.12/site-packages/torch/utils/_device.py", line 109, in __torch_function__
[core.py:935]     return func(*args, **kwargs)
[core.py:935]            ^^^^^^^^^^^^^^^^^^^^^
[core.py:935] torch.AcceleratorError: CUDA error: out of memory

Test Result

With the changes, the following command will work.

VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY=1 \
python3 -m vllm.entrypoints.openai.api_server \
    --model nvidia/Kimi-K2-Thinking-NVFP4 \
    --trust-remote-code \
    --cpu-offload-gb 350
[gpu_model_runner.py:3947] Model loading took 199.97 GiB memory and 173.309022 seconds

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the nvidia label Jan 24, 2026
@wzhao18 wzhao18 changed the title [Feature] Support CPU Offloading without Pytorch Pinned Memory that leads to 2x allocation [Feature] Support CPU Offloading without Pytorch Pinned Memory that leads to doubled allocation Jan 24, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a valuable workaround for the excessive memory allocation issue with PyTorch's pin_memory during CPU offloading. The approach of using mmap and cudaHostRegister for non-pinned tensors is well-implemented and controlled via new environment variables. The related code changes across different files are consistent and logical. However, I've identified a critical issue in the C++/CUDA implementation where an error condition is not handled, which could lead to crashes.

@mergify

mergify bot commented Jan 24, 2026

Hi @wzhao18, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint


@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.


@youkaichao
Member

Current PyTorch pin_memory implementation rounds allocation to the next power of 2

That's an interesting finding. I think we can directly remove PyTorch pin_memory in the offloading code path; we don't need PyTorch's caching behavior for the offloading functionality.

// This function assumes that `cpu_tensor` is a CPU tensor allocated with pinned
// memory, and that UVA (Unified Virtual Addressing) is enabled.
// This function assumes that `cpu_tensor` is a CPU tensor,
// and that UVA (Unified Virtual Addressing) is enabled.
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
Member


it seems to mix two functionalities in one function. I'd prefer to separate them. keep the original get_cuda_view_from_cpu_tensor untouched, and have another alloc_pinned_cpu_tensor_and_get_cuda_view(num_bytes) function.

Contributor Author

@wzhao18 wzhao18 Jan 27, 2026


Thanks for reviewing. I personally find it cleaner to keep a unified function that "creates a CUDA view from a CPU tensor". The only difference is whether the CPU tensor is already allocated with pinned memory. In both cases, the returned CUDA view keeps a reference to the CPU buffer.

Let me know if you insist on having separate functions. I can update it that way.

cudaError_t err =
    cudaHostRegister(host_ptr, aligned_size, cudaHostRegisterDefault);
if (err != cudaSuccess) {
  munmap(host_ptr, aligned_size);
  AT_ERROR("cudaHostRegister failed: ", cudaGetErrorString(err));
}
Member


is cudaHostAlloc enough? do we need to use system mmap and then use cudaHostRegister?

Contributor Author


Will try cudaHostAlloc and get back here.

Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have changed the implementation to use cudaHostAlloc instead of system mmap. Thanks for the suggestion.

Member

@youkaichao youkaichao left a comment


this direction looks good to me!

@mergify

mergify bot commented Jan 27, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wzhao18.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 27, 2026
@wzhao18 wzhao18 force-pushed the wzhao/kimi-k2-dgx-offload branch from e9dda7f to 6ff4cf9 Compare January 27, 2026 17:54
@mergify mergify bot removed the needs-rebase label Jan 27, 2026
@wzhao18 wzhao18 force-pushed the wzhao/kimi-k2-dgx-offload branch from 6ff4cf9 to 7144115 Compare January 27, 2026 17:59
@wzhao18 wzhao18 requested a review from youkaichao January 29, 2026 22:07
@wzhao18
Contributor Author

wzhao18 commented Jan 30, 2026

@youkaichao Thanks for the review. I have updated the PR based on your feedback. I am still weighing whether we should split get_cuda_view_from_cpu_tensor into two functions. Could you take another look and let me know which approach you prefer?

Besides that, is there anything else missing before this is ready to merge?

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 3, 2026
@wzhao18 wzhao18 force-pushed the wzhao/kimi-k2-dgx-offload branch from 7144115 to 619ca40 Compare February 4, 2026 16:49
@wzhao18
Contributor Author

wzhao18 commented Feb 4, 2026

@youkaichao @mgoin The failing CI tests seem to be flaky. The quantized model test fails on main too. Could you rerun the other failed tests? Thanks

@wzhao18 wzhao18 force-pushed the wzhao/kimi-k2-dgx-offload branch from 619ca40 to 9c1cdf1 Compare February 6, 2026 01:48
@mergify

mergify bot commented Feb 6, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wzhao18.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 6, 2026
@AndreasKaratzas
Collaborator

AndreasKaratzas commented Feb 13, 2026

Working on a fix. Will update this comment with upcoming PR.

EDIT 1: Proposed fix by @wzhao18 looks like it is working. Still testing ...

EDIT 2: All tests passed on ROCm. Pushed #34543

@minosfuture
Contributor

Hi @wzhao18, thanks for this work!

IIUC, with the change to get_cuda_view_from_cpu_tensor for non-pinned tensors, the original CPU tensor's memory is effectively abandoned — the data now lives in a freshly allocated pinned buffer. The original cpu_tensor still exists and holds its old address, but the returned CUDA view points to completely different memory. Any subsequent writes to the original CPU tensor would not be reflected in the CUDA view.

I find this a bit confusing. The name says "view" which implies aliased memory (no copy), but for non-pinned tensors it silently copies into a new pinned buffer. The caller would reasonably expect that mutations to the original CPU tensor are visible through the "view" — that's true for pinned tensors but false for non-pinned ones.
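The aliasing distinction can be illustrated without CUDA, using plain Python buffers as a stand-in for the pinned and non-pinned cases:

```python
import array

cpu = array.array("b", [1, 2, 3, 4])

view = memoryview(cpu)        # aliases cpu's storage, like the pinned-tensor case
copy = array.array("b", cpu)  # fresh buffer, like the silent non-pinned re-allocation

cpu[0] = 42
assert view[0] == 42          # mutation is visible through a true view...
assert copy[0] == 1           # ...but not through the copy
```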

Would it make more sense to rename this func, or separate the handling of pinned-/non-pinned tensor to two functions? The latter would be better for backward compatibility.

@wzhao18
Contributor Author

wzhao18 commented Feb 19, 2026

Hi @minosfuture. Thanks for pointing out this concern. I agree this could lead to different behaviors (aliasing or non-aliasing) of get_cuda_view_from_cpu_tensor depending on whether the input tensor is pinned. I will create a PR to separate the handling of pinned and non-pinned tensors for offloading into two functions.

@ehfd
Contributor

ehfd commented Mar 9, 2026

A note for anyone for whom the environment variables don't work: they are VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY and VLLM_WEIGHT_OFFLOADING_DISABLE_UVA.

@wzhao18
Contributor Author

wzhao18 commented Mar 9, 2026

@ehfd Thanks for pointing this out. The flag names were updated during PR review. I will update the description to reflect this change.

@ehfd
Contributor

ehfd commented Mar 20, 2026

@wzhao18 Question: Does --offload-backend prefetch work with VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY? Does VLLM_WEIGHT_OFFLOADING_DISABLE_UVA need to be set as well?

Moreover, what is the behavior when VLLM_WEIGHT_OFFLOADING_DISABLE_UVA is combined with --offload-backend uva?

@wzhao18
Contributor Author

wzhao18 commented Mar 20, 2026

@wzhao18 Question: Does --offload-backend prefetch work with VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY? Does VLLM_WEIGHT_OFFLOADING_DISABLE_UVA need to be set as well?

Moreover, what is the behavior when VLLM_WEIGHT_OFFLOADING_DISABLE_UVA is combined with --offload-backend uva?

No, the prefetch offloader is a different backend. The two env vars apply only within the UVA offloader, so they have no effect together with --offload-backend prefetch.

Disabling UVA within the uva backend sounds confusing, but it essentially loads the weights through explicit CPU-to-GPU transfers rather than automatically through UVA. The naming is a legacy artifact; that code path is a fallback for systems without UVA capability.

he-yufeng added a commit to he-yufeng/vllm that referenced this pull request Mar 20, 2026
…h offloader

The prefetch CPU offloader unconditionally allocates pinned memory,
ignoring the VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY env var. On
unified-memory systems like GH200 where pinned memory eats into the
GPU memory pool, this causes OOM during model loading.

The UVA offloader already checks this env var (added in vllm-project#32993). Apply
the same check to the three places in the prefetch backend that
allocate or assert pinned memory:
- _offload_to_cpu_internal(): initial CPU storage allocation
- _update_cpu_storage_from_param(): re-pinning after weight processing
- start_onload_to_static(): pinned memory assertion

Fixes vllm-project#37672

Signed-off-by: Yufeng He <40085740+he-yufeng@users.noreply.github.com>

Labels

nvidia ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

8 participants