diff --git a/vllm/v1/kv_offload/factory.py b/vllm/v1/kv_offload/factory.py
index ecbaebb0d967..d2a8f6b23880 100644
--- a/vllm/v1/kv_offload/factory.py
+++ b/vllm/v1/kv_offload/factory.py
@@ -56,3 +56,6 @@ def create_spec(
 OffloadingSpecFactory.register_spec(
     "CPUOffloadingSpec", "vllm.v1.kv_offload.cpu.spec", "CPUOffloadingSpec"
 )
+OffloadingSpecFactory.register_spec(
+    "FileOffloadingSpec", "vllm.v1.kv_offload.file.spec", "FileOffloadingSpec"
+)
diff --git a/vllm/v1/kv_offload/file/__init__.py b/vllm/v1/kv_offload/file/__init__.py
new file mode 100644
index 000000000000..b1724050c0ba
--- /dev/null
+++ b/vllm/v1/kv_offload/file/__init__.py
@@ -0,0 +1,19 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+File-based KV cache offloading.
+
+This module provides file-based offloading for KV cache data,
+storing blocks as binary files on disk.
+"""
+from vllm.v1.kv_offload.file.handler import FileOffloadingHandler
+from vllm.v1.kv_offload.file.load_store_spec import FileLoadStoreSpec
+from vllm.v1.kv_offload.file.manager import FileOffloadingManager
+from vllm.v1.kv_offload.file.spec import FileOffloadingSpec
+
+__all__ = [
+    "FileOffloadingHandler",
+    "FileOffloadingManager",
+    "FileOffloadingSpec",
+    "FileLoadStoreSpec",
+]
diff --git a/vllm/v1/kv_offload/file/handler.py b/vllm/v1/kv_offload/file/handler.py
new file mode 100644
index 000000000000..bec96bd308aa
--- /dev/null
+++ b/vllm/v1/kv_offload/file/handler.py
@@ -0,0 +1,291 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+FileOffloadingHandler: Worker-side file I/O for KV cache offloading.
+"""
+
+import os
+import threading
+import time
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass, field
+
+import numpy as np
+import torch
+
+from vllm.logger import init_logger
+from vllm.v1.kv_offload.file.load_store_spec import FileLoadStoreSpec
+from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
+from vllm.v1.kv_offload.worker.worker import OffloadingHandler, TransferResult
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class PendingTransfer:
+    job_id: int
+    done_event: threading.Event = field(default_factory=threading.Event)
+    result: TransferResult | None = None
+
+
+class FileOffloadingHandler(OffloadingHandler):
+    """
+    Handles KV data transfer between GPU memory and file storage.
+
+    Uses a thread pool for async file I/O.
+
+    Transfer types:
+    - GPU -> FILE (offload): serialize GPU tensor slices to binary files
+    - FILE -> GPU (restore): deserialize binary files to GPU tensor slices
+    """
+
+    def __init__(
+        self,
+        gpu_tensors: list[torch.Tensor],
+        block_size_bytes: int,
+        num_threads: int = 4,
+    ):
+        self.gpu_tensors = gpu_tensors
+        self.block_size_bytes = block_size_bytes
+        self.num_tensors = len(gpu_tensors)
+
+        self._executor = ThreadPoolExecutor(max_workers=num_threads)
+        self._pending: dict[int, PendingTransfer] = {}
+        self._lock = threading.Lock()
+
+        # Pre-allocate CPU buffers for each thread
+        self._cpu_buffers: dict[int, torch.Tensor] = {}
+
+    def _get_cpu_buffer(self) -> torch.Tensor:
+        """Get a cached CPU buffer for transfers."""
+        thread_id = threading.get_ident()
+        if thread_id not in self._cpu_buffers:
+            self._cpu_buffers[thread_id] = torch.empty(
+                self.block_size_bytes * self.num_tensors,
+                dtype=torch.int8,
+                pin_memory=True,
+            )
+        return self._cpu_buffers[thread_id]
+
+    def transfer_async(self, job_id: int, spec) -> bool:
+        """
+        Initiate an asynchronous file transfer.
+
+        Args:
+            job_id: unique job ID for completion tracking
+            spec: (src_spec, dst_spec) tuple
+
+        Returns:
+            True if transfer was submitted successfully.
+        """
+        src_spec, dst_spec = spec
+
+        if isinstance(src_spec, GPULoadStoreSpec) and isinstance(
+            dst_spec, FileLoadStoreSpec
+        ):
+            # GPU -> FILE (offload)
+            pending = PendingTransfer(job_id=job_id)
+            with self._lock:
+                self._pending[job_id] = pending
+            self._executor.submit(
+                self._transfer_gpu_to_file, pending, src_spec, dst_spec
+            )
+        elif isinstance(src_spec, FileLoadStoreSpec) and isinstance(
+            dst_spec, GPULoadStoreSpec
+        ):
+            # FILE -> GPU (restore)
+            pending = PendingTransfer(job_id=job_id)
+            with self._lock:
+                self._pending[job_id] = pending
+            self._executor.submit(
+                self._transfer_file_to_gpu, pending, src_spec, dst_spec
+            )
+        else:
+            logger.error(
+                "Unsupported transfer: %s -> %s",
+                type(src_spec).__name__,
+                type(dst_spec).__name__,
+            )
+            return False
+
+        return True
+
+    def _transfer_gpu_to_file(
+        self,
+        pending: PendingTransfer,
+        gpu_spec: GPULoadStoreSpec,
+        file_spec: FileLoadStoreSpec,
+    ) -> None:
+        """Transfer KV data from GPU to file."""
+        t0 = time.monotonic()
+        transfer_size = 0
+
+        try:
+            src_blocks = gpu_spec.block_ids
+            dst_paths = file_spec.file_paths
+            dst_offsets = file_spec.block_offsets
+
+            # Handle multiple tensor groups
+            group_sizes = (
+                gpu_spec.group_sizes
+                if hasattr(gpu_spec, "group_sizes")
+                else [len(src_blocks)]
+            )
+            block_indices = (
+                gpu_spec.block_indices if hasattr(gpu_spec, "block_indices") else None
+            )
+            del block_indices  # not used yet
+
+            tensor_idx = 0
+            block_idx = 0
+            for group_size in group_sizes:
+                gpu_tensor = self.gpu_tensors[tensor_idx]
+                group_blocks = src_blocks[block_idx : block_idx + group_size]
+                group_paths = dst_paths[block_idx : block_idx + group_size]
+                group_offsets = dst_offsets[block_idx : block_idx + group_size]
+
+                for block_id, file_path, offset in zip(
+                    group_blocks, group_paths, group_offsets
+                ):
+                    # Copy GPU tensor slice to CPU using torch
+                    gpu_slice = gpu_tensor[int(block_id)].cpu()
+                    src_bytes = gpu_slice.numpy().tobytes()
+
+                    # Write to file (create if not exists)
+                    os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)
+                    with open(file_path, "wb") as f:
+                        f.seek(offset)
+                        f.write(src_bytes)
+
+                    transfer_size += self.block_size_bytes
+
+                tensor_idx += 1
+                block_idx += group_size
+
+            result = TransferResult(
+                job_id=pending.job_id,
+                success=True,
+                transfer_size=transfer_size,
+                transfer_time=time.monotonic() - t0,
+                transfer_type=("GPU", "FILE"),
+            )
+        except Exception as e:
+            logger.error("GPU->FILE transfer failed for job %d: %r", pending.job_id, e)
+            result = TransferResult(
+                job_id=pending.job_id,
+                success=False,
+                transfer_time=time.monotonic() - t0,
+                transfer_type=("GPU", "FILE"),
+            )
+
+        with self._lock:
+            pending.result = result
+            pending.done_event.set()
+
+    def _transfer_file_to_gpu(
+        self,
+        pending: PendingTransfer,
+        file_spec: FileLoadStoreSpec,
+        gpu_spec: GPULoadStoreSpec,
+    ) -> None:
+        """Transfer KV data from file to GPU."""
+        t0 = time.monotonic()
+        transfer_size = 0
+
+        try:
+            src_paths = file_spec.file_paths
+            src_offsets = file_spec.block_offsets
+            dst_blocks = gpu_spec.block_ids
+
+            # Handle multiple tensor groups
+            group_sizes = (
+                gpu_spec.group_sizes
+                if hasattr(gpu_spec, "group_sizes")
+                else [len(dst_blocks)]
+            )
+            block_indices = (
+                gpu_spec.block_indices if hasattr(gpu_spec, "block_indices") else None
+            )
+            del block_indices  # not used yet
+
+            tensor_idx = 0
+            block_idx = 0
+            for group_size in group_sizes:
+                gpu_tensor = self.gpu_tensors[tensor_idx]
+                group_paths = src_paths[block_idx : block_idx + group_size]
+                group_offsets = src_offsets[block_idx : block_idx + group_size]
+                group_blocks = dst_blocks[block_idx : block_idx + group_size]
+
+                for file_path, offset, block_id in zip(
+                    group_paths, group_offsets, group_blocks
+                ):
+                    # Read from file
+                    with open(file_path, "rb") as f:
+                        f.seek(offset)
+                        data = f.read(self.block_size_bytes)
+
+                    # Convert to torch tensor and copy to GPU
+                    data_tensor = torch.from_numpy(
+                        np.frombuffer(data, dtype=np.int8)
+                    ).clone()
+                    gpu_tensor[int(block_id)].copy_(data_tensor)
+
+                    transfer_size += self.block_size_bytes
+
+                tensor_idx += 1
+                block_idx += group_size
+
+            result = TransferResult(
+                job_id=pending.job_id,
+                success=True,
+                transfer_size=transfer_size,
+                transfer_time=time.monotonic() - t0,
+                transfer_type=("FILE", "GPU"),
+            )
+        except Exception as e:
+            logger.error("FILE->GPU transfer failed for job %d: %r", pending.job_id, e)
+            result = TransferResult(
+                job_id=pending.job_id,
+                success=False,
+                transfer_time=time.monotonic() - t0,
+                transfer_type=("FILE", "GPU"),
+            )
+
+        with self._lock:
+            pending.result = result
+            pending.done_event.set()
+
+    def get_finished(self) -> list[TransferResult]:
+        """Get list of finished transfers."""
+        results = []
+        with self._lock:
+            done_ids = [
+                job_id for job_id, p in self._pending.items() if p.done_event.is_set()
+            ]
+            for job_id in done_ids:
+                pending = self._pending.pop(job_id)
+                if pending.result:
+                    results.append(pending.result)
+
+        return results
+
+    def wait(self, job_ids: set[int]) -> None:
+        """Wait for specified jobs to complete (blocking)."""
+        for job_id in job_ids:
+            with self._lock:
+                pending = self._pending.get(job_id)
+            if pending:
+                pending.done_event.wait()
+
+    def shutdown(self) -> None:
+        """Shutdown the handler and release resources."""
+        # Wait for all pending transfers
+        with self._lock:
+            pending_ids = list(self._pending.keys())
+
+        for job_id in pending_ids:
+            self.wait({job_id})
+
+        self._executor.shutdown(wait=True)
+        self._cpu_buffers.clear()
+        logger.info("FileOffloadingHandler shutdown complete")
diff --git a/vllm/v1/kv_offload/file/load_store_spec.py b/vllm/v1/kv_offload/file/load_store_spec.py
new file mode 100644
index 000000000000..b2ff559555d2
--- /dev/null
+++ b/vllm/v1/kv_offload/file/load_store_spec.py
@@ -0,0 +1,23 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+LoadStoreSpec for file-based KV offloading.
+"""
+from vllm.v1.kv_offload.abstract import LoadStoreSpec
+
+
+class FileLoadStoreSpec(LoadStoreSpec):
+    """
+    Spec for loading/storing KV blocks from/to files.
+
+    file_paths: list of file paths for the blocks.
+    block_offsets: byte offsets within each file (for mmap-style access).
+    """
+
+    def __init__(self, file_paths: list[str], block_offsets: list[int] | None = None):
+        self.file_paths = file_paths
+        self.block_offsets = block_offsets or [0] * len(file_paths)
+
+    @staticmethod
+    def medium() -> str:
+        return "FILE"
diff --git a/vllm/v1/kv_offload/file/manager.py b/vllm/v1/kv_offload/file/manager.py
new file mode 100644
index 000000000000..497cb718b172
--- /dev/null
+++ b/vllm/v1/kv_offload/file/manager.py
@@ -0,0 +1,246 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+FileOffloadingManager: Manages KV cache offloading to file storage.
+"""
+
+import ctypes
+from collections import OrderedDict
+from collections.abc import Iterable
+from pathlib import Path
+
+from vllm.logger import init_logger
+from vllm.v1.kv_offload.abstract import (
+    LoadStoreSpec,
+    OffloadingEvent,
+    OffloadingManager,
+    OffloadKey,
+    PrepareStoreOutput,
+    ReqContext,
+)
+from vllm.v1.kv_offload.file.load_store_spec import FileLoadStoreSpec
+
+logger = init_logger(__name__)
+
+
+class FileBlockStatus(ctypes.Structure):
+    """
+    Metadata for a file-stored block.
+    """
+
+    _fields_ = [
+        ("ref_cnt", ctypes.c_int32),
+        ("block_id", ctypes.c_int64),
+        ("is_ready", ctypes.c_bool),
+    ]
+
+    def __init__(self, block_id: int):
+        super().__init__()
+        self.ref_cnt = -1  # -1 = not ready yet
+        self.block_id = block_id
+        self.is_ready = False
+
+
+class FileOffloadingManager(OffloadingManager):
+    """
+    An OffloadingManager that stores KV blocks as files on disk.
+
+    File layout:
+        {storage_dir}/
+            {key_hex_0}.bin
+            {key_hex_1}.bin
+            ...
+
+    Each block is stored as a separate file for simplicity.
+    The manager tracks which blocks are stored and manages eviction.
+    """
+
+    def __init__(
+        self,
+        storage_dir: str,
+        num_blocks: int,
+        block_size_bytes: int,
+        enable_events: bool = False,
+    ):
+        self.storage_dir = Path(storage_dir)
+        self.storage_dir.mkdir(parents=True, exist_ok=True)
+        self._num_blocks = num_blocks
+        self.block_size_bytes = block_size_bytes
+        self._num_allocated_blocks = 0
+        self._free_list: list[int] = []
+
+        # key -> FileBlockStatus
+        self._blocks: OrderedDict[OffloadKey, FileBlockStatus] = OrderedDict()
+        self.events: list[OffloadingEvent] | None = [] if enable_events else None
+
+    def _get_num_free_blocks(self) -> int:
+        return len(self._free_list) + self._num_blocks - self._num_allocated_blocks
+
+    def _allocate_blocks(self, keys: list[OffloadKey]) -> list[FileBlockStatus]:
+        num_fresh = min(len(keys), self._num_blocks - self._num_allocated_blocks)
+        num_reused = len(keys) - num_fresh
+
+        blocks: list[FileBlockStatus] = []
+        for _ in range(num_fresh):
+            blocks.append(FileBlockStatus(self._num_allocated_blocks))
+            self._num_allocated_blocks += 1
+
+        for _ in range(num_reused):
+            blocks.append(FileBlockStatus(self._free_list.pop()))
+        return blocks
+
+    def _free_block(self, block: FileBlockStatus) -> None:
+        self._free_list.append(block.block_id)
+
+    def _key_to_path(self, key: OffloadKey) -> Path:
+        return self.storage_dir / f"{key.hex()}.bin"
+
+    def _delete_file(self, key: OffloadKey) -> None:
+        path = self._key_to_path(key)
+        try:
+            if path.exists():
+                path.unlink()
+        except OSError as e:
+            logger.warning("Failed to delete file %s: %r", path, e)
+
+    # --- OffloadingManager interface ---
+
+    def lookup(
+        self,
+        keys: Iterable[OffloadKey],
+        req_context: ReqContext,
+    ) -> int | None:
+        hit_count = 0
+        for key in keys:
+            block = self._blocks.get(key)
+            if block is None or not block.is_ready:
+                break
+            hit_count += 1
+        return hit_count
+
+    def prepare_load(
+        self,
+        keys: Iterable[OffloadKey],
+        req_context: ReqContext,
+    ) -> LoadStoreSpec:
+        keys_list = list(keys)
+        file_paths: list[str] = []
+        block_offsets: list[int] = []
+
+        for key in keys_list:
+            block = self._blocks.get(key)
+            assert block is not None, f"Block {key!r} not found in file cache"
+            assert block.is_ready, f"Block {key!r} is not ready for reading"
+            block.ref_cnt += 1
+            file_paths.append(str(self._key_to_path(key)))
+            # Each file contains one block's data, so offset is always 0
+            block_offsets.append(0)
+
+        return FileLoadStoreSpec(file_paths, block_offsets)
+
+    def touch(self, keys: Iterable[OffloadKey]) -> None:
+        for key in reversed(list(keys)):
+            if key in self._blocks:
+                self._blocks.move_to_end(key)
+
+    def complete_load(self, keys: Iterable[OffloadKey]) -> None:
+        for key in keys:
+            block = self._blocks.get(key)
+            assert block is not None, f"Block {key!r} not found"
+            assert block.ref_cnt > 0, f"Block {key!r} ref_cnt is already 0"
+            block.ref_cnt -= 1
+
+    def prepare_store(
+        self,
+        keys: Iterable[OffloadKey],
+        req_context: ReqContext,
+    ) -> PrepareStoreOutput | None:
+        keys_list = list(keys)
+
+        # Filter out blocks already stored
+        keys_to_store = [k for k in keys_list if k not in self._blocks]
+
+        if not keys_to_store:
+            return PrepareStoreOutput(
+                keys_to_store=[],
+                store_spec=FileLoadStoreSpec([], []),
+                evicted_keys=[],
+            )
+
+        num_blocks_to_evict = len(keys_to_store) - self._get_num_free_blocks()
+
+        to_evict: list[OffloadKey] = []
+        if num_blocks_to_evict > 0:
+            protected = set(keys_list)
+            candidates: list[tuple[OffloadKey, FileBlockStatus]] = []
+            for key, block in self._blocks.items():
+                if block.ref_cnt == 0 and key not in protected:
+                    candidates.append((key, block))
+                    if len(candidates) == num_blocks_to_evict:
+                        break
+
+            if len(candidates) < num_blocks_to_evict:
+                return None  # Cannot evict enough blocks
+
+            for key, block in candidates:
+                self._blocks.pop(key)
+                self._free_block(block)
+                self._delete_file(key)
+                to_evict.append(key)
+
+        if to_evict and self.events is not None:
+            self.events.append(
+                OffloadingEvent(keys=to_evict, medium="FILE", removed=True)
+            )
+
+        # Allocate blocks for new entries
+        new_blocks = self._allocate_blocks(keys_to_store)
+        for key, block in zip(keys_to_store, new_blocks):
+            self._blocks[key] = block
+
+        # Build store spec with file paths
+        # Each key maps to one file, so offset is always 0
+        file_paths = [str(self._key_to_path(k)) for k in keys_to_store]
+        block_offsets = [0] * len(keys_to_store)
+
+        return PrepareStoreOutput(
+            keys_to_store=keys_to_store,
+            store_spec=FileLoadStoreSpec(file_paths, block_offsets),
+            evicted_keys=to_evict,
+        )
+
+    def complete_store(self, keys: Iterable[OffloadKey], success: bool = True) -> None:
+        stored_keys: list[OffloadKey] = []
+
+        if success:
+            for key in keys:
+                block = self._blocks.get(key)
+                if block is not None and not block.is_ready:
+                    block.ref_cnt = 0
+                    block.is_ready = True
+                    stored_keys.append(key)
+        else:
+            for key in keys:
+                block = self._blocks.get(key)
+                if block is not None and not block.is_ready:
+                    self._free_block(block)
+                    self._blocks.pop(key)
+                    self._delete_file(key)
+
+        if stored_keys and self.events is not None:
+            self.events.append(
+                OffloadingEvent(keys=stored_keys, medium="FILE", removed=False)
+            )
+
+    def take_events(self) -> Iterable[OffloadingEvent]:
+        if self.events is not None:
+            yield from self.events
+            self.events.clear()
+
+    def shutdown(self) -> None:
+        # Clean up all files on shutdown
+        for key in list(self._blocks.keys()):
+            self._delete_file(key)
+        self._blocks.clear()
+        self._free_list.clear()
+        logger.info("FileOffloadingManager shutdown complete")
diff --git a/vllm/v1/kv_offload/file/spec.py b/vllm/v1/kv_offload/file/spec.py
new file mode 100644
index 000000000000..51f71dcc0ba8
--- /dev/null
+++ b/vllm/v1/kv_offload/file/spec.py
@@ -0,0 +1,136 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+FileOffloadingSpec: File-based KV cache offloading implementation.
+"""
+
+from collections.abc import Iterator
+from typing import TYPE_CHECKING
+
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
+from vllm.v1.kv_offload.file.handler import FileOffloadingHandler
+from vllm.v1.kv_offload.file.load_store_spec import FileLoadStoreSpec
+from vllm.v1.kv_offload.file.manager import FileOffloadingManager
+from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
+from vllm.v1.kv_offload.spec import CanonicalKVCaches, OffloadingSpec
+from vllm.v1.kv_offload.worker.worker import OffloadingHandler
+
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+    from vllm.v1.kv_cache_interface import KVCacheConfig
+
+logger = init_logger(__name__)
+
+
+class FileOffloadingSpec(OffloadingSpec):
+    """
+    File-based offloading spec.
+
+    Configuration options (via kv_connector_extra_config):
+    - storage_dir: Directory to store KV cache files
+      (default: /tmp/vllm_offload)
+    - num_blocks: Maximum number of blocks to store
+    - block_size_bytes: Size of each block in bytes
+      (computed from KV cache config if not specified)
+    - enable_events: Enable offloading events for debugging
+    - num_threads: Number of threads for file I/O
+      (default: 4)
+    """
+
+    def __init__(self, vllm_config: "VllmConfig", kv_cache_config: "KVCacheConfig"):
+        super().__init__(vllm_config, kv_cache_config)
+
+        # Get storage directory
+        self.storage_dir = self.extra_config.get("storage_dir", "/tmp/vllm_offload")
+
+        # Get number of blocks
+        self.num_blocks = self.extra_config.get("num_blocks")
+        if self.num_blocks is None:
+            raise ValueError(
+                "num_blocks must be specified in kv_connector_extra_config "
+                "for FileOffloadingSpec"
+            )
+
+        # Compute block size if not specified
+        self.block_size_bytes = self.extra_config.get("block_size_bytes")
+        if self.block_size_bytes is None:
+            # Compute from KV cache config
+            page_sizes = {
+                kv_cache_group.kv_cache_spec.page_size_bytes
+                for kv_cache_group in kv_cache_config.kv_cache_groups
+            }
+            if len(page_sizes) != 1:
+                raise ValueError(
+                    "Cannot compute block_size_bytes: multiple page sizes found"
+                )
+            page_size = page_sizes.pop()
+            kv_bytes_per_block = (
+                page_size
+                * len(kv_cache_config.kv_cache_tensors)
+                * vllm_config.parallel_config.world_size
+            )
+            self.block_size_bytes = kv_bytes_per_block * self.block_size_factor
+
+        # Number of I/O threads for file operations
+        self.num_threads = self.extra_config.get("num_threads", 4)
+
+        # Enable offloading events
+        kv_events_config = vllm_config.kv_events_config
+        self.enable_events = (
+            kv_events_config is not None and kv_events_config.enable_kv_cache_events
+        )
+
+        # Manager instance (scheduler-side)
+        self._manager: OffloadingManager | None = None
+
+        # Handler instance (worker-side)
+        self._handler: FileOffloadingHandler | None = None
+
+        logger.info(
+            "FileOffloadingSpec: storage_dir=%s, num_blocks=%d, block_size_bytes=%d",
+            self.storage_dir,
+            self.num_blocks,
+            self.block_size_bytes,
+        )
+
+    def get_manager(self) -> OffloadingManager:
+        """Get the offloading manager (scheduler-side)."""
+        if self._manager is None:
+            self._manager = FileOffloadingManager(
+                storage_dir=self.storage_dir,
+                num_blocks=self.num_blocks,
+                block_size_bytes=self.block_size_bytes,
+                enable_events=self.enable_events,
+            )
+        return self._manager
+
+    def get_handlers(
+        self, kv_caches: CanonicalKVCaches
+    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
+        """
+        Get offloading handlers for GPU ↔ FILE transfers.
+
+        Yields:
+            Tuples of (src_type, dst_type, handler).
+        """
+        if not current_platform.is_cuda_alike():
+            raise Exception(
+                "File-based offloading is currently only supported on CUDA-alike GPUs"
+            )
+
+        if self._handler is None:
+            self._handler = FileOffloadingHandler(
+                gpu_tensors=[t.tensor for t in kv_caches.tensors],
+                block_size_bytes=self.block_size_bytes,
+                num_threads=self.num_threads,
+            )
+
+        assert self._handler is not None
+
+        # GPU → FILE (offload)
+        yield GPULoadStoreSpec, FileLoadStoreSpec, self._handler
+
+        # FILE → GPU (restore)
+        yield FileLoadStoreSpec, GPULoadStoreSpec, self._handler
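
For reviewers, here is a minimal usage sketch (not part of this diff) of how the new spec might be selected through the offloading connector. It assumes the connector reads the spec class name and extra options from kv_connector_extra_config the same way the existing CPUOffloadingSpec is wired up through OffloadingSpecFactory; the key names below ("spec_name", "storage_dir", "num_blocks", "num_threads") are assumptions and should be checked against factory.py's create_spec before use.

    # Hypothetical usage sketch: enabling FileOffloadingSpec via the
    # offloading connector. Extra-config key names are assumptions;
    # verify them against OffloadingSpecFactory.create_spec.
    from vllm import LLM
    from vllm.config import KVTransferConfig

    llm = LLM(
        model="facebook/opt-125m",
        kv_transfer_config=KVTransferConfig(
            kv_connector="OffloadingConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "spec_name": "FileOffloadingSpec",   # assumed selector key
                "storage_dir": "/tmp/vllm_offload",  # where block files are written
                "num_blocks": 4096,                  # required by FileOffloadingSpec
                "num_threads": 4,                    # file I/O thread pool size
            },
        ),
    )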