Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 235 additions & 0 deletions tests/v1/kv_connector/unit/test_mooncake_store_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from types import SimpleNamespace

from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.store.data import (
LoadSpec,
ReqMeta,
RequestTracker,
)
from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.store.scheduler import (
Expand Down Expand Up @@ -109,3 +111,236 @@ def test_cached_request_without_spec_decode_keeps_current_step_save_overlap():
tracker = scheduler._request_trackers["req-0"]
assert tracker.token_len == 48
assert tracker.num_saved_tokens == 48


def _make_pending_load_unfinished_request(
scheduler: MooncakeStoreScheduler,
*,
num_tokens: int,
block_hashes: list[bytes],
block_ids: tuple[list[int], ...] = ([0, 1, 2],),
) -> None:
request = SimpleNamespace(
num_tokens=num_tokens,
block_hashes=block_hashes,
num_output_placeholders=0,
)
scheduler._unfinished_requests["req-0"] = (request, block_ids)


def _make_pending_load_scheduler_output() -> SimpleNamespace:
"""scheduler_output for a step where req-0 is parked on a pending load
(not in scheduled_new_reqs or scheduled_cached_reqs)."""
return SimpleNamespace(
finished_req_ids=set(),
preempted_req_ids=set(),
scheduled_new_reqs=[],
scheduled_cached_reqs=SimpleNamespace(
req_ids=[],
new_block_ids=[],
num_computed_tokens=[],
),
num_scheduled_tokens={},
scheduled_spec_decode_tokens={},
)


def test_pending_load_does_not_co_queue_save():
    """A request parked on an async load must not also be queued for save.

    Regression test: co-queuing a load and a save in one scheduling step
    yields a recv+send pair for the same req_id. When both completions land
    for the delay-freed request, the scheduler's
    ``_update_from_kv_xfer_finished`` trips ``assert req_id in self.requests``.
    """
    sched = _make_bare_scheduler()
    _make_pending_load_unfinished_request(
        sched,
        num_tokens=48,
        block_hashes=[b"h0", b"h1", b"h2"],
    )
    pending_load = LoadSpec(
        vllm_cached_tokens=0,
        kvpool_cached_tokens=48,
        can_load=True,
    )
    sched.load_specs["req-0"] = pending_load

    step_output = _make_pending_load_scheduler_output()
    connector_meta = sched.build_connector_meta(step_output)

    assert len(connector_meta.requests) == 1
    entry = connector_meta.requests[0]
    assert entry.req_id == "req-0"
    # With save disabled the worker does not call add_stored_request.
    assert entry.can_save is False
    # The load itself still goes out as planned.
    assert entry.load_spec is not None
    assert entry.load_spec.can_load is True
    # The saved-tokens watermark must remain 0: request_finished then sees
    # `num_saved_tokens <= 0` and frees right away instead of waiting for a
    # finished_sending that will never arrive.
    req_tracker = sched._request_trackers["req-0"]
    assert req_tracker.num_saved_tokens == 0


def _make_resumed_unfinished_request(
scheduler: MooncakeStoreScheduler,
*,
token_ids: list[int],
block_hashes: list[bytes],
num_computed_tokens: int,
) -> None:
request = SimpleNamespace(
all_token_ids=token_ids,
block_hashes=block_hashes,
num_computed_tokens=num_computed_tokens,
num_output_placeholders=0,
)
scheduler._unfinished_requests["req-0"] = (request, ([0, 1],))


def _make_resumed_scheduler_output(*, num_scheduled_tokens: int) -> SimpleNamespace:
return SimpleNamespace(
finished_req_ids=set(),
preempted_req_ids=set(),
scheduled_new_reqs=[],
scheduled_cached_reqs=SimpleNamespace(
req_ids=["req-0"],
new_block_ids=[([2],)],
num_computed_tokens=[0],
),
num_scheduled_tokens={"req-0": num_scheduled_tokens},
scheduled_spec_decode_tokens={},
)


def test_resumed_from_preemption_with_load_skips_save():
    """Resuming from preemption with a cache hit must skip the save.

    The co-queueing race also applies here, because the
    resumed-from-preemption branch of build_connector_meta passes
    load_spec.can_load=True as well. The save is skipped for this step only;
    subsequent cached_reqs steps save new tokens normally.
    """
    sched = _make_bare_scheduler()
    sched._preempted_req_ids = {"req-0"}
    _make_resumed_unfinished_request(
        sched,
        token_ids=list(range(48)),
        block_hashes=[b"h0", b"h1", b"h2"],
        num_computed_tokens=0,
    )
    cache_hit = LoadSpec(
        vllm_cached_tokens=0,
        kvpool_cached_tokens=48,
        can_load=True,
    )
    sched.load_specs["req-0"] = cache_hit

    step_output = _make_resumed_scheduler_output(num_scheduled_tokens=48)
    connector_meta = sched.build_connector_meta(step_output)

    assert len(connector_meta.requests) == 1
    entry = connector_meta.requests[0]
    assert entry.req_id == "req-0"
    assert entry.can_save is False
    assert entry.load_spec is not None
    assert entry.load_spec.can_load is True
    req_tracker = sched._request_trackers["req-0"]
    assert req_tracker.num_saved_tokens == 0


def test_resumed_from_preemption_without_load_still_saves():
    """Without a load_spec, a resumed request is saved exactly as before."""
    sched = _make_bare_scheduler()
    sched._preempted_req_ids = {"req-0"}
    _make_resumed_unfinished_request(
        sched,
        token_ids=list(range(48)),
        block_hashes=[b"h0", b"h1", b"h2"],
        num_computed_tokens=0,
    )

    # No entry is added to sched.load_specs here, so nothing suppresses save.
    step_output = _make_resumed_scheduler_output(num_scheduled_tokens=48)
    connector_meta = sched.build_connector_meta(step_output)

    assert len(connector_meta.requests) == 1
    entry = connector_meta.requests[0]
    assert entry.req_id == "req-0"
    assert entry.can_save is True
    assert entry.load_spec is None
    req_tracker = sched._request_trackers["req-0"]
    assert req_tracker.num_saved_tokens == 48


# Focused tests for ReqMeta.from_request_tracker — the centralized guard that
# enforces "a ReqMeta never carries both a save and a load".


def test_from_request_tracker_load_overrides_caller_skip_save():
    """A loadable LoadSpec forces skip_save even when the caller asks to save.

    The caller passes skip_save=False together with load_spec.can_load=True.
    from_request_tracker must override that to skip_save=True, otherwise it
    would produce a ReqMeta that the worker enqueues on both kv_send_thread
    and kv_recv_thread.
    """
    req_tracker = RequestTracker(
        req_id="req-0",
        token_len=48,
        allocated_block_ids=([0, 1, 2],),
        num_saved_tokens=0,
    )
    loadable_spec = LoadSpec(
        vllm_cached_tokens=0,
        kvpool_cached_tokens=48,
        can_load=True,
    )

    built = ReqMeta.from_request_tracker(
        req_tracker,
        block_size=16,
        load_spec=loadable_spec,
        skip_save=False,
        block_hashes=[b"h0", b"h1", b"h2"],
    )

    assert built is not None
    assert built.can_save is False
    assert built.load_spec is loadable_spec
    assert req_tracker.num_saved_tokens == 0


def test_from_request_tracker_load_with_can_load_false_still_saves():
    """A LoadSpec with can_load=False must not suppress the save.

    Such a spec arises, for example, when there are no external tokens left
    to load after update_state_after_alloc.
    """
    req_tracker = RequestTracker(
        req_id="req-0",
        token_len=48,
        allocated_block_ids=([0, 1, 2],),
        num_saved_tokens=0,
    )
    unloadable_spec = LoadSpec(
        vllm_cached_tokens=0,
        kvpool_cached_tokens=48,
        can_load=False,
    )

    built = ReqMeta.from_request_tracker(
        req_tracker,
        block_size=16,
        load_spec=unloadable_spec,
        skip_save=False,
        block_hashes=[b"h0", b"h1", b"h2"],
    )

    assert built is not None
    assert built.can_save is True
    # When can_load is False, from_request_tracker drops the load_spec.
    assert built.load_spec is None
    assert req_tracker.num_saved_tokens == 48


def test_from_request_tracker_no_load_saves_normally():
    """With no load_spec, from_request_tracker produces a plain save."""
    req_tracker = RequestTracker(
        req_id="req-0",
        token_len=48,
        allocated_block_ids=([0, 1, 2],),
        num_saved_tokens=0,
    )

    built = ReqMeta.from_request_tracker(
        req_tracker,
        block_size=16,
        load_spec=None,
        skip_save=False,
        block_hashes=[b"h0", b"h1", b"h2"],
    )

    assert built is not None
    assert built.can_save is True
    assert built.load_spec is None
    # All 48 tokens are recorded as saved on the tracker.
    assert req_tracker.num_saved_tokens == 48
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,12 @@ def from_request_tracker(
)

skip_save = skip_save or num_tokens_to_save < chunk_boundary
# A ReqMeta must never carry both a save AND a load.
# The save would also be wasted work — the bytes are being looked up
# in the store right now. Later cached_reqs steps save new tokens
# normally.
if load_spec is not None and load_spec.can_load:
skip_save = True
if skip_save and load_spec is None:
return None

Expand Down
Loading