From 1c2af08bdb582af7e8c88b47fbaf315519e15754 Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Wed, 21 Jan 2026 19:51:46 +0000 Subject: [PATCH 01/12] Use for offloading functional call Signed-off-by: wzhao18 --- vllm/model_executor/models/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index c47a6248ad6d..3362c4e2646e 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -678,7 +678,7 @@ def forward(*args, **kwargs): k: v.to(device, non_blocking=True) for k, v in module.state_dict().items() } - output = functional_call(module, device_state, args=args, kwargs=kwargs) + output = functional_call(module, device_state, args=args, kwargs=kwargs, tie_weights=False) module.forward = forward return output From 42d2ca6fd5528fd0a33d40be3cf1a925cbe36ef7 Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Wed, 21 Jan 2026 20:15:02 +0000 Subject: [PATCH 02/12] Set UVA offloading to false Signed-off-by: wzhao18 --- vllm/model_executor/models/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 3362c4e2646e..d57a1acf5bdf 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -637,7 +637,7 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: uva_available = is_uva_available() assert uva_available, "V1 CPU offloading requires uva (pin memory) support" - uva_offloading = True + uva_offloading = False # offload parameters to CPU # use pin_memory if possible, which helps cudagraph capture speed @@ -678,6 +678,9 @@ def forward(*args, **kwargs): k: v.to(device, non_blocking=True) for k, v in module.state_dict().items() } + + # set `tie_weights=False` as tied weights in original model + # become untied when calling .to(device) output = functional_call(module, device_state, args=args, kwargs=kwargs, tie_weights=False) module.forward = forward return output From ac59a01b43448b51c76335c19ee5470ca6ec0562 Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Wed, 21 Jan 2026 21:16:16 +0000 Subject: [PATCH 03/12] Fix MLAAttention weights not in target device during processing after loading Signed-off-by: wzhao18 --- vllm/model_executor/model_loader/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 51f62c15b30e..ba46efc803bf 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -111,7 +111,8 @@ def process_weights_after_loading( ): # TODO(lucas): see if there is a way to unify the signatures # of process_weights_after_loading - module.process_weights_after_loading(model_config.dtype) + with device_loading_context(module, target_device): + module.process_weights_after_loading(model_config.dtype) # Needed for torchao model reloading via model.reload_weights # @kylesayrs @jerryzh168 this can be removed if callers move to `reload_weights` From f9882faa939a616878cb9ae03344a6b6b62e4fe4 Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Wed, 21 Jan 2026 22:03:10 +0000 Subject: [PATCH 04/12] Fix multiple elements write to same memory addr error in gpu->cpu transfer Signed-off-by: wzhao18 --- vllm/model_executor/model_loader/utils.py | 17 +++-------------- vllm/model_executor/models/utils.py | 14 ++++---------- 2 files changed, 7 insertions(+), 24 deletions(-) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index ba46efc803bf..2601ec38403e 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -145,20 +145,9 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device) for name, p in module.named_parameters(): if name in original_device_states: original_device: torch.device = original_device_states[name] - if original_device.type == "cpu": - # `torch.empty_like` does not support `pin_memory` argument - cpu_data = torch.empty_strided( - size=p.data.size(), - stride=p.data.stride(), - dtype=p.data.dtype, - layout=p.data.layout, - device="cpu", - pin_memory=pin_memory, - ) - cpu_data.copy_(p.data) - p.data = cpu_data - else: - p.data = p.data.to(original_device) + p.data = p.data.to(original_device) + if original_device.type == "cpu" and pin_memory: + p.data = p.data.pin_memory() # New parameters or parameters already on target device are untouched diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index d57a1acf5bdf..5d44267d54d6 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -648,16 +648,10 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: # one module might have some parameters offloaded and some not break - # `torch.empty_like` does not support `pin_memory` argument - cpu_data = torch.empty_strided( - size=p.data.size(), - stride=p.data.stride(), - dtype=p.data.dtype, - layout=p.data.layout, - device="cpu", - pin_memory=pin_memory, - ) - cpu_data.copy_(p.data) + cpu_data = p.data.to(device="cpu") + if pin_memory: + cpu_data = cpu_data.pin_memory() + if not uva_offloading: p.data = cpu_data else: From 666a31a53462358f1eef3f78e283f8d80d236924 Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Wed, 21 Jan 2026 22:31:11 +0000 Subject: [PATCH 05/12] Not use pin memory for kimi-k2 on DGX station Signed-off-by: wzhao18 --- vllm/model_executor/model_loader/utils.py | 1 + vllm/model_executor/models/utils.py | 1 + 2 files changed, 2 insertions(+) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 2601ec38403e..b65648fe964b 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -142,6 +142,7 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device) finally: # Restore parameters to their original devices, ignoring new parameters pin_memory = is_pin_memory_available() + pin_memory = False for name, p in module.named_parameters(): if name in original_device_states: original_device: torch.device = original_device_states[name] diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 5d44267d54d6..795bc25a7f6f 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -638,6 +638,7 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: assert uva_available, "V1 CPU offloading requires uva (pin memory) support" uva_offloading = False + pin_memory = False # offload parameters to CPU # use pin_memory if possible, which helps cudagraph capture speed From 559fb8df57b16deb84ef251be15d892a2029d5cb Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Thu, 22 Jan 2026 22:16:58 +0000 Subject: [PATCH 06/12] Support UVA without pytorch pin memory Signed-off-by: wzhao18 --- csrc/cuda_view.cu | 74 +++++++++++++++-------- vllm/envs.py | 10 +++ vllm/model_executor/model_loader/utils.py | 25 ++++++-- vllm/model_executor/models/utils.py | 23 +++---- vllm/utils/torch_utils.py | 10 ++- 5 files changed, 99 insertions(+), 43 deletions(-) diff --git a/csrc/cuda_view.cu b/csrc/cuda_view.cu index 9853fc942bab..c851b1a0cea9 100644 --- a/csrc/cuda_view.cu +++ b/csrc/cuda_view.cu @@ -1,34 +1,58 @@ #include #include #include +#include -// This function assumes that `cpu_tensor` is a CPU tensor allocated with pinned -// memory, and that UVA (Unified Virtual Addressing) is enabled. +// This function assumes that `cpu_tensor` is a CPU tensor, +// and that UVA (Unified Virtual Addressing) is enabled. torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) { TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU"); - // Get raw host pointer from CPU tensor - void* host_ptr = cpu_tensor.data_ptr(); + if (cpu_tensor.is_pinned()) { + // If CPU tensor is pinned, directly get the device pointer. + void* host_ptr = const_cast(cpu_tensor.data_ptr()); + void* device_ptr = nullptr; + cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0); + TORCH_CHECK(err == cudaSuccess, + "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err)); + + return torch::from_blob( + device_ptr, cpu_tensor.sizes(), cpu_tensor.strides(), + [base = cpu_tensor](void*) {}, // keep cpu tensor alive + cpu_tensor.options().device(torch::kCUDA)); + } + + // If CPU tensor is not pinned, allocate a new pinned memory buffer. + torch::Tensor contiguous_cpu = cpu_tensor.contiguous(); + size_t nbytes = contiguous_cpu.nbytes(); + long page_size = sysconf(_SC_PAGESIZE); + size_t aligned_size = (nbytes + page_size - 1) & ~(page_size - 1); + + void* host_ptr = mmap(nullptr, aligned_size, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + + if (host_ptr == MAP_FAILED) { + AT_ERROR("mmap failed to allocate ", aligned_size, " bytes"); + } + + std::memcpy(host_ptr, contiguous_cpu.data_ptr(), nbytes); + + cudaError_t err = + cudaHostRegister(host_ptr, aligned_size, cudaHostRegisterDefault); + if (err != cudaSuccess) { + munmap(host_ptr, aligned_size); + AT_ERROR("cudaHostRegister failed: ", cudaGetErrorString(err)); + } - // Get a device pointer corresponding to the pinned host memory void* device_ptr = nullptr; - cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0); - TORCH_CHECK(err == cudaSuccess, - "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err)); - - // We'll use the same sizes, strides, and dtype as the CPU tensor. - // TODO: check if layout is respected. - auto sizes = cpu_tensor.sizes(); - auto strides = cpu_tensor.strides(); - auto options = cpu_tensor.options().device(torch::kCUDA); - - // use default no-op deleter, since the memory is owned by the original CPU - // tensor - torch::Tensor cuda_tensor = - torch::from_blob(device_ptr, sizes, strides, options); - - TORCH_CHECK(cuda_tensor.device().is_cuda(), - "Resulting tensor is not on CUDA device"); - - return cuda_tensor; -} + cudaHostGetDevicePointer(&device_ptr, host_ptr, 0); + + auto deleter = [host_ptr, aligned_size](void*) { + cudaHostUnregister(host_ptr); + munmap(host_ptr, aligned_size); + }; + + return torch::from_blob(device_ptr, contiguous_cpu.sizes(), + contiguous_cpu.strides(), deleter, + contiguous_cpu.options().device(torch::kCUDA)); +} \ No newline at end of file diff --git a/vllm/envs.py b/vllm/envs.py index caddf0b7642e..be08cbe9332a 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -229,6 +229,8 @@ VLLM_USE_V2_MODEL_RUNNER: bool = False VLLM_LOG_MODEL_INSPECTION: bool = False VLLM_DEBUG_MFU_METRICS: bool = False + VLLM_OFFLOADING_DISABLE_PIN_MEMORY: bool = False + VLLM_OFFLOADING_DISABLE_UVA: bool = False VLLM_DISABLE_LOG_LOGO: bool = False VLLM_LORA_DISABLE_PDL: bool = False @@ -1532,6 +1534,14 @@ def _get_or_set_default() -> str: "VLLM_DEBUG_MFU_METRICS": lambda: bool( int(os.getenv("VLLM_DEBUG_MFU_METRICS", "0")) ), + # Disable using pytorch's pin memory for CPU offloading. + "VLLM_OFFLOADING_DISABLE_PIN_MEMORY": lambda: bool( + int(os.getenv("VLLM_OFFLOADING_DISABLE_PIN_MEMORY", "0")) + ), + # Disable using UVA (Unified Virtual Addressing) for CPU offloading. + "VLLM_OFFLOADING_DISABLE_UVA": lambda: bool( + int(os.getenv("VLLM_OFFLOADING_DISABLE_UVA", "0")) + ), # Disable logging of vLLM logo at server startup time. "VLLM_DISABLE_LOG_LOGO": lambda: bool(int(os.getenv("VLLM_DISABLE_LOG_LOGO", "0"))), # Disable PDL for LoRA, as enabling PDL with LoRA on SM100 causes diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index b65648fe964b..8d1cbae82321 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -11,6 +11,8 @@ from torch import nn from typing_extensions import assert_never +import vllm.envs as envs +from vllm.attention.layer import Attention, MLAAttention from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention, MLAAttention @@ -25,6 +27,7 @@ from vllm.model_executor.models.interfaces import SupportsQuant from vllm.tracing import instrument from vllm.utils.platform_utils import is_pin_memory_available +from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor logger = init_logger(__name__) @@ -128,28 +131,40 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device) return original_device_states: dict[str, torch.device] = {} + uva_offloaded_parameters: list[str] = [] # Store original device states and move parameters to GPU if they're on CPU for name, p in module.named_parameters(): if p.device.type == "cpu": original_device_states[name] = p.device p.data = p.data.to(target_device) + if getattr(p, "_vllm_is_uva_offloaded", False): + uva_offloaded_parameters.append(name) # Parameters already on target device are not touched try: yield module finally: + use_pin_memory = ( + is_pin_memory_available() and not envs.VLLM_OFFLOADING_DISABLE_PIN_MEMORY + ) # Restore parameters to their original devices, ignoring new parameters - pin_memory = is_pin_memory_available() - pin_memory = False for name, p in module.named_parameters(): if name in original_device_states: original_device: torch.device = original_device_states[name] p.data = p.data.to(original_device) - if original_device.type == "cpu" and pin_memory: - p.data = p.data.pin_memory() - # New parameters or parameters already on target device are untouched + + # parameter is UVA offloaded, but was replaced with a new device tensor + # re-offload it to CPU using UVA + if name in uva_offloaded_parameters and not getattr( + p, "_vllm_is_uva_offloaded", False + ): + cpu_data = p.data.to(device="cpu") + if use_pin_memory: + cpu_data = cpu_data.pin_memory() + p.data = get_cuda_view_from_cpu_tensor(cpu_data) + p._vllm_is_uva_offloaded = True _MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]() diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 795bc25a7f6f..040a3cc1df7d 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -13,6 +13,7 @@ from torch.nn.modules.module import register_module_module_registration_hook from transformers import PretrainedConfig +import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed import ( get_tensor_model_parallel_rank, @@ -633,12 +634,10 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES: return module - pin_memory = is_pin_memory_available() - uva_available = is_uva_available() - - assert uva_available, "V1 CPU offloading requires uva (pin memory) support" - uva_offloading = False - pin_memory = False + pin_memory = ( + is_pin_memory_available() and not envs.VLLM_OFFLOADING_DISABLE_PIN_MEMORY + ) + uva_offloading = is_uva_available() and not envs.VLLM_OFFLOADING_DISABLE_UVA # offload parameters to CPU # use pin_memory if possible, which helps cudagraph capture speed @@ -652,13 +651,13 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: cpu_data = p.data.to(device="cpu") if pin_memory: cpu_data = cpu_data.pin_memory() - + if not uva_offloading: p.data = cpu_data else: - # keep the cpu data alive - p._vllm_offloaded_cpu_data = cpu_data p.data = get_accelerator_view_from_cpu_tensor(cpu_data) + p._vllm_is_uva_offloaded = True + _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size() offloaded_parameters = True @@ -675,8 +674,10 @@ def forward(*args, **kwargs): } # set `tie_weights=False` as tied weights in original model - # become untied when calling .to(device) - output = functional_call(module, device_state, args=args, kwargs=kwargs, tie_weights=False) + # become untied when calling .to(device) individually + output = functional_call( + module, device_state, args=args, kwargs=kwargs, tie_weights=False + ) module.forward = forward return output diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index 0274b305e47f..1bff517fd7ba 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -678,12 +678,18 @@ def get_accelerator_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tens """ Get an accelerator view of a CPU tensor using Unified Virtual Addressing (UVA). """ - assert cpu_tensor.is_pinned(), "CPU tensor must be pinned" from vllm.platforms import current_platform if current_platform.is_xpu(): + assert cpu_tensor.is_pinned(), "CPU tensor must be pinned" return torch.ops._C.get_xpu_view_from_cpu_tensor(cpu_tensor) - return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor) + elif current_platform.is_cuda(): + return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor) + else: + raise ValueError( + f"`get_accelerator_view_from_cpu_tensor` is currently " + f"not supported in: {current_platform.device_name}" + ) # Helper function used in testing. From ae3e2216ad62c171eea3fde030f61005f9686118 Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Tue, 27 Jan 2026 06:17:08 +0000 Subject: [PATCH 07/12] Use cudahostalloc for page-locked memory Signed-off-by: wzhao18 --- csrc/cuda_view.cu | 33 ++++++++++++++------------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/csrc/cuda_view.cu b/csrc/cuda_view.cu index c851b1a0cea9..a162646737d1 100644 --- a/csrc/cuda_view.cu +++ b/csrc/cuda_view.cu @@ -1,7 +1,6 @@ #include #include #include -#include // This function assumes that `cpu_tensor` is a CPU tensor, // and that UVA (Unified Virtual Addressing) is enabled. @@ -25,32 +24,28 @@ torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) { // If CPU tensor is not pinned, allocate a new pinned memory buffer. torch::Tensor contiguous_cpu = cpu_tensor.contiguous(); size_t nbytes = contiguous_cpu.nbytes(); - long page_size = sysconf(_SC_PAGESIZE); - size_t aligned_size = (nbytes + page_size - 1) & ~(page_size - 1); - void* host_ptr = mmap(nullptr, aligned_size, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); - - if (host_ptr == MAP_FAILED) { - AT_ERROR("mmap failed to allocate ", aligned_size, " bytes"); + void* host_ptr = nullptr; + cudaError_t err = cudaHostAlloc(&host_ptr, nbytes, cudaHostAllocMapped); + if (err != cudaSuccess) { + AT_ERROR("cudaHostAlloc failed: ", cudaGetErrorString(err)); } - std::memcpy(host_ptr, contiguous_cpu.data_ptr(), nbytes); - - cudaError_t err = - cudaHostRegister(host_ptr, aligned_size, cudaHostRegisterDefault); + err = cudaMemcpy(host_ptr, contiguous_cpu.data_ptr(), nbytes, + cudaMemcpyDefault); if (err != cudaSuccess) { - munmap(host_ptr, aligned_size); - AT_ERROR("cudaHostRegister failed: ", cudaGetErrorString(err)); + cudaFreeHost(host_ptr); + AT_ERROR("cudaMemcpy failed: ", cudaGetErrorString(err)); } void* device_ptr = nullptr; - cudaHostGetDevicePointer(&device_ptr, host_ptr, 0); + err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0); + if (err != cudaSuccess) { + cudaFreeHost(host_ptr); + AT_ERROR("cudaHostGetDevicePointer failed: ", cudaGetErrorString(err)); + } - auto deleter = [host_ptr, aligned_size](void*) { - cudaHostUnregister(host_ptr); - munmap(host_ptr, aligned_size); - }; + auto deleter = [host_ptr](void*) { cudaFreeHost(host_ptr); }; return torch::from_blob(device_ptr, contiguous_cpu.sizes(), contiguous_cpu.strides(), deleter, From c706555825c886bc847e828adfdf17ce8eeb19b1 Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Tue, 27 Jan 2026 07:09:09 +0000 Subject: [PATCH 08/12] Add offloading tests Signed-off-by: wzhao18 --- tests/basic_correctness/test_cpu_offload.py | 23 +++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/tests/basic_correctness/test_cpu_offload.py b/tests/basic_correctness/test_cpu_offload.py index 89839372c309..86740b4b5e82 100644 --- a/tests/basic_correctness/test_cpu_offload.py +++ b/tests/basic_correctness/test_cpu_offload.py @@ -1,10 +1,29 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + from ..utils import compare_two_settings -def test_cpu_offload(): +@pytest.mark.parametrize("disable_pin_memory", [False, True]) +@pytest.mark.parametrize("disable_uva", [False, True]) +def test_cpu_offload(disable_pin_memory, disable_uva): + env_vars = { + "VLLM_OFFLOADING_DISABLE_PIN_MEMORY": str(int(disable_pin_memory)), + "VLLM_OFFLOADING_DISABLE_UVA": str(int(disable_uva)), + } + + args = ["--cpu-offload-gb", "1"] + + # cuda graph only works with UVA offloading + if disable_uva: + args.append("--enforce-eager") + compare_two_settings( - "hmellor/tiny-random-LlamaForCausalLM", [], ["--cpu-offload-gb", "1"] + model="hmellor/tiny-random-LlamaForCausalLM", + arg1=[], + arg2=args, + env1=None, + env2=env_vars, ) From 170368bf4777e5980961fff0734ad814cd9f491d Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Tue, 27 Jan 2026 17:54:23 +0000 Subject: [PATCH 09/12] resolve conflicts Signed-off-by: wzhao18 --- vllm/model_executor/model_loader/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 8d1cbae82321..aead774aa6f4 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -12,7 +12,6 @@ from typing_extensions import assert_never import vllm.envs as envs -from vllm.attention.layer import Attention, MLAAttention from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention, MLAAttention From 85a8eaa15dfd2d307291d6569767581a7afdee90 Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Wed, 4 Feb 2026 16:43:42 +0000 Subject: [PATCH 10/12] Handle empty tensor case in get_cuda_view_from_cpu_tensor Signed-off-by: wzhao18 --- csrc/cuda_view.cu | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/csrc/cuda_view.cu b/csrc/cuda_view.cu index a162646737d1..73b368cb6003 100644 --- a/csrc/cuda_view.cu +++ b/csrc/cuda_view.cu @@ -7,6 +7,12 @@ torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) { TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU"); + // handle empty tensor + if (cpu_tensor.numel() == 0) { + return torch::empty(cpu_tensor.sizes(), + cpu_tensor.options().device(torch::kCUDA)); + } + if (cpu_tensor.is_pinned()) { // If CPU tensor is pinned, directly get the device pointer. void* host_ptr = const_cast(cpu_tensor.data_ptr()); From e065905d2d98ca5ea1e19399a4c1eeb513421243 Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Fri, 6 Feb 2026 15:24:11 +0000 Subject: [PATCH 11/12] Merge from main Signed-off-by: wzhao18 --- vllm/model_executor/model_loader/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index aead774aa6f4..afa6dcb11807 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -26,7 +26,7 @@ from vllm.model_executor.models.interfaces import SupportsQuant from vllm.tracing import instrument from vllm.utils.platform_utils import is_pin_memory_available -from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor +from vllm.utils.torch_utils import get_accelerator_view_from_cpu_tensor logger = init_logger(__name__) @@ -162,7 +162,7 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device) cpu_data = p.data.to(device="cpu") if use_pin_memory: cpu_data = cpu_data.pin_memory() - p.data = get_cuda_view_from_cpu_tensor(cpu_data) + p.data = get_accelerator_view_from_cpu_tensor(cpu_data) p._vllm_is_uva_offloaded = True From 990a24109be2e8c0ec495f144dd052d4c5c8d648 Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Mon, 9 Feb 2026 18:28:01 +0000 Subject: [PATCH 12/12] Rename env vars Signed-off-by: wzhao18 --- tests/basic_correctness/test_cpu_offload.py | 4 ++-- vllm/envs.py | 12 ++++++------ vllm/model_executor/model_loader/utils.py | 3 ++- vllm/model_executor/models/utils.py | 4 ++-- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/tests/basic_correctness/test_cpu_offload.py b/tests/basic_correctness/test_cpu_offload.py index 86740b4b5e82..c1df36b369a9 100644 --- a/tests/basic_correctness/test_cpu_offload.py +++ b/tests/basic_correctness/test_cpu_offload.py @@ -10,8 +10,8 @@ @pytest.mark.parametrize("disable_uva", [False, True]) def test_cpu_offload(disable_pin_memory, disable_uva): env_vars = { - "VLLM_OFFLOADING_DISABLE_PIN_MEMORY": str(int(disable_pin_memory)), - "VLLM_OFFLOADING_DISABLE_UVA": str(int(disable_uva)), + "VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY": str(int(disable_pin_memory)), + "VLLM_WEIGHT_OFFLOADING_DISABLE_UVA": str(int(disable_uva)), } args = ["--cpu-offload-gb", "1"] diff --git a/vllm/envs.py b/vllm/envs.py index cd83e84eed52..fcab6a6dd351 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -229,8 +229,8 @@ VLLM_USE_V2_MODEL_RUNNER: bool = False VLLM_LOG_MODEL_INSPECTION: bool = False VLLM_DEBUG_MFU_METRICS: bool = False - VLLM_OFFLOADING_DISABLE_PIN_MEMORY: bool = False - VLLM_OFFLOADING_DISABLE_UVA: bool = False + VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY: bool = False + VLLM_WEIGHT_OFFLOADING_DISABLE_UVA: bool = False VLLM_DISABLE_LOG_LOGO: bool = False VLLM_LORA_DISABLE_PDL: bool = False @@ -1535,12 +1535,12 @@ def _get_or_set_default() -> str: int(os.getenv("VLLM_DEBUG_MFU_METRICS", "0")) ), # Disable using pytorch's pin memory for CPU offloading. - "VLLM_OFFLOADING_DISABLE_PIN_MEMORY": lambda: bool( - int(os.getenv("VLLM_OFFLOADING_DISABLE_PIN_MEMORY", "0")) + "VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY": lambda: bool( + int(os.getenv("VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY", "0")) ), # Disable using UVA (Unified Virtual Addressing) for CPU offloading. - "VLLM_OFFLOADING_DISABLE_UVA": lambda: bool( - int(os.getenv("VLLM_OFFLOADING_DISABLE_UVA", "0")) + "VLLM_WEIGHT_OFFLOADING_DISABLE_UVA": lambda: bool( + int(os.getenv("VLLM_WEIGHT_OFFLOADING_DISABLE_UVA", "0")) ), # Disable logging of vLLM logo at server startup time. "VLLM_DISABLE_LOG_LOGO": lambda: bool(int(os.getenv("VLLM_DISABLE_LOG_LOGO", "0"))), diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index afa6dcb11807..dc525c4541af 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -146,7 +146,8 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device) finally: use_pin_memory = ( - is_pin_memory_available() and not envs.VLLM_OFFLOADING_DISABLE_PIN_MEMORY + is_pin_memory_available() + and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY ) # Restore parameters to their original devices, ignoring new parameters for name, p in module.named_parameters(): diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 040a3cc1df7d..c942178d0af3 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -635,9 +635,9 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: return module pin_memory = ( - is_pin_memory_available() and not envs.VLLM_OFFLOADING_DISABLE_PIN_MEMORY + is_pin_memory_available() and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY ) - uva_offloading = is_uva_available() and not envs.VLLM_OFFLOADING_DISABLE_UVA + uva_offloading = is_uva_available() and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_UVA # offload parameters to CPU # use pin_memory if possible, which helps cudagraph capture speed