
[kv_offload+HMA][8/N]: Support multi-group worker transfer #38453

Merged
orozery merged 7 commits into vllm-project:main from
orozery:kv-offload-worker-swap-multiple-groups
Apr 22, 2026

Conversation

@orozery
Collaborator

@orozery orozery commented Mar 29, 2026

This PR extends the CPU-GPU offloading handler to support transfers with multiple KV cache groups.
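As a rough illustration of the multi-group idea (the names below are hypothetical, not the actual vLLM API): with multiple KV cache groups, a flat list of block IDs can be partitioned per group using per-group start offsets:

```python
# Illustrative sketch only; the real transfer spec layout in vLLM may differ.
import numpy as np

def split_block_ids_per_group(block_ids: np.ndarray,
                              group_starts: list[int]) -> list[np.ndarray]:
    """Split a flat array of block IDs into one array per KV cache group.

    group_starts[i] is the index of the first block belonging to group i;
    an implicit sentinel len(block_ids) closes the last group.
    """
    bounds = group_starts + [len(block_ids)]
    return [block_ids[bounds[i]:bounds[i + 1]]
            for i in range(len(group_starts))]

groups = split_block_ids_per_group(np.arange(10), [0, 4, 7])
print([g.tolist() for g in groups])  # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
```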

@orozery orozery requested a review from ApostaC as a code owner March 29, 2026 05:25

@claude claude Bot left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify Bot added the v1 label Mar 29, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request enables support for multiple KV cache groups within the CPU-GPU offloading handler. Key changes include a refactored transfer_async method in vllm/v1/kv_offload/worker/cpu_gpu.py to process group-specific sizes and block indices, facilitating unaligned CPU->GPU transfers. Additionally, expand_block_ids was updated with bounds checking, and a new multi-group test case was added to tests/v1/kv_offload/test_cpu_gpu.py. I have no feedback to provide as there were no review comments to evaluate.
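For illustration only, here is a simplified sketch of what a bounds-checked `expand_block_ids` helper could look like; the real helper in `vllm/v1/kv_offload/worker/cpu_gpu.py` may differ in both signature and semantics:

```python
# Hypothetical sketch: expand each block ID into `ratio` consecutive
# sub-block IDs, rejecting IDs outside the valid range (the "bounds
# checking" mentioned in the review summary).
import numpy as np

def expand_block_ids(block_ids: np.ndarray, ratio: int,
                     num_blocks: int) -> np.ndarray:
    if block_ids.size and (block_ids.min() < 0
                           or block_ids.max() >= num_blocks):
        raise IndexError("block ID out of range")
    # Each ID b maps to [b * ratio, b * ratio + ratio).
    return (block_ids[:, None] * ratio + np.arange(ratio)).reshape(-1)

print(expand_block_ids(np.array([2, 5]), 3, 8).tolist())
# [6, 7, 8, 15, 16, 17]
```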

@mergify
Contributor

mergify Bot commented Apr 3, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @orozery.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 3, 2026
hickeyma added a commit to hickeyma/vllm that referenced this pull request Apr 3, 2026
Recommendation from review as more work needs to be done to enable HMA
for CPU side and will be done in
vllm-project#38453

Review comment:

- vllm-project#37688 (comment)
- vllm-project#37688 (comment)

Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
@orozery orozery force-pushed the kv-offload-worker-swap-multiple-groups branch from 442e8c9 to ed7062e Compare April 5, 2026 10:08
@mergify mergify Bot removed the needs-rebase label Apr 5, 2026
@orozery orozery added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 6, 2026
@orozery orozery force-pushed the kv-offload-worker-swap-multiple-groups branch from ed7062e to 3865eb2 Compare April 9, 2026 12:12
@mergify mergify Bot added the kv-connector label Apr 9, 2026
@mergify
Contributor

mergify Bot commented Apr 14, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @orozery.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 14, 2026
This commit extends the CPU-GPU offloading handler to support transfers
with multiple KV cache groups.

Signed-off-by: Or Ozeri <oro@il.ibm.com>
@orozery orozery force-pushed the kv-offload-worker-swap-multiple-groups branch from 3865eb2 to 108e185 Compare April 15, 2026 08:56
@mergify mergify Bot removed the needs-rebase label Apr 15, 2026
Collaborator

@NickLucche NickLucche left a comment


LGTM, only left one question

block_indices[i] will represent the block index of the first block in group #i.
Thus, len(block_indices) == len(group_sizes) = number of KV cache groups.
- This information is required in order to support loading from offloaded blocks
+ This information is required in order to support off/loading from offloaded blocks
Collaborator


nit: why is this off/loading? 😆

Collaborator Author


I think I meant offloading + onloading, but indeed it is even confusing to me :)
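The invariant discussed in this thread (`len(block_indices) == len(group_sizes)`, with `block_indices[i]` the index of the first block in group `i`) can be sketched as an exclusive prefix sum of the group sizes; the values here are illustrative:

```python
import numpy as np

# block_indices[i] = first block of group i, i.e. the exclusive prefix
# sum of group_sizes; both arrays have one entry per KV cache group.
group_sizes = [4, 3, 3]  # blocks per KV cache group (example values)
block_indices = np.concatenate(([0], np.cumsum(group_sizes)[:-1]))
print(block_indices.tolist())  # [0, 4, 7]
assert len(block_indices) == len(group_sizes)
```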

Comment on lines +188 to +190
# There are 2 types of transfers:
# 1. GPU -> CPU
# 2. CPU -> GPU
Collaborator


this comment is useful

Comment on lines +223 to +225
all_src = np.empty(num_copy_ops, dtype=np.int64)
all_dst = np.empty(num_copy_ops, dtype=np.int64)
all_sizes = np.empty(num_copy_ops, dtype=np.int64)
Collaborator


what would be the max size here if we were to use persistent cpu buffers?

Collaborator Author


I think that would be O(max_model_len), as the number of ops is some constant multiple of the number of blocks being copied. The constant factor would be num_groups × num_tensors_per_group (e.g. 2 for flash attention).
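A back-of-envelope version of the sizing argument above, with all concrete values chosen for illustration only:

```python
# Worst-case copy-op count if the per-transfer index buffers were made
# persistent: blocks moved per request is bounded by max_model_len /
# block_size, times num_groups * num_tensors_per_group copy ops per block.
max_model_len = 8192         # example value
block_size = 16              # example value
num_groups = 2               # KV cache groups (example)
num_tensors_per_group = 2    # e.g. separate K and V tensors

max_blocks = max_model_len // block_size
max_copy_ops = max_blocks * num_groups * num_tensors_per_group
print(max_copy_ops)  # 2048
```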

@mergify mergify Bot requested a review from xuechendi as a code owner April 21, 2026 14:38
@orozery orozery merged commit aad88f8 into vllm-project:main Apr 22, 2026
61 checks passed
Copilot AI pushed a commit to hongbolv/vllm that referenced this pull request Apr 22, 2026
…ect#38453)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: hongbolv <33214277+hongbolv@users.noreply.github.com>
baonudesifeizhai pushed a commit to baonudesifeizhai/vllm that referenced this pull request Apr 23, 2026
…ect#38453)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
yzong-rh pushed a commit to yzong-rh/vllm that referenced this pull request Apr 23, 2026
…ect#38453)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Signed-off-by: Yifan <yzong@redhat.com>
avinashsingh77 pushed a commit to avinashsingh77/vllm that referenced this pull request Apr 27, 2026
…ect#38453)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Signed-off-by: Avinash Singh <avinashsingh.rcoem@gmail.com>
Lafunamor pushed a commit to Lafunamor/vllm that referenced this pull request May 1, 2026
…ect#38453)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Signed-off-by: Adrian <info@zzit.ch>
iboiko-habana pushed a commit to vllm-project/vllm-gaudi that referenced this pull request May 6, 2026
…stream regressions: MoE refactor, DeepSeek V4 router, KV offload HMA (#1403)

Fix multiple upstream vLLM regressions breaking vllm-gaudi unit tests at
vllm commit `5b39b268f506150dbab38f6f6c04b7c843e37c07`

## Fixes

### 1. MoE runner import rename (upstream
[#40560](vllm-project/vllm#40560))

`moe_runner_base.py` was removed and split into `moe_runner.py` +
`moe_runner_interface.py`. `MoERunnerBase` class renamed to `MoERunner`.

**Changes:**
- `vllm_gaudi/ops/hpu_fused_moe.py`: Import `MoERunner as MoERunnerBase`
from `moe_runner` module; update `get_layer_from_name` import path
- `vllm_gaudi/ops/hpu_lora.py`: Change `_all_lora_classes` from `set` to
`tuple` (upstream
[#35077](vllm-project/vllm#35077))

### 2. DeepSeek V4 router API — `hash_indices_table` (upstream
[#40860](vllm-project/vllm#40860))

`FusedMoE.__init__` now passes `hash_indices_table` and
`zero_expert_type` to `create_fused_moe_router()`, and `input_ids` kwarg
to `apply_monolithic()`.

**Changes:**
- `vllm_gaudi/ops/hpu_fused_moe.py`: Add `hash_indices_table` parameter
to HPU `create_fused_moe_router()` override; pass it to
`FusedTopKBiasRouter`; add `scoring_func` assertion
- `vllm_gaudi/ops/hpu_compressed_tensors.py`: Add `**kwargs` to
`HPUCompressedTensorsWNA16MoEMethod.apply_monolithic()` for `input_ids`
kwarg
- `tests/unit_tests/ops/utils.py`: Add `zero_expert_type` and
`hash_indices_table` to `create_fused_moe()` test helper

### 3. MoE test `_forward_dispatch` removal (upstream
[#40560](vllm-project/vllm#40560))

`MoERunnerBase._forward_dispatch()` was removed. Tests must use
`runner.forward()` with a proper `ForwardContext`.

**Changes:**
- `tests/unit_tests/ops/test_hpu_fused_moe.py`: Replace
`_forward_dispatch` call with `runner.forward()`; use real
`ForwardContext` with `no_compile_layers`
- `tests/unit_tests/ops/test_hpu_compressed_tensors.py`: Same migration

### 4. KV offload scheduler — HMA multi-group + per-job store completion
(upstream [#39186](vllm-project/vllm#39186),
[#39403](vllm-project/vllm#39403),
[#38453](vllm-project/vllm#38453),
[#39401](vllm-project/vllm#39401),
[#39402](vllm-project/vllm#39402))

Offloading scheduler was refactored for multi-group KV support and
per-job store completion tracking.

**Changes:**
- `tests/unit_tests/kv_offload/offloading_connector/utils.py`: Sync with
upstream — `OffloadKey`/`OffloadingWorkerMetadata` types, async
scheduling support, `TransferJobStatus` tracking,
`build_connector_worker_meta()` integration
- `tests/unit_tests/kv_offload/utils.py`: Add `kv_connector_worker_meta`
parameter to `create_model_runner_output()`

### 5. OffloadingConnectorMetadata per-job API in model runner (upstream
[#39186](vllm-project/vllm#39186))

`OffloadingConnectorMetadata` fields were renamed from
`reqs_to_store`/`reqs_to_load` (`dict[ReqId, TransferSpec]`) to
`store_jobs`/`load_jobs` (`dict[int, TransferJob]`), with `req_id` now
inside `TransferJob`. This caused `EngineDeadError` crashes in all tests
that use KV offloading or LoRA (3 CI failures).

**Changes:**
- `vllm_gaudi/v1/worker/hpu_model_runner.py`: Update
`_get_prompts_and_decodes()` to extract `req_id` from `TransferJob`
objects in `store_jobs`/`load_jobs` instead of iterating over removed
`reqs_to_store`/`reqs_to_load` — both for direct
`OffloadingConnectorMetadata` and nested `MultiKVConnectorMetadata`
cases
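As a hypothetical illustration of the rename described in this section (the real vLLM types differ in detail): the old shape was `dict[ReqId, TransferSpec]`, while the new shape keys jobs by an integer job ID and carries `req_id` inside the job object:

```python
# Illustrative stand-ins only; field names and types are assumptions,
# not the actual vLLM definitions.
from dataclasses import dataclass

@dataclass
class TransferJob:
    req_id: str      # request ID, now carried inside the job
    spec: tuple      # placeholder for the actual transfer spec

# New-style metadata: dict[int, TransferJob] keyed by job ID.
store_jobs: dict[int, TransferJob] = {
    7: TransferJob(req_id="req-a", spec=(0, 4)),
}

# Consumers that used to iterate reqs_to_store.keys() now extract
# req_id from each job instead:
req_ids = {job.req_id for job in store_jobs.values()}
print(sorted(req_ids))  # ['req-a']
```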

### 6. LoRA punica wrapper — add_shrink / add_expand (upstream
[#35077](vllm-project/vllm#35077))

Upstream refactored `PunicaWrapperBase` to add `add_shrink()` and
`add_expand()` methods. HPU punica wrapper was missing these, causing
`AttributeError`.

**Changes:**
- `vllm_gaudi/lora/hpu_punica_wrapper.py`: Implement `add_shrink()` and
`add_expand()` methods in HPU punica wrapper

### 7. Rejection sampler — synthetic_mode kwarg (upstream
[#40662](vllm-project/vllm#40662))

Upstream added `synthetic_mode` parameter to `rejection_sample()`. HPU
override was missing it.

**Changes:**
- `vllm_gaudi/v1/worker/hpu_model_runner.py`: Accept `synthetic_mode`
kwarg in HPU `rejection_sample` override

### 8. MoE forward — pass input_ids to custom op (upstream
[#40860](vllm-project/vllm#40860))

DeepSeek V4 support added `input_ids` parameter at position 3 of
`_moe_forward` / `_moe_forward_shared` custom ops.
`patched_fused_moe_forward` was not passing it through to
`_forward_impl` and `_forward_entry`, causing `RuntimeError: Expected
Optional[Tensor] but found str` in `hpu_dp_tests`.

**Changes:**
- `vllm_gaudi/ops/hpu_fused_moe.py`: Pass `input_ids` through both
`_forward_impl` and `_forward_entry` calls in
`patched_fused_moe_forward()`

---------

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
Copilot AI pushed a commit to hongbolv/vllm that referenced this pull request May 7, 2026
…ect#38453)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: hongbolv <33214277+hongbolv@users.noreply.github.com>

Labels

kv-connector · ready (ONLY add when PR is ready to merge/full CI is needed) · v1
