Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 50 additions & 25 deletions csrc/cuda_view.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,58 @@
#include <torch/cuda.h>
#include <cuda_runtime.h>

// This function assumes that `cpu_tensor` is a CPU tensor,
// and that UVA (Unified Virtual Addressing) is enabled.
//
// Returns a CUDA tensor that aliases host memory:
//  - If `cpu_tensor` is already pinned, the view aliases its storage
//    directly; the deleter capture keeps the CPU tensor alive.
//  - Otherwise, the data is copied into a freshly allocated mapped pinned
//    buffer, which is freed when the returned tensor is destroyed.
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
  TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");

  // Handle empty tensors: there is no storage to alias, so just return an
  // empty CUDA tensor with the same sizes and dtype.
  if (cpu_tensor.numel() == 0) {
    return torch::empty(cpu_tensor.sizes(),
                        cpu_tensor.options().device(torch::kCUDA));
  }

  if (cpu_tensor.is_pinned()) {
    // CPU tensor is pinned: map its host pointer into the device address
    // space directly (no copy). `data_ptr()` is already non-const here,
    // so no cast is needed.
    void* host_ptr = cpu_tensor.data_ptr();
    void* device_ptr = nullptr;
    cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
    TORCH_CHECK(err == cudaSuccess,
                "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));

    return torch::from_blob(
        device_ptr, cpu_tensor.sizes(), cpu_tensor.strides(),
        [base = cpu_tensor](void*) {},  // no-op deleter keeps cpu tensor alive
        cpu_tensor.options().device(torch::kCUDA));
  }

  // CPU tensor is not pinned: allocate a new mapped pinned buffer and copy
  // the (contiguous) data into it. The returned view owns this buffer.
  torch::Tensor contiguous_cpu = cpu_tensor.contiguous();
  size_t nbytes = contiguous_cpu.nbytes();

  void* host_ptr = nullptr;
  cudaError_t err = cudaHostAlloc(&host_ptr, nbytes, cudaHostAllocMapped);
  if (err != cudaSuccess) {
    AT_ERROR("cudaHostAlloc failed: ", cudaGetErrorString(err));
  }

  err = cudaMemcpy(host_ptr, contiguous_cpu.data_ptr(), nbytes,
                   cudaMemcpyDefault);
  if (err != cudaSuccess) {
    cudaFreeHost(host_ptr);  // do not leak the pinned buffer on failure
    AT_ERROR("cudaMemcpy failed: ", cudaGetErrorString(err));
  }

  // Get a device pointer corresponding to the pinned host memory.
  void* device_ptr = nullptr;
  err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
  if (err != cudaSuccess) {
    cudaFreeHost(host_ptr);
    AT_ERROR("cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
  }

  // The returned tensor owns the pinned allocation; free it on destruction.
  auto deleter = [host_ptr](void*) { cudaFreeHost(host_ptr); };

  return torch::from_blob(device_ptr, contiguous_cpu.sizes(),
                          contiguous_cpu.strides(), deleter,
                          contiguous_cpu.options().device(torch::kCUDA));
}
23 changes: 21 additions & 2 deletions tests/basic_correctness/test_cpu_offload.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from ..utils import compare_two_settings


@pytest.mark.parametrize("disable_pin_memory", [False, True])
@pytest.mark.parametrize("disable_uva", [False, True])
def test_cpu_offload(disable_pin_memory, disable_uva):
    """Compare baseline inference against CPU weight offloading.

    Runs the model once with no offloading and once with 1 GiB of CPU
    offload, across all combinations of the pin-memory and UVA kill
    switches, and asserts the outputs match.
    """
    env_vars = {
        "VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY": str(int(disable_pin_memory)),
        "VLLM_WEIGHT_OFFLOADING_DISABLE_UVA": str(int(disable_uva)),
    }

    args = ["--cpu-offload-gb", "1"]

    # cuda graph only works with UVA offloading
    if disable_uva:
        args.append("--enforce-eager")

    compare_two_settings(
        model="hmellor/tiny-random-LlamaForCausalLM",
        arg1=[],
        arg2=args,
        env1=None,
        env2=env_vars,
    )
10 changes: 10 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@
VLLM_USE_V2_MODEL_RUNNER: bool = False
VLLM_LOG_MODEL_INSPECTION: bool = False
VLLM_DEBUG_MFU_METRICS: bool = False
VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY: bool = False
VLLM_WEIGHT_OFFLOADING_DISABLE_UVA: bool = False
VLLM_DISABLE_LOG_LOGO: bool = False
VLLM_LORA_DISABLE_PDL: bool = False

Expand Down Expand Up @@ -1542,6 +1544,14 @@ def _get_or_set_default() -> str:
"VLLM_DEBUG_MFU_METRICS": lambda: bool(
int(os.getenv("VLLM_DEBUG_MFU_METRICS", "0"))
),
# Disable using pytorch's pin memory for CPU offloading.
"VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY": lambda: bool(
int(os.getenv("VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY", "0"))
),
# Disable using UVA (Unified Virtual Addressing) for CPU offloading.
"VLLM_WEIGHT_OFFLOADING_DISABLE_UVA": lambda: bool(
int(os.getenv("VLLM_WEIGHT_OFFLOADING_DISABLE_UVA", "0"))
),
Comment on lines +1547 to +1554
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Since we also do kv offloading, it might be better to explicitly say weight offloading

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the review. I will rename them to VLLM_WEIGHT_OFFLOADING...

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated.

# Disable logging of vLLM logo at server startup time.
"VLLM_DISABLE_LOG_LOGO": lambda: bool(int(os.getenv("VLLM_DISABLE_LOG_LOGO", "0"))),
# Disable PDL for LoRA, as enabling PDL with LoRA on SM100 causes
Expand Down
40 changes: 23 additions & 17 deletions vllm/model_executor/model_loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torch import nn
from typing_extensions import assert_never

import vllm.envs as envs
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention, MLAAttention
Expand All @@ -25,6 +26,7 @@
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.tracing import instrument
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import get_accelerator_view_from_cpu_tensor

logger = init_logger(__name__)

Expand Down Expand Up @@ -111,7 +113,8 @@ def process_weights_after_loading(
):
# TODO(lucas): see if there is a way to unify the signatures
# of process_weights_after_loading
module.process_weights_after_loading(model_config.dtype)
with device_loading_context(module, target_device):
module.process_weights_after_loading(model_config.dtype)

# Needed for torchao model reloading via model.reload_weights
# @kylesayrs @jerryzh168 this can be removed if callers move to `reload_weights`
Expand All @@ -127,38 +130,41 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device)
return

original_device_states: dict[str, torch.device] = {}
uva_offloaded_parameters: list[str] = []

# Store original device states and move parameters to GPU if they're on CPU
for name, p in module.named_parameters():
if p.device.type == "cpu":
original_device_states[name] = p.device
p.data = p.data.to(target_device)
if getattr(p, "_vllm_is_uva_offloaded", False):
uva_offloaded_parameters.append(name)
# Parameters already on target device are not touched

try:
yield module

finally:
use_pin_memory = (
is_pin_memory_available()
and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY
)
# Restore parameters to their original devices, ignoring new parameters
pin_memory = is_pin_memory_available()
for name, p in module.named_parameters():
if name in original_device_states:
original_device: torch.device = original_device_states[name]
if original_device.type == "cpu":
# `torch.empty_like` does not support `pin_memory` argument
cpu_data = torch.empty_strided(
size=p.data.size(),
stride=p.data.stride(),
dtype=p.data.dtype,
layout=p.data.layout,
device="cpu",
pin_memory=pin_memory,
)
cpu_data.copy_(p.data)
p.data = cpu_data
else:
p.data = p.data.to(original_device)
# New parameters or parameters already on target device are untouched
p.data = p.data.to(original_device)

# parameter is UVA offloaded, but was replaced with a new device tensor
# re-offload it to CPU using UVA
if name in uva_offloaded_parameters and not getattr(
p, "_vllm_is_uva_offloaded", False
):
cpu_data = p.data.to(device="cpu")
if use_pin_memory:
cpu_data = cpu_data.pin_memory()
p.data = get_accelerator_view_from_cpu_tensor(cpu_data)
p._vllm_is_uva_offloaded = True


_MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
Expand Down
35 changes: 17 additions & 18 deletions vllm/model_executor/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from torch.nn.modules.module import register_module_module_registration_hook
from transformers import PretrainedConfig

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
Expand Down Expand Up @@ -633,11 +634,10 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
return module

pin_memory = is_pin_memory_available()
uva_available = is_uva_available()

assert uva_available, "V1 CPU offloading requires uva (pin memory) support"
uva_offloading = True
pin_memory = (
is_pin_memory_available() and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY
)
uva_offloading = is_uva_available() and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_UVA

# offload parameters to CPU
# use pin_memory if possible, which helps cudagraph capture speed
Expand All @@ -648,22 +648,16 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
# one module might have some parameters offloaded and some not
break

# `torch.empty_like` does not support `pin_memory` argument
cpu_data = torch.empty_strided(
size=p.data.size(),
stride=p.data.stride(),
dtype=p.data.dtype,
layout=p.data.layout,
device="cpu",
pin_memory=pin_memory,
)
cpu_data.copy_(p.data)
cpu_data = p.data.to(device="cpu")
if pin_memory:
cpu_data = cpu_data.pin_memory()

if not uva_offloading:
p.data = cpu_data
else:
# keep the cpu data alive
p._vllm_offloaded_cpu_data = cpu_data
p.data = get_accelerator_view_from_cpu_tensor(cpu_data)
p._vllm_is_uva_offloaded = True

_CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
offloaded_parameters = True

Expand All @@ -678,7 +672,12 @@ def forward(*args, **kwargs):
k: v.to(device, non_blocking=True)
for k, v in module.state_dict().items()
}
output = functional_call(module, device_state, args=args, kwargs=kwargs)

# set `tie_weights=False` as tied weights in original model
# become untied when calling .to(device) individually
output = functional_call(
module, device_state, args=args, kwargs=kwargs, tie_weights=False
)
module.forward = forward
return output

Expand Down
10 changes: 8 additions & 2 deletions vllm/utils/torch_utils.py
Original file line number Diff line number Diff line change
def get_accelerator_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
    """
    Get an accelerator view of a CPU tensor using Unified Virtual Addressing (UVA).

    On XPU the input must already be pinned; on CUDA the underlying custom op
    accepts both pinned and unpinned CPU tensors. Raises ValueError on any
    other platform.
    """
    # Imported lazily to avoid a circular import at module load time.
    from vllm.platforms import current_platform

    if current_platform.is_xpu():
        assert cpu_tensor.is_pinned(), "CPU tensor must be pinned"
        return torch.ops._C.get_xpu_view_from_cpu_tensor(cpu_tensor)
    elif current_platform.is_cuda():
        return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
    else:
        raise ValueError(
            f"`get_accelerator_view_from_cpu_tensor` is currently "
            f"not supported in: {current_platform.device_name}"
        )


# Helper function used in testing.
Expand Down