diff --git a/tests/v1/kv_connector/unit/test_mooncake_store_scheduler.py b/tests/v1/kv_connector/unit/test_mooncake_store_scheduler.py index 4b46c03f5831..2551774ef18f 100644 --- a/tests/v1/kv_connector/unit/test_mooncake_store_scheduler.py +++ b/tests/v1/kv_connector/unit/test_mooncake_store_scheduler.py @@ -44,6 +44,21 @@ def _make_scheduler_output(*, scheduled_spec_tokens: list[int] | None): ) +def _make_preemption_scheduler_output(): + return SimpleNamespace( + finished_req_ids=set(), + preempted_req_ids={"req-0"}, + scheduled_new_reqs=[], + scheduled_cached_reqs=SimpleNamespace( + req_ids=[], + new_block_ids=[], + num_computed_tokens=[], + ), + num_scheduled_tokens={}, + scheduled_spec_decode_tokens={}, + ) + + def _add_unfinished_request( scheduler: MooncakeStoreScheduler, *, @@ -113,6 +128,48 @@ def test_cached_request_without_spec_decode_keeps_current_step_save_overlap(): assert tracker.num_saved_tokens == 48 +def test_preemption_resets_tracker_before_request_finished(): + scheduler = _make_bare_scheduler() + _add_unfinished_request( + scheduler, + token_ids=list(range(44)), + block_hashes=[b"h0", b"h1"], + prefill_end_tokens=48, + ) + + scheduler.build_connector_meta(_make_preemption_scheduler_output()) + + tracker = scheduler._request_trackers["req-0"] + assert tracker.token_len == 0 + assert tracker.allocated_block_ids == () + assert tracker.num_saved_tokens == 0 + assert tracker.token_ids is None + assert tracker.prefill_end_tokens == 0 + request = SimpleNamespace(request_id="req-0") + assert scheduler.request_finished(request, ([0, 1],)) == (False, None) + + +def test_preemption_clears_stale_load_state(): + scheduler = _make_bare_scheduler() + _make_pending_load_unfinished_request( + scheduler, + num_tokens=48, + block_hashes=[b"h0", b"h1", b"h2"], + block_ids=([10, 11, 12],), + ) + scheduler.load_specs["req-0"] = LoadSpec( + vllm_cached_tokens=0, + kvpool_cached_tokens=48, + can_load=True, + ) + + meta = scheduler.build_connector_meta(_make_preemption_scheduler_output()) + + assert meta.requests == [] + assert "req-0" not in scheduler.load_specs + assert "req-0" not in scheduler._unfinished_requests + + def _make_pending_load_unfinished_request( scheduler: MooncakeStoreScheduler, *, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/data.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/data.py index f3e9a2e64469..8aa2625852fe 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/data.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/data.py @@ -173,6 +173,13 @@ class RequestTracker: # request it includes previously-generated tokens, which are re-prefilled. prefill_end_tokens: int = 0 + def reset(self) -> None: + self.token_len = 0 + self.allocated_block_ids = () + self.num_saved_tokens = 0 + self.token_ids = None + self.prefill_end_tokens = 0 + def update( self, new_block_ids: tuple[list[int], ...] | list[int], diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/scheduler.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/scheduler.py index 5922965974fd..0bf178f26f73 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/scheduler.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/scheduler.py @@ -176,7 +176,9 @@ def build_connector_meta( preempted_ids = scheduler_output.preempted_req_ids or set() self._preempted_req_ids.update(preempted_ids) for req_id in preempted_ids: - self._request_trackers.pop(req_id, None) + self.load_specs.pop(req_id, None) + if request_tracker := self._request_trackers.get(req_id): + request_tracker.reset() self._unfinished_requests.pop(req_id, None) meta = MooncakeStoreConnectorMetadata( @@ -374,8 +376,10 @@ def request_finished( if self.kv_role == "kv_consumer": return False, None tracker = self._request_trackers.get(request.request_id) - assert tracker is not None - if tracker.num_saved_tokens <= 0: + # Missing tracker can happen when the request is aborted before the + # connector observes the normal finished lifecycle or is preempted + # before finishing. + if tracker is None or tracker.num_saved_tokens <= 0: return False, None total_blocks = sum(len(g) for g in block_ids) delay_free_blocks = total_blocks > 0