diff --git a/tests/v1/kv_offload/test_cpu_manager.py b/tests/v1/kv_offload/test_cpu_manager.py index ac44c04db732..eea0367bf503 100644 --- a/tests/v1/kv_offload/test_cpu_manager.py +++ b/tests/v1/kv_offload/test_cpu_manager.py @@ -12,9 +12,8 @@ OffloadingEvent, PrepareStoreOutput, ) -from vllm.v1.kv_offload.arc_manager import ARCOffloadingManager -from vllm.v1.kv_offload.backends.cpu import CPUBackend -from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager +from vllm.v1.kv_offload.cpu.manager import CPUOffloadingManager +from vllm.v1.kv_offload.cpu.policies.arc import ARCCachePolicy from vllm.v1.kv_offload.mediums import CPULoadStoreSpec @@ -79,12 +78,12 @@ def to_hash_sets(int_sets: tuple[set[int], ...]) -> tuple[set[BlockHash], ...]: assert tuple(stores) == to_hash_sets(expected_stores) -@pytest.mark.parametrize("manager_class", [LRUOffloadingManager, ARCOffloadingManager]) -def test_already_stored_block_not_evicted_during_prepare_store(manager_class): +@pytest.mark.parametrize("eviction_policy", ["lru", "arc"]) +def test_already_stored_block_not_evicted_during_prepare_store(eviction_policy): """ Regression test: a block that is already stored must not be evicted by prepare_store() when it needs to make room for new blocks. - Applies to both LRUOffloadingManager and ARCOffloadingManager. + Applies to both lru and arc policies. Scenario: - Store blocks [1, 2] and complete. @@ -96,8 +95,12 @@ def test_already_stored_block_not_evicted_during_prepare_store(manager_class): - After complete_store([2, 3, 4, 5]), block 2 must still be present. """ block_size = 256 - cpu_backend = CPUBackend(block_size=block_size, num_blocks=4) - manager = manager_class(cpu_backend, enable_events=True) + manager = CPUOffloadingManager( + block_size=block_size, + num_blocks=4, + cache_policy=eviction_policy, + enable_events=True, + ) # store [1, 2] and complete manager.prepare_store(to_hashes([1, 2])) @@ -129,12 +132,13 @@ def test_already_stored_block_not_evicted_during_prepare_store(manager_class): def test_cpu_manager(): """ - Tests LRUOffloadingManager with a CPUBackend. + Tests CPUOffloadingManager with lru policy. """ # initialize a CPU backend with a capacity of 4 blocks block_size = 256 - cpu_backend = CPUBackend(block_size=block_size, num_blocks=4) - cpu_manager = LRUOffloadingManager(cpu_backend, enable_events=True) + cpu_manager = CPUOffloadingManager( + block_size=block_size, num_blocks=4, cache_policy="lru", enable_events=True + ) # prepare store [1, 2] prepare_store_output = cpu_manager.prepare_store(to_hashes([1, 2])) @@ -241,13 +245,15 @@ def test_cpu_manager(): def test_arc_manager_basic(): """ - Tests ARCOffloadingManager basic operations with a CPUBackend. + Tests CPUOffloadingManager with arc policy. Verifies that ARC handles store, load, and lookup operations correctly. """ - # initialize a CPU backend with a capacity of 4 blocks block_size = 256 - cpu_backend = CPUBackend(block_size=block_size, num_blocks=4) - arc_manager = ARCOffloadingManager(cpu_backend, enable_events=True) + arc_manager = CPUOffloadingManager( + block_size=block_size, num_blocks=4, cache_policy="arc", enable_events=True + ) + arc_policy = arc_manager._policy + assert isinstance(arc_policy, ARCCachePolicy) # prepare store [1, 2] prepare_store_output = arc_manager.prepare_store(to_hashes([1, 2])) @@ -278,8 +284,8 @@ def test_arc_manager_basic(): assert arc_manager.lookup(to_hashes([1, 2, 3])) == 2 # blocks should be in T1 (recent) - assert len(arc_manager.t1) == 2 - assert len(arc_manager.t2) == 0 + assert len(arc_policy.t1) == 2 + assert len(arc_policy.t2) == 0 def test_arc_manager_t1_to_t2_promotion(): @@ -288,23 +294,26 @@ def test_arc_manager_t1_to_t2_promotion(): This is a key feature of ARC's adaptive behavior. """ block_size = 256 - cpu_backend = CPUBackend(block_size=block_size, num_blocks=4) - arc_manager = ARCOffloadingManager(cpu_backend, enable_events=False) + arc_manager = CPUOffloadingManager( + block_size=block_size, num_blocks=4, cache_policy="arc", enable_events=False + ) + arc_policy = arc_manager._policy + assert isinstance(arc_policy, ARCCachePolicy) # store and complete block 1 arc_manager.prepare_store(to_hashes([1])) arc_manager.complete_store(to_hashes([1])) # block 1 starts in T1 (recent) - assert to_hashes([1])[0] in arc_manager.t1 - assert to_hashes([1])[0] not in arc_manager.t2 + assert to_hashes([1])[0] in arc_policy.t1 + assert to_hashes([1])[0] not in arc_policy.t2 # touch block 1 (simulate second access) arc_manager.touch(to_hashes([1])) # block 1 should now be in T2 (frequent) - assert to_hashes([1])[0] not in arc_manager.t1 - assert to_hashes([1])[0] in arc_manager.t2 + assert to_hashes([1])[0] not in arc_policy.t1 + assert to_hashes([1])[0] in arc_policy.t2 def test_arc_manager_eviction_with_load(): @@ -313,8 +322,9 @@ def test_arc_manager_eviction_with_load(): Verifies that blocks being loaded (ref_cnt > 0) cannot be evicted. """ block_size = 256 - cpu_backend = CPUBackend(block_size=block_size, num_blocks=4) - arc_manager = ARCOffloadingManager(cpu_backend, enable_events=True) + arc_manager = CPUOffloadingManager( + block_size=block_size, num_blocks=4, cache_policy="arc", enable_events=True + ) # prepare and complete store [1, 2, 3, 4] prepare_store_output = arc_manager.prepare_store(to_hashes([1, 2, 3, 4])) @@ -354,28 +364,31 @@ def test_arc_manager_adaptive_target(): When a block in B2 is accessed, target_t1_size decreases. """ block_size = 256 - cpu_backend = CPUBackend(block_size=block_size, num_blocks=2) - arc_manager = ARCOffloadingManager(cpu_backend, enable_events=False) + arc_manager = CPUOffloadingManager( + block_size=block_size, num_blocks=2, cache_policy="arc", enable_events=False + ) + arc_policy = arc_manager._policy + assert isinstance(arc_policy, ARCCachePolicy) # store blocks 1, 2 (fills cache) arc_manager.prepare_store(to_hashes([1, 2])) arc_manager.complete_store(to_hashes([1, 2])) - initial_target = arc_manager.target_t1_size + initial_target = arc_policy.target_t1_size # store block 3, evicting block 1 (moves to B1 ghost list) arc_manager.prepare_store(to_hashes([3])) arc_manager.complete_store(to_hashes([3])) # block 1 should be in B1 (ghost list) - assert to_hashes([1])[0] in arc_manager.b1 + assert to_hashes([1])[0] in arc_policy.b1 # touch block 1 (cache miss, but in B1) # this should increase target_t1_size (favor recency) arc_manager.touch(to_hashes([1])) # target should have increased - assert arc_manager.target_t1_size > initial_target + assert arc_policy.target_t1_size > initial_target def test_arc_manager_t1_t2_eviction_policy(): @@ -384,8 +397,11 @@ def test_arc_manager_t1_t2_eviction_policy(): If |T1| >= target_t1_size, evict from T1, otherwise from T2. """ block_size = 256 - cpu_backend = CPUBackend(block_size=block_size, num_blocks=4) - arc_manager = ARCOffloadingManager(cpu_backend, enable_events=False) + arc_manager = CPUOffloadingManager( + block_size=block_size, num_blocks=4, cache_policy="arc", enable_events=False + ) + arc_policy = arc_manager._policy + assert isinstance(arc_policy, ARCCachePolicy) # store blocks 1, 2, 3, 4 arc_manager.prepare_store(to_hashes([1, 2, 3, 4])) @@ -395,12 +411,12 @@ def test_arc_manager_t1_t2_eviction_policy(): arc_manager.touch(to_hashes([3, 4])) # now: T1 = {1, 2}, T2 = {3, 4} - assert len(arc_manager.t1) == 2 - assert len(arc_manager.t2) == 2 + assert len(arc_policy.t1) == 2 + assert len(arc_policy.t2) == 2 # set target_t1_size to prefer evicting from T1 # (when |T1| >= target, evict from T1) - arc_manager.target_t1_size = 1 + arc_policy.target_t1_size = 1 # store block 5, should evict from T1 (block 1, LRU in T1) output = arc_manager.prepare_store(to_hashes([5])) @@ -410,9 +426,9 @@ def test_arc_manager_t1_t2_eviction_policy(): arc_manager.complete_store(to_hashes([5])) # block 1 should be in B1 (ghost list) - assert to_hashes([1])[0] in arc_manager.b1 + assert to_hashes([1])[0] in arc_policy.b1 # block 5 should be in T1 - assert to_hashes([5])[0] in arc_manager.t1 + assert to_hashes([5])[0] in arc_policy.t1 def test_arc_manager_ghost_list_bounds(): @@ -421,8 +437,11 @@ def test_arc_manager_ghost_list_bounds(): They should be capped at cache_capacity. """ block_size = 256 - cpu_backend = CPUBackend(block_size=block_size, num_blocks=2) - arc_manager = ARCOffloadingManager(cpu_backend, enable_events=False) + arc_manager = CPUOffloadingManager( + block_size=block_size, num_blocks=2, cache_policy="arc", enable_events=False + ) + arc_policy = arc_manager._policy + assert isinstance(arc_policy, ARCCachePolicy) # fill cache with blocks 1, 2 arc_manager.prepare_store(to_hashes([1, 2])) @@ -434,8 +453,8 @@ def test_arc_manager_ghost_list_bounds(): arc_manager.complete_store(to_hashes([i])) # ghost lists should not exceed cache_capacity - assert len(arc_manager.b1) <= arc_manager.cache_capacity - assert len(arc_manager.b2) <= arc_manager.cache_capacity + assert len(arc_policy.b1) <= arc_policy.cache_capacity + assert len(arc_policy.b2) <= arc_policy.cache_capacity def test_arc_manager_touch_ordering(): @@ -444,8 +463,11 @@ def test_arc_manager_touch_ordering(): Similar to LRU test but verifies T1/T2 ordering. """ block_size = 256 - cpu_backend = CPUBackend(block_size=block_size, num_blocks=4) - arc_manager = ARCOffloadingManager(cpu_backend, enable_events=True) + arc_manager = CPUOffloadingManager( + block_size=block_size, num_blocks=4, cache_policy="arc", enable_events=True + ) + arc_policy = arc_manager._policy + assert isinstance(arc_policy, ARCCachePolicy) # store blocks 1, 2, 3, 4 arc_manager.prepare_store(to_hashes([1, 2, 3, 4])) @@ -459,8 +481,8 @@ def test_arc_manager_touch_ordering(): arc_manager.touch(to_hashes([1, 3, 4])) # T1 = {2}, T2 = {1, 3, 4} (in that order, with 4 most recent) - assert len(arc_manager.t1) == 1 - assert len(arc_manager.t2) == 3 + assert len(arc_policy.t1) == 1 + assert len(arc_policy.t2) == 3 # store block 5, should evict from T1 (block 2, only one in T1) prepare_store_output = arc_manager.prepare_store(to_hashes([5])) @@ -480,8 +502,11 @@ def test_arc_manager_failed_store(): Similar to LRU test but for ARC. """ block_size = 256 - cpu_backend = CPUBackend(block_size=block_size, num_blocks=4) - arc_manager = ARCOffloadingManager(cpu_backend, enable_events=True) + arc_manager = CPUOffloadingManager( + block_size=block_size, num_blocks=4, cache_policy="arc", enable_events=True + ) + arc_policy = arc_manager._policy + assert isinstance(arc_policy, ARCCachePolicy) # store blocks 1, 2, 3, 4 arc_manager.prepare_store(to_hashes([1, 2, 3, 4])) @@ -498,12 +523,12 @@ def test_arc_manager_failed_store(): # block 5 should not be in cache assert arc_manager.lookup(to_hashes([5])) == 0 # block 5 should not be in T1 or T2 - assert to_hashes([5])[0] not in arc_manager.t1 - assert to_hashes([5])[0] not in arc_manager.t2 + assert to_hashes([5])[0] not in arc_policy.t1 + assert to_hashes([5])[0] not in arc_policy.t2 # evicted block should still be gone (in B1 ghost list) evicted_hash = prepare_store_output.block_hashes_evicted[0] - assert evicted_hash in arc_manager.b1 + assert evicted_hash in arc_policy.b1 def test_arc_manager_full_scenario(): @@ -512,8 +537,11 @@ def test_arc_manager_full_scenario(): Similar to the full LRU test but adapted for ARC behavior. """ block_size = 256 - cpu_backend = CPUBackend(block_size=block_size, num_blocks=4) - arc_manager = ARCOffloadingManager(cpu_backend, enable_events=True) + arc_manager = CPUOffloadingManager( + block_size=block_size, num_blocks=4, cache_policy="arc", enable_events=True + ) + arc_policy = arc_manager._policy + assert isinstance(arc_policy, ARCCachePolicy) # store [1, 2] arc_manager.prepare_store(to_hashes([1, 2])) @@ -529,8 +557,8 @@ def test_arc_manager_full_scenario(): arc_manager.touch(to_hashes([2, 3])) # T1 has {4, 5}, T2 has {2, 3} - assert len(arc_manager.t1) == 2 - assert len(arc_manager.t2) == 2 + assert len(arc_policy.t1) == 2 + assert len(arc_policy.t2) == 2 # store [6] -> should evict from T1 (4 is oldest in T1) prepare_store_output = arc_manager.prepare_store(to_hashes([6])) @@ -548,11 +576,12 @@ def test_arc_manager_full_scenario(): def test_filter_reused_manager(): """ - Tests FilterReusedOffloadingManager with a CPUBackend. + Tests FilterReusedOffloadingManager with a CPUOffloadingManager. """ block_size = 256 - cpu_backend = CPUBackend(block_size=block_size, num_blocks=4) - lru_manager = LRUOffloadingManager(cpu_backend, enable_events=True) + lru_manager = CPUOffloadingManager( + block_size=block_size, num_blocks=4, cache_policy="lru", enable_events=True + ) from vllm.v1.kv_offload.reuse_manager import FilterReusedOffloadingManager diff --git a/vllm/v1/kv_offload/arc_manager.py b/vllm/v1/kv_offload/arc_manager.py deleted file mode 100644 index e3bb54a2cac3..000000000000 --- a/vllm/v1/kv_offload/arc_manager.py +++ /dev/null @@ -1,244 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections import OrderedDict -from collections.abc import Iterable - -from vllm.v1.core.kv_cache_utils import BlockHash -from vllm.v1.kv_offload.abstract import ( - LoadStoreSpec, - OffloadingEvent, - OffloadingManager, - PrepareStoreOutput, -) -from vllm.v1.kv_offload.backend import Backend, BlockStatus - - -class ARCOffloadingManager(OffloadingManager): - """ - An OffloadingManager implementing the ARC (Adaptive Replacement Cache) - eviction policy with a pluggable backend. - - Data Structures: - T1: Recent cache containing blocks accessed once. - T2: Frequent cache containing blocks accessed multiple times. - B1/B2: Ghost lists tracking recently evicted blocks from T1/T2. - target_t1_size: Adaptive target size for the T1 partition. - - Algorithm Flow: - 1. Cache lookup (lookup): - Searches T1 and T2 for block hashes and counts consecutive hits - until a miss or non-ready block is encountered. - - 2. Cache touch (touch) - Adaptive Learning: - For each block_hash (in reverse order): - - If in T1: Move to T2 (promotion from recent to frequent). - - If in T2: Move to MRU position (end of queue). - - If in B1 ghost list: Increase target_t1_size. - - If in B2 ghost list: Decrease target_t1_size. - - 3. Block eviction (prepare_store) - Adaptive Replacement: - Determines eviction source based on adaptive target: - - If T1 size > target_t1_size: Evict from T1, add to B1. - - Otherwise: Evict from T2, add to B2. - Finally, bound each ghost list size. - - 4. Block insertion (prepare_store): - New blocks are always inserted into T1 and removed from B1/B2 if - present. Blocks may later be promoted to T2 during touch operations. - - Adaptive Behavior: - The algorithm self-tunes the recency vs. frequency trade-off: - - B1 hit: Recent access patterns matter more → increase T1. - - B2 hit: Frequent access patterns matter more → decrease T1. - """ - - def __init__(self, backend: Backend, enable_events: bool = False): - self.backend: Backend = backend - self.target_t1_size: float = 0.0 - self.t1: OrderedDict[BlockHash, BlockStatus] = OrderedDict() - self.t2: OrderedDict[BlockHash, BlockStatus] = OrderedDict() - # block_hash -> None (only care about presence) - self.b1: OrderedDict[BlockHash, None] = OrderedDict() - self.b2: OrderedDict[BlockHash, None] = OrderedDict() - self.events: list[OffloadingEvent] | None = [] if enable_events else None - self.cache_capacity: int = self.backend.get_num_free_blocks() - - def lookup(self, block_hashes: Iterable[BlockHash]) -> int | None: - hit_count = 0 - for block_hash in block_hashes: - block = self.t1.get(block_hash) or self.t2.get(block_hash) - if block is None or not block.is_ready: - break - hit_count += 1 - return hit_count - - def prepare_load(self, block_hashes: Iterable[BlockHash]) -> LoadStoreSpec: - blocks = [] - for block_hash in block_hashes: - block = self.t1.get(block_hash) or self.t2.get(block_hash) - assert block is not None, f"Block {block_hash!r} not found in cache" - assert block.is_ready, f"Block {block_hash!r} is not ready for reading" - - block.ref_cnt += 1 - blocks.append(block) - - return self.backend.get_load_store_spec(block_hashes, blocks) - - def touch(self, block_hashes: Iterable[BlockHash]): - for block_hash in reversed(list(block_hashes)): - if block_hash in self.t1: - block = self.t1.pop(block_hash) - if not block.is_ready: - # block was just prepared to be stored, not really touched twice - # keep it in T1 and mark as most recently used - self.t1[block_hash] = block - else: - self.t2[block_hash] = block - - elif block_hash in self.t2: - self.t2.move_to_end(block_hash) - - elif block_hash in self.b1: - delta = max(1, len(self.b2) / len(self.b1)) - self.target_t1_size = min( - self.target_t1_size + delta, self.cache_capacity - ) - # move to MRU position (end) to keep it fresh in the ghost list - self.b1.move_to_end(block_hash) - - elif block_hash in self.b2: - delta = max(1, len(self.b1) / len(self.b2)) - self.target_t1_size = max(self.target_t1_size - delta, 0) - # move to MRU position (end) to keep it fresh in the ghost list - self.b2.move_to_end(block_hash) - - def complete_load(self, block_hashes: Iterable[BlockHash]): - for block_hash in block_hashes: - block = self.t1.get(block_hash) or self.t2.get(block_hash) - assert block is not None, f"Block {block_hash!r} not found" - assert block.ref_cnt > 0, f"Block {block_hash!r} ref_cnt is already 0" - - block.ref_cnt -= 1 - - def prepare_store( - self, block_hashes: Iterable[BlockHash] - ) -> PrepareStoreOutput | None: - block_hashes_list = list(block_hashes) - - block_hashes_to_store = [] - for block_hash in block_hashes_list: - if block_hash not in self.t1 and block_hash not in self.t2: - block_hashes_to_store.append(block_hash) - - if not block_hashes_to_store: - return PrepareStoreOutput( - block_hashes_to_store=[], - store_spec=self.backend.get_load_store_spec([], []), - block_hashes_evicted=[], - ) - - num_blocks_to_evict = ( - len(block_hashes_to_store) - self.backend.get_num_free_blocks() - ) - - to_evict = [] - if num_blocks_to_evict > 0: - # Blocks from the original input are excluded from eviction candidates: - # a block that was already stored must remain in the cache after this call. - protected = set(block_hashes_list) - while num_blocks_to_evict > 0: - block_to_evict = None - if len(self.t1) >= int(self.target_t1_size): - # try to evict the least recently used (oldest) block from T1 - for block_hash, block in self.t1.items(): - if block.ref_cnt == 0 and block_hash not in protected: - block_to_evict = (block_hash, block) - eviction_t = self.t1 - eviction_b = self.b1 - break - if not block_to_evict: - # try to evict the least recently used (oldest) block from T2 - for block_hash, block in self.t2.items(): - if block.ref_cnt == 0 and block_hash not in protected: - block_to_evict = (block_hash, block) - eviction_t = self.t2 - eviction_b = self.b2 - break - else: - # cannot evict enough blocks, cache is full of in-use items - return None - - block_hash, block = block_to_evict - del eviction_t[block_hash] - eviction_b[block_hash] = None - to_evict.append(block_hash) - self.backend.free(block) - num_blocks_to_evict -= 1 - - for b in [self.b1, self.b2]: - for i in range(len(b) - self.cache_capacity): - b.popitem(last=False) - - if to_evict and self.events is not None: - self.events.append( - OffloadingEvent( - block_hashes=to_evict, - block_size=self.backend.block_size, - medium=self.backend.medium, - removed=True, - ) - ) - - blocks = self.backend.allocate_blocks(block_hashes_to_store) - assert len(blocks) == len(block_hashes_to_store), ( - "Backend did not allocate the expected number of blocks" - ) - - for block_hash, block in zip(block_hashes_to_store, blocks): - self.t1[block_hash] = block - - self.b1.pop(block_hash, None) - self.b2.pop(block_hash, None) - - store_spec = self.backend.get_load_store_spec(block_hashes_to_store, blocks) - - return PrepareStoreOutput( - block_hashes_to_store=block_hashes_to_store, - store_spec=store_spec, - block_hashes_evicted=to_evict, - ) - - def complete_store(self, block_hashes: Iterable[BlockHash], success: bool = True): - stored_block_hashes: list[BlockHash] = [] - - if success: - for block_hash in block_hashes: - block = self.t1.get(block_hash) or self.t2.get(block_hash) - - if block is not None and not block.is_ready: - block.ref_cnt = 0 - stored_block_hashes.append(block_hash) - else: - for block_hash in block_hashes: - block = self.t1.pop(block_hash, None) - - if block is None: - block = self.t2.pop(block_hash, None) - - if block is not None and not block.is_ready: - self.backend.free(block) - - if stored_block_hashes and self.events is not None: - self.events.append( - OffloadingEvent( - block_hashes=stored_block_hashes, - block_size=self.backend.block_size, - medium=self.backend.medium, - removed=False, - ) - ) - - def take_events(self) -> Iterable[OffloadingEvent]: - if self.events is not None: - yield from self.events - self.events.clear() diff --git a/vllm/v1/kv_offload/backend.py b/vllm/v1/kv_offload/backend.py deleted file mode 100644 index 538f7bf0584b..000000000000 --- a/vllm/v1/kv_offload/backend.py +++ /dev/null @@ -1,97 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import ctypes -from abc import ABC, abstractmethod -from collections.abc import Iterable - -from vllm.v1.core.kv_cache_utils import BlockHash -from vllm.v1.kv_offload.abstract import LoadStoreSpec - - -class BlockStatus(ctypes.Structure): - """ - Offloading status for a single block of KV data. - Holds the following information: - - ref_cnt - the current number of transfers using this block as a source. - A value of -1 indicates the block is not yet ready to be read. - load_store_spec - backend-specific information on how to actually - read/write the block. - """ - - _fields_ = [("ref_cnt", ctypes.c_int32)] - - def __init__(self): - super().__init__() - # initialize block as "not ready" (ref_cnt = -1) - self.ref_cnt = -1 - - @property - def is_ready(self) -> bool: - """ - Returns whether the block is ready to be read. - """ - return self.ref_cnt >= 0 - - -class Backend(ABC): - """ - An abstract class for allocating and returning specs for writing - KV blocks to some backend. - """ - - def __init__(self, block_size: int, medium: str): - self.block_size = block_size - self.medium = medium - - @abstractmethod - def get_num_free_blocks(self): - """ - Returns the number of current number of blocks that can be allocated. - """ - pass - - @abstractmethod - def allocate_blocks(self, block_hashes: list[BlockHash]) -> list[BlockStatus]: - """ - Allocate space for writing blocks. - This method assumes there is enough space for allocation. - It is unsafe to use without checking get_num_free_blocks beforehand. - - Args: - block_hashes: the hashes identifying the blocks to be written. - - Returns: - A list of BlockStatus for the allocated blocks. - The ref_cnt of each returned item will be -1, meaning the block - is not yet ready to be read. - """ - pass - - @abstractmethod - def free(self, block: BlockStatus): - """ - Free a previously allocated block. - You should only call this function with blocks returned by - allocate_blocks, and only once per each block. - - Args: - block: The block to be freed. - """ - pass - - def get_load_store_spec( - self, block_hashes: Iterable[BlockHash], blocks: Iterable[BlockStatus] - ) -> LoadStoreSpec: - """ - Get backend-specific information on how to read/write blocks. - - Args: - block_hashes: the list of block hashes identifying the blocks. - blocks: the list of blocks. - - Returns: - A LoadStoreSpec that can be used by a worker - to read/write the blocks. - """ - raise NotImplementedError diff --git a/vllm/v1/kv_offload/backends/cpu.py b/vllm/v1/kv_offload/backends/cpu.py deleted file mode 100644 index 736cf37853cd..000000000000 --- a/vllm/v1/kv_offload/backends/cpu.py +++ /dev/null @@ -1,62 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import ctypes -from collections.abc import Iterable - -from vllm.v1.core.kv_cache_utils import BlockHash -from vllm.v1.kv_offload.abstract import LoadStoreSpec -from vllm.v1.kv_offload.backend import Backend, BlockStatus -from vllm.v1.kv_offload.mediums import CPULoadStoreSpec - - -class CPUBlockStatus(BlockStatus): - _fields_ = BlockStatus._fields_ + [("block_id", ctypes.c_int64)] # type: ignore - - def __init__(self, block_id: int): - super().__init__() - self.block_id = block_id - - -class CPUBackend(Backend): - def __init__(self, block_size: int, num_blocks: int): - super().__init__(block_size=block_size, medium=CPULoadStoreSpec.medium()) - - self.num_blocks: int = num_blocks - self.num_allocated_blocks: int = 0 - self.allocated_blocks_free_list: list[int] = [] - - def get_num_free_blocks(self): - return ( - len(self.allocated_blocks_free_list) - + self.num_blocks - - self.num_allocated_blocks - ) - - def allocate_blocks(self, block_hashes: list[BlockHash]) -> list[BlockStatus]: - num_fresh_blocks = min( - len(block_hashes), self.num_blocks - self.num_allocated_blocks - ) - num_reused_blocks = len(block_hashes) - num_fresh_blocks - assert len(self.allocated_blocks_free_list) >= num_reused_blocks - - # allocate fresh blocks - blocks: list[BlockStatus] = [] - for _ in range(num_fresh_blocks): - blocks.append(CPUBlockStatus(self.num_allocated_blocks)) - self.num_allocated_blocks += 1 - - # allocate reused blocks - for _ in range(num_reused_blocks): - block_id = self.allocated_blocks_free_list.pop() - blocks.append(CPUBlockStatus(block_id)) - - return blocks - - def free(self, block: BlockStatus): - assert isinstance(block, CPUBlockStatus) - self.allocated_blocks_free_list.append(block.block_id) - - def get_load_store_spec( - self, block_hashes: Iterable[BlockHash], blocks: Iterable[BlockStatus] - ) -> LoadStoreSpec: - return CPULoadStoreSpec([block.block_id for block in blocks]) diff --git a/vllm/v1/kv_offload/backends/__init__.py b/vllm/v1/kv_offload/cpu/__init__.py similarity index 100% rename from vllm/v1/kv_offload/backends/__init__.py rename to vllm/v1/kv_offload/cpu/__init__.py diff --git a/vllm/v1/kv_offload/cpu/manager.py b/vllm/v1/kv_offload/cpu/manager.py new file mode 100644 index 000000000000..66f0e6736a9d --- /dev/null +++ b/vllm/v1/kv_offload/cpu/manager.py @@ -0,0 +1,208 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from typing import Literal + +from vllm.v1.core.kv_cache_utils import BlockHash +from vllm.v1.kv_offload.abstract import ( + LoadStoreSpec, + OffloadingEvent, + OffloadingManager, + PrepareStoreOutput, +) +from vllm.v1.kv_offload.cpu.policies.abstract import BlockStatus, CachePolicy +from vllm.v1.kv_offload.cpu.policies.arc import ARCCachePolicy +from vllm.v1.kv_offload.cpu.policies.lru import LRUCachePolicy +from vllm.v1.kv_offload.mediums import CPULoadStoreSpec + +_CACHE_POLICIES: dict[str, type[CachePolicy]] = { + "lru": LRUCachePolicy, + "arc": ARCCachePolicy, +} + + +class CPUOffloadingManager(OffloadingManager): + """ + An OffloadingManager with a pluggable CachePolicy (LRU or ARC). + + The manager owns all shared logic: ref-counting, event emission, + block pool management, and the prepare_store/complete_store skeletons. + Policy-specific block organization and eviction decisions are delegated + to the CachePolicy implementation. + """ + + def __init__( + self, + block_size: int, + num_blocks: int, + cache_policy: Literal["lru", "arc"] = "lru", + enable_events: bool = False, + ): + self.block_size: int = block_size + self.medium: str = CPULoadStoreSpec.medium() + self._num_blocks: int = num_blocks + self._num_allocated_blocks: int = 0 + self._free_list: list[int] = [] + self.events: list[OffloadingEvent] | None = [] if enable_events else None + policy_cls = _CACHE_POLICIES.get(cache_policy) + if policy_cls is None: + raise ValueError( + f"Unknown cache policy: {cache_policy!r}. " + f"Supported: {list(_CACHE_POLICIES)}" + ) + self._policy: CachePolicy = policy_cls(cache_capacity=num_blocks) + + # --- block pool --- + + def _get_num_free_blocks(self) -> int: + return len(self._free_list) + self._num_blocks - self._num_allocated_blocks + + def _allocate_blocks(self, block_hashes: list[BlockHash]) -> list[BlockStatus]: + num_fresh = min( + len(block_hashes), self._num_blocks - self._num_allocated_blocks + ) + num_reused = len(block_hashes) - num_fresh + assert len(self._free_list) >= num_reused + + # allocate fresh blocks + blocks: list[BlockStatus] = [] + for _ in range(num_fresh): + blocks.append(BlockStatus(self._num_allocated_blocks)) + self._num_allocated_blocks += 1 + + # allocate reused blocks + for _ in range(num_reused): + blocks.append(BlockStatus(self._free_list.pop())) + return blocks + + def _free_block(self, block: BlockStatus) -> None: + self._free_list.append(block.block_id) + + def _get_load_store_spec( + self, + block_hashes: Iterable[BlockHash], + blocks: Iterable[BlockStatus], + ) -> CPULoadStoreSpec: + return CPULoadStoreSpec([block.block_id for block in blocks]) + + # --- OffloadingManager interface --- + + def lookup(self, block_hashes: Iterable[BlockHash]) -> int | None: + hit_count = 0 + for block_hash in block_hashes: + block = self._policy.get(block_hash) + if block is None or not block.is_ready: + break + hit_count += 1 + return hit_count + + def prepare_load(self, block_hashes: Iterable[BlockHash]) -> LoadStoreSpec: + blocks = [] + for block_hash in block_hashes: + block = self._policy.get(block_hash) + assert block is not None, f"Block {block_hash!r} not found in cache" + assert block.is_ready, f"Block {block_hash!r} is not ready for reading" + block.ref_cnt += 1 + blocks.append(block) + return self._get_load_store_spec(block_hashes, blocks) + + def touch(self, block_hashes: Iterable[BlockHash]) -> None: + self._policy.touch(block_hashes) + + def complete_load(self, block_hashes: Iterable[BlockHash]) -> None: + for block_hash in block_hashes: + block = self._policy.get(block_hash) + assert block is not None, f"Block {block_hash!r} not found" + assert block.ref_cnt > 0, f"Block {block_hash!r} ref_cnt is already 0" + block.ref_cnt -= 1 + + def prepare_store( + self, block_hashes: Iterable[BlockHash] + ) -> PrepareStoreOutput | None: + block_hashes_list = list(block_hashes) + + # filter out blocks that are already stored + block_hashes_to_store = [ + bh for bh in block_hashes_list if self._policy.get(bh) is None + ] + + if not block_hashes_to_store: + return PrepareStoreOutput( + block_hashes_to_store=[], + store_spec=self._get_load_store_spec([], []), + block_hashes_evicted=[], + ) + + num_blocks_to_evict = len(block_hashes_to_store) - self._get_num_free_blocks() + + to_evict: list[BlockHash] = [] + if num_blocks_to_evict > 0: + # Blocks from the original input are excluded from eviction candidates: + # a block that was already stored must remain in the cache after this call. + protected = set(block_hashes_list) + evicted = self._policy.evict(num_blocks_to_evict, protected) + if evicted is None: + return None + for block_hash, block in evicted: + self._free_block(block) + to_evict.append(block_hash) + + if to_evict and self.events is not None: + self.events.append( + OffloadingEvent( + block_hashes=to_evict, + block_size=self.block_size, + medium=self.medium, + removed=True, + ) + ) + + blocks = self._allocate_blocks(block_hashes_to_store) + assert len(blocks) == len(block_hashes_to_store), ( + "Block pool did not allocate the expected number of blocks" + ) + + for block_hash, block in zip(block_hashes_to_store, blocks): + self._policy.insert(block_hash, block) + + # build store specs for allocated blocks + store_spec = self._get_load_store_spec(block_hashes_to_store, blocks) + + return PrepareStoreOutput( + block_hashes_to_store=block_hashes_to_store, + store_spec=store_spec, + block_hashes_evicted=to_evict, + ) + + def complete_store( + self, block_hashes: Iterable[BlockHash], success: bool = True + ) -> None: + stored_block_hashes: list[BlockHash] = [] + + if success: + for block_hash in block_hashes: + block = self._policy.get(block_hash) + if block is not None and not block.is_ready: + block.ref_cnt = 0 + stored_block_hashes.append(block_hash) + else: + for block_hash in block_hashes: + block = self._policy.get(block_hash) + if block is not None and not block.is_ready: + self._policy.remove(block_hash) + self._free_block(block) + + if stored_block_hashes and self.events is not None: + self.events.append( + OffloadingEvent( + block_hashes=stored_block_hashes, + block_size=self.block_size, + medium=self.medium, + removed=False, + ) + ) + + def take_events(self) -> Iterable[OffloadingEvent]: + if self.events is not None: + yield from self.events + self.events.clear() diff --git a/vllm/v1/kv_offload/cpu/policies/__init__.py b/vllm/v1/kv_offload/cpu/policies/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/v1/kv_offload/cpu/policies/abstract.py b/vllm/v1/kv_offload/cpu/policies/abstract.py new file mode 100644 index 000000000000..b45bb34cbd2e --- /dev/null +++ b/vllm/v1/kv_offload/cpu/policies/abstract.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import ctypes +from abc import ABC, abstractmethod +from collections.abc import Iterable + +from vllm.v1.core.kv_cache_utils import BlockHash + + +class BlockStatus(ctypes.Structure): + """ + Offloading status for a single block of KV data. + Holds the following information: + + ref_cnt - the current number of transfers using this block as a source. + A value of -1 indicates the block is not yet ready to be read. + block_id - index of the physical CPU buffer slot. + """ + + _fields_ = [("ref_cnt", ctypes.c_int32), ("block_id", ctypes.c_int64)] + + def __init__(self, block_id: int): + super().__init__() + # initialize block as "not ready" (ref_cnt = -1) + self.ref_cnt = -1 + self.block_id = block_id + + @property + def is_ready(self) -> bool: + """ + Returns whether the block is ready to be read. + """ + return self.ref_cnt >= 0 + + +class CachePolicy(ABC): + """ + Encapsulates both block organization (data structures) and replacement + decisions (which block to evict). LRU and ARC differ in both dimensions — + ARC's ghost lists and target_t1_size live at the intersection of storage + and eviction, so they cannot be separated cleanly. + """ + + @abstractmethod + def __init__(self, cache_capacity: int) -> None: ... + + @abstractmethod + def get(self, block_hash: BlockHash) -> BlockStatus | None: + """Find block in data structures. Returns None if not present.""" + + @abstractmethod + def insert(self, block_hash: BlockHash, block: BlockStatus) -> None: + """Add a newly allocated block. For ARC: also removes from ghost lists.""" + + @abstractmethod + def remove(self, block_hash: BlockHash) -> None: + """Remove a block (used to clean up after a failed store).""" + + @abstractmethod + def touch(self, block_hashes: Iterable[BlockHash]) -> None: + """Mark blocks as recently used.""" + + @abstractmethod + def evict( + self, n: int, protected: set[BlockHash] + ) -> list[tuple[BlockHash, BlockStatus]] | None: + """ + Evict exactly n blocks, skipping any in protected. + + Returns a list of (block_hash, block) for the evicted blocks, + or None if n evictions cannot be satisfied. The operation is atomic: + if None is returned, no state changes are made. + + For ARC: ghost list cleanup (trimming to cache_capacity) is performed + at the end of a successful eviction. + """ diff --git a/vllm/v1/kv_offload/cpu/policies/arc.py b/vllm/v1/kv_offload/cpu/policies/arc.py new file mode 100644 index 000000000000..fdcb16badd45 --- /dev/null +++ b/vllm/v1/kv_offload/cpu/policies/arc.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections import OrderedDict +from collections.abc import Iterable + +from vllm.v1.core.kv_cache_utils import BlockHash +from vllm.v1.kv_offload.cpu.policies.abstract import BlockStatus, CachePolicy + + +class ARCCachePolicy(CachePolicy): + """ + ARC (Adaptive Replacement Cache) cache policy. + + Data Structures: + T1: Recent cache containing blocks accessed once. + T2: Frequent cache containing blocks accessed multiple times. + B1/B2: Ghost lists tracking recently evicted blocks from T1/T2. + target_t1_size: Adaptive target size for the T1 partition. + + Algorithm Flow: + 1. Cache lookup (lookup): + Searches T1 and T2 for block hashes and counts consecutive hits + until a miss or non-ready block is encountered. + + 2. Cache touch (touch) - Adaptive Learning: + For each block_hash (in reverse order): + - If in T1: Move to T2 (promotion from recent to frequent). + - If in T2: Move to MRU position (end of queue). + - If in B1 ghost list: Increase target_t1_size. + - If in B2 ghost list: Decrease target_t1_size. + + 3. Block eviction (evict) - Adaptive Replacement: + Determines eviction source based on adaptive target: + - If T1 size >= target_t1_size: Evict from T1, add to B1. + - Otherwise: Evict from T2, add to B2. + Finally, bound each ghost list size. + + 4. Block insertion (insert): + New blocks are always inserted into T1 and removed from B1/B2 if + present. Blocks may later be promoted to T2 during touch operations. + + Adaptive Behavior: + The algorithm self-tunes the recency vs. frequency trade-off: + - B1 hit: Recent access patterns matter more → increase T1. + - B2 hit: Frequent access patterns matter more → decrease T1. + """ + + def __init__(self, cache_capacity: int): + self.cache_capacity: int = cache_capacity + self.target_t1_size: float = 0.0 + self.t1: OrderedDict[BlockHash, BlockStatus] = OrderedDict() + self.t2: OrderedDict[BlockHash, BlockStatus] = OrderedDict() + # block_hash -> None (only care about presence) + self.b1: OrderedDict[BlockHash, None] = OrderedDict() + self.b2: OrderedDict[BlockHash, None] = OrderedDict() + + def get(self, block_hash: BlockHash) -> BlockStatus | None: + return self.t1.get(block_hash) or self.t2.get(block_hash) + + def insert(self, block_hash: BlockHash, block: BlockStatus) -> None: + self.t1[block_hash] = block + self.b1.pop(block_hash, None) + self.b2.pop(block_hash, None) + + def remove(self, block_hash: BlockHash) -> None: + if self.t1.pop(block_hash, None) is None: + self.t2.pop(block_hash, None) + + def touch(self, block_hashes: Iterable[BlockHash]) -> None: + for block_hash in reversed(list(block_hashes)): + if block_hash in self.t1: + block = self.t1.pop(block_hash) + if not block.is_ready: + # block was just prepared to be stored, not really touched + # twice — keep it in T1 and mark as most recently used + self.t1[block_hash] = block + else: + self.t2[block_hash] = block + + elif block_hash in self.t2: + self.t2.move_to_end(block_hash) + + elif block_hash in self.b1: + delta = max(1, len(self.b2) / len(self.b1)) + self.target_t1_size = min( + self.target_t1_size + delta, self.cache_capacity + ) + # move to MRU position (end) to keep it fresh in the ghost list + self.b1.move_to_end(block_hash) + + elif block_hash in self.b2: + delta = max(1, len(self.b1) / len(self.b2)) + self.target_t1_size = max(self.target_t1_size - delta, 0) + # move to MRU position (end) to keep it fresh in the ghost list + self.b2.move_to_end(block_hash) + + def evict( + self, n: int, protected: set[BlockHash] + ) -> list[tuple[BlockHash, BlockStatus]] | None: + if n == 0: + return [] + + # Collect candidates atomically: simulate T1 size changes as we select, + # but do not modify actual data structures until all n are found. + candidates: list[ + tuple[BlockHash, BlockStatus, bool] + ] = [] # (hash, block, from_t1) + already_selected: set[BlockHash] = set() + virtual_t1_size = len(self.t1) + + for _ in range(n): + candidate: tuple[BlockHash, BlockStatus, bool] | None = None + + if virtual_t1_size >= int(self.target_t1_size): + for block_hash, block in self.t1.items(): + if ( + block.ref_cnt == 0 + and block_hash not in protected + and block_hash not in already_selected + ): + candidate = (block_hash, block, True) + virtual_t1_size -= 1 + break + + if candidate is None: + for block_hash, block in self.t2.items(): + if ( + block.ref_cnt == 0 + and block_hash not in protected + and block_hash not in already_selected + ): + candidate = (block_hash, block, False) + break + if candidate is None: + return None + + candidates.append(candidate) + already_selected.add(candidate[0]) + + # Apply all evictions now that we know n candidates exist. + result: list[tuple[BlockHash, BlockStatus]] = [] + for block_hash, block, from_t1 in candidates: + if from_t1: + del self.t1[block_hash] + self.b1[block_hash] = None + else: + del self.t2[block_hash] + self.b2[block_hash] = None + result.append((block_hash, block)) + + # Trim ghost lists to cache_capacity. + for ghost in (self.b1, self.b2): + for _ in range(len(ghost) - self.cache_capacity): + ghost.popitem(last=False) + + return result diff --git a/vllm/v1/kv_offload/cpu/policies/lru.py b/vllm/v1/kv_offload/cpu/policies/lru.py new file mode 100644 index 000000000000..b29b81f3c82e --- /dev/null +++ b/vllm/v1/kv_offload/cpu/policies/lru.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections import OrderedDict +from collections.abc import Iterable + +from vllm.v1.core.kv_cache_utils import BlockHash +from vllm.v1.kv_offload.cpu.policies.abstract import BlockStatus, CachePolicy + + +class LRUCachePolicy(CachePolicy): + """LRU cache policy backed by a single OrderedDict.""" + + def __init__(self, cache_capacity: int): + # cache_capacity unused by LRU but accepted for a uniform constructor + self.blocks: OrderedDict[BlockHash, BlockStatus] = OrderedDict() + + def get(self, block_hash: BlockHash) -> BlockStatus | None: + return self.blocks.get(block_hash) + + def insert(self, block_hash: BlockHash, block: BlockStatus) -> None: + self.blocks[block_hash] = block + + def remove(self, block_hash: BlockHash) -> None: + del self.blocks[block_hash] + + def touch(self, block_hashes: Iterable[BlockHash]) -> None: + for block_hash in reversed(list(block_hashes)): + if block_hash in self.blocks: + self.blocks.move_to_end(block_hash) + + def evict( + self, n: int, protected: set[BlockHash] + ) -> list[tuple[BlockHash, BlockStatus]] | None: + if n == 0: + return [] + candidates: list[tuple[BlockHash, BlockStatus]] = [] + for block_hash, block in self.blocks.items(): + if block.ref_cnt == 0 and block_hash not in protected: + candidates.append((block_hash, block)) + if len(candidates) == n: + break + if len(candidates) < n: + return None + for block_hash, _ in candidates: + del self.blocks[block_hash] + return candidates diff --git a/vllm/v1/kv_offload/cpu.py b/vllm/v1/kv_offload/cpu/spec.py similarity index 83% rename from vllm/v1/kv_offload/cpu.py rename to vllm/v1/kv_offload/cpu/spec.py index b1acff99ea1a..810967077a40 100644 --- a/vllm/v1/kv_offload/cpu.py +++ b/vllm/v1/kv_offload/cpu/spec.py @@ -9,9 +9,7 @@ from vllm.v1.attention.backend import AttentionBackend from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager -from vllm.v1.kv_offload.arc_manager import ARCOffloadingManager -from vllm.v1.kv_offload.backends.cpu import CPUBackend -from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager +from vllm.v1.kv_offload.cpu.manager import CPUOffloadingManager from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec from vllm.v1.kv_offload.reuse_manager import FilterReusedOffloadingManager from vllm.v1.kv_offload.spec import OffloadingSpec @@ -68,23 +66,13 @@ def get_manager(self) -> OffloadingManager: assert len(self.gpu_block_size) == 1 gpu_block_size = self.gpu_block_size[0] offloaded_block_size = gpu_block_size * self.block_size_factor - backend = CPUBackend( - block_size=offloaded_block_size, num_blocks=self.num_blocks - ) - if self.eviction_policy == "lru": - self._manager = LRUOffloadingManager( - backend=backend, enable_events=enable_events - ) - elif self.eviction_policy == "arc": - self._manager = ARCOffloadingManager( - backend=backend, enable_events=enable_events - ) - else: - raise ValueError( - f"Unknown eviction policy: {self.eviction_policy}. " - f"Supported policies: lru, arc" - ) + self._manager = CPUOffloadingManager( + block_size=offloaded_block_size, + num_blocks=self.num_blocks, + cache_policy=self.eviction_policy, # type: ignore[arg-type] + enable_events=enable_events, + ) # store_threshold: how many times a block must appear in lookup() # before it is eligible for CPU offloading. Values < 2 disable diff --git a/vllm/v1/kv_offload/factory.py b/vllm/v1/kv_offload/factory.py index d42f2cc63ba5..ecbaebb0d967 100644 --- a/vllm/v1/kv_offload/factory.py +++ b/vllm/v1/kv_offload/factory.py @@ -54,5 +54,5 @@ def create_spec( # Register various specs here. OffloadingSpecFactory.register_spec( - "CPUOffloadingSpec", "vllm.v1.kv_offload.cpu", "CPUOffloadingSpec" + "CPUOffloadingSpec", "vllm.v1.kv_offload.cpu.spec", "CPUOffloadingSpec" ) diff --git a/vllm/v1/kv_offload/lru_manager.py b/vllm/v1/kv_offload/lru_manager.py deleted file mode 100644 index 43dc7f7f19dd..000000000000 --- a/vllm/v1/kv_offload/lru_manager.py +++ /dev/null @@ -1,146 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections import OrderedDict -from collections.abc import Iterable - -from vllm.v1.core.kv_cache_utils import BlockHash -from vllm.v1.kv_offload.abstract import ( - LoadStoreSpec, - OffloadingEvent, - OffloadingManager, - PrepareStoreOutput, -) -from vllm.v1.kv_offload.backend import Backend, BlockStatus - - -class LRUOffloadingManager(OffloadingManager): - """ - An OffloadingManager with a pluggable backend, which evicts blocks by LRU. - """ - - def __init__(self, backend: Backend, enable_events: bool = False): - self.backend: Backend = backend - # block_hash -> BlockStatus - self.blocks: OrderedDict[BlockHash, BlockStatus] = OrderedDict() - self.events: list[OffloadingEvent] | None = [] if enable_events else None - - def lookup(self, block_hashes: Iterable[BlockHash]) -> int | None: - hit_count = 0 - for block_hash in block_hashes: - block = self.blocks.get(block_hash) - if block is None or not block.is_ready: - break - hit_count += 1 - return hit_count - - def prepare_load(self, block_hashes: Iterable[BlockHash]) -> LoadStoreSpec: - blocks = [] - for block_hash in block_hashes: - block = self.blocks[block_hash] - assert block.is_ready - block.ref_cnt += 1 - blocks.append(block) - - return self.backend.get_load_store_spec(block_hashes, blocks) - - def touch(self, block_hashes: Iterable[BlockHash]): - for block_hash in reversed(list(block_hashes)): - if self.blocks.get(block_hash): - self.blocks.move_to_end(block_hash) - - def complete_load(self, block_hashes: Iterable[BlockHash]): - for block_hash in block_hashes: - block = self.blocks[block_hash] - assert block.ref_cnt > 0 - block.ref_cnt -= 1 - - def prepare_store( - self, block_hashes: Iterable[BlockHash] - ) -> PrepareStoreOutput | None: - block_hashes_list = list(block_hashes) - - # filter out blocks that are already stored - block_hashes_to_store = [ - block_hash - for block_hash in block_hashes_list - if block_hash not in self.blocks - ] - - num_blocks_to_evict = ( - len(block_hashes_to_store) - self.backend.get_num_free_blocks() - ) - - # build list of blocks to evict - to_evict = [] - if num_blocks_to_evict > 0: - # Blocks from the original input are excluded from eviction candidates: - # a block that was already stored must remain in the cache after this call. - protected = set(block_hashes_list) - for block_hash, block in self.blocks.items(): - if block.ref_cnt == 0 and block_hash not in protected: - to_evict.append(block_hash) - num_blocks_to_evict -= 1 - if num_blocks_to_evict == 0: - break - else: - # we could not evict enough blocks - return None - - # evict blocks - for block_hash in to_evict: - self.backend.free(self.blocks.pop(block_hash)) - - if to_evict and self.events is not None: - self.events.append( - OffloadingEvent( - block_hashes=to_evict, - block_size=self.backend.block_size, - medium=self.backend.medium, - removed=True, - ) - ) - - blocks = self.backend.allocate_blocks(block_hashes_to_store) - assert len(blocks) == len(block_hashes_to_store) - - for block_hash, block in zip(block_hashes_to_store, blocks): - self.blocks[block_hash] = block - - # build store specs for allocated blocks - store_spec = self.backend.get_load_store_spec(block_hashes_to_store, blocks) - - return PrepareStoreOutput( - block_hashes_to_store=block_hashes_to_store, - store_spec=store_spec, - block_hashes_evicted=to_evict, - ) - - def complete_store(self, block_hashes: Iterable[BlockHash], success: bool = True): - stored_block_hashes: list[BlockHash] = [] - if success: - for block_hash in block_hashes: - block = self.blocks[block_hash] - if not block.is_ready: - block.ref_cnt = 0 - stored_block_hashes.append(block_hash) - else: - for block_hash in block_hashes: - block = self.blocks[block_hash] - if not block.is_ready: - self.backend.free(block) - del self.blocks[block_hash] - - if stored_block_hashes and self.events is not None: - self.events.append( - OffloadingEvent( - block_hashes=stored_block_hashes, - block_size=self.backend.block_size, - medium=self.backend.medium, - removed=False, - ) - ) - - def take_events(self) -> Iterable[OffloadingEvent]: - if self.events is not None: - yield from self.events - self.events.clear()