
[KV Connector] Opt DecodeBenchConnector into SupportsHMA #41770

Merged
ywang96 merged 1 commit into vllm-project:main from liuzijing2014:decode-bench-supports-hma
May 7, 2026
Conversation

@liuzijing2014
Collaborator

@liuzijing2014 liuzijing2014 commented May 6, 2026

Summary

Opt DecodeBenchConnector into SupportsHMA so decode-only benchmark recipes can run with the hybrid KV cache manager enabled. Without this, users must pass --disable-hybrid-kv-cache-manager, which collapses hybrid-model KV cache groups (SWA / MLA compress=4 / MLA compress=128 / sparse indexer) into a single uniform page size and throws away the compression savings.
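
The cost of disabling HMA described above can be illustrated with a toy calculation. The byte counts and budget below are made up for illustration; this is not vLLM's actual KV cache accounting:

```python
# Toy illustration: collapsing KV cache groups with different per-token
# footprints to one uniform (worst-case) page size wastes the savings of
# the compressed groups. All numbers are illustrative, not vLLM's.

# Hypothetical per-token KV bytes for each group.
group_bytes_per_token = {
    "full_attention": 1024,
    "mla_compress_4": 256,    # 4x compression
    "mla_compress_128": 8,    # 128x compression
}

budget = 1 << 30  # 1 GiB KV budget per group, for illustration

# Per-group sizing (HMA enabled): each group packs tokens at its own rate.
tokens_per_group = {g: budget // b for g, b in group_bytes_per_token.items()}

# Unified sizing (HMA disabled): every group is charged at the largest,
# uncompressed rate.
worst = max(group_bytes_per_token.values())
unified_tokens = budget // worst

print(tokens_per_group["mla_compress_128"] // unified_tokens)  # -> 128
```

Under these toy numbers the 128x-compressed group holds 128x fewer tokens once the page size is unified, which is the "throws away the compression savings" effect.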

Implementation is minimal because this connector is a dummy fill that owns no external per-block state:

  • Inherit from SupportsHMA.
  • Implement request_finished_all_groups: same cleanup as the single-group request_finished (delegates to connector_scheduler.request_finished()); block_ids is ignored because there is no per-block external state to release.
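
The two steps above can be sketched as follows. The vLLM base classes and the scheduler-side helper are replaced here by minimal stand-in stubs, so this is a shape sketch of the change, not the actual diff:

```python
# Sketch of the change, with stand-in stubs for the vLLM classes.
# SupportsHMA, the scheduler helper, and the return shape are simplified;
# names follow the PR description.

class SupportsHMA:
    """Marker mixin: connector is safe to run with the hybrid KV cache manager."""


class _ConnectorScheduler:
    """Stub for the connector's scheduler-side component."""

    def request_finished(self, request):
        # DecodeBenchConnector owns no external per-block state, so
        # cleanup is just internal bookkeeping.
        return False, None  # (delay_free_blocks, kv_xfer_params)


class DecodeBenchConnector(SupportsHMA):  # step 1: inherit the mixin
    def __init__(self):
        self.connector_scheduler = _ConnectorScheduler()

    def request_finished(self, request, block_ids):
        # Single-group cleanup path.
        return self.connector_scheduler.request_finished(request)

    def request_finished_all_groups(self, request, block_ids):
        # Step 2: same cleanup as the single-group variant. block_ids
        # (one list per KV cache group under HMA) is ignored because
        # there is no per-block external state to release.
        return self.connector_scheduler.request_finished(request)
```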

Not a duplicate

Searched open PRs for DecodeBenchConnector, SupportsHMA, and decode_bench_connector (via the GitHub search API). No PR targets this connector. The closest neighbor, #41644 ("Keep HMA enabled for supported KV connectors"), is orthogonal — it changes the default HMA decision in vllm/config/vllm.py for connectors that already declare SupportsHMA. This PR adds one more connector to that supported set; the two changes compose.

Test plan

Hardware: 1× node with 4× GB200, single-node DEP=4 (DP=4, EP=4). Model: DeepSeek-V4-Flash with --load-format dummy and --tokenizer-mode deepseek_v4. Connector: DecodeBenchConnector, HMA explicitly enabled via --no-disable-hybrid-kv-cache-manager.

Same vllm serve command in both runs:

vllm serve <DSv4-Flash> --port 8000 \
  -dp 4 -ep --data-parallel-size-local 4 \
  --load-format dummy --tokenizer-mode deepseek_v4 \
  --kv-cache-dtype fp8 --max-model-len 8192 \
  --no-disable-hybrid-kv-cache-manager \
  --max-num-seqs 32 --max-num-batched-tokens 32 \
  --block-size 256 --gpu-memory-utilization 0.85 \
  --kv-transfer-config '{"kv_connector":"DecodeBenchConnector","kv_role":"kv_both","kv_connector_extra_config":{"fill_mean":0.015,"fill_std":0.02}}'

Without the PR (only difference: class DecodeBenchConnector(KVConnectorBase_V1) — no SupportsHMA mixin)

Crashes at engine init in KVConnectorFactory.create_connector_v1():

ValueError: Connector DecodeBenchConnector does not support HMA but HMA is enabled.
            Please set `--disable-hybrid-kv-cache-manager`.

All four DP ranks fail identically; the engine cores then exit with RuntimeError: Worker failed with error '...'.
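
The failing check can be pictured with a minimal sketch. This is an assumed shape for the capability gate, not vLLM's exact factory code:

```python
# Minimal sketch (assumed shape) of the capability check behind the error
# above: with the hybrid KV cache manager enabled, the factory rejects
# connectors that do not opt in via the SupportsHMA marker.

class SupportsHMA:
    pass


def create_connector_v1(connector_cls, hma_enabled: bool):
    if hma_enabled and not issubclass(connector_cls, SupportsHMA):
        raise ValueError(
            f"Connector {connector_cls.__name__} does not support HMA but "
            "HMA is enabled. Please set `--disable-hybrid-kv-cache-manager`."
        )
    return connector_cls()


class DecodeBenchConnector:               # without the PR: no SupportsHMA mixin
    pass


class DecodeBenchConnectorHMA(SupportsHMA):  # with the PR: opted in
    pass
```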

With the PR

Server starts cleanly. HMA is active:

INFO ... [deepseek_v4_attention.py:614] Using DeepSeek's fp8_ds_mla KV cache format.
INFO ... [kv_cache_utils.py:1713] GPU KV cache size: 86,928 tokens     (each of 4 ranks)
INFO ...:     Application startup complete.

Functional sanity check via vllm bench serve (random ISL=4000, OSL=100, 32 prompts, max-concurrency=8):

Successful requests:                     32
Failed requests:                         0
Request throughput (req/s):              4.52
Total token throughput (tok/s):          18545.58
Mean TPOT (ms):                          12.78    P99: 20.53
Mean TTFT (ms):                          486.70   P99: 1456.91
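
As a quick consistency check on those numbers (assuming every request completes the full ISL + OSL), token throughput should be roughly request throughput times tokens per request:

```python
# Sanity check of the benchmark output above: with ISL=4000 and OSL=100,
# each request moves about 4100 tokens, so total token throughput should
# be close to request throughput x 4100.

req_per_s = 4.52
tokens_per_req = 4000 + 100
approx_tok_per_s = req_per_s * tokens_per_req
print(approx_tok_per_s)  # -> 18532.0, within 0.1% of the reported 18545.58
```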

@mergify mergify Bot added the kv-connector label May 6, 2026
@liuzijing2014 liuzijing2014 marked this pull request as ready for review May 6, 2026 01:03

@claude claude Bot left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request updates the DecodeBenchConnector to support the hybrid KV cache manager (HMA) by inheriting from the SupportsHMA interface and implementing the request_finished_all_groups method for proper cleanup. I have no feedback to provide as there were no review comments to evaluate.

@markmc markmc requested a review from tlrmchlsmth May 6, 2026 09:31
@ywang96 ywang96 added the ready ONLY add when PR is ready to merge/full CI is needed label May 6, 2026
@ywang96 ywang96 enabled auto-merge (squash) May 6, 2026 21:27
@ywang96 ywang96 disabled auto-merge May 7, 2026 22:03
@ywang96 ywang96 enabled auto-merge (squash) May 7, 2026 22:03
Previously the HMA (hybrid KV cache manager) layer refused to activate
when DecodeBenchConnector was in use, because the connector did not
advertise SupportsHMA. That forced decode-only benchmark recipes to
pass --disable-hybrid-kv-cache-manager, which collapsed hybrid-model
KV cache groups (SWA / MLA compress=4 / MLA compress=128 / sparse
indexer) into a single uniform page size via unify_kv_cache_spec_
page_size, throwing away the compression savings and capping concurrent
capacity on hybrid models (e.g. DeepSeek-V4 saw ~43 concurrent
8k/1k requests instead of the model's true ceiling).

This connector is a dummy fill that owns no external per-block state,
so the HMA path has nothing extra to do. Implementation is minimal:

- Inherit from SupportsHMA.
- Implement request_finished_all_groups: delegates to the same
  scheduler.request_finished() cleanup as the single-group variant,
  ignoring block_ids (no per-block state to release).

With this change, recipes can drop --disable-hybrid-kv-cache-manager
and let HMA size each KV cache group correctly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
auto-merge was automatically disabled May 7, 2026 22:51

Head branch was pushed to by a user without write access

@liuzijing2014 liuzijing2014 force-pushed the decode-bench-supports-hma branch from 22e669d to 9404a3d May 7, 2026 22:51
@ywang96 ywang96 merged commit 09a7cc5 into vllm-project:main May 7, 2026
5 of 6 checks passed
whytem pushed a commit to whytem/vllm that referenced this pull request May 8, 2026
…t#41770)

Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
libinta pushed a commit to libinta/vllm that referenced this pull request May 8, 2026
…t#41770)

Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Libin Tang <libin.tang@intel.com>

Labels

kv-connector ready ONLY add when PR is ready to merge/full CI is needed
