diff --git a/tests/v1/kv_connector/unit/test_offloading_connector.py b/tests/v1/kv_connector/unit/test_offloading_connector.py index 35ebfa1e73b4..af948ce3aef3 100644 --- a/tests/v1/kv_connector/unit/test_offloading_connector.py +++ b/tests/v1/kv_connector/unit/test_offloading_connector.py @@ -107,7 +107,7 @@ def close(self): self.sub.close() -def _wait_for_prefix_cache_reset(llm: LLM) -> None: +def _wait_for_prefix_cache_reset(llm: LLM, reset_connector: bool = False) -> None: """Wait for async offload transfers to finish so prefix cache can reset. The GPU-to-CPU offload runs on a CUDA stream asynchronously. While blocks @@ -115,10 +115,14 @@ def _wait_for_prefix_cache_reset(llm: LLM) -> None: ``False``. Between retries we send a dummy single-token prefill to force the engine to step, which polls the worker for completed transfers and frees GPU blocks. + + Args: + llm: The LLM instance to reset. + reset_connector: If True, also reset the KV connector state. """ _dummy_params = SamplingParams(max_tokens=1) deadline = time.monotonic() + _RESET_CACHE_TIMEOUT - while not llm.reset_prefix_cache(): + while not llm.reset_prefix_cache(reset_connector=reset_connector): if time.monotonic() > deadline: raise TimeoutError( "reset_prefix_cache did not succeed within " @@ -133,7 +137,9 @@ def _wait_for_prefix_cache_reset(llm: LLM) -> None: ) -def _latency_test(llm: LLM, subscriber: MockSubscriber | None): +def _latency_test( + llm: LLM, subscriber: MockSubscriber | None, reset_connector: bool = False +): sampling_params = SamplingParams(max_tokens=1) num_times_cpu_better_than_cold = 0 @@ -163,7 +169,7 @@ def _latency_test(llm: LLM, subscriber: MockSubscriber | None): # Wait for the async CPU offload to finish, then reset prefix cache # so the next generate() must reload from CPU rather than GPU. 
def test_fs_tiering_offloading(tmp_path) -> None:
    """Exercise OffloadingConnector with TieringOffloadingSpec and an
    ``fs_python`` secondary tier rooted at ``tmp_path``."""
    # Ask the OS for a currently-free TCP port by binding to port 0; the
    # socket is closed again before the event publisher binds to that port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(("0.0.0.0", 0))
        port: int = probe.getsockname()[1]
    events_endpoint = f"tcp://*:{port}"

    kv_events_config = KVEventsConfig(
        enable_kv_cache_events=True,
        publisher="zmq",
        endpoint=events_endpoint,
        topic="test",
    )
    kv_transfer_config = KVTransferConfig(
        kv_connector="OffloadingConnector",
        kv_role="kv_both",
        kv_connector_extra_config={
            "cpu_bytes_to_use": 1 << 30,
            "block_size": 48,
            "spec_name": "TieringOffloadingSpec",
            "secondary_tiers": [{"type": "fs_python", "root_dir": str(tmp_path)}],
        },
    )

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        max_model_len=512,
        gpu_memory_utilization=0.5,
        kv_events_config=kv_events_config,
        kv_transfer_config=kv_transfer_config,
    )
    # The publisher binds on "*"; the subscriber connects via loopback.
    subscriber = MockSubscriber(
        events_endpoint.replace("*", "127.0.0.1"),
        topic=kv_events_config.topic,
    )
    try:
        _latency_test(llm, subscriber, reset_connector=True)
        _accuracy_test(llm, subscriber)
    finally:
        subscriber.close()
        del llm
def make_mapper_from_offloading_spec(**kwargs) -> FileMapper:
    """Build a FileMapper from a freshly mocked OffloadingSpec.

    Recognised keyword overrides (anything else is ignored): ``model_name``,
    ``hash_block_size``, ``dtype``, ``tp_size``, ``pp_size``, ``pcp_size``,
    ``dcp_size``, ``rank``, ``block_size_factor`` and ``root_dir``.
    """
    opts = {
        "model_name": "test-model",
        "hash_block_size": 16,
        "dtype": "float16",
        "tp_size": 1,
        "pp_size": 1,
        "pcp_size": 1,
        "dcp_size": 1,
        "rank": 0,
        "block_size_factor": 1,
        "root_dir": "/tmp/cache",
        **kwargs,
    }

    # Fresh mocks per call so the module-level mocks are never mutated.
    vllm_config = MagicMock()
    vllm_config.model_config.model = opts["model_name"]
    vllm_config.cache_config.block_size = opts["hash_block_size"]
    vllm_config.cache_config.cache_dtype = f"torch.{opts['dtype']}"

    parallel = vllm_config.parallel_config
    parallel.tensor_parallel_size = opts["tp_size"]
    parallel.pipeline_parallel_size = opts["pp_size"]
    parallel.prefill_context_parallel_size = opts["pcp_size"]
    parallel.decode_context_parallel_size = opts["dcp_size"]
    parallel.rank = opts["rank"]

    kv_cache_config = MagicMock()
    kv_cache_config.kv_cache_groups = []

    spec = MagicMock(spec=OffloadingSpec)
    spec.vllm_config = vllm_config
    spec.kv_cache_config = kv_cache_config
    spec.block_size_factor = opts["block_size_factor"]

    return FileMapper.from_offloading_spec(
        root_dir=opts["root_dir"],
        offloading_spec=spec,
        gpu_blocks_per_file=spec.block_size_factor,
    )


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


def test_get_file_name_full_structure():
    """
    Path must match:
        <base_path>_r<rank>/<3 hex chars>/<2 hex chars>_g<group_idx>/<hash_hex>.bin

    Concretely:
    - The segment immediately after root_dir is the base path suffixed
      with `_r<rank>` (here `_r3`)
    - The next segment is the first 3 hex chars of the block hash
    - The next segment is the following 2 hex chars plus `_g<group_idx>`
    - The final segment is the full hash in hex plus `.bin`
    """
    # Deterministic 8-byte hash: 00 01 02 ... 07
    offload_key = make_offload_key(bytes(range(8)), 2)
    mapper = make_mapper_from_offloading_spec(rank=3)

    assert mapper.get_file_name(offload_key) == (
        "/tmp/cache/test-model_588656ebcc66_r3/000/10_g2/0001020304050607.bin"
    )


def test_get_run_config_fields():
    """Overrides show up in the run config; everything else keeps its default."""
    mapper = make_mapper_from_offloading_spec(
        model_name="my-model",
        dtype="bfloat16",
        tp_size=2,
    )
    expected = {
        "model_name": "my-model",
        "hash_block_size": 16,
        "gpu_blocks_per_file": 1,
        "tp_size": 2,
        "pp_size": 1,
        "pcp_size": 1,
        "dcp_size": 1,
        "dtype": "bfloat16",
        "kv_cache_groups": [],
        "inference_engine": "vllm",
    }
    assert mapper.get_run_config() == expected
test_get_config_file_path(): + fm = make_mapper_from_offloading_spec() + config_path = fm.get_config_file_path() + assert config_path == f"{fm.base_path}/config.json" diff --git a/tests/v1/kv_offload/test_fs_tier.py b/tests/v1/kv_offload/test_fs_tier.py new file mode 100644 index 000000000000..fcb5879b9bdc --- /dev/null +++ b/tests/v1/kv_offload/test_fs_tier.py @@ -0,0 +1,256 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Unit tests for FileSystemTierManager. + +These tests use real disk I/O to verify the Python filesystem tier implementation. +The tier manager writes KV cache blocks to disk and reads them back, verifying +data integrity throughout the process. +""" + +import os +import time +from unittest.mock import MagicMock + +import numpy as np +import pytest +import torch + +from vllm.v1.kv_offload.base import OffloadKey, ReqContext, make_offload_key +from vllm.v1.kv_offload.tiering.base import JobMetadata +from vllm.v1.kv_offload.tiering.fs.manager import ( + FileSystemTierManager, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_BLOCK_ELEMENTS = 512 * 1024 # 2 MB per block (float32 × 512K = 2MB) +_DTYPE = torch.float32 +_CTX = ReqContext(req_id="test") + +_MOCK_VLLM_CONFIG = MagicMock() +_MOCK_VLLM_CONFIG.model_config.model = "test-model" +_MOCK_VLLM_CONFIG.cache_config.block_size = 16 +_MOCK_VLLM_CONFIG.cache_config.cache_dtype = "torch.float32" +_MOCK_VLLM_CONFIG.parallel_config.tensor_parallel_size = 1 +_MOCK_VLLM_CONFIG.parallel_config.pipeline_parallel_size = 1 +_MOCK_VLLM_CONFIG.parallel_config.prefill_context_parallel_size = 1 +_MOCK_VLLM_CONFIG.parallel_config.decode_context_parallel_size = 1 +_MOCK_VLLM_CONFIG.parallel_config.rank = 0 + +_MOCK_KV_CACHE_CONFIG = MagicMock() +_MOCK_KV_CACHE_CONFIG.kv_cache_groups = [] + +_MOCK_OFFLOADING_SPEC = 
def key(n: int) -> OffloadKey:
    """Return an OffloadKey for group 0 whose hash is *n* as 8 big-endian bytes."""
    return make_offload_key(n.to_bytes(8, "big"), 0)


def make_job(
    job_id: int,
    keys: list[OffloadKey],
    block_ids: list[int] | None = None,
    is_promotion: bool = False,
) -> JobMetadata:
    """Build a JobMetadata; block ids default to ``0..len(keys)-1``."""
    if block_ids is None:
        block_ids = list(range(len(keys)))
    return JobMetadata(
        job_id=job_id,
        keys=keys,
        block_ids=np.array(block_ids, dtype=np.int64),
        is_promotion=is_promotion,
        req_context=_CTX,
    )


def drain(tier: FileSystemTierManager, max_rounds: int = 40) -> list:
    """
    Call get_finished() repeatedly until no new results arrive for 5
    consecutive rounds or max_rounds is reached.

    Each round sleeps 10 ms first, so the result may legitimately be empty if
    the job is still running when polling gives up. Callers must therefore
    assert on the number of results, not only on their contents.
    """
    results = []
    idle = 0
    for _ in range(max_rounds):
        time.sleep(0.01)
        new = list(tier.get_finished())
        results.extend(new)
        if new:
            idle = 0
        else:
            idle += 1
            if idle >= 5:
                break
    return results


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture
def fs_tier(tmp_path):
    """Yield ``(tier, tensor)``: a tier backed by a 4-block float32 tensor."""
    tensor = torch.zeros((4, _BLOCK_ELEMENTS), dtype=_DTYPE)
    mock_view = memoryview(tensor.numpy())
    tier = FileSystemTierManager(
        offloading_spec=_MOCK_OFFLOADING_SPEC,
        primary_kv_view=mock_view,
        tier_type="fs_python",
        root_dir=str(tmp_path),
        n_read_threads=4,
        n_write_threads=4,
    )
    yield tier, tensor
    tier.shutdown()


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


def test_lookup_empty_tier(fs_tier):
    tier, _ = fs_tier
    assert tier.lookup(key(1), _CTX) is False
    assert tier.lookup(key(2), _CTX) is False


def test_store_creates_file_and_lookup_succeeds(fs_tier):
    tier, _ = fs_tier
    job = make_job(1, [key(1)], [0])
    tier.submit_store(job)
    results = drain(tier)
    assert len(results) == 1
    assert results[0].success
    assert tier.lookup(key(1), _CTX) is True
    dest = tier.file_mapper.get_file_name(key(1))
    assert os.path.exists(dest), f"Expected file at {dest}"


def test_store_then_load_roundtrip(fs_tier):
    tier, _ = fs_tier
    job_s = make_job(1, [key(1), key(2)], [0, 1])
    tier.submit_store(job_s)
    store_results = drain(tier)
    # One JobResult per job; without the length check all() is vacuously true
    # when drain() gave up before the job finished.
    assert len(store_results) == 1
    assert all(r.success for r in store_results)

    assert tier.lookup(key(1), _CTX) is True
    assert tier.lookup(key(2), _CTX) is True

    job_l = make_job(2, [key(1), key(2)], [2, 3], is_promotion=True)
    tier.submit_load(job_l)
    load_results = drain(tier)
    assert len(load_results) == 1
    assert all(r.success for r in load_results)
    # Blocks stay on disk after load
    assert tier.lookup(key(1), _CTX) is True
    assert tier.lookup(key(2), _CTX) is True


def test_invalid_path_raises_at_construction():
    """Construction must fail immediately when the config file cannot be written."""
    tensor = torch.zeros((32, _BLOCK_ELEMENTS), dtype=_DTYPE)
    mock_view = memoryview(tensor.numpy())

    with pytest.raises(OSError):
        FileSystemTierManager(
            offloading_spec=_MOCK_OFFLOADING_SPEC,
            primary_kv_view=mock_view,
            tier_type="fs_python",
            root_dir="/dev/null/invalid_path",
        )


def test_failed_load_missing_file(fs_tier):
    """Test that loading a block whose file does not exist results in a failed job."""
    tier, _ = fs_tier
    job = make_job(1, [key(99)], [0], is_promotion=True)
    tier.submit_load(job)
    results = drain(tier)
    assert len(results) == 1
    assert not results[0].success


def test_multiple_jobs_tracked_independently(fs_tier):
    tier, _ = fs_tier
    job1 = make_job(1, [key(1)], [0])
    job2 = make_job(2, [key(2)], [1])
    tier.submit_store(job1)
    tier.submit_store(job2)
    results = drain(tier)
    job_ids = {r.job_id for r in results}
    assert job_ids == {1, 2}
    assert tier.lookup(key(1), _CTX) is True
    assert tier.lookup(key(2), _CTX) is True


def test_multi_block_job_partial_failure(fs_tier):
    """A load job where one block file is missing yields a single failed JobResult."""
    tier, _ = fs_tier
    # Store two of three keys
    tier.submit_store(make_job(1, [key(10), key(11)], [0, 1]))
    store_results = drain(tier)
    assert len(store_results) == 1
    assert all(r.success for r in store_results)

    # Load all three — key(99) was never stored
    tier.submit_load(
        make_job(2, [key(10), key(11), key(99)], [0, 1, 2], is_promotion=True)
    )
    results = drain(tier)

    assert len(results) == 1
    assert results[0].job_id == 2
    assert not results[0].success


def test_shutdown_discards_pending_tasks(fs_tier):
    """Shutdown clears both queues and stops all worker threads without draining."""
    tier, _ = fs_tier
    # Submit many tasks to ensure some remain pending
    for i in range(10):
        tier.submit_store(make_job(i, [key(i)], [i % 4]))

    # Shutdown immediately without draining
    tier.shutdown()

    # Verify queues are cleared and threads stopped
    assert len(tier._pool._load_q) == 0
    assert len(tier._pool._store_q) == 0
    assert all(not t.is_alive() for t in tier._pool._threads)
class FileMapper:
    """
    FileMapper maps KV blocks (given by their hash) to file names.

    Every setting that decides whether two runs may share stored blocks is
    collected in ``self.fields``. A truncated SHA-256 of that dict, prefixed
    with the sanitized model name, names the base directory under
    ``root_dir``. The worker rank is not part of the digest: block files live
    under ``<base_path>_r<rank>/`` while ``config.json`` is placed directly in
    ``<base_path>/``.
    """

    def __init__(
        self,
        root_dir: str,
        model_name: str,
        hash_block_size: int,
        gpu_blocks_per_file: int,
        tp_size: int,
        pp_size: int,
        pcp_size: int,
        dcp_size: int,
        rank: int,
        dtype: str,
        kv_cache_groups: list[dict] | None = None,
        inference_engine: str = "vllm",
        parallel_agnostic: bool = False,
    ):
        """
        Initialize the file mapper. Each worker constructs its own, but
        `config.json` is shared across workers since rank lives outside the hash.
        When `parallel_agnostic=True`, tp/pp/pcp/dcp are forced to 1 and rank
        to 0 so multiple parallelism layouts collapse into the same folder.
        """
        import copy

        if parallel_agnostic:
            tp_size = pp_size = pcp_size = dcp_size = 1
            rank = 0
        self.rank: int = rank
        self.fields: dict = {
            "model_name": model_name,
            "hash_block_size": hash_block_size,
            "gpu_blocks_per_file": gpu_blocks_per_file,
            "tp_size": tp_size,
            "pp_size": pp_size,
            "pcp_size": pcp_size,
            "dcp_size": dcp_size,
            "dtype": str(dtype),
            # Private copy: base_path below is derived from these values, so
            # the caller must not be able to change them afterwards.
            "kv_cache_groups": copy.deepcopy(kv_cache_groups) if kv_cache_groups else [],
            "inference_engine": inference_engine,
        }
        self.base_path: str = self._compute_base_path(root_dir, self.fields)

    @classmethod
    def from_offloading_spec(
        cls,
        root_dir: str,
        offloading_spec: OffloadingSpec,
        gpu_blocks_per_file: int = 1,
        parallel_agnostic: bool = False,
    ) -> "FileMapper":
        """Build a FileMapper from an OffloadingSpec.

        Reads the model name, block size, cache dtype (with any ``torch.``
        prefix stripped) and parallel sizes/rank from
        ``offloading_spec.vllm_config``, and one ``{block_size, layer_names}``
        entry per group from ``offloading_spec.kv_cache_config``.
        """
        vllm_config = offloading_spec.vllm_config
        kv_cache_config = offloading_spec.kv_cache_config

        parallel_config = vllm_config.parallel_config
        dtype = str(vllm_config.cache_config.cache_dtype).replace("torch.", "")
        kv_cache_groups = [
            {
                "block_size": group.kv_cache_spec.block_size,
                "layer_names": list(group.layer_names),
            }
            for group in kv_cache_config.kv_cache_groups
        ]
        return cls(
            root_dir=root_dir,
            model_name=vllm_config.model_config.model,
            hash_block_size=vllm_config.cache_config.block_size,
            gpu_blocks_per_file=gpu_blocks_per_file,
            tp_size=parallel_config.tensor_parallel_size,
            pp_size=parallel_config.pipeline_parallel_size,
            pcp_size=parallel_config.prefill_context_parallel_size,
            dcp_size=parallel_config.decode_context_parallel_size,
            rank=parallel_config.rank,
            dtype=dtype,
            kv_cache_groups=kv_cache_groups,
            parallel_agnostic=parallel_agnostic,
        )

    def get_file_name(self, key: OffloadKey) -> str:
        """Map an OffloadKey to its block file path.

        Layout:
            <base_path>_r<rank>/<hash_hex[:3]>/<hash_hex[3:5]>_g<group_idx>/<hash_hex>.bin
        """
        hash_hex = get_offload_block_hash(key).hex()
        group_idx = get_offload_group_idx(key)
        # Two levels of hash-prefix directories limit per-directory fan-out.
        subfolder1, subfolder2 = hash_hex[:3], hash_hex[3:5]
        return (
            f"{self.base_path}_r{self.rank}"
            f"/{subfolder1}/{subfolder2}_g{group_idx}/{hash_hex}.bin"
        )

    def get_run_config(self) -> dict:
        """Return an independent deep copy of the fields that define this cache."""
        import copy

        # A shallow copy would still share the nested kv_cache_groups list.
        return copy.deepcopy(self.fields)

    def get_config_file_path(self) -> str:
        """Return ``<base_path>/config.json`` (rank-independent)."""
        return f"{self.base_path}/{_CONFIG_FILENAME}"

    @staticmethod
    def _compute_base_path(root_dir: str, fields: dict) -> str:
        """
        Layout: <root_dir>/<safe_model_name>_<digest>.
        safe_model_name replaces '/' with '_' so HuggingFace IDs don't nest.
        digest is the first _BASE_PATH_HASH_LEN hex chars of the SHA-256 of
        the canonical (sorted-key, compact) JSON encoding of ``fields``.
        """
        canonical = json.dumps(fields, sort_keys=True, separators=(",", ":"))
        digest = hashlib.sha256(canonical.encode("utf-8")).hexdigest()[
            :_BASE_PATH_HASH_LEN
        ]
        safe_model_name = fields["model_name"].replace("/", "_")
        return f"{root_dir}/{safe_model_name}_{digest}"
def _ensure_dirs(path: str) -> None:
    """Create parent directories of *path* if they don't exist."""
    parent = os.path.dirname(path)
    # A bare filename has no parent component; os.makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)


def store_block(
    dest_path: str,
    buffer: memoryview,
    offset: int,
    block_size: int,
) -> None:
    """
    Store callback: Writes to a temp file then atomically replaces the destination.

    Args:
        dest_path: Final path of the block file. If it already exists the
            call returns without writing anything.
        buffer: Source buffer; it is cast to a flat byte view before slicing.
        offset: Byte offset of the block inside ``buffer``.
        block_size: Number of bytes to write.

    Raises:
        ValueError: ``[offset, offset + block_size)`` is not fully inside
            ``buffer``; nothing is written.
        OSError: The write or rename failed; the temp file is removed.
    """
    # Check if block already exists to avoid redundant writes
    if os.path.exists(dest_path):
        return

    # Cast to a flat byte view so the slice uses byte indices; the raw
    # memoryview may be multi-dimensional with itemsize > 1.
    view_slice = buffer.cast("B")[offset : offset + block_size]
    if len(view_slice) != block_size:
        # Slicing silently clamps; persisting a truncated block would poison
        # the cache with a file that can never be loaded.
        raise ValueError(
            f"Block [{offset}, {offset + block_size}) is outside the source buffer"
        )

    tmp_path = dest_path + _get_tmp_suffix()
    # Ensure parent directories exist
    _ensure_dirs(dest_path)

    try:
        fd = os.open(
            tmp_path,
            os.O_CREAT | os.O_EXCL | os.O_WRONLY | os.O_TRUNC | O_DIRECT,
            0o644,
        )
        try:
            written = 0
            # os.write may transfer fewer bytes than requested; keep going
            # until the whole block is on disk.
            while written < block_size:
                chunk = os.write(fd, view_slice[written:])
                if chunk == 0:
                    raise OSError(
                        f"Short write: expected {block_size} bytes, wrote {written}"
                    )
                written += chunk
        finally:
            os.close(fd)
        os.replace(tmp_path, dest_path)
    except Exception:
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            # The temp file was never created; nothing to clean up.
            pass
        except OSError as cleanup_exc:
            logger.warning("Failed to remove temp file %s: %s", tmp_path, cleanup_exc)
        raise


def load_block(
    source_path: str,
    view: memoryview,
    offset: int,
    block_size: int,
) -> None:
    """
    Load callback: read one KV block from disk. Remove the file on failure.

    The file is removed only when it exists and could not be read in full. A
    missing file or an invalid destination range leaves the disk untouched.

    Args:
        source_path: Path of the block file.
        view: Destination buffer; it is cast to a flat byte view before slicing.
        offset: Byte offset of the block inside ``view``.
        block_size: Number of bytes expected in the file.

    Raises:
        ValueError: ``[offset, offset + block_size)`` is not fully inside ``view``.
        FileNotFoundError: ``source_path`` does not exist.
        OSError: The file could not be opened or is shorter than ``block_size``.
    """
    view_slice = view.cast("B")[offset : offset + block_size]
    if len(view_slice) != block_size:
        # A clamped slice would look like a short read and delete a file that
        # is perfectly valid; reject the bad request instead.
        raise ValueError(
            f"Block [{offset}, {offset + block_size}) is outside the destination buffer"
        )

    fd: int | None = None
    try:
        fd = os.open(source_path, os.O_RDONLY | O_DIRECT)
        bytes_read = 0
        # os.readv may return fewer bytes than requested; 0 means end of file.
        while bytes_read < block_size:
            chunk = os.readv(fd, [view_slice[bytes_read:]])
            if chunk == 0:
                raise OSError(
                    f"Short read: expected {block_size} bytes, read {bytes_read}"
                )
            bytes_read += chunk
    except FileNotFoundError:
        # Nothing on disk to clean up.
        raise
    except Exception:
        try:
            os.remove(source_path)
        except OSError as cleanup_exc:
            logger.warning(
                "Failed to remove unreadable file %s: %s", source_path, cleanup_exc
            )
        raise
    finally:
        if fd is not None:
            os.close(fd)
class FileSystemTierManager(SecondaryTierManager):
    """
    Pure-Python disk-backed secondary tier.

    Read-priority threads service load jobs preferentially; write-priority
    threads service store jobs preferentially. Both groups can drain either
    queue, so neither starves.

    submit_store / submit_load are non-blocking: they enqueue tasks and return.
    get_finished() polls job completion and returns completed JobResults.

    """

    def __init__(
        self,
        offloading_spec: "OffloadingSpec",
        primary_kv_view: memoryview,
        tier_type: str,
        root_dir: str,
        n_read_threads: int = 16,
        n_write_threads: int = 16,
    ):
        """
        Args:
            offloading_spec: contains the vllm_config, kv_cache_config
                and block_size_factor.
            primary_kv_view: Memoryview of the primary tier's CPU KV cache.
            tier_type: Tier type identifier, set by SecondaryTierFactory.
            root_dir: Root directory for block files.
            n_read_threads: Number of read-priority I/O threads.
            n_write_threads: Number of write-priority I/O threads.

        Raises:
            ValueError: ``primary_kv_view`` exposes no strides.
            OSError: The config directory or file cannot be created.
        """
        super().__init__(offloading_spec, primary_kv_view, tier_type)

        # Size in bytes of one block = stride of the view's first dimension.
        strides = primary_kv_view.strides
        if not strides:
            raise ValueError("primary_kv_view.strides cannot be None or empty")
        self._block_size: int = strides[0]

        # Create file mapper
        self.file_mapper = FileMapper.from_offloading_spec(
            root_dir=root_dir,
            offloading_spec=offloading_spec,
            gpu_blocks_per_file=offloading_spec.block_size_factor,
        )

        # Write config file
        self._write_config_if_missing(
            self.file_mapper.get_config_file_path(),
            self.file_mapper.get_run_config(),
        )

        self._pool = DualQueueThreadPool(
            n_read_threads,
            n_write_threads,
            thread_name_prefix="vllm_kv_py_fs",
        )

    @staticmethod
    def _write_config_if_missing(config_path: str, run_config: dict) -> None:
        """Create ``config_path`` atomically unless it already exists.

        The file is written under a per-process temporary name and then
        renamed into place, so another worker sharing the same config file
        never observes a partially written document.
        """
        os.makedirs(os.path.dirname(config_path), exist_ok=True)
        if os.path.exists(config_path):
            return
        tmp_path = f"{config_path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(run_config, f, indent=2, sort_keys=True)
            os.replace(tmp_path, config_path)
        except BaseException:
            # Best-effort cleanup; the original error is what matters.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def _block_tasks(self, io_fn, job_metadata: JobMetadata) -> list:
        """Build one zero-argument callable per (key, block id) pair.

        ``strict=True`` turns a keys/block_ids length mismatch into an
        immediate ValueError instead of a job whose task count can never be
        reached.
        """
        return [
            functools.partial(
                io_fn,
                self.file_mapper.get_file_name(key),
                self._primary_kv_view,
                int(bid) * self._block_size,
                self._block_size,
            )
            for key, bid in zip(
                job_metadata.keys, job_metadata.block_ids, strict=True
            )
        ]

    def lookup(
        self, key: OffloadKey, req_context: ReqContext | None = None
    ) -> bool | None:
        """Return whether the block file for ``key`` currently exists on disk."""
        return os.path.exists(self.file_mapper.get_file_name(key))

    def submit_store(self, job_metadata: JobMetadata) -> None:
        """Enqueue one store task per block of the job; returns immediately."""
        tasks = self._block_tasks(store_block, job_metadata)
        self._pool.enqueue_store(job_metadata.job_id, len(tasks), tasks)

    def submit_load(self, job_metadata: JobMetadata) -> None:
        """Enqueue one load task per block of the job; returns immediately."""
        tasks = self._block_tasks(load_block, job_metadata)
        self._pool.enqueue_load(job_metadata.job_id, len(tasks), tasks)

    def get_finished(self) -> Iterable[JobResult]:
        """
        Collect completed jobs from the finished-jobs queue.
        """
        return (
            JobResult(job_id=job_id, success=success)
            for job_id, success in self._pool.get_finished()
        )

    def shutdown(self) -> None:
        """
        Release resources held by this tier.

        Shuts down the thread pool, clearing pending tasks and waiting for
        active threads to complete.
        """
        self._pool.shutdown(wait=True)
class JobState:
    """
    Thread-safe completion tracker for a set of per-block I/O tasks.

    Each task calls task_done(success) when it finishes.
    """

    __slots__ = ("_job_id", "_n_tasks", "_completed", "_success", "_lock")

    def __init__(self, job_id: JobId, n_tasks: int) -> None:
        self._job_id: JobId = job_id
        self._n_tasks = n_tasks
        self._completed = 0
        self._success = True
        self._lock = threading.Lock()

    @property
    def job_id(self) -> JobId:
        return self._job_id

    def task_done(self, success: bool) -> tuple[bool, bool]:
        """Record one finished task.

        Returns:
            ``(job_completed, job_success)``: ``job_completed`` is True only
            for the call that brings the count up to ``n_tasks``;
            ``job_success`` is False once any task has reported failure.
        """
        with self._lock:
            self._completed += 1
            if not success:
                self._success = False
            return self._completed == self._n_tasks, self._success


class DualQueueThreadPool:
    """
    Thread pool with two task queues (load and store) and two thread groups.

    Load-priority threads drain the load queue first, then fall back to the
    store queue. Store-priority threads do the reverse. Both queues share
    a single condition variable.
    """

    def __init__(
        self,
        n_read_threads: int,
        n_write_threads: int,
        thread_name_prefix: str = "fs_secondary_tier",
    ) -> None:
        self._load_q: deque = deque()
        self._store_q: deque = deque()
        self._condition = threading.Condition(threading.Lock())
        self._stop = False
        self._threads: list[threading.Thread] = []
        self._finished_q: deque[tuple[JobId, bool]] = deque()

        # Load-priority workers first, then store-priority workers.
        for count, load_priority, tag in (
            (n_read_threads, True, "l"),
            (n_write_threads, False, "s"),
        ):
            for i in range(count):
                t = threading.Thread(
                    target=self._worker,
                    args=(load_priority,),
                    name=f"{thread_name_prefix}_{tag}{i}",
                    daemon=True,
                )
                t.start()
                self._threads.append(t)

    def enqueue_load(
        self,
        job_id: JobId,
        n_tasks: int,
        tasks: Iterable[Callable],
    ) -> None:
        """Enqueue load tasks for a job (high-priority for load-priority threads)."""
        self._enqueue(self._load_q, job_id, n_tasks, tasks)

    def enqueue_store(
        self,
        job_id: JobId,
        n_tasks: int,
        tasks: Iterable[Callable],
    ) -> None:
        """Enqueue store tasks for a job (high-priority for store-priority threads)."""
        self._enqueue(self._store_q, job_id, n_tasks, tasks)

    def _enqueue(
        self,
        queue: deque,
        job_id: JobId,
        n_tasks: int,
        tasks: Iterable[Callable],
    ) -> None:
        """Validate a job and append its tasks to ``queue``.

        Raises:
            ValueError: ``tasks`` does not yield exactly ``n_tasks`` callables.
                Such a job could never be reported correctly: too few tasks
                never reach the completion count, too many reach it early.
        """
        # Materialise before taking the lock so a lazy iterable does not do
        # its work while the workers are blocked on the condition.
        batch = list(tasks)
        if len(batch) != n_tasks:
            raise ValueError(
                f"job {job_id}: expected {n_tasks} tasks, got {len(batch)}"
            )
        if not batch:
            # No task will ever call task_done(), so report completion now.
            self._finished_q.append((job_id, True))
            return

        state = JobState(job_id, n_tasks)
        with self._condition:
            queue.extend((fn, state) for fn in batch)
            self._condition.notify(n_tasks)

    def get_finished(self) -> list[tuple[JobId, bool]]:
        """Pop and return every ``(job_id, success)`` pair reported so far."""
        jobs = []
        while self._finished_q:
            jobs.append(self._finished_q.popleft())
        return jobs

    def shutdown(self, wait: bool = True) -> None:
        """Stop the workers and discard tasks that have not started.

        Jobs whose tasks are discarded are never reported by get_finished().
        With ``wait=True`` the call returns after every worker thread exited.
        """
        with self._condition:
            self._stop = True
            self._load_q.clear()
            self._store_q.clear()
            self._condition.notify_all()
        if wait:
            for t in self._threads:
                t.join()

    def _worker(self, load_priority: bool) -> None:
        # Wait for tasks, process from primary queue first, fall back to secondary.
        while True:
            with self._condition:
                self._condition.wait_for(
                    lambda: self._stop or self._load_q or self._store_q
                )
                if self._stop:
                    return
                primary = self._load_q if load_priority else self._store_q
                secondary = self._store_q if load_priority else self._load_q
                task, state = primary.popleft() if primary else secondary.popleft()
            # Run the task outside the lock so other workers can dequeue.
            try:
                task()
                job_finished, success = state.task_done(True)
            except Exception as exc:
                logger.error(
                    "FileSystemTierManagerPython: job %s block I/O failed: %s",
                    state.job_id,
                    exc,
                )
                job_finished, success = state.task_done(False)

            if job_finished:
                self._finished_q.append((state.job_id, success))