
[KV Connector] Implement on_new_request for LMCacheMPConnector #42321

Open
chfeng-cs wants to merge 2 commits into vllm-project:main from chfeng-cs:fix-41784-lmcache-on-new-request

Conversation

chfeng-cs (Contributor) commented May 11, 2026

Purpose

Fixes #41784.

Under high load the scheduler's running queue fills the token budget on every step, so the waiting-queue loop is never entered and get_num_new_matched_tokens is never called. The KV connector
only learns about a waiting request when the scheduler finally processes it — which can be many iterations later. By that point a disk or remote KV fetch that should have started seconds ago is
only just kicking off, causing GPU stalls.

#41383 added an on_new_request hook to KVConnectorBase_V1, called by the scheduler the moment a request enters add_request(). This PR implements that hook in LMCacheMPConnector: it submits the
async lookup via maybe_submit_lookup_request immediately, so disk fetches are already in flight by the time the scheduler processes the request. The path is opt-in via the
lmcache.mp.early_prefetch key in kv_connector_extra_config (default false).

Resumable (streaming-input) sessions are skipped because their token IDs are incomplete at add_request time; get_num_new_matched_tokens handles those when the full prompt is available.
maybe_submit_lookup_request is idempotent (guarded by lookup_futures), so the subsequent call from get_num_new_matched_tokens is a safe no-op.
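
For illustration, a minimal sketch of how the hook could look. Only on_new_request, maybe_submit_lookup_request, and lookup_futures come from this PR's description; early_prefetch and _is_resumable are assumed names, not the PR's actual code:

    # Sketch only: early_prefetch and _is_resumable are hypothetical names.
    def on_new_request(self, request) -> None:
        if not self.early_prefetch:
            # Early prefetch is opt-in (default false); preserve old behavior.
            return
        if self._is_resumable(request):
            # Token IDs are incomplete at add_request time for resumable
            # sessions; get_num_new_matched_tokens handles them later.
            return
        # Safe even though get_num_new_matched_tokens will call this again:
        # maybe_submit_lookup_request checks lookup_futures first.
        self.maybe_submit_lookup_request(request)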

This PR does not duplicate existing work: #42170 adds a new notify_new_request hook to the scheduler, duplicating what #41383 already landed. This PR instead implements the existing
on_new_request hook in LMCacheMPConnector, which is the missing piece.

Design note: why no budget control is needed

The earlier approach (#42086) used a per-step scheduler pass (_early_prefetch_waiting_kv) with a token-budget cap to rate-limit how many waiting requests were hinted per scheduling iteration.
That budget was necessary because the pass iterated over the entire waiting queue on every step — without a cap it could submit O(queue_size) lookups per step repeatedly.

on_new_request makes the budget unnecessary:

  • It fires exactly once per request, driven by actual arrivals through add_request(). Even a burst of N concurrent requests produces N sequential add_request() calls — not a batch.
  • LMCache's own worker-thread pool (--max-cpu-workers) serialises lookup execution internally, so the connector does not need to throttle submission.
  • maybe_submit_lookup_request is idempotent (guarded by lookup_futures), so duplicate submission is already impossible.

The push model (on_new_request) is self-limiting by construction; the poll model (maybe_prefetch_request) required an explicit budget to compensate for repeated iteration.
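
To make the self-limiting argument concrete, the idempotency guard in maybe_submit_lookup_request can be pictured as a dictionary check. This is a hedged sketch, not LMCache's actual internals; _lookup_pool and _do_lookup are assumed names:

    # Hypothetical shape of the lookup_futures guard.
    def maybe_submit_lookup_request(self, request):
        req_id = request.request_id
        if req_id in self.lookup_futures:
            # Already submitted (e.g. by on_new_request): safe no-op.
            return self.lookup_futures[req_id]
        fut = self._lookup_pool.submit(self._do_lookup, request)
        self.lookup_futures[req_id] = fut
        return fut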

Test plan

  pytest tests/v1/kv_connector/unit/test_lmcache_mp_connector.py -v

No GPU, LMCache server, or NIXL needed — uses a mock connector. Tests are skipped automatically when lmcache is not installed.
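
As an illustration, the idempotency test plausibly has the following shape; the mock_connector fixture and the attributes it exposes are assumptions, not the PR's actual test code:

    # Hypothetical sketch of test_on_new_request_idempotent_on_second_add.
    # "req" avoids clashing with pytest's built-in "request" fixture.
    def test_on_new_request_idempotent_on_second_add(mock_connector, req):
        mock_connector.on_new_request(req)
        first = mock_connector.lookup_futures[req.request_id]
        mock_connector.on_new_request(req)  # second add submits nothing new
        assert mock_connector.lookup_futures[req.request_id] is first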

Test results:

  tests/v1/kv_connector/unit/test_lmcache_mp_connector.py::test_on_new_request_submits_lookup        PASSED
  tests/v1/kv_connector/unit/test_lmcache_mp_connector.py::test_on_new_request_creates_tracker       PASSED
  tests/v1/kv_connector/unit/test_lmcache_mp_connector.py::test_on_new_request_skips_resumable       PASSED
  tests/v1/kv_connector/unit/test_lmcache_mp_connector.py::test_on_new_request_disabled_by_default   PASSED
  tests/v1/kv_connector/unit/test_lmcache_mp_connector.py::test_on_new_request_idempotent_on_second_add PASSED

Benchmark

Preliminary results under queue-saturated load (benchmarks/benchmark_lmcache_prefetch.py):

Setup: Qwen3.5-0.8B, LMCacheMPConnector with disk L2 backend, --prompt-tokens 200, 48 measurement requests after warmup.

  tag              p50 ms   p90 ms   p99 ms   mean ms
  baseline         4767.4   5350.7   5388.9   4340.9
  early_prefetch   4189.5   4411.9   4490.4   3765.9
  improvement      ↓12.1%   ↓17.5%   ↓16.7%   ↓13.2%

Note: the GTX 1660 Super (Turing, CC 7.5) does not support FlashAttention-2 and has lower PCIe bandwidth than datacenter GPUs, so absolute TTFT numbers are high. Benchmark results on A10/A100 will be added once the PAI instance is available; the relative improvement is expected to hold or increase under higher throughput.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.


claude (Bot) left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

gemini-code-assist (Bot) left a comment

Code Review

This pull request implements the on_new_request method in LMCacheMPConnector to trigger early KV cache lookups for non-resumable requests. It also includes unit tests to verify that lookups are submitted correctly, trackers are created, and resumable requests are skipped. I have no feedback to provide.

chfeng-cs added 2 commits May 12, 2026 01:16
Fixes vllm-project#41784.

Under high load the scheduler's running queue fills the token budget on
every step, so the waiting-queue loop is never entered and
get_num_new_matched_tokens is never called. As a result the KV connector
learns about waiting requests only when the scheduler finally processes
them, causing disk/remote KV fetches to start too late and the GPU to
stall waiting for data.

The on_new_request hook (added to KVConnectorBase_V1 in vllm-project#41383) is
called by the scheduler the moment a request enters add_request(),
before any scheduling loop runs. Implementing it in LMCacheMPConnector
lets the connector submit the async lookup immediately via
maybe_submit_lookup_request, so disk fetches are already in flight by
the time the scheduler processes the request.

Resumable (streaming-input) sessions are skipped because their token
IDs are incomplete at add_request time; get_num_new_matched_tokens
handles those when the full prompt is available.

maybe_submit_lookup_request is idempotent (guarded by lookup_futures),
so the subsequent call from get_num_new_matched_tokens is a safe no-op.

The early-prefetch path is opt-in via lmcache.mp.early_prefetch in
kv_connector_extra_config (default false), preserving prior behavior
for existing deployments. Users who want to hide disk lookup latency
behind ongoing GPU work set:

    "kv_connector_extra_config": {
        "lmcache.mp.early_prefetch": true
    }

Signed-off-by: Ethan Feng <ethan.fengch@gmail.com>
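
For context, wiring the flag above into an engine could look roughly like this Python sketch; the "LMCacheMPConnector" connector name and "kv_both" role are assumptions based on this PR's naming, and the model is a placeholder:

    from vllm import LLM
    from vllm.config import KVTransferConfig

    # Sketch only: connector/role strings are assumed, not verified.
    llm = LLM(
        model="<your-model>",
        kv_transfer_config=KVTransferConfig(
            kv_connector="LMCacheMPConnector",
            kv_role="kv_both",
            kv_connector_extra_config={"lmcache.mp.early_prefetch": True},
        ),
    )
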
benchmarks/benchmark_lmcache_prefetch.py measures TTFT improvement
under queue-saturated load when lmcache.mp.early_prefetch is enabled.

Signed-off-by: Ethan Feng <ethan.fengch@gmail.com>
chfeng-cs force-pushed the fix-41784-lmcache-on-new-request branch from 5f6558b to 0af100f on May 11, 2026 at 17:17
mergify (Bot) added the performance (Performance-related issues) label on May 11, 2026

Labels

kv-connector · performance (Performance-related issues) · v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Performance]: KV prefetching not spawned on time can cause unnecessary GPU idle

1 participant