Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
289 changes: 287 additions & 2 deletions tests/v1/core/test_single_type_kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@
)
from vllm.v1.core.single_type_kv_cache_manager import (
ChunkedLocalAttentionManager,
MambaManager,
SlidingWindowManager,
)
from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, SlidingWindowSpec
from vllm.v1.kv_cache_interface import (
ChunkedLocalAttentionSpec,
MambaSpec,
SlidingWindowSpec,
)

pytestmark = pytest.mark.cpu_test

Expand Down Expand Up @@ -480,4 +485,284 @@ def test_predictor_matches_allocator_blocks_calculation_with_admission_cap():
f"num_tokens={num_tokens}: predictor returned {predicted} "
f"but allocator pulled {len(new_blocks)}"
)
total_computed = num_tokens


## Mamba prefix caching with PD (prefill/decode disaggregation)


def _make_mamba_align_manager(block_size, block_pool):
    """Build a caching-enabled ``MambaManager`` using the "align" cache mode.

    Args:
        block_size: Block size passed to the ``MambaSpec``.
        block_pool: Block pool the manager allocates from.

    Returns:
        A ``MambaManager`` for KV cache group 0 with prefix caching enabled.
    """
    # Large enough that the admission cap never limits these tests.
    effectively_unlimited = 10**9

    # NOTE(review): two shapes are given but only one dtype — confirm that
    # MambaSpec tolerates the length mismatch.
    mamba_spec = MambaSpec(
        block_size=block_size,
        shapes=((1,), (1,)),
        dtypes=(torch.float32,),
        mamba_cache_mode="align",
    )
    manager = MambaManager(
        mamba_spec,
        block_pool=block_pool,
        enable_caching=True,
        kv_cache_group_id=0,
        max_admission_blocks_per_request=effectively_unlimited,
    )
    return manager


class _FakeRequest:
"""Minimal stand-in for vllm.v1.request.Request used by cache_blocks."""

def __init__(self, request_id, block_hashes):
self.request_id = request_id
self.block_hashes = block_hashes


def test_mamba_align_num_cached_block_excludes_null_blocks_pd():
    """All-external (PD) allocation must leave ``num_cached_block`` at 0.

    No local prefix hit occurs, so every computed token is external.
    If ``allocate_new_computed_blocks`` counted the null padding blocks as
    cached, ``cache_blocks()`` would early-return and the real state block
    would never be hashed.
    """
    block_size = 128
    num_external_tokens = 400
    request_id = "req_pd"

    pool = BlockPool(
        num_gpu_blocks=100,
        enable_caching=True,
        hash_block_size=block_size,
    )
    manager = _make_mamba_align_manager(block_size, pool)

    manager.allocate_new_computed_blocks(
        request_id,
        new_computed_blocks=[],
        num_local_computed_tokens=0,
        num_external_computed_tokens=num_external_tokens,
    )

    # get_num_skipped_tokens(400)=399 → num_skipped_blocks=3, so the
    # expected layout is [null, null, null, fresh].
    allocated = manager.req_to_blocks[request_id]
    assert len(allocated) == 4
    *padding_blocks, state_block = allocated
    assert all(block.is_null for block in padding_blocks)
    assert not state_block.is_null

    # The fix under test: the count must be 0, not 3.
    assert manager.num_cached_block[request_id] == 0


def test_mamba_align_num_cached_block_with_local_prefix_hit():
    """A two-block local prefix hit must yield ``num_cached_block == 2``.

    Both computed blocks come from a local prefix hit and no tokens are
    external. The first block is skipped (replaced by a null block), yet
    ``num_cached_block`` must still count all original computed blocks.
    """
    block_size = 128
    request_id = "req_partial"

    pool = BlockPool(
        num_gpu_blocks=100,
        enable_caching=True,
        hash_block_size=block_size,
    )
    manager = _make_mamba_align_manager(block_size, pool)

    # Give two pool blocks hashes so they look like local prefix hits.
    prefix_hits = []
    for pool_index, raw_hash in ((10, b"h0"), (11, b"h1")):
        hit_block = pool.blocks[pool_index]
        hit_block.block_hash = make_block_hash_with_group_id(
            BlockHash(raw_hash), 0
        )
        prefix_hits.append(hit_block)
    first_hit, second_hit = prefix_hits

    # 256 tokens → get_num_skipped_tokens(256)=255 → num_skipped_blocks=1,
    # so only the second hit block survives in the request's block list.
    manager.allocate_new_computed_blocks(
        request_id,
        new_computed_blocks=[first_hit, second_hit],
        num_local_computed_tokens=2 * block_size,
        num_external_computed_tokens=0,
    )

    # Expected layout: [null, second_hit]
    allocated = manager.req_to_blocks[request_id]
    assert len(allocated) == 2
    skipped_slot, surviving_slot = allocated
    assert skipped_slot.is_null
    assert surviving_slot is second_hit

    # num_cached_block = len(original new_computed_blocks) = 2
    assert manager.num_cached_block[request_id] == 2


def test_mamba_align_cache_blocks_registers_null_hashes():
    """``cache_blocks()`` must register hashes for null-block positions.

    After caching, each full block's hash has to be present in the pool's
    hash map (mapped to the null block) so that ``find_longest_cache_hit``
    can discover it.
    """
    block_size = 128
    request_id = "req_nullhash"
    num_external = 400  # 3 full blocks + 16-token partial
    num_full_blocks = num_external // block_size  # 3

    pool = BlockPool(
        num_gpu_blocks=100,
        enable_caching=True,
        hash_block_size=block_size,
    )
    manager = _make_mamba_align_manager(block_size, pool)

    manager.allocate_new_computed_blocks(
        request_id,
        new_computed_blocks=[],
        num_local_computed_tokens=0,
        num_external_computed_tokens=num_external,
    )

    # The fake request needs hashes covering at least the full blocks; one
    # extra hash is supplied.
    block_hashes = [
        BlockHash(f"h{i}".encode()) for i in range(num_full_blocks + 1)
    ]
    fake_req = _FakeRequest(request_id, block_hashes)

    # Cache exactly the tokens spanned by the 3 full blocks.
    manager.cache_blocks(fake_req, num_full_blocks * block_size)

    # Every full-block hash must now resolve to the null block.
    for i, block_hash in enumerate(block_hashes[:num_full_blocks]):
        lookup_key = make_block_hash_with_group_id(block_hash, 0)
        cached = pool.cached_block_hash_to_block.get_one_block(lookup_key)
        assert cached is not None, f"hash[{i}] not found after cache_blocks"
        assert cached.is_null, f"hash[{i}] should map to null_block"


def test_mamba_align_cache_blocks_does_not_early_return_pd():
    """``cache_blocks`` must make progress even when all blocks are null.

    The original bug was an early return in this situation. Here the first
    call must advance ``num_cached_block`` past the null blocks, so that a
    later call covering more tokens can cache the real state block.
    """
    block_size = 128
    request_id = "req_noreturn"
    num_external = 400
    num_full_blocks = num_external // block_size  # 3

    pool = BlockPool(
        num_gpu_blocks=100,
        enable_caching=True,
        hash_block_size=block_size,
    )
    manager = _make_mamba_align_manager(block_size, pool)

    manager.allocate_new_computed_blocks(
        request_id,
        new_computed_blocks=[],
        num_local_computed_tokens=0,
        num_external_computed_tokens=num_external,
    )

    block_hashes = [
        BlockHash(f"h{i}".encode()) for i in range(num_full_blocks + 1)
    ]
    fake_req = _FakeRequest(request_id, block_hashes)

    # Step 1: blocks 0..2 (all null) are processed and the counter advances.
    tokens_in_full_blocks = num_full_blocks * block_size
    manager.cache_blocks(fake_req, tokens_in_full_blocks)
    assert manager.num_cached_block[request_id] == num_full_blocks

    # Step 2: one more block of tokens — block 3, the real state block,
    # gets cached.
    tokens_including_state_block = (num_full_blocks + 1) * block_size
    manager.cache_blocks(fake_req, tokens_including_state_block)
    assert manager.num_cached_block[request_id] == num_full_blocks + 1

    # The state block is a real (non-null) block and now carries a hash.
    state_block = manager.req_to_blocks[request_id][num_full_blocks]
    assert not state_block.is_null
    assert state_block.block_hash is not None


def test_mamba_align_find_longest_cache_hit_after_pd_caching():
    """Round trip: blocks cached for a PD request produce a later cache hit.

    A first request with only external tokens is allocated and cached.
    Looking up the same block hashes afterwards must return a hit that
    covers all full blocks.
    """
    block_size = 128
    request_id = "req_roundtrip"
    num_tokens = 3 * block_size  # 384 — exactly block-aligned
    num_full_blocks = num_tokens // block_size  # 3

    pool = BlockPool(
        num_gpu_blocks=100,
        enable_caching=True,
        hash_block_size=block_size,
    )
    manager = _make_mamba_align_manager(block_size, pool)
    kv_cache_spec = manager.kv_cache_spec

    # --- First request (PD): allocate + cache ---
    manager.allocate_new_computed_blocks(
        request_id,
        new_computed_blocks=[],
        num_local_computed_tokens=0,
        num_external_computed_tokens=num_tokens,
    )

    block_hashes = [
        BlockHash(f"h{i}".encode()) for i in range(num_full_blocks + 1)
    ]
    fake_req = _FakeRequest(request_id, block_hashes)

    # Cache the full blocks (null positions registered by MambaManager).
    manager.cache_blocks(fake_req, num_full_blocks * block_size)

    # Simulate decode advancing to fill block 3, then cache it too.
    # allocate_new_blocks would normally add the block; here a fresh one is
    # appended manually if the request does not have it yet.
    req_blocks = manager.req_to_blocks[request_id]
    if len(req_blocks) <= num_full_blocks:
        req_blocks.append(pool.get_new_blocks(1)[0])
    manager.cache_blocks(fake_req, (num_full_blocks + 1) * block_size)

    # --- Second request: find_longest_cache_hit ---
    hit = MambaManager.find_longest_cache_hit(
        block_hashes=block_hashes,
        max_length=num_tokens,
        kv_cache_group_ids=[0],
        block_pool=pool,
        kv_cache_spec=kv_cache_spec,
        use_eagle=False,
        alignment_tokens=block_size,
    )

    # Mamba searches right-to-left. With 384 tokens → max_num_blocks=3,
    # it checks block_hashes[2]. The null-block hash was registered,
    # so this should be a hit of length 3 blocks.
    hit_blocks = hit[0]
    assert len(hit_blocks) == num_full_blocks, (
        f"Expected hit length {num_full_blocks}, got {len(hit_blocks)}"
    )


def test_mamba_align_swa_unchanged_by_num_cached_block_fix():
    """The ``num_cached_block`` fix must not alter sliding-window behavior.

    For SWA, the computed blocks handed in already include null padding, so
    the length of that list before skipping equals ``len(req_blocks)`` —
    the value the old code used.
    """
    block_size = 2
    window_tokens = 4
    request_id = "swa_req"

    swa_spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=window_tokens,
    )
    pool = BlockPool(
        num_gpu_blocks=100,
        enable_caching=True,
        hash_block_size=block_size,
    )
    manager = get_sliding_window_manager(swa_spec, pool)

    # Simulate SWA find_longest_cache_hit returning [NULL, NULL, B7, B8]
    # for a 4-block prefix with a 2-block sliding window.
    computed_blocks = [
        pool.null_block,
        pool.null_block,
        pool.blocks[7],
        pool.blocks[8],
    ]
    # Give the real blocks hashes, as they would have after a cache hit.
    for pool_index, raw_hash in ((7, b"b7"), (8, b"b8")):
        pool.blocks[pool_index].block_hash = make_block_hash_with_group_id(
            BlockHash(raw_hash), 0
        )

    manager.allocate_new_computed_blocks(
        request_id,
        new_computed_blocks=computed_blocks,
        num_local_computed_tokens=4 * block_size,  # 8 tokens
        num_external_computed_tokens=0,
    )

    # SWA: get_num_skipped_tokens(8) = max(0, 8-4+1) = 5 → num_skipped_blocks=2
    # new_computed_blocks[2:] = [B7, B8]
    # req_blocks = [null, null, B7, B8]
    # num_cached_block should be 4 (= len(original computed_blocks)),
    # which equals len(req_blocks) — matching the old behavior.
    assert manager.num_cached_block[request_id] == 4
    assert len(manager.req_to_blocks[request_id]) == 4
24 changes: 24 additions & 0 deletions tests/v1/kv_connector/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc

import pytest
import torch


@pytest.fixture
def clean_gpu_memory_between_tests():
    """Release cached accelerator memory before and after the requesting test.

    Cleanup runs before the test so that allocations left over from earlier
    tests in the same session don't prevent this test from reserving the
    memory it needs. It runs again after the test so the next test starts
    from a clean slate.
    """

    def _release_cached_memory():
        # Collect unreachable Python objects first, then empty the cache.
        gc.collect()
        if torch.accelerator.is_available():
            torch.accelerator.empty_cache()

    _release_cached_memory()
    yield
    _release_cached_memory()
Loading
Loading