diff --git a/benchmarks/kv_connector_prefetch/__init__.py b/benchmarks/kv_connector_prefetch/__init__.py new file mode 100644 index 000000000000..593a790cd72e --- /dev/null +++ b/benchmarks/kv_connector_prefetch/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""KV connector early-prefetch microbenchmark package.""" diff --git a/benchmarks/kv_connector_prefetch/run_microbench.py b/benchmarks/kv_connector_prefetch/run_microbench.py new file mode 100644 index 000000000000..7c854c49dcb7 --- /dev/null +++ b/benchmarks/kv_connector_prefetch/run_microbench.py @@ -0,0 +1,288 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Microbenchmark for KV connector early-prefetch (#41784). + +Reproduces the failure mode jiangxiaosheng described in the issue: a saturated +running queue prevents the waiting-queue scheduling loop from ever invoking +the connector for newly arrived requests, so async lookups never start until +the running queue drains, leaving the GPU idle waiting on the just-started +lookup. + +The benchmark uses a synthetic `SlowMockKVConnector` that simulates a +disk-backed KV store with a fixed wall-clock lookup latency. We: + + 1) Drown the engine in long-decode "warmup" requests that saturate + `--max-num-batched-tokens`. + 2) Slightly later, submit short "probe" requests and measure their TTFT. + +With the prefetch optimization disabled (``--budget 0``), each probe pays +the full simulated lookup latency right before its first token. With the +optimization enabled (``--budget`` >= total prompt-token sum of probes), +the lookup timer starts at the top of the next schedule step rather than +when the probe finally reaches the head of the waiting queue, hiding the +latency behind ongoing GPU work. + +Run from the repo root: + + .venv/bin/python -m benchmarks.kv_connector_prefetch.run_microbench \\ + --model Qwen/Qwen2.5-0.5B-Instruct \\ + --max-num-batched-tokens 512 \\ + --num-warmup 6 --num-probe 12 \\ + --latency-ms 200 \\ + --compare + +The ``--compare`` flag runs the benchmark twice (budget=0 then +budget=large) and prints a side-by-side delta. Without ``--compare`` the +benchmark runs once with ``--budget``. 
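+
+As a rough sizing rule: with the invocation above (``--num-probe 12`` and the
+default ``--probe-prompt-tokens 64``) the probes sum to 12 * 64 = 768 prompt
+tokens, so any ``--budget`` of 768 or more lets every probe receive its
+prefetch hint in a single schedule step.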
+""" + +from __future__ import annotations + +import argparse +import asyncio +import os +import random +import statistics +import time +import uuid +from dataclasses import dataclass + +from vllm import SamplingParams +from vllm.config import KVTransferConfig +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.inputs import TokensPrompt +from vllm.usage.usage_lib import UsageContext +from vllm.v1.engine.async_llm import AsyncLLM + +SLOW_MOCK_MODULE = "benchmarks.kv_connector_prefetch.slow_mock_connector" + + +@dataclass +class ProbeResult: + request_id: str + submit_time: float + first_token_time: float + finish_time: float + + @property + def ttft(self) -> float: + return self.first_token_time - self.submit_time + + @property + def e2e(self) -> float: + return self.finish_time - self.submit_time + + +def _make_engine(args: argparse.Namespace, budget: int) -> AsyncLLM: + kv_transfer_config = KVTransferConfig( + kv_connector="SlowMockKVConnector", + kv_role="kv_both", + kv_connector_module_path=SLOW_MOCK_MODULE, + ) + + engine_args = AsyncEngineArgs( + model=args.model, + max_model_len=args.max_model_len, + max_num_batched_tokens=args.max_num_batched_tokens, + max_num_seqs=args.max_num_seqs, + gpu_memory_utilization=args.gpu_memory_utilization, + dtype=args.dtype, + enforce_eager=args.enforce_eager, + disable_log_stats=True, + enable_prefix_caching=False, + kv_transfer_config=kv_transfer_config, + kv_connector_prefetch_token_budget=budget, + ) + return AsyncLLM.from_engine_args(engine_args, usage_context=UsageContext.LLM_CLASS) + + +_RNG = random.Random(0xC0FFEE) + + +def _build_prompt(num_tokens: int) -> TokensPrompt: + # Random token ids give us exact length control (no tokenizer variance) + # and cheap uniqueness across requests. We avoid id 0/1/2 which tend + # to be special tokens in most tokenizers. + token_ids = [_RNG.randint(100, 30000) for _ in range(num_tokens)] + return TokensPrompt(prompt_token_ids=token_ids) + + +async def _run_request( + engine: AsyncLLM, + prompt: TokensPrompt, + sampling: SamplingParams, + submit_time: float, +) -> ProbeResult: + rid = str(uuid.uuid4()) + first_token_time: float | None = None + async for output in engine.generate( + prompt=prompt, sampling_params=sampling, request_id=rid + ): + if first_token_time is None and output.outputs and output.outputs[0].token_ids: + first_token_time = time.perf_counter() + if output.finished: + finish_time = time.perf_counter() + assert first_token_time is not None + return ProbeResult( + request_id=rid, + submit_time=submit_time, + first_token_time=first_token_time, + finish_time=finish_time, + ) + raise RuntimeError(f"request {rid} ended without `finished`") + + +async def _drive_one(engine: AsyncLLM, args: argparse.Namespace) -> list[ProbeResult]: + # Phase 1: kick off warmup requests that will saturate the running queue. + warmup_sampling = SamplingParams( + max_tokens=args.warmup_max_tokens, temperature=0.0, ignore_eos=True + ) + probe_sampling = SamplingParams( + max_tokens=args.probe_max_tokens, temperature=0.0, ignore_eos=True + ) + + # Burn-in: one tiny request to make the first compile/cudagraph cost + # land outside our measurement window. 
+ burn_prompt = _build_prompt(8) + burn_sampling = SamplingParams(max_tokens=1, temperature=0.0, ignore_eos=True) + await _run_request(engine, burn_prompt, burn_sampling, time.perf_counter()) + + warmup_tasks = [] + for _ in range(args.num_warmup): + prompt = _build_prompt(args.warmup_prompt_tokens) + warmup_tasks.append( + asyncio.create_task( + _run_request(engine, prompt, warmup_sampling, time.perf_counter()) + ) + ) + + # Let the running queue fill before we submit the probes. This is what + # makes the prefetch pass observable: by the time probes arrive, the + # running queue is already saturating max_num_batched_tokens, so the + # waiting-queue scheduling loop would never run -- and without the + # early prefetch pass, the connector lookup never starts until later. + await asyncio.sleep(args.warmup_settle_s) + + probe_tasks = [] + for _ in range(args.num_probe): + prompt = _build_prompt(args.probe_prompt_tokens) + submit = time.perf_counter() + probe_tasks.append( + asyncio.create_task(_run_request(engine, prompt, probe_sampling, submit)) + ) + + probes = await asyncio.gather(*probe_tasks) + # Drain warmups so engine shutdown is clean. + await asyncio.gather(*warmup_tasks) + return probes + + +def _summarize(probes: list[ProbeResult], label: str) -> dict[str, float]: + ttfts = sorted(p.ttft for p in probes) + n = len(ttfts) + p50 = ttfts[n // 2] + p90 = ttfts[min(n - 1, int(n * 0.9))] + p99 = ttfts[min(n - 1, int(n * 0.99))] + mean = statistics.fmean(ttfts) + print(f"\n=== {label} (n={n}) ===") + print(f" TTFT mean : {mean * 1000:8.2f} ms") + print(f" TTFT p50 : {p50 * 1000:8.2f} ms") + print(f" TTFT p90 : {p90 * 1000:8.2f} ms") + print(f" TTFT p99 : {p99 * 1000:8.2f} ms") + print(f" TTFT min : {ttfts[0] * 1000:8.2f} ms") + print(f" TTFT max : {ttfts[-1] * 1000:8.2f} ms") + return {"mean": mean, "p50": p50, "p90": p90, "p99": p99} + + +async def _amain(args: argparse.Namespace) -> None: + os.environ["VLLM_SLOWMOCK_LATENCY_MS"] = str(args.latency_ms) + os.environ["VLLM_SLOWMOCK_MATCHED_TOK"] = str(args.matched_tokens) + + if args.compare: + # Run twice: budget=0 (disabled) then budget=large (enabled). + # Use a generous budget so every probe gets its hint in step 1. + big_budget = max( + args.num_probe * args.probe_prompt_tokens * 4, + args.max_num_batched_tokens * 4, + ) + + engine = _make_engine(args, budget=0) + try: + off = await _drive_one(engine, args) + finally: + engine.shutdown() + off_stats = _summarize(off, label="BUDGET=0 (disabled)") + + # Brief gap so any lingering background work settles. + await asyncio.sleep(1.0) + + engine = _make_engine(args, budget=big_budget) + try: + on = await _drive_one(engine, args) + finally: + engine.shutdown() + on_stats = _summarize(on, label=f"BUDGET={big_budget} (enabled)") + + print("\n=== DELTA (enabled - disabled) ===") + for k in ("mean", "p50", "p90", "p99"): + d_ms = (on_stats[k] - off_stats[k]) * 1000 + print(f" TTFT {k}: {d_ms:+8.2f} ms") + print( + "\nExpected sign: negative (TTFT improves) when the simulated " + "lookup latency was being paid in front of the GPU under " + "BUDGET=0 and is hidden behind running-queue work under " + "BUDGET>0." 
+        )
+    else:
+        engine = _make_engine(args, budget=args.budget)
+        try:
+            probes = await _drive_one(engine, args)
+        finally:
+            engine.shutdown()
+        _summarize(probes, label=f"BUDGET={args.budget}")
+
+
+def _parse_args() -> argparse.Namespace:
+    p = argparse.ArgumentParser(description=__doc__)
+    p.add_argument("--model", default="Qwen/Qwen2.5-0.5B-Instruct")
+    p.add_argument("--max-model-len", type=int, default=2048)
+    p.add_argument("--max-num-batched-tokens", type=int, default=512)
+    p.add_argument("--max-num-seqs", type=int, default=32)
+    p.add_argument("--gpu-memory-utilization", type=float, default=0.85)
+    p.add_argument("--dtype", default="auto")
+    p.add_argument(
+        "--enforce-eager", action=argparse.BooleanOptionalAction, default=True
+    )
+
+    # The "warmup" phase is what saturates the per-step token budget. We
+    # rely on prefill, not decode, to do the saturating: prefill consumes
+    # `warmup_prompt_tokens` per step per request until each prompt is
+    # fully ingested, while decode only consumes 1 token per step per
+    # request. Pick num_warmup * warmup_prompt_tokens well above
+    # max_num_batched_tokens so several consecutive schedule steps run
+    # with the running queue completely full -- that is the window in
+    # which the bug bites.
+    p.add_argument("--num-warmup", type=int, default=4)
+    p.add_argument("--warmup-prompt-tokens", type=int, default=384)
+    p.add_argument("--warmup-max-tokens", type=int, default=64)
+    p.add_argument("--warmup-settle-s", type=float, default=0.5)
+
+    p.add_argument("--num-probe", type=int, default=8)
+    p.add_argument("--probe-prompt-tokens", type=int, default=64)
+    p.add_argument("--probe-max-tokens", type=int, default=1)
+
+    p.add_argument("--latency-ms", type=float, default=200.0)
+    p.add_argument("--matched-tokens", type=int, default=0)
+    p.add_argument("--budget", type=int, default=0)
+    p.add_argument(
+        "--compare",
+        action="store_true",
+        help="Run twice (budget=0 then large budget) and print delta.",
+    )
+    return p.parse_args()
+
+
+def main() -> None:
+    args = _parse_args()
+    asyncio.run(_amain(args))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/benchmarks/kv_connector_prefetch/slow_mock_connector.py b/benchmarks/kv_connector_prefetch/slow_mock_connector.py
new file mode 100644
index 000000000000..503dcd73ece0
--- /dev/null
+++ b/benchmarks/kv_connector_prefetch/slow_mock_connector.py
@@ -0,0 +1,158 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""SlowMockKVConnector: a synthetic KV connector for the prefetch microbench.
+
+Simulates a disk-backed KV store: every request takes a configurable amount
+of wall-clock time (`VLLM_SLOWMOCK_LATENCY_MS`) before its lookup result is
+"ready". Until the latency has elapsed, `get_num_new_matched_tokens` returns
+(None, True) -- i.e. "ask again later, async". Once the latency has elapsed,
+it returns (matched_tokens, False). This is the failure mode the early
+prefetch pass is meant to hide: when the running queue saturates the per-
+step token budget, the waiting-queue scheduling loop never runs and the
+lookup timer never starts.
+
+The connector is registered via `KVTransferConfig.kv_connector_module_path`
+so worker processes can import it without modifying the vLLM package. It
+intentionally does **no** real KV transfer: the worker-side methods are
+no-ops, and `get_num_new_matched_tokens` reports "matched but already
+loaded" so the scheduler does not actually wait for KV blocks.
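+
+Concretely, with the default `VLLM_SLOWMOCK_LATENCY_MS=200`: the first time a
+request is seen (via `maybe_prefetch_request` or, failing that, via
+`get_num_new_matched_tokens`) the connector records `ready_at = now + 0.2s`.
+Every `get_num_new_matched_tokens` call before `ready_at` returns
+(None, True) ("still looking, ask again"); the first call at or after
+`ready_at` returns (matched_tokens, False), i.e. the hit count with no async
+KV load left to wait for.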
+""" + +from __future__ import annotations + +import os +import time +from typing import TYPE_CHECKING, Any + +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.v1.request import Request + + +def _env_float(name: str, default: float) -> float: + raw = os.environ.get(name) + if raw is None or raw == "": + return default + return float(raw) + + +def _env_int(name: str, default: int) -> int: + raw = os.environ.get(name) + if raw is None or raw == "": + return default + return int(raw) + + +class _Empty(KVConnectorMetadata): + pass + + +class SlowMockKVConnector(KVConnectorBase_V1): + """Synthetic disk-like connector with configurable lookup latency. + + Configured via env vars (read at construction time): + VLLM_SLOWMOCK_LATENCY_MS: per-request lookup wall-clock latency. + Default 200 ms. + VLLM_SLOWMOCK_MATCHED_TOK: number of tokens reported as cache-hit + once the lookup is "ready". Default 0 + (we just want to show the *latency hide*, + not actually transfer KV). + """ + + def __init__( + self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: KVCacheConfig | None = None, + ): + super().__init__(vllm_config, role, kv_cache_config) + self._latency_s = _env_float("VLLM_SLOWMOCK_LATENCY_MS", 200.0) / 1000.0 + self._matched_tokens = _env_int("VLLM_SLOWMOCK_MATCHED_TOK", 0) + # request_id -> monotonic time at which the lookup is considered done. + self._lookup_ready_at: dict[str, float] = {} + # request_id -> True if the early prefetch hook already kicked off + # the lookup timer for this request. + self._prefetched: dict[str, bool] = {} + + # ------------------------------------------------------------------ + # Scheduler-side + # ------------------------------------------------------------------ + def maybe_prefetch_request(self, request: Request) -> bool: + if request.request_id in self._lookup_ready_at: + return False + self._lookup_ready_at[request.request_id] = time.monotonic() + self._latency_s + self._prefetched[request.request_id] = True + return True + + def get_num_new_matched_tokens( + self, + request: Request, + num_computed_tokens: int, + ) -> tuple[int | None, bool]: + # First time we see this request through the regular path: start the + # timer now (mirrors the LMCache behavior of starting the lookup + # inside get_num_new_matched_tokens). If maybe_prefetch_request was + # already called for this request, the timer has been running since + # then -- which is the whole point of the optimization. + ready_at = self._lookup_ready_at.get(request.request_id) + if ready_at is None: + ready_at = time.monotonic() + self._latency_s + self._lookup_ready_at[request.request_id] = ready_at + + if time.monotonic() < ready_at: + # Lookup still in flight -- ask the scheduler to retry later. 
+ return None, True + return self._matched_tokens, False + + def update_state_after_alloc( + self, + request: Request, + blocks: KVCacheBlocks, + num_external_tokens: int, + ) -> None: + return + + def build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: + return _Empty() + + def request_finished( + self, + request: Request, + block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + self._lookup_ready_at.pop(request.request_id, None) + self._prefetched.pop(request.request_id, None) + return False, None + + # ------------------------------------------------------------------ + # Worker-side: this connector does no real transfer. + # ------------------------------------------------------------------ + def start_load_kv(self, forward_context: ForwardContext, **kwargs: Any) -> None: + return + + def wait_for_layer_load(self, layer_name: str) -> None: + return + + def save_kv_layer( + self, + layer_name: str, + kv_layer, + attn_metadata, + **kwargs: Any, + ) -> None: + return + + def wait_for_save(self) -> None: + return diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 42f4825e2b3b..2bc6423c22ba 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -24,6 +24,7 @@ from vllm.utils.hashing import sha256 from vllm.v1.core.encoder_cache_manager import EncoderCacheManager from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash +from vllm.v1.core.sched.interface import PauseState from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.engine import FinishReason @@ -3836,6 +3837,170 @@ def test_remote_kv_promotion_keeps_fcfs_with_grammar_prefix(): ] +def _build_prefetch_scheduler(token_budget: int) -> Scheduler: + """Build a real scheduler with a mockable connector prefetch hook. + + Uses the standard `create_scheduler` + MockKVConnector path to avoid + hand-stitching scheduler internals, then overrides + `kv_connector_prefetch_token_budget` and replaces + `connector.maybe_prefetch_request` with a Mock so call sites can be + asserted against. The real `get_num_new_matched_tokens` on + MockKVConnector remains intact, so we can also assert that the early + prefetch pass does NOT invoke it. 
+ """ + scheduler = create_scheduler( + use_kv_connector=mock_kv(matched_tokens=0, is_async=False), + ) + scheduler.kv_connector_prefetch_token_budget = token_budget + scheduler.connector.maybe_prefetch_request = Mock(return_value=True) + scheduler.connector.get_num_new_matched_tokens = Mock( + wraps=scheduler.connector.get_num_new_matched_tokens + ) + return scheduler + + +def _prefetch_call_args(scheduler: Scheduler) -> list[Request]: + return [ + call.args[0] + for call in scheduler.connector.maybe_prefetch_request.call_args_list + ] + + +def test_kv_connector_early_prefetch_disabled_by_default(): + """Default token budget of 0 disables early prefetch entirely.""" + scheduler = create_scheduler( + use_kv_connector=mock_kv(matched_tokens=0, is_async=False), + ) + assert scheduler.kv_connector_prefetch_token_budget == 0 + scheduler.connector.maybe_prefetch_request = Mock(return_value=True) + + for request in create_requests(num_requests=3, num_tokens=10): + scheduler.add_request(request) + + scheduler._early_prefetch_waiting_kv() + scheduler.connector.maybe_prefetch_request.assert_not_called() + + +def test_kv_connector_early_prefetch_respects_token_budget(): + """Cumulative `request.num_tokens` is bounded by the per-step budget.""" + scheduler = _build_prefetch_scheduler(token_budget=25) + requests = create_requests(num_requests=3, num_tokens=10) + for request in requests: + scheduler.add_request(request) + + scheduler._early_prefetch_waiting_kv() + + # 10 + 10 = 20 <= 25; adding the third (30) would exceed the budget, + # so iteration stops and the third request is left for a later step. + assert _prefetch_call_args(scheduler) == [requests[0], requests[1]] + # The early pass must not invoke the regular matched-tokens API. + scheduler.connector.get_num_new_matched_tokens.assert_not_called() + + +def test_kv_connector_early_prefetch_persists_across_steps(): + """Already-prefetched ids are remembered across schedule steps so we + don't re-invoke the connector hook for the same request.""" + scheduler = _build_prefetch_scheduler(token_budget=15) + requests = create_requests(num_requests=3, num_tokens=10) + for request in requests: + scheduler.add_request(request) + + scheduler._early_prefetch_waiting_kv() + scheduler._early_prefetch_waiting_kv() + + # Step 1: prefetch requests[0] (10 used; +10 would be 20 > 15 → stop). + # Step 2: requests[0] is already in the cache → skipped; requests[1] + # fits a fresh per-step budget of 15. 
+ assert _prefetch_call_args(scheduler) == [requests[0], requests[1]] + + +@pytest.mark.parametrize("pause_state", [PauseState.PAUSED_NEW, PauseState.PAUSED_ALL]) +def test_kv_connector_early_prefetch_skipped_when_paused(pause_state): + """Paused schedulers must not kick off any prefetch work.""" + scheduler = _build_prefetch_scheduler(token_budget=1024) + for request in create_requests(num_requests=2, num_tokens=10): + scheduler.add_request(request) + scheduler._pause_state = pause_state + + scheduler._early_prefetch_waiting_kv() + scheduler.connector.maybe_prefetch_request.assert_not_called() + + +@pytest.mark.parametrize( + "skip_status", + [ + RequestStatus.WAITING_FOR_REMOTE_KVS, + RequestStatus.WAITING_FOR_STRUCTURED_OUTPUT_GRAMMAR, + RequestStatus.PREEMPTED, + ], +) +def test_kv_connector_early_prefetch_skips_non_plain_waiting(skip_status): + """Only plain-WAITING requests with no computed tokens are prefetched.""" + scheduler = _build_prefetch_scheduler(token_budget=1024) + skip, ok = create_requests(num_requests=2, num_tokens=10) + scheduler.add_request(skip) + scheduler.add_request(ok) + skip.status = skip_status + + scheduler._early_prefetch_waiting_kv() + assert _prefetch_call_args(scheduler) == [ok] + + +def test_kv_connector_early_prefetch_skips_partially_computed_request(): + """Requests that already carry computed tokens are skipped.""" + scheduler = _build_prefetch_scheduler(token_budget=1024) + partial, fresh = create_requests(num_requests=2, num_tokens=10) + scheduler.add_request(partial) + scheduler.add_request(fresh) + partial.num_computed_tokens = 4 + + scheduler._early_prefetch_waiting_kv() + assert _prefetch_call_args(scheduler) == [fresh] + + +def test_kv_connector_early_prefetch_id_cleared_on_finish(): + """Finishing/aborting a request must drop it from the prefetched set, + so reuse of the same id later is safe.""" + scheduler = _build_prefetch_scheduler(token_budget=1024) + request = create_requests(num_requests=1, num_tokens=10)[0] + scheduler.add_request(request) + + scheduler._early_prefetch_waiting_kv() + assert request.request_id in scheduler._kv_connector_prefetched_req_ids + + scheduler.finish_requests(request.request_id, RequestStatus.FINISHED_ABORTED) + assert request.request_id not in scheduler._kv_connector_prefetched_req_ids + + +def test_kv_connector_early_prefetch_oversized_head_preserves_fcfs(): + """A head-of-queue request larger than the budget blocks subsequent + prefetches in the same step (FCFS — never bypass the head).""" + scheduler = _build_prefetch_scheduler(token_budget=15) + big = create_requests(num_requests=1, num_tokens=20, req_ids=["big"])[0] + small = create_requests(num_requests=1, num_tokens=10, req_ids=["small"])[0] + scheduler.add_request(big) + scheduler.add_request(small) + + scheduler._early_prefetch_waiting_kv() + scheduler.connector.maybe_prefetch_request.assert_not_called() + + +def test_kv_connector_early_prefetch_handles_negative_hook_return(): + """A connector that declines to prefetch (returns False) must not + consume budget nor be marked as prefetched, so a later step can retry.""" + scheduler = _build_prefetch_scheduler(token_budget=1024) + scheduler.connector.maybe_prefetch_request = Mock(return_value=False) + request = create_requests(num_requests=1, num_tokens=10)[0] + scheduler.add_request(request) + + scheduler._early_prefetch_waiting_kv() + assert request.request_id not in scheduler._kv_connector_prefetched_req_ids + + scheduler._early_prefetch_waiting_kv() + # Same request should be probed again on the next 
step. + assert scheduler.connector.maybe_prefetch_request.call_count == 2 + + def test_fcfs_mixed_skipped_waiting_types_keep_order(): scheduler = create_scheduler(max_num_batched_tokens=20) scheduler._update_waiting_for_remote_kv = Mock() diff --git a/tests/v1/kv_connector/unit/test_lmcache_mp_connector.py b/tests/v1/kv_connector/unit/test_lmcache_mp_connector.py new file mode 100644 index 000000000000..9d8b5dbade50 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_lmcache_mp_connector.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for LMCacheMPConnector.maybe_prefetch_request. + +The full connector requires an LMCache backend and a running scheduler +adapter. These tests instead bind the methods under test to a MagicMock +with `spec=LMCacheMPConnector`, so we exercise the real prefetch logic +against a stubbed scheduler adapter without spinning up LMCache. +""" + +from unittest.mock import MagicMock + +import pytest + +pytest.importorskip("lmcache") + +from tests.v1.core.utils import create_requests # noqa: E402 +from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_mp_connector import ( # noqa: E402 + LMCacheMPConnector, +) +from vllm.v1.request import RequestStatus # noqa: E402 + + +def _bind(connector: MagicMock, *method_names: str) -> None: + """Bind real `LMCacheMPConnector` methods onto a MagicMock instance.""" + for name in method_names: + method = getattr(LMCacheMPConnector, name) + setattr(connector, name, method.__get__(connector, LMCacheMPConnector)) + + +def _make_connector() -> MagicMock: + connector = MagicMock(spec=LMCacheMPConnector) + connector.request_trackers = {} + connector.scheduler_adapter = MagicMock() + _bind( + connector, + "maybe_prefetch_request", + "_get_or_create_request_tracker", + "_maybe_submit_lookup_request", + ) + return connector + + +def test_maybe_prefetch_request_submits_lookup_for_fresh_waiting_request(): + """A plain WAITING request with no computed tokens should trigger a + lookup submission and have its tracker created.""" + connector = _make_connector() + request = create_requests(num_requests=1, num_tokens=10)[0] + assert request.status == RequestStatus.WAITING + assert request.num_computed_tokens == 0 + + submitted = connector.maybe_prefetch_request(request) + + assert submitted is True + assert request.request_id in connector.request_trackers + submit = connector.scheduler_adapter.maybe_submit_lookup_request + submit.assert_called_once() + call = submit.call_args + assert call.args[0] == request.request_id + assert call.kwargs["token_ids"] == list(request.all_token_ids) + # cache_salt comes from the tracker; defaults to "" when request has none. 
+ assert call.kwargs["cache_salt"] == (request.cache_salt or "") + + +@pytest.mark.parametrize( + "skip_status", + [ + RequestStatus.PREEMPTED, + RequestStatus.WAITING_FOR_REMOTE_KVS, + RequestStatus.WAITING_FOR_STRUCTURED_OUTPUT_GRAMMAR, + RequestStatus.RUNNING, + ], +) +def test_maybe_prefetch_request_skips_non_plain_waiting(skip_status): + """The hook must bail (and not submit a lookup) for any request that + is not in the plain WAITING state.""" + connector = _make_connector() + request = create_requests(num_requests=1, num_tokens=10)[0] + request.status = skip_status + + submitted = connector.maybe_prefetch_request(request) + + assert submitted is False + connector.scheduler_adapter.maybe_submit_lookup_request.assert_not_called() + assert request.request_id not in connector.request_trackers + + +def test_maybe_prefetch_request_skips_when_partially_computed(): + """Already-partially-computed requests should not have prefetch hints + issued, even if their status is still WAITING.""" + connector = _make_connector() + request = create_requests(num_requests=1, num_tokens=10)[0] + request.num_computed_tokens = 4 + + submitted = connector.maybe_prefetch_request(request) + + assert submitted is False + connector.scheduler_adapter.maybe_submit_lookup_request.assert_not_called() + + +def test_maybe_prefetch_request_is_idempotent_for_repeat_calls(): + """Calling the hook twice for the same fresh waiting request reuses the + existing tracker and just re-submits the lookup -- the underlying + `scheduler_adapter.maybe_submit_lookup_request` is itself the + deduplication boundary on the LMCache side.""" + connector = _make_connector() + request = create_requests(num_requests=1, num_tokens=10)[0] + + assert connector.maybe_prefetch_request(request) is True + tracker = connector.request_trackers[request.request_id] + + assert connector.maybe_prefetch_request(request) is True + # Same tracker instance is retained on the second call. + assert connector.request_trackers[request.request_id] is tracker + assert connector.scheduler_adapter.maybe_submit_lookup_request.call_count == 2 diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index fb6951ea7dd1..1cb92763c79c 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -154,6 +154,26 @@ class SchedulerConfig: while a larger value (e.g., 10) reduces host overhead and may increase throughput by batching multiple tokens before sending.""" + kv_connector_prefetch_token_budget: int = Field(default=0, ge=0) + """Per-step prompt-token budget for early KV connector prefetch hints. + + When positive, the scheduler asks the KV connector to start async + prefetch for waiting requests at the top of each schedule step (before + the running queue is processed). This lets connectors with disk- or + network-backed KV stores (e.g. LMCache) overlap their lookups with GPU + work that is still running, instead of stalling the GPU once the + waiting queue is finally polled. See GH issue #41784 for the motivating + trace. + + The budget is summed by `request.num_tokens` of hinted requests in a + single step and bounds connector-side staging memory pressure (token + count is a closer proxy to that pressure than a fixed request count). + Hints already submitted in earlier steps are remembered and not + reissued, so a small per-step budget will steadily drain a backlog + across consecutive steps without re-doing work. 
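+
+    For example, with a budget of 512 and three waiting prompts of 300, 200
+    and 150 tokens, the first two are hinted in one step
+    (300 + 200 = 500 <= 512); the third would push the sum past the budget,
+    so the pass stops there and hints it at the start of a later step with a
+    fresh 512-token budget.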
+ + Default 0 disables the feature and preserves prior behavior.""" + @staticmethod def default_factory(**kwargs): """ diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index ef143cba7fb5..343fc545e0a2 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -10,6 +10,8 @@ get_num_new_matched_tokens() - get number of new tokens that exist in the remote KV cache. Might be called multiple times for a given request and should be side-effect free. + maybe_prefetch_request() - optionally starts connector-side + prefetch work when a request is waiting to be scheduled. update_state_after_alloc() - update KVConnector state after temporary buffer alloc by the CacheManager. update_connector_output() - update KVConnector state after @@ -446,6 +448,42 @@ def build_connector_worker_meta(self) -> KVConnectorWorkerMetadata | None: # Scheduler-side methods # ============================== + def maybe_prefetch_request(self, request: "Request") -> bool: + """ + Optionally start connector-side prefetch work for a waiting request. + + This is a best-effort hint invoked by the scheduler at the top of a + schedule step, before the running queue is processed, so connectors + with async lookups (e.g. disk-backed KV stores) can begin loading + staging data while the GPU is still busy with running requests. The + default implementation is a no-op for connectors that do not benefit + from early lookup. + + Contract: + - Pure hint -- implementations MUST NOT allocate KV blocks, mutate + scheduler-visible request state, or change anything that the + regular scheduling path reads. Matching, allocation, and + metadata remain owned by `get_num_new_matched_tokens` / + `update_state_after_alloc` and friends. + - Idempotent -- the scheduler tracks ids it has already hinted and + will not call this repeatedly for the same request, but + implementations should still tolerate redundant calls (e.g. + after preempt/resume) without re-issuing duplicate work. + - Non-blocking -- callers expect this to return promptly. + + Args: + request: The waiting request the connector may prefetch for. + The scheduler only invokes this for requests in plain + `WAITING` status with no computed tokens yet. + + Returns: + True if prefetch work was actually submitted for this request, + otherwise False. The scheduler uses the return value only for + prefetch-budget accounting -- a `False` return leaves the budget + untouched and lets a later step retry the hint. 
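+
+        Example (a minimal sketch, not a vLLM implementation: the
+        `_inflight_hints` set and the `_lookup_pool.submit` call stand in
+        for whatever async lookup API a concrete connector has)::
+
+            def maybe_prefetch_request(self, request: "Request") -> bool:
+                rid = request.request_id
+                if rid in self._inflight_hints:
+                    # Tolerate redundant hints without duplicating work.
+                    return False
+                # Fire-and-forget: enqueue the async lookup and return
+                # immediately; no scheduler-visible state is touched.
+                self._inflight_hints.add(rid)
+                self._lookup_pool.submit(rid, list(request.all_token_ids))
+                return True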
+ """ + return False + @abstractmethod def get_num_new_matched_tokens( self, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py index f55f04a08252..67143f870f8c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py @@ -766,11 +766,7 @@ def get_num_new_matched_tokens( if request.status == RequestStatus.PREEMPTED: return 0, False - self.scheduler_adapter.maybe_submit_lookup_request( - request.request_id, - token_ids=list(request.all_token_ids), - cache_salt=tracker.cache_salt, - ) + self._maybe_submit_lookup_request(request, tracker) ret = self.scheduler_adapter.check_lookup_result(request.request_id) if ret is None: @@ -799,6 +795,14 @@ def get_num_new_matched_tokens( ) return need_to_load, need_to_load > 0 + def maybe_prefetch_request(self, request: "Request") -> bool: + if request.status != RequestStatus.WAITING or request.num_computed_tokens != 0: + return False + + tracker = self._get_or_create_request_tracker(request) + self._maybe_submit_lookup_request(request, tracker) + return True + def update_state_after_alloc( self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int ): @@ -1177,6 +1181,15 @@ def _get_or_create_request_tracker( self.request_trackers[request_id] = new_tracker return self.request_trackers[request_id] + def _maybe_submit_lookup_request( + self, request: "Request", tracker: LMCacheMPRequestTracker + ) -> None: + self.scheduler_adapter.maybe_submit_lookup_request( + request.request_id, + token_ids=list(request.all_token_ids), + cache_salt=tracker.cache_salt, + ) + def _cleanup_request_tracker(self, request_id: str) -> None: """ Clean up request tracker and associated lookup future for a request. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 1b3803139217..60b0edc9239b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -669,6 +669,10 @@ class EngineArgs: stream_interval: int = SchedulerConfig.stream_interval + kv_connector_prefetch_token_budget: int = ( + SchedulerConfig.kv_connector_prefetch_token_budget + ) + kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill optimization_level: OptimizationLevel = VllmConfig.optimization_level performance_mode: PerformanceMode = VllmConfig.performance_mode @@ -1363,6 +1367,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: scheduler_group.add_argument( "--stream-interval", **scheduler_kwargs["stream_interval"] ) + scheduler_group.add_argument( + "--kv-connector-prefetch-token-budget", + **scheduler_kwargs["kv_connector_prefetch_token_budget"], + ) # Compilation arguments compilation_kwargs = get_kwargs(CompilationConfig) @@ -1969,6 +1977,9 @@ def create_engine_config( disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager, async_scheduling=self.async_scheduling, stream_interval=self.stream_interval, + kv_connector_prefetch_token_budget=( + self.kv_connector_prefetch_token_budget + ), ) if not model_config.is_multimodal_model and self.default_mm_loras: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 8aaeb3970079..46f283ef8b29 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -163,6 +163,17 @@ def __init__( # requests skipped in waiting flow due async deps or constraints. 
self.skipped_waiting = create_request_queue(self.policy) self.running: list[Request] = [] + # Per-step prompt-token budget for early KV connector prefetch hints. + # 0 (the default) disables the early-prefetch pass and preserves the + # pre-existing scheduling behavior. + self.kv_connector_prefetch_token_budget = ( + self.scheduler_config.kv_connector_prefetch_token_budget + ) + # Request ids for which a connector prefetch hint has already been + # submitted. Persists across schedule steps so we don't re-issue the + # hint for the same request, and is cleaned up by `_free_request` + # when the request finishes (including aborts). + self._kv_connector_prefetched_req_ids: set[str] = set() # The request IDs that are finished in between the previous and the # current steps. This is used to notify the workers about the finished @@ -307,6 +318,77 @@ def _mamba_block_aligned_split( pass return num_new_tokens + def _early_prefetch_waiting_kv(self) -> None: + """Hint the KV connector to start async prefetch for waiting requests + before the running queue is scheduled. + + Motivation (see GH issue #41784): when the running queue saturates the + per-step token budget, the waiting-queue scheduling loop never runs, + so connectors that perform async loads (e.g. LMCache disk reads) never + observe newly arrived requests until the running queue drains. By the + time the connector lookup is finally issued, the GPU is already idle + waiting for it, leaking hundreds of milliseconds per request and + cascading under load. + + This pass walks the waiting queues and asks the connector to *start* + the lookup early via `KVConnectorBase_V1.maybe_prefetch_request`. The + regular scheduling path still owns matching, allocation, and metadata; + the hint is purely a "kick async work off sooner" primitive, which + keeps the scheduler / connector contract unchanged. + + Bounding: prompt-token usage is summed per step against + `kv_connector_prefetch_token_budget` to bound in-flight prefetch + memory pressure on the connector (token budget is a closer proxy to + connector-side staging memory than a fixed request count). FCFS is + preserved -- if the head request would exceed the budget we stop + rather than skip it. Hints already submitted in earlier steps are + tracked in `_kv_connector_prefetched_req_ids` so we never repeat a + call for the same request. + + No-ops when: + - no KV connector is configured; + - the scheduler is paused (PAUSED_NEW or PAUSED_ALL) -- in those + states no waiting request will be scheduled, so prefetching is + wasted work and risks premature eviction of the staged data; + - the prefetch token budget is 0 (the default). + """ + if ( + self.connector is None + or self._pause_state != PauseState.UNPAUSED + or self.kv_connector_prefetch_token_budget <= 0 + ): + return + + # TODO: Allow connectors to provide a capacity-based default budget, + # e.g. derived from CPU staging pool size, instead of requiring users + # to tune a fixed token budget. + tokens_used = 0 + for queue in (self.waiting, self.skipped_waiting): + for request in queue: + if request.request_id in self._kv_connector_prefetched_req_ids: + continue + # Only fresh waiting requests are eligible. Preempted or + # blocked-waiting (e.g. WAITING_FOR_REMOTE_KVS) entries have + # connector-side state already in flight or are mid-resume, + # and per-connector contracts (e.g. LMCache bails on + # PREEMPTED) make a prefetch hint either redundant or wrong. 
+ if ( + request.status != RequestStatus.WAITING + or request.num_computed_tokens != 0 + ): + continue + + num_tokens = request.num_tokens + if tokens_used + num_tokens > self.kv_connector_prefetch_token_budget: + # FCFS: stop on first ineligible request rather than + # skipping past it. Subsequent steps will retry from the + # same head once budget is available again. + return + + if self.connector.maybe_prefetch_request(request): + self._kv_connector_prefetched_req_ids.add(request.request_id) + tokens_used += num_tokens + def schedule(self) -> SchedulerOutput: # NOTE(woosuk) on the scheduling algorithm: # There's no "decoding phase" nor "prefill phase" in the scheduler. @@ -342,6 +424,8 @@ def schedule(self) -> SchedulerOutput: self.kv_cache_manager.new_step_starts() + self._early_prefetch_waiting_kv() + # First, schedule the RUNNING requests. req_index = 0 while req_index < len(self.running) and token_budget > 0: @@ -991,6 +1075,7 @@ def _update_request_as_session( session.num_prompt_tokens = len(session.prompt_token_ids) session.arrival_time = update.arrival_time session.sampling_params = update.sampling_params + self._kv_connector_prefetched_req_ids.discard(session.request_id) if session.status == RequestStatus.WAITING_FOR_STREAMING_REQ: self.num_waiting_for_streaming_input -= 1 session.status = RequestStatus.WAITING @@ -1752,6 +1837,7 @@ def _free_request( ) -> dict[str, Any] | None: assert request.is_finished() + self._kv_connector_prefetched_req_ids.discard(request.request_id) connector_delay_free_blocks, kv_xfer_params = self._connector_finished(request) self.encoder_cache_manager.free(request) request_id = request.request_id