
[Core][KVConnector] Support HMA+NixlConnector#32204

Open
NickLucche wants to merge 23 commits intovllm-project:mainfrom
NickLucche:nixl-hma-rebase

Conversation


@NickLucche NickLucche commented Jan 12, 2026

Overview

Currently, connectors cannot take full advantage of models that employ hybrid attention (full attention + sliding-window attention, FA+SWA): because the hybrid KV cache manager is disabled, all layers are treated as FA.

This PR enables NixlConnector to work with the HMA, drastically reducing the number of bytes/regions moved per transfer for SWA+FA models, while laying the groundwork for state-based models (Mamba etc.).
Example of the former:

# NON-HMA (current master)
(EngineCore_DP0 pid=521538) get_block_descs_ids [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]]
(EngineCore_DP0 pid=521538) get_block_descs_ids num output 4284

# HMA --no-enable-prefix-caching --no-disable-hybrid-kv-cache-manager (this PR)
get_block_descs_ids (remote descs) [[47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63], [110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126], ... [379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441]]


get_block_descs_ids num output 1650
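To illustrate why HMA shrinks the descriptor lists above: sliding-window groups only ever need the trailing window of blocks, while the full-attention group needs all of them. A minimal sketch (function name and shapes are illustrative, not the PR's actual API):

```python
def clip_sw_groups(block_ids, sw_sizes):
    """Keep all blocks for full-attention groups (sw_size == 0),
    but only the trailing window for sliding-window groups."""
    return tuple(
        group if sw == 0 else group[-sw:]
        for group, sw in zip(block_ids, sw_sizes)
    )


# Two KV cache groups: full attention + sliding window of 4 blocks.
block_ids = ([1, 2, 3, 4, 5, 6, 7, 8], [11, 12, 13, 14, 15, 16, 17, 18])
print(clip_sw_groups(block_ids, sw_sizes=[0, 4]))
# → ([1, 2, 3, 4, 5, 6, 7, 8], [15, 16, 17, 18])
```

Only the clipped per-group IDs need NIXL descriptors, which is where the drop from 4284 to 1650 descriptors comes from.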

UPDATE: see comments below for a discussion on marking invalid blocks.

Test with

Enable HMA experimental support with --no-disable-hybrid-kv-cache-manager:

# usual P/D command
vllm serve google/gemma-3-4b-it \
--trust-remote-code \
--block-size 64 \
--no-enable-prefix-caching \
--no-disable-hybrid-kv-cache-manager \
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

# usual toy_proxy_server.py command

lm-eval results:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.74|±  |0.0441|
|     |       |strict-match    |     5|exact_match|↑  | 0.74|±  |0.0441|

Or run the newly added test file:

pytest -x -v -s tests/v1/kv_connector/unit/test_nixl_connector_hma.py

EDIT:
I've also validated part of the lm-eval CI locally; you can test the different tracked configs with

cd tests && ENABLE_HMA_FLAG=1 bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh

To test the invalid-block handling with HMA, run python -m pytest -s -v -x tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py::test_hma_sync_recompute_evicts_all_blocks.

TODOs

  • pre-commit + mypy
  • Report and handle block-level failures
  • verify logical<>physical kernel block path
  • eval with hma disabled to make sure there's no regression
  • verify mamba-like models (defer to separate PR)
  • run lm-eval on different config
  • verify with llama4 (old optimization has been removed)
  • verify host-backed transfers (D2H->H2D)
  • block_size_ratio !=1 (defer to separate PR)

cc @heheda12345 @KuntaiDu @ivanium, with whom I'm working on this.

Benchmarks

ShareGPT results, no prefix caching, 8xH100.

Main:

# Main DTP4-PTP4
============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Maximum request concurrency:             10
Benchmark duration (s):                  74.16
Total input tokens:                      215312
Total generated tokens:                  193901
Request throughput (req/s):              13.48
Output token throughput (tok/s):         2614.71
Peak output token throughput (tok/s):    2933.00
Peak concurrent requests:                36.00
Total token throughput (tok/s):          5518.14
---------------Time to First Token----------------
Mean TTFT (ms):                          92.65
Median TTFT (ms):                        36.87
P99 TTFT (ms):                           2410.05
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          3.31
Median TPOT (ms):                        3.32
P99 TPOT (ms):                           3.46
---------------Inter-token Latency----------------
Mean ITL (ms):                           3.31
Median ITL (ms):                         3.31
P99 ITL (ms):                            4.23
==================================================

# Main "WideEP" D DPEP4 - PTP4
============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Maximum request concurrency:             10
Benchmark duration (s):                  99.42
Total input tokens:                      215312
Total generated tokens:                  199033
Request throughput (req/s):              10.06
Output token throughput (tok/s):         2001.95
Peak output token throughput (tok/s):    2236.00
Peak concurrent requests:                27.00
Total token throughput (tok/s):          4167.65
---------------Time to First Token----------------
Mean TTFT (ms):                          90.22
Median TTFT (ms):                        41.13
P99 TTFT (ms):                           120.99
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          4.51
Median TPOT (ms):                        4.48
P99 TPOT (ms):                           5.08
---------------Inter-token Latency----------------
Mean ITL (ms):                           4.50
Median ITL (ms):                         4.41
P99 ITL (ms):                            7.92
================================================== 

This PR:

# HMA DTP4 - PTP4
============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Maximum request concurrency:             10
Benchmark duration (s):                  68.28
Total input tokens:                      215312
Total generated tokens:                  191092
Request throughput (req/s):              14.65
Output token throughput (tok/s):         2798.60
Peak output token throughput (tok/s):    2924.00
Peak concurrent requests:                32.00
Total token throughput (tok/s):          5951.91
---------------Time to First Token----------------
Mean TTFT (ms):                          45.78
Median TTFT (ms):                        37.02
P99 TTFT (ms):                           563.35
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          3.32
Median TPOT (ms):                        3.32
P99 TPOT (ms):                           3.47
---------------Inter-token Latency----------------
Mean ITL (ms):                           3.32
Median ITL (ms):                         3.32
P99 ITL (ms):                            4.23
==================================================

# HMA PR "WideEP" D DPEP4 - PTP4
============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Maximum request concurrency:             10
Benchmark duration (s):                  98.65
Total input tokens:                      215312
Total generated tokens:                  199033
Request throughput (req/s):              10.14
Output token throughput (tok/s):         2017.54
Peak output token throughput (tok/s):    2218.00
Peak concurrent requests:                30.00
Total token throughput (tok/s):          4200.10
---------------Time to First Token----------------
Mean TTFT (ms):                          88.58
Median TTFT (ms):                        40.12
P99 TTFT (ms):                           168.45
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          4.48
Median TPOT (ms):                        4.47
P99 TPOT (ms):                           4.71
---------------Inter-token Latency----------------
Mean ITL (ms):                           4.47
Median ITL (ms):                         4.44
P99 ITL (ms):                            6.21
==================================================

So, up to ~7% higher throughput in this small-scale intra-node setup. An inter-node one would be more interesting to analyze.


Note

Enables HMA-aware KV transfer for FA+SWA models, reducing transferred regions and aligning connector behavior with hybrid KV cache groups.

  • NixlConnector/Scheduler/Worker now accept kv_cache_config and operate on multi-group BlockIds (tuples per KV group); add request_finished_all_groups and HMA marker via SupportsHMA
  • Scheduler clips SW groups (sw_sizes) and passes unclipped/then clipped IDs: uses get_unhashed_block_ids_all_groups, get_blocks_in_fa_kv_group, and computes desc IDs across groups; full-prefix hits use empty lists
  • Worker reads/clips per-group blocks, flattens for descriptor IDs, maps logical→kernel blocks per group, and marks failures using FA group only; asserts HMA requires same remote block size
  • Utils: add get_blocks_in_fa_kv_group; KVCacheBlocks adds get_unhashed_block_ids_all_groups
  • Scheduler integrates HMA path: computes finished tokens from FA group, updates invalid-block handling and eviction across groups; deprecates old llama4 local-attn optimization path
  • Extensive test updates and new test_nixl_connector_hma.py covering SW sizing, logical→kernel mapping, fewer SW blocks, metadata structure; adapt existing tests to new APIs and configs

Written by Cursor Bugbot for commit 59b3474.


mergify bot commented Jan 12, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork


@gemini-code-assist bot left a comment


Code Review

This pull request introduces support for the Hybrid Memory Allocator (HMA) in the NixlConnector, which is a significant step towards optimizing performance for models with hybrid attention mechanisms. The changes are comprehensive, affecting the connector's core logic, the scheduler, and associated tests. The introduction of new data structures to handle multiple KV cache groups and the logic for clipping blocks for sliding window attention appear to be well-thought-out. The addition of dedicated HMA tests is also a positive aspect of this PR.

I've identified a critical issue in nixl_connector.py that could lead to a runtime error in non-HMA scenarios with differing block sizes. My review includes a specific comment with a suggested fix for this issue. Apart from that, the changes look solid and move vLLM forward in supporting more complex KV cache management strategies.

Comment on lines +2168 to +2169
local_block_ids = tuple(local_block_ids) if local_block_ids else []
remote_block_ids = tuple(remote_block_ids)

critical

There are a couple of issues on these lines that will cause a runtime error when block_size_ratio > 1 for a non-HMA setup.

  1. The conditional if local_block_ids: on line 2168 will raise a ValueError because local_block_ids is a numpy array at this point. The truth value of a numpy array with more than one element is ambiguous. You should use if local_block_ids.size > 0: to check for emptiness.
  2. The conversions tuple(local_block_ids) and tuple(remote_block_ids) are incorrect. They convert the array/list into a tuple of elements (e.g., (1, 2, 3)), but the subsequent code, especially _get_block_descs_ids, expects a tuple of lists (e.g., ([1, 2, 3],)). This will cause np.concatenate to fail.

To fix this, you need to correctly check for an empty array and then wrap the result in a tuple to maintain the tuple[list[int], ...] structure.

Suggested change
local_block_ids = tuple(local_block_ids) if local_block_ids else []
remote_block_ids = tuple(remote_block_ids)
local_block_ids = (local_block_ids.tolist(),) if local_block_ids.size > 0 else ()
remote_block_ids = (remote_block_ids,)
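For context, the numpy truthiness pitfall described above can be reproduced in a few standalone lines (a sketch, not the connector's actual code):

```python
import numpy as np

local_block_ids = np.array([1, 2, 3])

# `if local_block_ids:` implicitly calls bool() on the array, which
# raises for arrays with more than one element.
try:
    bool(local_block_ids)
except ValueError as e:
    print(e)  # prints the "truth value ... is ambiguous" error

# The suggested fix: an explicit size check, wrapping the result in a
# tuple so it keeps the tuple[list[int], ...] structure downstream
# code (e.g. _get_block_descs_ids) expects.
local = (local_block_ids.tolist(),) if local_block_ids.size > 0 else ()
print(local)  # → ([1, 2, 3],)
```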

Comment on lines +1780 to +1783
# For the purpose of marking blocks as invalid, only report FA ones to
# handle blocks<>tokens mapping consistently.
# for idx, block_id in zip(range(req_num_computed_blocks), req_block_ids):
for idx, block_id in zip(range(req_num_computed_blocks), req_block_ids[fa_blocks_idx]):

Could use a review on the KV transfer failure logic in this file.
Tagging you guys since it's a major PR @njhill @sdavidbd @wseaton

# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] to
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
local_block_ids = local_block_ids[: len(remote_block_ids)]
local_block_ids = tuple(local_block_ids) if local_block_ids else []
@xuechendi commented Jan 12, 2026
I tested with the original code in a heterogeneous setting; it reports:

local_block_ids = tuple(local_block_ids) if local_block_ids else []
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

After fixing that, other issues occur, so I pushed a new PR to fix heterogeneous support:

NickLucche#7


mergify bot commented Jan 13, 2026

Hi @NickLucche, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@mergify mergify bot removed the needs-rebase label Jan 13, 2026


Comment on lines +2310 to +2315
for i, remote_group in enumerate(remote_block_ids):
num_remote_blocks = len(remote_group)
num_local_blocks = len(local_block_ids[i])
assert num_local_blocks <= num_remote_blocks
# Partial prefix cache hit: just read uncomputed blocks.
if num_local_blocks < num_remote_blocks:

@heheda12345 is this the expected behavior for prefix caching with hma?

else 0
for group in kv_cache_config.kv_cache_groups
]
self.sw_sizes = [n_tokens // self.block_size for n_tokens in sw_sizes_tokens]

Suggested change
self.sw_sizes = [n_tokens // self.block_size for n_tokens in sw_sizes_tokens]
self.sw_sizes = [cdiv(n_tokens, self.block_size) for n_tokens in sw_sizes_tokens]

If the block size is 16 and the sliding-window size is 24, I think we need to hit 2 consecutive blocks to get a cache hit.

Does NIXL support hitting 1 block + 8 additional tokens?


Good point, thanks @heheda12345.

Does NIXL support hitting 1 block + 8 additional tokens?

Right now we only count unhashed blocks, so only full-block hits count as prefix cache hits; the partial ones are transferred from P.
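The floor-vs-ceiling point above can be checked directly (cdiv here mirrors a standard ceiling-division helper; a sketch, not the PR's code):

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, as suggested in the review comment.
    return -(a // -b)


block_size, sliding_window = 16, 24
print(sliding_window // block_size)      # → 1 (floor: misses 8 tokens)
print(cdiv(sliding_window, block_size))  # → 2 (covers the full window)
```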

Comment on lines -2307 to -2310
else:
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
local_descs_list = []

Removing the old model-specific optimization, as per the TODO.

Comment on lines +2054 to +2055
fa_blocks = req_block_ids[self._full_attention_group_idx]
max_num_blocks = len(fa_blocks)

Still using the FA group here, because the way we map blocks->tokens is not reliable if we have e.g. a Mamba layer.

req_num_computed_tokens = request.num_cached_tokens

all_req_block_ids = (
(block_id for group in req_block_ids for block_id in group)

Unravel all blocks from all groups in a generator; otherwise use the same logic as before.


Could you add a comment explaining this? AFAICT: when is_hma=True, this flattens blocks across all groups and iterates with a single index. But then that mixes full attn and sliding window block IDs, so idx doesn't correspond to a position within the full attn group anymore. Is that right?


That's correct! The idea is that any block_id (note that IDs are still unique across all blocks) can fail during a transfer, per the discussion in the main thread.
So we iterate over all blocks involved in the request.
If one has failed, we "reset" the request state.

Otherwise, the logic is unchanged. I added a comment here.
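A toy sketch of the flattening being discussed (illustrative names; since block IDs are globally unique across groups, a single flat scan can detect any failed block):

```python
# Two KV cache groups: full-attention and sliding-window.
req_block_ids = ([1, 2, 3], [10, 11, 12])

# Generator flattening all blocks from all groups, as in the snippet above.
all_req_block_ids = (
    block_id for group in req_block_ids for block_id in group
)

failed_blocks = {11}  # hypothetical transfer failure
request_invalid = any(b in failed_blocks for b in all_req_block_ids)
print(request_invalid)  # → True: one failed block "resets" the request
```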

@mergify mergify bot removed the needs-rebase label Feb 6, 2026
@NickLucche

I've updated the PR with a couple of benchmark runs.

cc @tlrmchlsmth

"deepseek-ai/deepseek-vl2-small": 0.59,
"deepseek-ai/deepseek-vl2-tiny": 0.19,
"deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65,
"google/gemma-3-4b-it": 0.74,

Let's add openai/gpt-oss-20b, as a still-small but very popular sliding-window model?


Can I defer this to another PR? I have to check whether this can run on L4s

EngineId = str
# Block IDs as returned by the hybrid KV cache manager. list[list[int]] allows
# mutability and is for connector-internal use only.
BlockIds = tuple[list[int], ...] | list[list[int]]

How about Sequence[list[int]]? (Both work for me.)


I don't have a strong opinion; I think the current one, though, is the minimal set of possible types. Sequence is elegant but is a superset of the above.

# When connector does not support HMA, a single group is present here
num_computed_tokens = (
len(block_ids[self._full_attention_group_idx]) * self.block_size
)

#34616 is pretty nice!

Comment on lines +337 to +339
for group in kv_cache_config.kv_cache_groups:
if isinstance(group.kv_cache_spec, MambaSpec):
raise ValueError("NixlConnector does not support Mamba models.")

Instead of checking explicitly for MambaSpec, I think it's better to check that all specs are of supported types (e.g. AttentionSpec).
This will protect the NIXL connector from future unsupported specs.
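A hedged sketch of that allow-list approach (the class and field names are assumed stand-ins mirroring vLLM's kv_cache_interface types, not the PR's final code):

```python
from dataclasses import dataclass


class AttentionSpec: ...  # stand-in for attention-style specs
class MambaSpec: ...      # stand-in for a state-based (unsupported) spec


@dataclass
class KVCacheGroup:
    kv_cache_spec: object


def validate_groups(groups: list[KVCacheGroup]) -> None:
    # Allow-list supported specs: anything that is not attention-style
    # fails loudly, including spec types added in the future.
    for group in groups:
        if not isinstance(group.kv_cache_spec, AttentionSpec):
            raise ValueError(
                f"NixlConnector does not support "
                f"{type(group.kv_cache_spec).__name__}."
            )


validate_groups([KVCacheGroup(AttentionSpec())])  # ok
try:
    validate_groups([KVCacheGroup(MambaSpec())])
except ValueError as e:
    print(e)  # → NixlConnector does not support MambaSpec.
```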


I probably want to address this separately due to deepseek

self.use_host_buffer = (
vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
)
self._is_hma_enabled = (

I think introducing this field makes it harder to follow the logic where it is used.
When you use something like if not self._is_hma_enabled, you make assumptions (e.g. that the number of groups must be 1) that are derived from the fact that HMA is not enabled.
I think it's better to use explicit, generic fields like self.blocks_per_sw and self.num_kv_groups.

@mergify mergify bot added the needs-rebase label Feb 25, 2026
@sarckk

sarckk commented Mar 1, 2026

@NickLucche are you still working on this?

@NickLucche

@sarckk yeah, we've been discussing proper KV block recovery for HMA, so things are stuck here for now.

