Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions vllm/config/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,17 @@ class CacheConfig:
Note that this requires fast CPU-GPU interconnect, as part of the model is
loaded from CPU memory to GPU memory on the fly in each model forward pass.
"""
cpu_offload_params: set[str] = Field(default_factory=set)
""" The set of parameter name segments to target for CPU offloading.
Unmatched parameters are not offloaded. If this set is empty, parameters
are offloaded non-selectively until the memory limit defined by
`cpu_offload_gb` is reached.
Examples:
- For parameter name "mlp.experts.w2_weight":
- "experts" or "experts.w2_weight" will match.
- "expert" or "w2" will NOT match (must be exact segments).
This allows distinguishing parameters like "w2_weight" and "w2_weight_scale".
"""
calculate_kv_scales: bool = False
"""This enables dynamic calculation of `k_scale` and `v_scale` when
kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
Expand Down
5 changes: 5 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ class EngineArgs:
disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
swap_space: float = CacheConfig.swap_space
cpu_offload_gb: float = CacheConfig.cpu_offload_gb
cpu_offload_params: set[str] = get_field(CacheConfig, "cpu_offload_params")
gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes
max_num_batched_tokens: int | None = None
Expand Down Expand Up @@ -942,6 +943,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"]
)
cache_group.add_argument("--cpu-offload-gb", **cache_kwargs["cpu_offload_gb"])
cache_group.add_argument(
"--cpu-offload-params", **cache_kwargs["cpu_offload_params"]
)
cache_group.add_argument(
"--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]
)
Expand Down Expand Up @@ -1453,6 +1457,7 @@ def create_engine_config(
enable_prefix_caching=self.enable_prefix_caching,
prefix_caching_hash_algo=self.prefix_caching_hash_algo,
cpu_offload_gb=self.cpu_offload_gb,
cpu_offload_params=self.cpu_offload_params,
calculate_kv_scales=self.calculate_kv_scales,
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
mamba_cache_dtype=self.mamba_cache_dtype,
Expand Down
24 changes: 23 additions & 1 deletion vllm/model_executor/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils.math_utils import cdiv
from vllm.utils.mem_utils import format_gib
from vllm.utils.platform_utils import (
is_pin_memory_available,
is_uva_available,
Expand Down Expand Up @@ -613,6 +614,7 @@ def forward(self, *args, **kwargs):

# Running total of bytes offloaded to CPU so far (incremented per parameter).
_CPU_OFFLOAD_BYTES = 0
# Byte budget for CPU offloading; set via set_cpu_offload_max_bytes().
_CPU_OFFLOAD_MAX_BYTES = 0
# Parameter-name segments selected for offloading; an empty set means
# offload non-selectively until the byte budget is reached.
_CPU_OFFLOAD_PARAMS: set[str] = set()


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
Expand All @@ -621,6 +623,11 @@ def set_cpu_offload_max_bytes(max_bytes: int) -> None:
_CPU_OFFLOAD_MAX_BYTES = max_bytes


def set_cpu_offload_params(params: set[str]) -> None:
    """Set the parameter-name segments targeted for CPU offloading.

    An empty set means offloading is non-selective: parameters are
    offloaded until the byte budget configured via
    ``set_cpu_offload_max_bytes`` is exhausted.

    Args:
        params: Dot-separated parameter-name segments to match
            (e.g. ``"experts.w2_weight"``); matching is performed on
            whole segments only.
    """
    global _CPU_OFFLOAD_PARAMS
    # Copy defensively so that later caller-side mutation of the passed-in
    # set cannot silently change offloading behavior.
    _CPU_OFFLOAD_PARAMS = set(params)


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
if (params := next(module.parameters(), None)) is None:
return module
Expand All @@ -642,12 +649,23 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
# offload parameters to CPU
# use pin_memory if possible, which helps cudagraph capture speed
offloaded_parameters = False
for p in module.parameters():
for name, p in module.named_parameters():
if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
# we use per-parameter offloading
# one module might have some parameters offloaded and some not
break

if _CPU_OFFLOAD_PARAMS:
# Check if parameter belongs to the offloading set
# Add dots here to ensure we match full segments only
# e.g., "experts.w2_weight" matches "mlp.experts.w2_weight" but not
# "mlp.experts.w2_weight_scale"
should_offload = any(
f".{param}." in f".{name}." for param in _CPU_OFFLOAD_PARAMS
)
if not should_offload:
continue

cpu_data = p.data.to(device="cpu")
if pin_memory:
cpu_data = cpu_data.pin_memory()
Expand Down Expand Up @@ -708,6 +726,10 @@ def make_layers(
]
+ [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]
)
if _CPU_OFFLOAD_MAX_BYTES > 0:
logger.info(
"Total CPU offloaded parameters: %s GBs", format_gib(_CPU_OFFLOAD_BYTES)
)
return start_layer, end_layer, modules


Expand Down
6 changes: 5 additions & 1 deletion vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,13 @@ def __init__(
self.speculative_config = vllm_config.speculative_config
self.observability_config = vllm_config.observability_config

from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.model_executor.models.utils import (
set_cpu_offload_max_bytes,
set_cpu_offload_params,
)

set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3))
set_cpu_offload_params(self.cache_config.cpu_offload_params)

model_config = self.model_config
cache_config = self.cache_config
Expand Down