Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions tests/v1/kv_connector/unit/test_mooncake_store_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,21 @@ def _make_scheduler_output(*, scheduled_spec_tokens: list[int] | None):
)


def _make_preemption_scheduler_output():
return SimpleNamespace(
finished_req_ids=set(),
preempted_req_ids={"req-0"},
scheduled_new_reqs=[],
scheduled_cached_reqs=SimpleNamespace(
req_ids=[],
new_block_ids=[],
num_computed_tokens=[],
),
num_scheduled_tokens={},
scheduled_spec_decode_tokens={},
)


def _add_unfinished_request(
scheduler: MooncakeStoreScheduler,
*,
Expand Down Expand Up @@ -113,6 +128,48 @@ def test_cached_request_without_spec_decode_keeps_current_step_save_overlap():
assert tracker.num_saved_tokens == 48


def test_preemption_resets_tracker_before_request_finished():
    """A preempted request keeps its tracker, but with all progress cleared.

    After a preemption-only step the tracker for ``req-0`` must still be
    registered with zeroed counters, no blocks and no token ids, and a
    subsequent ``request_finished`` call must report ``(False, None)``.
    """
    req_id = "req-0"
    sched = _make_bare_scheduler()
    _add_unfinished_request(
        sched,
        token_ids=list(range(44)),
        block_hashes=[b"h0", b"h1"],
        prefill_end_tokens=48,
    )
    preemption_step = _make_preemption_scheduler_output()

    sched.build_connector_meta(preemption_step)

    # The tracker entry survives preemption; only its contents are wiped.
    state = sched._request_trackers[req_id]
    assert state.token_len == 0
    assert state.allocated_block_ids == ()
    assert state.num_saved_tokens == 0
    assert state.token_ids is None
    assert state.prefill_end_tokens == 0

    # Finishing the request afterwards must not raise.
    finished = SimpleNamespace(request_id=req_id)
    outcome = sched.request_finished(finished, ([0, 1],))
    assert outcome == (False, None)


def test_preemption_clears_stale_load_state():
    """Preempting a request with a pending load drops its load bookkeeping.

    The preemption-only step must produce no request metadata, and
    ``req-0`` must be gone from both ``load_specs`` and
    ``_unfinished_requests`` afterwards.
    """
    sched = _make_bare_scheduler()
    _make_pending_load_unfinished_request(
        sched,
        num_tokens=48,
        block_hashes=[b"h0", b"h1", b"h2"],
        block_ids=([10, 11, 12],),
    )
    # Register a load spec that would be stale once the request is preempted.
    stale_spec = LoadSpec(
        vllm_cached_tokens=0,
        kvpool_cached_tokens=48,
        can_load=True,
    )
    sched.load_specs["req-0"] = stale_spec

    preemption_step = _make_preemption_scheduler_output()
    meta = sched.build_connector_meta(preemption_step)

    assert meta.requests == []
    for registry in (sched.load_specs, sched._unfinished_requests):
        assert "req-0" not in registry


def _make_pending_load_unfinished_request(
scheduler: MooncakeStoreScheduler,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,13 @@ class RequestTracker:
# request it includes previously-generated tokens, which are re-prefilled.
prefill_end_tokens: int = 0

def reset(self) -> None:
    """Return every progress field of this tracker to its empty value.

    The three token counters go back to ``0``, the allocated block ids
    become an empty tuple, and the stored token ids are dropped
    (``None``). Attributes not assigned here are left untouched.
    """
    # Drop the per-request payloads.
    self.token_ids = None
    self.allocated_block_ids = ()
    # All counters restart from zero.
    self.token_len = self.num_saved_tokens = self.prefill_end_tokens = 0

def update(
self,
new_block_ids: tuple[list[int], ...] | list[int],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ def build_connector_meta(
preempted_ids = scheduler_output.preempted_req_ids or set()
self._preempted_req_ids.update(preempted_ids)
for req_id in preempted_ids:
self._request_trackers.pop(req_id, None)
self.load_specs.pop(req_id, None)
if request_tracker := self._request_trackers.get(req_id):
request_tracker.reset()
self._unfinished_requests.pop(req_id, None)

meta = MooncakeStoreConnectorMetadata(
Expand Down Expand Up @@ -374,8 +376,10 @@ def request_finished(
if self.kv_role == "kv_consumer":
return False, None
tracker = self._request_trackers.get(request.request_id)
assert tracker is not None
if tracker.num_saved_tokens <= 0:
# Missing tracker can happen when the request is aborted before the
# connector observes the normal finished lifecycle or is preempted
# before finishing.
if tracker is None or tracker.num_saved_tokens <= 0:
return False, None
total_blocks = sum(len(g) for g in block_ids)
delay_free_blocks = total_blocks > 0
Expand Down
Loading