[KVConnector] scheduler: Add HMA support for KV load recovery#35223

Open
orozery wants to merge 1 commit into vllm-project:main from orozery:scheduler-hma-recovery

Conversation

@orozery
Collaborator

@orozery orozery commented Feb 24, 2026

This PR extends the KV load recovery flow to support HMA.
Models using sliding-window attention will fail on error instead of recomputing.
Depends on #34616.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request extends the KV load recovery mechanism to support the Hybrid Memory Allocator (HMA) by handling multiple KV cache groups. It also introduces a change where models with sliding-window attention will fail on error instead of recomputing, which is a safer default.

The changes in the tests and scheduler logic for handling num_computed_tokens in async loading scenarios are consistent. However, I've found a critical bug in the _update_requests_with_invalid_blocks function where the accounting for affected tokens is incorrect when multiple cache groups are present. Please see my detailed comment for a fix.

@orozery
Collaborator Author

orozery commented Feb 24, 2026

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request extends the KV load recovery mechanism to support the Hybrid Memory Allocator (HMA) by refactoring how requests with invalid KV blocks are handled. The changes correctly disable recovery for models with sliding windows. However, I've found a critical issue in the implementation for handling invalid blocks with HMA. The logic for calculating the number of affected tokens is placed inside a loop over KV cache groups, which will lead to incorrect accounting when a request is affected by failures in multiple groups. I've provided a comment with details on how to fix this.

@orozery orozery force-pushed the scheduler-hma-recovery branch from 4e47781 to 7e55abc Compare February 24, 2026 20:10
@mergify

mergify bot commented Feb 24, 2026

Hi @orozery, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@orozery orozery force-pushed the scheduler-hma-recovery branch from 7e55abc to 8a8e519 Compare February 24, 2026 20:20
@orozery orozery force-pushed the scheduler-hma-recovery branch from 8a8e519 to 4e87cfe Compare February 24, 2026 20:56
@mergify

mergify bot commented Feb 24, 2026

Hi @orozery, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@orozery orozery force-pushed the scheduler-hma-recovery branch from 4e87cfe to 1cb34bc Compare February 24, 2026 21:02
@orozery orozery force-pushed the scheduler-hma-recovery branch from 1cb34bc to f8112ad Compare February 24, 2026 21:26
@orozery
Collaborator Author

orozery commented Feb 24, 2026

To be clear, it is possible to enable re-computation even for sliding window.
It would simply have to always re-compute the entire request.
The reason I did not implement it is that a naive implementation would add a large amount of code to _update_requests_with_invalid_blocks, which is already very bulky.
So I think it's better to first break this function down into smaller, more readable pieces, or maybe do a wider refactoring of the entire recovery flow.
For now, I think we're good just failing the request, since we know nobody uses it with sliding windows today (as it currently crashes).
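The trade-off described above can be sketched as follows. This is an illustrative sketch only, not vLLM's actual scheduler code: `plan_recovery` and its arguments are hypothetical names. For sliding-window models, the earlier window state backing the prefix is lost, so the safe options are failing the request or recomputing it from offset 0; full-attention models can roll back just to the first invalid token.

```python
# Hypothetical sketch of the recovery decision; not vLLM's scheduler API.

def plan_recovery(
    first_invalid_token: int,
    num_computed_tokens: int,
    has_sliding_window: bool,
    allow_full_recompute: bool,
) -> tuple[str, int]:
    """Decide how to recover from a KV load failure.

    Returns (action, new_num_computed_tokens).
    """
    if has_sliding_window:
        if allow_full_recompute:
            # Earlier window state is lost, so a partial rollback is
            # unsafe: the entire request must be recomputed.
            return ("recompute_all", 0)
        # The simpler option taken in this PR: fail the request.
        return ("fail_request", num_computed_tokens)
    # Full attention: roll back only to the first invalid token.
    return ("recompute_partial", min(first_invalid_token, num_computed_tokens))
```

Under these assumptions, the sliding-window branch never does a partial rollback, which is the invariant the discussion above is about.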

@orozery
Collaborator Author

orozery commented Feb 24, 2026

Forgot to cc @sdavidbd in case you're still around :)

Member

@tlrmchlsmth tlrmchlsmth left a comment


lgtm but @heheda12345 and/or @NickLucche should approve as well

@orozery orozery force-pushed the scheduler-hma-recovery branch 3 times, most recently from 8f47716 to 8c011ae Compare February 24, 2026 22:29
@orozery
Collaborator Author

orozery commented Feb 24, 2026

Relaxed the check a bit to allow sliding-window re-computation when HMA is off.

@orozery orozery force-pushed the scheduler-hma-recovery branch 2 times, most recently from d994c2f to f02a5c8 Compare February 25, 2026 08:30
Collaborator

@NickLucche NickLucche left a comment


Thanks for the work, @orozery.
Can you check whether this works with MLA models?

@orozery
Collaborator Author

orozery commented Feb 25, 2026

@tlrmchlsmth I found quite a few bugs in my previous implementation.
I basically rewrote it, so I think you should have another look.
With this, re-computation is supported (from offset 0) even for sliding window / SSMs.

@orozery
Collaborator Author

orozery commented Feb 25, 2026

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request extends the KV load recovery mechanism to support the Hybrid Memory Allocator (HMA). The key change is in the scheduler's handling of invalid KV cache blocks. The logic now differentiates between models with only full-attention layers and those with other layer types, such as sliding-window attention or SSMs. For the latter, any KV load failure triggers a full recomputation of the request; for full-attention models, a more granular partial recomputation is performed. The implementation correctly handles multiple KV cache groups with varying block sizes, a core requirement for HMA. The changes to how num_computed_tokens is managed for asynchronous loads and the updates to the test suite are consistent and appropriate. The overall implementation appears solid, and I did not find any issues of high or critical severity.

@tlrmchlsmth
Member

@tlrmchlsmth I found quite a few bugs in my previous implementation. I basically rewrote it, so I think you should have another look. With this, re-computation is supported (from offset 0) even for sliding window / SSMs.

ok, thanks for the update - sounds like I'll need to take a closer look this time 😄

Member

@tlrmchlsmth tlrmchlsmth left a comment


Can we add some tests to cover the failure cases we're trying to handle?

Comment on lines +2019 to +2021
        # Count the number of prefix cached tokens.
        if request.num_cached_tokens < 0:
            request.num_cached_tokens = request.num_computed_tokens
Collaborator


why wasn't this needed earlier?

Collaborator Author


It was a bug: num_cached_tokens was not set for async-loaded requests.
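A minimal sketch of the fix being discussed, using a hypothetical stand-in for the scheduler's request object (not vLLM's actual class): num_cached_tokens starts unset (-1) and is recorded exactly once, so async-loaded requests also get it set the first time their computed tokens are observed.

```python
class Request:
    """Hypothetical stand-in for the scheduler's request object."""

    def __init__(self) -> None:
        self.num_computed_tokens = 0
        self.num_cached_tokens = -1  # -1 means "not yet recorded"


def record_prefix_cached_tokens(request: Request) -> None:
    # Count the number of prefix cached tokens (first observation only),
    # regardless of whether the KV load was sync or async.
    if request.num_cached_tokens < 0:
        request.num_cached_tokens = request.num_computed_tokens
```

Calling the helper a second time with a larger num_computed_tokens leaves num_cached_tokens at its first recorded value.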

Comment on lines -2090 to -2092
        else:
            # Sync loading. num_computed_tokens includes new tokens
            req_num_computed_tokens = request.num_cached_tokens
Collaborator


I don't fully understand how we're covering this case.

Collaborator Author


This is another bug I'm fixing here:
req_num_computed_tokens is supposed to be local_gpu_tokens + num_external_tokens.
For the sync case, the old code got it from request.num_cached_tokens.
But this does not work for requests resuming from preemption, since request.num_cached_tokens is only set once in the request's lifetime.
So instead of using request.num_cached_tokens, we use request.num_computed_tokens - num_scheduled_tokens to get local_gpu_tokens + num_external_tokens.

Collaborator Author


This also simplifies the code which previously had different handling for sync and async.
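The unified accounting described in this thread can be sketched with a hypothetical helper (the name and signature are illustrative, not vLLM's): deriving the pre-load token count from the live counters works the same for sync and async loading, including requests resumed from preemption, unlike num_cached_tokens, which is recorded only once.

```python
def tokens_before_load(num_computed_tokens: int,
                       num_scheduled_tokens: int) -> int:
    """Hypothetical helper: local_gpu_tokens + num_external_tokens.

    Derived from live counters instead of request.num_cached_tokens,
    which is set only once and goes stale after preemption.
    """
    return num_computed_tokens - num_scheduled_tokens
```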

        req_num_computed_tokens = request.num_computed_tokens

        is_affected = True
        # iterate request blocks by group
Collaborator


nit:

Suggested change:

        # iterate request blocks by group
        assert len(self.kv_cache_config.kv_cache_groups) == len(req_block_id_groups)

Comment on lines +2107 to +2108
        if is_affected and self.has_non_full_attention_layer:
            break
Collaborator


I don't think I'm following: if we have more than one group, then that implies they're not all full attention.
Why would we still loop then?

Collaborator Author


As discussed offline, I don't see an assertion in kv_cache_utils.py that multiple groups imply non-full attention, so I prefer not to assume it, even for future compatibility.

Collaborator


I prefer to add an assert for this assumption here. I can't imagine a case where all layers are full attention but we create different groups for them. Even if we have multiple attention layers with different hidden sizes, we will merge them into one group (e.g., DeepSeek V3.2). I'd prefer to revisit this path if a new case ever forces us to create multiple groups for full attention, since the reason behind such a case may be strange.

        # on a full prompt hit, we need to re-compute the last token
        # in order to be able to sample the next token
        if request.num_computed_tokens >= request.num_tokens:
            request.num_computed_tokens = request.num_tokens - 1
Collaborator


why do we do this line after `self.kv_cache_manager.cache_blocks(request, request.num_computed_tokens)`?

Collaborator Author


The reason we decrease by 1 is that the model runner cannot run with num_scheduled_tokens = 0 for a request.
So this is basically a model runner limitation.

cache_blocks has nothing to do with the model runner; it does not have the limitation that requires us to decrease by 1.
Decreasing by 1 before calling cache_blocks means we don't cache the last block.
This block is valid, so why not cache it?

For sync connectors we do cache the last block when we call allocate_slots:

        num_tokens_to_cache = min(
            total_computed_tokens + num_new_tokens,
            request.num_tokens,
        )
        self.coordinator.cache_blocks(request, num_tokens_to_cache)

Why should the behaviors be different for async connectors?
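The ordering argued for above can be sketched with hypothetical interfaces (not the actual scheduler code): cache every valid block first, then clamp num_computed_tokens so the model runner always has at least one token to schedule.

```python
from types import SimpleNamespace


def finalize_full_prefix_hit(request, cache_blocks) -> None:
    """Hypothetical helper illustrating the cache-then-clamp order."""
    # The last block is valid, so cache it before decreasing the count.
    cache_blocks(request, request.num_computed_tokens)
    # On a full prompt hit, re-compute the last token so the model
    # runner can sample the next one (it cannot run a request with
    # num_scheduled_tokens = 0).
    if request.num_computed_tokens >= request.num_tokens:
        request.num_computed_tokens = request.num_tokens - 1
```

Under this ordering, a request with a full 16-token prompt hit caches all 16 tokens and is then clamped to 15 computed tokens for scheduling.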

@orozery orozery force-pushed the scheduler-hma-recovery branch from f02a5c8 to 0000209 Compare March 1, 2026 15:46
@orozery
Collaborator Author

orozery commented Mar 1, 2026

I've made some changes:

  1. Removed theoretical support for multi-group full attention (as suggested by @NickLucche and @heheda12345).
  2. Removed the bugfix for error handling of preempted requests (to reduce the scope of this PR).
  3. Refactored the code to make minimal changes to the original flow.
  4. Added unit tests for the multi-group case.

@orozery orozery force-pushed the scheduler-hma-recovery branch 2 times, most recently from 90a9f27 to af56820 Compare March 5, 2026 09:28
This commit extends the KV load recovery flow to support HMA.

Signed-off-by: Or Ozeri <oro@il.ibm.com>
@orozery orozery force-pushed the scheduler-hma-recovery branch from af56820 to 3fe5f8a Compare March 5, 2026 14:45