
[Core][KV Connector] Bounded early prefetch for waiting requests #42086

Open
chfeng-cs wants to merge 1 commit into vllm-project:main from
chfeng-cs:fix-41784-kv-prefetch

Conversation


@chfeng-cs (Contributor) commented May 8, 2026

Summary

Fixes #41784

When the running queue saturates the per-step token budget, the
waiting-queue scheduling loop never executes, so KV connectors with
async disk/network loads (e.g. LMCache) never start their lookups
until the running queue drains. By the time a lookup finally starts,
the GPU has nothing left to run and sits idle waiting on it, wasting
hundreds of milliseconds per request; under load these stalls cascade.

This PR adds an optional `_early_prefetch_waiting_kv` pass that runs
at the top of each `schedule()` call, before the running queue is
processed, and hints the KV connector to start async lookups for
waiting requests. This lets connectors kick off async loads while the
GPU is still busy, hiding disk latency behind ongoing GPU work.

Design

New scheduler config: `kv_connector_prefetch_token_budget`
(default 0, disabled), plumbed through `EngineArgs` /
`--kv-connector-prefetch-token-budget` so it is reachable from
`LLM(...)` and `vllm serve`. The budget is accounted per step as the
sum of `request.num_tokens`, bounding connector-side staging memory
pressure. FCFS is preserved: iteration stops at the first request that
would exceed the budget rather than skipping past it.
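
For illustration, enabling the budget from both entry points could look like the sketch below; the `LLM(...)` keyword is assumed to mirror the `EngineArgs` field name, and the model, budget value, and connector setup are placeholders rather than part of this PR's diff.

```python
# Hedged sketch: enabling the prefetch budget from the offline API.
# Assumes the LLM(...) keyword mirrors the EngineArgs field name; the model,
# budget value, and connector setup are illustrative only.
from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",
    # ... plus whatever kv_transfer_config your KV connector (e.g. LMCache) needs
    kv_connector_prefetch_token_budget=512,  # 0 (the default) disables the pass
)

# Equivalent for the online server:
#   vllm serve facebook/opt-125m --kv-connector-prefetch-token-budget 512
```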

New interface: `KVConnectorBase_V1.maybe_prefetch_request()`, a pure
hint that is a no-op by default (returns False). Implementations must
not allocate KV blocks or mutate scheduler-visible request state. The
hook returns a bool signalling whether work was actually submitted, so
budget accounting can skip no-ops.
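
Since the diff is not reproduced here, the exact signature is an assumption; a minimal sketch of the base-class default consistent with the contract above:

```python
# Hedged sketch of the base-class default. Only the method name and the
# "pure hint, returns False by default" contract come from this PR text;
# the parameter name and type annotation are assumptions.
class KVConnectorBase_V1:
    def maybe_prefetch_request(self, request: "Request") -> bool:
        """Hint that a waiting request's KV could be prefetched asynchronously.

        Implementations must not allocate KV blocks or mutate
        scheduler-visible request state. Return True only if async work
        was actually submitted, so budget accounting can skip no-ops.
        """
        return False
```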

Cross-step deduplication: submitted request ids are tracked in
`_kv_connector_prefetched_req_ids` and not re-submitted in later
steps. The set is cleaned up on request finish/abort to prevent stale
state.
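
Taken together, the budget, the FCFS stop, and the cross-step dedup suggest a pass roughly like the following; this is an approximation of the described behavior, and scheduler attribute names other than those quoted above are assumptions, not the actual diff.

```python
# Hedged sketch of the _early_prefetch_waiting_kv pass, approximating the
# behavior described above. Attribute names such as self.waiting,
# self.connector and self.scheduler_config are assumptions.
def _early_prefetch_waiting_kv(self) -> None:
    budget = self.scheduler_config.kv_connector_prefetch_token_budget
    if budget <= 0 or self.connector is None:
        return  # feature disabled (default) or no KV connector configured
    for request in self.waiting:  # FCFS order
        if request.request_id in self._kv_connector_prefetched_req_ids:
            continue  # already hinted in an earlier schedule() step
        if request.num_tokens > budget:
            break  # would exceed the per-step budget: stop, don't skip past (FCFS)
        if self.connector.maybe_prefetch_request(request):
            # Charge the budget and dedup only when work was actually submitted;
            # a False return leaves the request eligible for a later retry.
            self._kv_connector_prefetched_req_ids.add(request.request_id)
            budget -= request.num_tokens
```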

LMCacheMPConnector: implements `maybe_prefetch_request()` by
extracting the existing lookup-submission path into a shared
`_maybe_submit_lookup_request` helper, reused by both the new hook
and the original `get_num_new_matched_tokens` path.
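
On the connector side the change is essentially delegation to that shared helper; a rough sketch, with the helper's signature and the surrounding class details assumed:

```python
# Hedged sketch of the LMCacheMPConnector side. The method and helper names
# come from this description; their signatures and the class details are
# assumptions.
class LMCacheMPConnector(KVConnectorBase_V1):
    def maybe_prefetch_request(self, request: "Request") -> bool:
        # Pure hint: start the same async lookup that the
        # get_num_new_matched_tokens() path would otherwise submit later.
        return self._maybe_submit_lookup_request(request)
```

The original `get_num_new_matched_tokens` path reuses the same helper, so a lookup started by the hint is presumably found already in flight when the request is eventually scheduled.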

Benchmark (simulated disk, GTX 1660 Super)

Synthetic benchmark using SlowMockKVConnector (200 ms simulated
disk latency), opt-125m, --max-num-batched-tokens 128:
BUDGET=0 (disabled): TTFT mean ≈ 457.7 ms
BUDGET=512 (enabled): TTFT mean ≈ 355.3 ms
DELTA: TTFT mean -102.4 ms (p50 -101.8 / p90 -104.1 / p99 -104.1)

Roughly 50% of the lookup latency is hidden. The remaining gap closes
as the running-queue saturation window lengthens; in production
workloads with sustained heavy load the improvement approaches 100%.

Reproduce on any GPU (~1 min):

python -m benchmarks.kv_connector_prefetch.run_microbench \
  --model facebook/opt-125m \
  --max-model-len 512 --max-num-batched-tokens 128 \
  --gpu-memory-utilization 0.6 \
  --num-warmup 3 --warmup-prompt-tokens 96 \
  --num-probe 4 --latency-ms 200 \
  --compare

Real LMCache + disk backend benchmark on A10 in progress.

Test Plan

  • tests/v1/core/test_scheduler.py: 9 new unit tests covering
    default-disabled behavior, per-step token budget bound, cross-step
    id persistence, paused-state no-op, blocked/preempted/partial-
    computed skip semantics, finish/abort cleanup, FCFS preservation on
    oversized head, and hook-returns-False retry behavior.
  • tests/v1/kv_connector/unit/test_lmcache_mp_connector.py: unit
    tests for the LMCache prefetch hook (skipped when lmcache is not
    installed), covering fresh waiting requests, non-plain-waiting skip,
    partial-computed skip, and idempotency.
  • benchmarks/kv_connector_prefetch/: synthetic microbench with
    SlowMockKVConnector that reproduces the #41784 saturation pattern
    with a configurable disk-latency mock, no real hardware dependencies.

Commit message:

Adds a `kv_connector_prefetch_token_budget` scheduler config and an
`_early_prefetch_waiting_kv` pass that, before the running queue is
scheduled, hints the KV connector to start async lookups for waiting
requests. This hides connector lookup latency (e.g. LMCache disk reads)
behind ongoing GPU work in the saturation regime described in vllm-project#41784,
where a full running queue otherwise prevents the waiting-queue loop
from ever invoking the connector.

The hint is delivered via a new optional `maybe_prefetch_request` hook
on `KVConnectorBase_V1` (default no-op) and implemented for
`LMCacheMPConnector` by extracting the existing lookup-submission path
into a shared helper. The scheduler tracks already-hinted ids across
schedule steps and drops them on request finish, so each request is
hinted at most once. Default budget is 0 (disabled), preserving prior
behavior.

Plumbed the new config through `EngineArgs` /
`--kv-connector-prefetch-token-budget` so it is reachable from
`LLM(...)` and `vllm serve`.

Test plan:
- New unit tests in tests/v1/core/test_scheduler.py covering: default-
  disabled behavior, per-step token budget bound, cross-step id
  persistence, paused-state no-op, blocked / preempted / partial-
  computed skip semantics, finish/abort cleanup, FCFS preservation
  on oversized head, and hook-returns-False retry behavior.
- New tests/v1/kv_connector/unit/test_lmcache_mp_connector.py covering
  the LMCache prefetch hook (skipped when lmcache is not installed).
- Synthetic microbench under benchmarks/kv_connector_prefetch/ that
  reproduces the vllm-project#41784 saturation pattern with a configurable disk-
  latency mock connector.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Ethan Feng <ethan.fengch@gmail.com>

@claude (Bot) left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.


@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces an early-prefetch mechanism for the KV connector to hide lookup latency behind GPU work. It adds a kv_connector_prefetch_token_budget configuration to the scheduler, allowing it to issue async prefetch hints for waiting requests at the start of each schedule step. The implementation includes updates to the base KV connector interface, the LMCache connector, and the V1 scheduler, along with a new microbenchmark and comprehensive unit tests to verify the prefetch logic and its impact on TTFT. I have no feedback to provide.


@NickLucche (Collaborator) left a comment


Can you check whether the callback added in #41383 works for your use case, or whether you need the observableQ proposed in that RFC, @chfeng-cs?

@chfeng-cs (Contributor, Author)

@NickLucche thanks for the pointer — the callback added in #41383 definitely looks relevant for my use case, especially the new on_new_request() hook and the connector-side bookkeeping flow.

I still need to study the implementation in more detail and test it against the scenario from #41784 before I can tell whether it fully covers the issue, or whether something closer to the observableQ approach is still needed.


Labels

kv-connector, performance (Performance-related issues), v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Performance]: KV prefetching not spawned on time can cause unnecessary GPU idle

2 participants