Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions tests/v1/kv_offload/test_cpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional

import numpy as np

from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.kv_offload.abstract import (LoadStoreSpec, OffloadingEvent,
PrepareStoreOutput)
from vllm.v1.kv_offload.backends.cpu import CPUBackend
from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec


@dataclass
class ExpectedPrepareStoreOutput:
    """
    Expected outcome of a prepare_store call, written with plain integers.

    The integer hashes are turned into BlockHash values before being
    compared against an actual PrepareStoreOutput.
    """
    # integer hashes expected to be reported as "to store"
    block_hashes_to_store: list[int]
    # block IDs expected in the returned store spec
    store_block_ids: list[int]
    # integer hashes expected to be reported as evicted
    block_hashes_evicted: list[int]


def to_hashes(int_hashes: list[int]) -> list[BlockHash]:
    """
    Convert integer test hashes into BlockHash values.

    Each integer is rendered as its decimal string, encoded to bytes and
    wrapped in BlockHash.
    """
    encoded_hashes = (str(int_hash).encode() for int_hash in int_hashes)
    return [BlockHash(encoded) for encoded in encoded_hashes]


def verify_store_output(
        prepare_store_output: Optional[PrepareStoreOutput],
        expected_prepare_store_output: ExpectedPrepareStoreOutput):
    """
    Assert that a prepare_store result matches the expectation.

    Args:
        prepare_store_output: the value returned by prepare_store.
            Must not be None.
        expected_prepare_store_output: the expected hashes to store,
            evicted hashes and store block IDs, given as integers.
    """
    expected = expected_prepare_store_output
    assert prepare_store_output is not None

    # hashes are compared as BlockHash lists, so order matters
    hashes_to_store = to_hashes(expected.block_hashes_to_store)
    assert prepare_store_output.block_hashes_to_store == hashes_to_store

    hashes_evicted = to_hashes(expected.block_hashes_evicted)
    assert prepare_store_output.block_hashes_evicted == hashes_evicted

    # the store spec must be a CPU spec carrying the expected block IDs
    store_spec = prepare_store_output.store_spec
    assert isinstance(store_spec, CPULoadStoreSpec)
    expected_block_ids = np.array(expected.store_block_ids, dtype=np.int64)
    assert np.array_equal(expected_block_ids, store_spec.block_ids)


def verify_load_output(prepare_load_output: LoadStoreSpec,
                       expected_prepare_load_output: list[int]):
    """
    Assert that a prepare_load result is a CPU spec with the given block IDs.

    Args:
        prepare_load_output: the value returned by prepare_load.
        expected_prepare_load_output: the expected block IDs, in order.
    """
    assert isinstance(prepare_load_output, CPULoadStoreSpec)
    actual_block_ids = prepare_load_output.block_ids
    expected_block_ids = np.array(expected_prepare_load_output,
                                  dtype=np.int64)
    assert np.array_equal(expected_block_ids, actual_block_ids)


def verify_events(events: Iterable[OffloadingEvent],
                  block_size: int,
                  expected_stores: tuple[set[int], ...] = (),
                  expected_evictions: tuple[set[int], ...] = ()):
    """
    Assert that the given offloading events match the expectation.

    Every event must report the CPU medium and the given block size.
    Events are split by their ``removed`` flag into evictions and stores,
    and each group is compared, in order, with the expected hash sets.

    Args:
        events: the events to verify.
        block_size: the block size every event is expected to report.
        expected_stores: one set of integer hashes per expected store event.
        expected_evictions: one set of integer hashes per expected
            eviction event.
    """
    stores: list[set[BlockHash]] = []
    evictions: list[set[BlockHash]] = []
    for event in events:
        assert event.medium == CPULoadStoreSpec.medium()
        assert event.block_size == block_size
        # events flagged as removed are evictions, all others are stores
        target = evictions if event.removed else stores
        target.append(set(event.block_hashes))

    # convert the integer expectations to sets of BlockHash
    expected_eviction_sets = tuple(
        set(to_hashes(list(int_set))) for int_set in expected_evictions)
    expected_store_sets = tuple(
        set(to_hashes(list(int_set))) for int_set in expected_stores)

    assert tuple(evictions) == expected_eviction_sets
    assert tuple(stores) == expected_store_sets


def test_cpu_manager():
    """
    Tests LRUOffloadingManager with a CPUBackend.

    Runs a store / lookup / load / evict / touch sequence against a
    4-block CPU backend and checks the returned specs, the lookup
    results and the emitted events along the way.
    """
    # CPU backend with a capacity of 4 blocks, wrapped by the LRU manager
    block_size = 256
    manager = LRUOffloadingManager(
        CPUBackend(block_size=block_size, num_blocks=4),
        enable_events=True)

    # prepare store [1, 2]
    verify_store_output(
        manager.prepare_store(to_hashes([1, 2])),
        ExpectedPrepareStoreOutput(
            block_hashes_to_store=[1, 2],
            store_block_ids=[0, 1],
            block_hashes_evicted=[],
        ))

    # lookup [1, 2] -> not ready (store was prepared but not completed)
    assert manager.lookup(to_hashes([1, 2])) == 0

    # no events so far
    assert not list(manager.take_events())

    # complete store [1, 2] -> a single store event
    manager.complete_store(to_hashes([1, 2]))
    verify_events(manager.take_events(),
                  block_size,
                  expected_stores=({1, 2}, ))

    # lookup [1], [1, 2] and [1, 2, 3]
    for query, expected_hits in (([1], 1), ([1, 2], 2), ([1, 2, 3], 2)):
        assert manager.lookup(to_hashes(query)) == expected_hits

    # prepare store [2, 3, 4, 5] -> [2] is already stored, evicts [1]
    verify_store_output(
        manager.prepare_store(to_hashes([2, 3, 4, 5])),
        ExpectedPrepareStoreOutput(
            block_hashes_to_store=[3, 4, 5],
            store_block_ids=[2, 3, 0],
            block_hashes_evicted=[1],
        ))

    # verify eviction event
    verify_events(manager.take_events(),
                  block_size,
                  expected_evictions=({1}, ))

    # prepare store with no space
    assert manager.prepare_store(to_hashes([1, 6])) is None

    # complete store [2, 3, 4, 5]
    manager.complete_store(to_hashes([2, 3, 4, 5]))

    # prepare load [2, 3]
    verify_load_output(manager.prepare_load(to_hashes([2, 3])), [1, 2])

    # prepare store with no space ([2, 3] is being loaded)
    assert manager.prepare_store(to_hashes([6, 7, 8])) is None

    # complete load [2, 3]
    manager.complete_load(to_hashes([2, 3]))

    # prepare store [6, 7, 8] -> evicts [2, 3, 4] (oldest)
    verify_store_output(
        manager.prepare_store(to_hashes([6, 7, 8])),
        ExpectedPrepareStoreOutput(
            block_hashes_to_store=[6, 7, 8],
            store_block_ids=[3, 2, 1],
            block_hashes_evicted=[2, 3, 4],
        ))

    # complete store [6, 7, 8]
    manager.complete_store(to_hashes([6, 7, 8]))

    # touch [5, 6, 7] (move to end of LRU order)
    manager.touch(to_hashes([5, 6, 7]))

    # prepare store [9] -> evicts [8] (oldest following previous touch)
    verify_store_output(
        manager.prepare_store(to_hashes([9])),
        ExpectedPrepareStoreOutput(
            block_hashes_to_store=[9],
            store_block_ids=[1],
            block_hashes_evicted=[8],
        ))

    # complete store [7, 9] with failure
    manager.complete_store(to_hashes([7, 9]), success=False)

    # assert [7] is still stored, but [9] is not
    for query, expected_hits in (([7], 1), ([9], 0)):
        assert manager.lookup(to_hashes(query)) == expected_hits

    # all events emitted since the last take_events call
    verify_events(manager.take_events(),
                  block_size,
                  expected_stores=({3, 4, 5}, {6, 7, 8}),
                  expected_evictions=({2, 3, 4}, {8}))
96 changes: 96 additions & 0 deletions vllm/v1/kv_offload/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ctypes
from abc import ABC, abstractmethod
from collections.abc import Iterable

from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.kv_offload.abstract import LoadStoreSpec


class BlockStatus(ctypes.Structure):
    """
    Offloading status for a single block of KV data.

    Defines a single field:

    ref_cnt - the current number of transfers using this block as a source.
              A value of -1 indicates the block is not yet ready to be read.

    Subclasses may extend _fields_ with backend-specific data.
    """
    # NOTE: a ctypes structure is used here with the intent of reducing
    # the memory footprint per block (more compact representation).
    _fields_ = [("ref_cnt", ctypes.c_int32)]

    def __init__(self):
        # a new block starts out as "not ready" (ref_cnt = -1)
        super().__init__(ref_cnt=-1)

    @property
    def is_ready(self) -> bool:
        """
        Whether the block can be read, i.e. ref_cnt is non-negative.
        """
        return 0 <= self.ref_cnt


class Backend(ABC):
    """
    Abstract interface for a backend that allocates space for KV blocks
    and produces the specs used to read/write them.
    """

    def __init__(self, block_size: int, medium: str):
        # both values are simply recorded on the instance
        self.medium = medium
        self.block_size = block_size

    @abstractmethod
    def get_num_free_blocks(self):
        """
        Returns the number of blocks that can currently be allocated.
        """

    @abstractmethod
    def allocate_blocks(self,
                        block_hashes: list[BlockHash]) -> list[BlockStatus]:
        """
        Allocate space for writing blocks.

        This method assumes there is enough space for the allocation.
        It is unsafe to call it without first checking
        get_num_free_blocks.

        Args:
            block_hashes: the hashes identifying the blocks to be written.

        Returns:
            A list of BlockStatus for the allocated blocks.
            Each returned item has a ref_cnt of -1, meaning the block
            is not yet ready to be read.
        """

    @abstractmethod
    def free(self, block: BlockStatus):
        """
        Free a previously allocated block.

        Only blocks returned by allocate_blocks should be passed here,
        and each of them only once.

        Args:
            block: The block to be freed.
        """

    def get_load_store_spec(self, block_hashes: Iterable[BlockHash],
                            blocks: Iterable[BlockStatus]) -> LoadStoreSpec:
        """
        Get backend-specific information on how to read/write blocks.

        The base implementation always raises NotImplementedError.

        Args:
            block_hashes: the list of block hashes identifying the blocks.
            blocks: the list of blocks.

        Returns:
            A LoadStoreSpec that can be used by a worker
            to read/write the blocks.
        """
        raise NotImplementedError
61 changes: 61 additions & 0 deletions vllm/v1/kv_offload/backends/cpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ctypes
from collections.abc import Iterable

from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.kv_offload.abstract import LoadStoreSpec
from vllm.v1.kv_offload.backend import Backend, BlockStatus
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec


class CPUBlockStatus(BlockStatus):
    """
    BlockStatus extended with the ID of the CPU block holding the data.
    """
    # inherited fields first, followed by the 64-bit block ID
    _fields_ = [  # type: ignore
        *BlockStatus._fields_,
        ("block_id", ctypes.c_int64),
    ]

    def __init__(self, block_id: int):
        super().__init__()
        self.block_id = block_id


class CPUBackend(Backend):
    """
    Backend handing out CPU block IDs from a fixed pool of num_blocks.

    Block IDs are first handed out sequentially (0, 1, 2, ...).
    IDs of freed blocks are kept in a free list and reused once all
    num_blocks IDs have been handed out.
    """

    def __init__(self, block_size: int, num_blocks: int):
        super().__init__(block_size=block_size,
                         medium=CPULoadStoreSpec.medium())

        # total capacity, in blocks
        self.num_blocks: int = num_blocks
        # number of distinct block IDs handed out so far
        self.num_allocated_blocks: int = 0
        # IDs of freed blocks, available for reuse
        self.allocated_blocks_free_list: list[int] = []

    def get_num_free_blocks(self):
        """
        Returns the count of never-used blocks plus freed blocks.
        """
        num_never_used = self.num_blocks - self.num_allocated_blocks
        return len(self.allocated_blocks_free_list) + num_never_used

    def allocate_blocks(self,
                        block_hashes: list[BlockHash]) -> list[BlockStatus]:
        """
        Allocate one block per given hash.

        Never-used block IDs are taken first; any remainder is taken
        from the end of the free list. Asserts that the free list is
        large enough to cover that remainder.
        """
        num_requested = len(block_hashes)
        num_fresh = min(num_requested,
                        self.num_blocks - self.num_allocated_blocks)
        num_reused = num_requested - num_fresh
        assert len(self.allocated_blocks_free_list) >= num_reused

        # fresh blocks get the next sequential IDs
        first_fresh_id = self.num_allocated_blocks
        self.num_allocated_blocks = first_fresh_id + num_fresh
        blocks: list[BlockStatus] = [
            CPUBlockStatus(block_id)
            for block_id in range(first_fresh_id, self.num_allocated_blocks)
        ]

        # reused blocks take the most recently freed IDs first
        free_list = self.allocated_blocks_free_list
        blocks.extend(
            CPUBlockStatus(free_list.pop()) for _ in range(num_reused))

        return blocks

    def free(self, block: BlockStatus):
        """
        Return the ID of the given block to the free list.
        """
        assert isinstance(block, CPUBlockStatus)
        freed_block_id = block.block_id
        self.allocated_blocks_free_list.append(freed_block_id)

    def get_load_store_spec(self, block_hashes: Iterable[BlockHash],
                            blocks: Iterable[BlockStatus]) -> LoadStoreSpec:
        """
        Build a CPULoadStoreSpec from the IDs of the given blocks.

        block_hashes is not used by this backend.
        """
        block_ids = [block.block_id for block in blocks]
        return CPULoadStoreSpec(block_ids)
Loading