116 changes: 116 additions & 0 deletions tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py
@@ -15,9 +15,18 @@
from unittest.mock import Mock

import pytest
import torch

from vllm.distributed.kv_transfer.kv_connector.v1 import SupportsHMA
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
SlidingWindowSpec,
)
from vllm.v1.request import FinishReason, Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager

from .utils import (
create_model_runner_output,
@@ -478,3 +487,110 @@ def cache_blocks_spy(req, num_tokens):

# request should be in the running queue
assert request in recompute_scheduler.running


@pytest.fixture
def multi_group_fail_scheduler():
vllm_config = create_vllm_config(kv_load_failure_policy="fail")
kv_cache_config = KVCacheConfig(
num_blocks=1000,
kv_cache_tensors=[],
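# Two KV cache groups with different specs (full attention, block_size 16;
# sliding window, block_size 32) force the hybrid multi-group path.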
kv_cache_groups=[
KVCacheGroupSpec(
["layer1"],
FullAttentionSpec(
block_size=16,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
),
),
KVCacheGroupSpec(
["layer2"],
SlidingWindowSpec(
block_size=32,
num_kv_heads=1,
head_size=1,
dtype=torch.float16,
sliding_window=32 * 4,
),
),
],
)
vllm_config.cache_config.num_gpu_blocks = 1000
return Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
block_size=16,
)


def test_sync_fail_multi_group_invalid_blocks_evicted(multi_group_fail_scheduler):
"""
Test sync fail case with multiple KV groups

Same as test_sync_fail_invalid_blocks_evicted but for multiple KV groups.
"""
scheduler = multi_group_fail_scheduler

num_prompt_tokens = 100 * scheduler.block_size
num_external_computed_tokens = 75 * scheduler.block_size

request = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request)

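# Mock KV connector method.
# req_id -> num_external_computed_tokens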
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}

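# Mock connector that also implements the SupportsHMA protocol, so the
# scheduler takes the hybrid-memory-allocator (multi-group) code path.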
class HMAMockConnector(Mock, SupportsHMA):
def request_finished_all_groups(self, _, __):
return False, None

scheduler.connector = HMAMockConnector()
scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
)
scheduler.connector.request_finished.return_value = (False, None)
scheduler.connector.take_events.return_value = ()

scheduler_output = scheduler.schedule()

# get the allocated block IDs of the last KV cache group, dropping
# null (zero) padding block IDs
req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[-1]
req_block_ids = [block_id for block_id in req_block_ids if block_id]
invalid_block_id = req_block_ids[1]
invalid_block_ids = {invalid_block_id}

# report invalid blocks - request should fail
model_runner_output = create_model_runner_output(
[request],
invalid_block_ids=invalid_block_ids,
)

_ = scheduler.update_from_output(scheduler_output, model_runner_output)

# verify the request was removed from scheduler
assert request.request_id not in scheduler.requests
assert len(scheduler.running) == 0

# verify invalid block was actually freed from cache
try:
block_id_groups = scheduler.kv_cache_manager.get_block_ids(request.request_id)
# if we get here, check if blocks were actually freed
if block_id_groups is not None:
num_allocated_blocks = 0
for block_ids in block_id_groups:
num_allocated_blocks += sum(block_id != 0 for block_id in block_ids)
if num_allocated_blocks > 0:
pytest.fail(
f"Invalid blocks still tracked for finished request! "
f"Request {request.request_id} should have been freed but "
f"still has {num_allocated_blocks} blocks allocated."
)
# blocks list exists but is empty - this is fine, they were freed
except KeyError:
# expected - request completely removed from tracking
pass
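Note: both test files rely on a _make_get_num_new_matched_tokens helper from
the module's shared test utilities, which this diff does not show. A minimal
sketch of what it plausibly looks like, assuming the connector's
get_num_new_matched_tokens(request, num_computed_tokens) returns a
(num_tokens, load_kv_async) tuple (the exact signature lives in the existing
utils, not in this diff):

def _make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load):
    # Build a Mock side_effect that reports the configured number of
    # externally matched tokens for each request, plus the async-load flag.
    def side_effect(request, num_computed_tokens):
        return req_num_new_matched_tokens[request.request_id], async_load

    return side_effect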
169 changes: 169 additions & 0 deletions tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py
@@ -5,9 +5,17 @@
from unittest.mock import Mock

import pytest
import torch

from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
SlidingWindowSpec,
)
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager

from .utils import (
create_model_runner_output,
@@ -333,3 +341,164 @@ def test_async_progressive_load_failure(
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.failed_recving_kv_req_ids == {request.request_id}
assert scheduler.connector.get_num_new_matched_tokens.call_count == 1


@pytest.fixture
def multi_group_scheduler():
vllm_config = create_vllm_config(kv_load_failure_policy="recompute")
kv_cache_config = KVCacheConfig(
num_blocks=1000,
kv_cache_tensors=[],
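# Same two-group (hybrid) layout as multi_group_fail_scheduler, but with
# the "recompute" load-failure policy.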
kv_cache_groups=[
KVCacheGroupSpec(
["layer1"],
FullAttentionSpec(
block_size=16,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
),
),
KVCacheGroupSpec(
["layer2"],
SlidingWindowSpec(
block_size=32,
num_kv_heads=1,
head_size=1,
dtype=torch.float16,
sliding_window=32 * 4,
),
),
],
)
vllm_config.cache_config.num_gpu_blocks = 1000
return Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
block_size=16,
)


def test_async_load_failure_multiple_kv_groups(multi_group_scheduler):
scheduler = multi_group_scheduler

num_prompt_tokens = 100 * scheduler.block_size
num_external_computed_tokens = 50 * scheduler.block_size

request1 = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request1)
request2 = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request2)
request3 = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request3)

# Mock KV connector method.
# req_id -> num_external_computed_tokens
req_num_new_matched_tokens = {
request1.request_id: num_external_computed_tokens,
request2.request_id: num_external_computed_tokens,
request3.request_id: num_external_computed_tokens,
}

scheduler.connector = Mock()
scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=True)
)
scheduler.connector.take_events.return_value = ()

scheduler_output = scheduler.schedule()
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3

# Simulate a failure in loading one of request2's full attention blocks
# and a failure in loading one of request3's sliding window blocks.
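# (group index 0 = full attention group, group index 1 = sliding window group)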
req2_block_ids = scheduler.kv_cache_manager.get_block_ids(request2.request_id)[0]
req3_block_ids = scheduler.kv_cache_manager.get_block_ids(request3.request_id)[1]
invalid_block_ids = {req2_block_ids[-1], req3_block_ids[-1]}
model_runner_output = create_model_runner_output(
reqs=[],
finished_recving={request1.request_id},
invalid_block_ids=invalid_block_ids,
use_eos=True,
)

scheduler.update_from_output(scheduler_output, model_runner_output)

assert len(scheduler.waiting) == 3
for request in scheduler.waiting:
if request.request_id != request1.request_id:
assert request.num_computed_tokens == 0
assert request.num_external_computed_tokens == 0
else:
assert request.num_computed_tokens == num_external_computed_tokens
assert request.num_external_computed_tokens == num_external_computed_tokens
assert scheduler.failed_recving_kv_req_ids == {
request2.request_id,
request3.request_id,
}


def test_sync_load_failure_multiple_kv_groups(multi_group_scheduler):
scheduler = multi_group_scheduler
scheduler.max_num_scheduled_tokens = 10000

num_prompt_tokens = 100 * scheduler.block_size
num_external_computed_tokens = 75 * scheduler.block_size
common_prefix_len = 50 * scheduler.block_size
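# Requests 1-3 share a 50-block prefix, so some of their full attention
# blocks are shared across requests.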

request1 = create_request(
num_tokens=num_prompt_tokens, common_prefix_len=common_prefix_len
)
scheduler.add_request(request=request1)
request2 = create_request(
num_tokens=num_prompt_tokens, common_prefix_len=common_prefix_len
)
scheduler.add_request(request=request2)
request3 = create_request(
num_tokens=num_prompt_tokens, common_prefix_len=common_prefix_len
)
scheduler.add_request(request=request3)
request4 = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request4)

# Mock KV connector method.
# req_id -> num_external_computed_tokens
req_num_new_matched_tokens = {
request1.request_id: num_external_computed_tokens,
request2.request_id: num_external_computed_tokens,
request3.request_id: num_external_computed_tokens,
request4.request_id: num_external_computed_tokens,
}

scheduler.connector = Mock()
scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=False)
)
scheduler.connector.take_events.return_value = ()

scheduler_output = scheduler.schedule()
assert scheduler.connector.get_num_new_matched_tokens.call_count == 4

# Simulate a failure in loading:
# - one of request2's full attention blocks (a shared block).
# - one of request3's sliding window blocks (a non-shared block).
req2_block_ids = scheduler.kv_cache_manager.get_block_ids(request2.request_id)[0]
req3_block_ids = scheduler.kv_cache_manager.get_block_ids(request3.request_id)[1]
req3_block_ids = [block_id for block_id in req3_block_ids if block_id]
invalid_block_ids = {req2_block_ids[1], req3_block_ids[1]}
model_runner_output = create_model_runner_output(
reqs=[request1, request2, request3, request4],
invalid_block_ids=invalid_block_ids,
)

scheduler.update_from_output(scheduler_output, model_runner_output)

assert len(scheduler.running) == 4
for request in scheduler.running:
if request.request_id != request4.request_id:
assert request.num_computed_tokens == 0
assert request.num_external_computed_tokens == 0
else:
assert request.num_computed_tokens == num_external_computed_tokens
assert request.num_external_computed_tokens == num_external_computed_tokens
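Together these two tests pin down the "recompute" policy behavior across KV
cache groups: an async load failure parks the affected requests in
failed_recving_kv_req_ids while leaving the successfully loaded request
untouched, whereas a sync load failure resets num_computed_tokens and
num_external_computed_tokens so the affected requests are recomputed from
scratch.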
43 changes: 36 additions & 7 deletions vllm/v1/core/sched/scheduler.py
@@ -32,6 +32,7 @@
)
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.encoder_budget import MultiModalBudget
+from vllm.utils.math_utils import cdiv
from vllm.v1.core.encoder_cache_manager import (
EncoderCacheManager,
EncoderDecoderCacheManager,
@@ -2094,8 +2095,7 @@ def _update_requests_with_invalid_blocks(
is_affected = False
marked_invalid_block = False
req_id = request.request_id
-# TODO (davidb): add support for hybrid memory allocator
-(req_block_ids,) = self.kv_cache_manager.get_block_ids(req_id)

# We iterate only over blocks that may contain externally computed
# tokens
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
@@ -2105,15 +2105,42 @@
# Sync loading. num_computed_tokens includes new tokens
req_num_computed_tokens = request.num_cached_tokens

-req_num_computed_blocks = (
-    req_num_computed_tokens + self.block_size - 1
-) // self.block_size
-for idx, block_id in zip(range(req_num_computed_blocks), req_block_ids):
+block_sizes = (
+    kv_cache_group.kv_cache_spec.block_size
+    for kv_cache_group in self.kv_cache_config.kv_cache_groups
+)
+
+req_block_ids = itertools.chain.from_iterable(
+    group_block_ids[: cdiv(req_num_computed_tokens, block_size)]
+    for group_block_ids, block_size in zip(
+        self.kv_cache_manager.get_block_ids(req_id), block_sizes
+    )
+)
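+# cdiv rounds up, so with e.g. 100 computed tokens the block_size-16
+# group contributes its first cdiv(100, 16) = 7 block IDs and the
+# block_size-32 group its first 4, chained into one flat iterable.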

+for idx, block_id in enumerate(req_block_ids):
if block_id not in invalid_block_ids:
continue

is_affected = True

+if len(self.kv_cache_config.kv_cache_groups) > 1:
+    # More than one KV cache group means not all layers are full
+    # attention; some use sliding window attention or SSM, so the
+    # entire request has to be recomputed.
+    total_affected_tokens += request.num_computed_tokens
+    request.num_computed_tokens = 0
+    request.num_external_computed_tokens = 0
+    marked_invalid_block = True
+    if evict_blocks:
+        # evict all of the request's blocks, across all groups
+        blocks_to_evict.update(
+            itertools.chain.from_iterable(
+                self.kv_cache_manager.get_block_ids(req_id)
+            )
+        )
+    break

if block_id in marked_invalid_block_ids:
# This invalid block is shared with a previous request
# and was already marked for recomputation.
@@ -2140,7 +2167,9 @@ def _update_requests_with_invalid_blocks(
request.num_external_computed_tokens -= num_affected_tokens
# collect invalid block and all downstream dependent blocks
if evict_blocks:
-blocks_to_evict.update(req_block_ids[idx:])
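+# Single-group path: idx indexes the only group's block list, so this
+# evicts the invalid block and everything after it.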
+blocks_to_evict.update(
+    self.kv_cache_manager.get_block_ids(req_id)[0][idx:]
+)

if is_affected:
if not marked_invalid_block: