[Core][KVConnector] Support HMA+NixlConnector #32204
NickLucche wants to merge 23 commits into vllm-project:main
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request introduces support for the Hybrid Memory Allocator (HMA) in the NixlConnector, which is a significant step towards optimizing performance for models with hybrid attention mechanisms. The changes are comprehensive, affecting the connector's core logic, the scheduler, and associated tests. The introduction of new data structures to handle multiple KV cache groups and the logic for clipping blocks for sliding window attention appear to be well-thought-out. The addition of dedicated HMA tests is also a positive aspect of this PR.
I've identified a critical issue in nixl_connector.py that could lead to a runtime error in non-HMA scenarios with differing block sizes. My review includes a specific comment with a suggested fix for this issue. Apart from that, the changes look solid and move vLLM forward in supporting more complex KV cache management strategies.
```python
local_block_ids = tuple(local_block_ids) if local_block_ids else []
remote_block_ids = tuple(remote_block_ids)
```
There are a couple of issues on these lines that will cause a runtime error when `block_size_ratio > 1` in a non-HMA setup.

- The conditional `if local_block_ids:` on line 2168 will raise a `ValueError` because `local_block_ids` is a numpy array at this point. The truth value of a numpy array with more than one element is ambiguous. You should use `if local_block_ids.size > 0:` to check for emptiness.
- The conversions `tuple(local_block_ids)` and `tuple(remote_block_ids)` are incorrect. They convert the array/list into a tuple of elements (e.g., `(1, 2, 3)`), but the subsequent code, especially `_get_block_descs_ids`, expects a tuple of lists (e.g., `([1, 2, 3],)`). This will cause `np.concatenate` to fail.

To fix this, correctly check for an empty array and then wrap the result in a tuple to maintain the `tuple[list[int], ...]` structure.
```diff
- local_block_ids = tuple(local_block_ids) if local_block_ids else []
- remote_block_ids = tuple(remote_block_ids)
+ local_block_ids = (local_block_ids.tolist(),) if local_block_ids.size > 0 else ()
+ remote_block_ids = (remote_block_ids,)
```
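The ambiguity described above can be reproduced in isolation. A minimal sketch (the helper name `normalize_block_ids` is illustrative, not from the PR):

```python
import numpy as np

def normalize_block_ids(local_block_ids: np.ndarray) -> tuple[list[int], ...]:
    # `if local_block_ids:` would raise ValueError for arrays with more than
    # one element, so check emptiness explicitly via .size instead.
    if local_block_ids.size > 0:
        # Wrap in a one-element tuple to keep the tuple[list[int], ...]
        # shape that downstream code (e.g. a desc-id builder) expects.
        return (local_block_ids.tolist(),)
    return ()

ids = np.array([1, 2, 3])
try:
    bool(ids)  # ambiguous truth value for a multi-element array
except ValueError:
    pass  # this is the error the original `if local_block_ids:` hits

print(normalize_block_ids(ids))                       # → ([1, 2, 3],)
print(normalize_block_ids(np.array([], dtype=int)))   # → ()
```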
vllm/v1/core/sched/scheduler.py (outdated)
```python
# For the purpose of marking blocks as invalid, only report FA ones to
# handle blocks<>tokens mapping consistently.
# for idx, block_id in zip(range(req_num_computed_blocks), req_block_ids):
for idx, block_id in zip(range(req_num_computed_blocks), req_block_ids[fa_blocks_idx]):
```
```python
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] to
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
local_block_ids = local_block_ids[: len(remote_block_ids)]
local_block_ids = tuple(local_block_ids) if local_block_ids else []
```
I tested the original code in a heterogeneous setting, and it reports:

```
local_block_ids = tuple(local_block_ids) if local_block_ids else []
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
```

After fixing that, other issues occur, so I pushed a new PR (below) to fix heterogeneous support.
Force-pushed from 760c8cf to 7789163.
Hi @NickLucche, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from ca6c740 to 62204b2.
Hi @NickLucche, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
```python
for i, remote_group in enumerate(remote_block_ids):
    num_remote_blocks = len(remote_group)
    num_local_blocks = len(local_block_ids[i])
    assert num_local_blocks <= num_remote_blocks
    # Partial prefix cache hit: just read uncomputed blocks.
    if num_local_blocks < num_remote_blocks:
```
@heheda12345 is this the expected behavior for prefix caching with hma?
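To make the partial-hit branch concrete, here is a hedged sketch of what clipping the remote side implies (function name and shapes are illustrative; the real connector operates on per-group tuples):

```python
def clip_remote_blocks(local_group: list[int], remote_group: list[int]) -> list[int]:
    """Partial prefix cache hit: the local side only holds blocks that still
    need data, so only the matching tail of the remote blocks is read."""
    num_local, num_remote = len(local_group), len(remote_group)
    assert num_local <= num_remote
    if num_local < num_remote:
        # Keep the last num_local remote blocks; guard the num_local == 0
        # case, since remote_group[-0:] would return the whole list.
        return remote_group[-num_local:] if num_local else []
    return remote_group

# Local needs 2 blocks filled; only the last 2 remote blocks are transferred.
print(clip_remote_blocks([5, 6], [20, 21, 22, 23]))  # → [22, 23]
print(clip_remote_blocks([], [1, 2]))                # → []
```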
```python
    else 0
    for group in kv_cache_config.kv_cache_groups
]
self.sw_sizes = [n_tokens // self.block_size for n_tokens in sw_sizes_tokens]
```
```diff
- self.sw_sizes = [n_tokens // self.block_size for n_tokens in sw_sizes_tokens]
+ self.sw_sizes = [cdiv(n_tokens, self.block_size) for n_tokens in sw_sizes_tokens]
```
If the block size is 16 and the sliding window size is 24, I think we need to hit 2 consecutive blocks to get a cache hit.
Does NIXL support hitting 1 block + 8 additional tokens?
Good point, thanks @heheda12345.

> Does NIXL support hitting 1 block + 8 additional tokens?

Right now we only count unhashed blocks, so hits are on full blocks only as prefix cache hits, and we transfer the partial ones from P.
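The ceil-division point in the suggestion above can be checked numerically. A small sketch, assuming `cdiv` is the usual `-(-a // b)` ceiling division used across vLLM utilities:

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling integer division: cdiv(24, 16) == 2, cdiv(32, 16) == 2.
    return -(-a // b)

block_size = 16
sliding_window = 24

# Floor division under-counts: a 24-token window spans parts of 2 blocks,
# so a single 16-token block is not enough to cover it.
print(sliding_window // block_size)      # → 1
print(cdiv(sliding_window, block_size))  # → 2
```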
Commits (each Signed-off-by: NickLucche <nlucches@redhat.com>):
- update tests
- review
Force-pushed from 1bc96a2 to ccb3957.
```python
else:
    # TODO(mgoin): remove this once we have hybrid memory allocator
    # Optimization for models with local attention (Llama 4)
    local_descs_list = []
```
Removing the old model-specific optimization, as per the TODO.
```python
fa_blocks = req_block_ids[self._full_attention_group_idx]
max_num_blocks = len(fa_blocks)
```
We still fetch the FA group here because the way we map blocks->tokens is not reliable if we have (e.g.) a mamba layer in this position.
```python
req_num_computed_tokens = request.num_cached_tokens

all_req_block_ids = (
    (block_id for group in req_block_ids for block_id in group)
```
Unravel all blocks from all groups in a generator; otherwise use the same logic we had before.
Could you add a comment explaining this? AFAICT: when `is_hma=True`, this flattens blocks across all groups and iterates with a single index. But then that mixes full-attn and sliding-window block IDs, so `idx` doesn't correspond to a position within the full-attn group anymore. Is that right?
That's correct! The idea is that any `block_id` (note all IDs are still unique across blocks) can fail during a transfer, as per the discussion in the main thread.
So we iterate over all blocks involved in the request.
If one has failed, we "reset" the request state.
Otherwise the logic is unchanged. I added a comment here.
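A minimal sketch of the flattening discussed in this thread (function name and the `is_hma` flag are illustrative; it assumes block IDs are unique across groups, as stated above):

```python
from collections.abc import Iterator

def iter_all_block_ids(
    req_block_ids: tuple[list[int], ...], is_hma: bool
) -> Iterator[int]:
    # With HMA, any block in any KV cache group may fail during a transfer,
    # so iterate over every block of the request in a single flat generator.
    if is_hma:
        return (block_id for group in req_block_ids for block_id in group)
    # Without HMA there is a single group, so the old behavior is kept.
    return iter(req_block_ids[0])

groups = ([1, 2, 3], [7, 8])
print(list(iter_all_block_ids(groups, is_hma=True)))   # → [1, 2, 3, 7, 8]
print(list(iter_all_block_ids(groups, is_hma=False)))  # → [1, 2, 3]
```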
I've updated the PR with a couple of benchmark runs. cc @tlrmchlsmth
```python
"deepseek-ai/deepseek-vl2-small": 0.59,
"deepseek-ai/deepseek-vl2-tiny": 0.19,
"deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65,
"google/gemma-3-4b-it": 0.74,
```
Let's add `openai/gpt-oss-20b`, as a still-small but very popular sliding-window model?
Can I defer this to another PR? I have to check whether this can run on L4s
```python
EngineId = str
# Block ids as returned by the hybrid KV cache manager. list[list[int]] allows
# mutability and is for connector-internal use only.
BlockIds = tuple[list[int], ...] | list[list[int]]
```
How about `Sequence[list[int]]`? (both work for me)
I don't have a strong opinion; I think the current one, though, is the minimal/smallest set of possible types. `Sequence` is elegant but is a superset of the above.
```python
# When connector does not support HMA, a single group is present here
num_computed_tokens = (
    len(block_ids[self._full_attention_group_idx]) * self.block_size
)
```
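A hedged sketch of the assumption behind this snippet (the standalone function is illustrative; in the connector these are `self.` attributes): token counts are derived from the full-attention group only, since its blocks map 1:1 to `block_size` tokens.

```python
def num_computed_tokens_from_fa_group(
    block_ids: tuple[list[int], ...],
    full_attention_group_idx: int,
    block_size: int,
) -> int:
    # Only the full-attention group's blocks reliably map to block_size
    # tokens each; other groups (e.g. sliding-window) may hold clipped
    # block lists, so they cannot be used for the tokens<->blocks mapping.
    return len(block_ids[full_attention_group_idx]) * block_size

# FA group (index 0) has 4 blocks of 16 tokens; the SW group was clipped.
print(num_computed_tokens_from_fa_group(([0, 1, 2, 3], [8, 9]), 0, 16))  # → 64
```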
```python
for group in kv_cache_config.kv_cache_groups:
    if isinstance(group.kv_cache_spec, MambaSpec):
        raise ValueError("NixlConnector does not support Mamba models.")
```
Instead of checking explicitly for `MambaSpec`, I think it's better to check that all specs are of supported types (e.g. `AttentionSpec`).
This will protect the nixl connector from future unsupported specs.
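The allow-list idea suggested here can be sketched with stand-in classes (`AttentionSpec`/`MambaSpec` names come from the discussion, but the class bodies and `validate_specs` helper below are illustrative, not vLLM's actual implementation):

```python
# Stand-in spec hierarchy; in vLLM these live in the KV cache interface module.
class KVCacheSpec: ...
class AttentionSpec(KVCacheSpec): ...
class MambaSpec(KVCacheSpec): ...

def validate_specs(specs: list[KVCacheSpec]) -> None:
    # Allow-list check: rejects MambaSpec today *and* any future
    # unsupported spec type, unlike an explicit MambaSpec deny-list.
    for spec in specs:
        if not isinstance(spec, AttentionSpec):
            raise ValueError(
                f"NixlConnector does not support {type(spec).__name__}."
            )

validate_specs([AttentionSpec(), AttentionSpec()])  # passes silently
try:
    validate_specs([AttentionSpec(), MambaSpec()])
except ValueError as e:
    print(e)  # → NixlConnector does not support MambaSpec.
```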
I probably want to address this separately due to deepseek
```python
self.use_host_buffer = (
    vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
)
self._is_hma_enabled = (
```
I think introducing this field makes it harder to follow the logic where it is used.
Basically, when you write something like `if not self._is_hma_enabled`, you make assumptions (e.g. that the number of groups must be 1) that are derived from the fact that HMA is not enabled.
I think it's better to use explicit, generic fields like `self.blocks_per_sw` and `self.num_kv_groups`.
This pull request has merge conflicts that must be resolved before it can be merged.
@NickLucche are you still working on this?
@sarckk yeah, we've been discussing proper KV block recovery for HMA, so things are stuck here for now.
Overview
Currently connectors cannot take full advantage of models that employ hybrid attention (FA+SWA) and instead treat all layers as FA, because the Hybrid KV Cache Manager is disabled.
This PR enables NixlConnector to work with the HMA, drastically reducing the number of bytes/regions moved per transfer for SWA+FA models, while laying the groundwork for state-based ones (Mamba etc.).
Example of the former:
UPDATE: see comments below for a discussion on marking invalid blocks.
Test with: enable HMA experimental support with `--no-disable-hybrid-kv-cache-manager`.

lm-eval results:
or the newly added file.

EDIT: I've also validated part of the lm-eval CI locally; you can test out the different tracked configs with it.

Run

```shell
python -m pytest -s -v -x tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py::test_hma_sync_recompute_evicts_all_blocks
```

for testing the invalid block handling with HMA.

TODOs
cc: working with @heheda12345 @KuntaiDu @ivanium
Benchmarks
ShareGPT results, no-prefix-caching 8xH100.
Main:
This PR:
so up to ~7% throughput improvement in this small-scale intra-node setup. An inter-node setup would be more interesting to analyze.
Note
Enables HMA-aware KV transfer for FA+SWA models, reducing transferred regions and aligning connector behavior with hybrid KV cache groups.

- Connectors receive `kv_cache_config` and operate on multi-group `BlockIds` (tuples per KV group); add `request_finished_all_groups` and an HMA marker via `SupportsHMA`
- Computes sliding-window sizes (`sw_sizes`) and passes unclipped, then clipped IDs: uses `get_unhashed_block_ids_all_groups`, `get_blocks_in_fa_kv_group`, and computes desc IDs across groups; full-prefix hits use empty lists
- `get_blocks_in_fa_kv_group`; KVCacheBlocks adds `get_unhashed_block_ids_all_groups`
- New `test_nixl_connector_hma.py` covering SW sizing, logical→kernel mapping, fewer SW blocks, metadata structure; adapt existing tests to new APIs and configs

Written by Cursor Bugbot for commit 59b3474. This will update automatically on new commits.