-
Notifications
You must be signed in to change notification settings - Fork 1.2k
[Feature] Simple yet General CPU KV Cache Offloading #8743
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
HF-001
wants to merge
16
commits into
vllm-project:main
Choose a base branch
from
HF-001:simple_kv_offload
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
77b3705
[Feature] add simple kvcache offloading
HF-001 a2d6217
[Feature] add simple kvcache offload
HF-001 a215201
fix
HF-001 5ef2850
Merge branch 'main' into simple_kv_offload
HF-001 d5ebc7b
fix
HF-001 4d6291a
fix
HF-001 d00d50f
fix
HF-001 3ae148d
fix
HF-001 6fbecad
fix
HF-001 7a9883d
Merge branch 'main' into simple_kv_offload
HF-001 0116ef1
Merge branch 'main' into simple_kv_offload
HF-001 169aa4f
fix
HF-001 ecf7144
fix
HF-001 a5ff9f7
Merge branch 'main' into simple_kv_offload
HF-001 cc16bbc
fix
HF-001 dd92c94
Merge branch 'main' into simple_kv_offload
HF-001 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| # | ||
| # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. | ||
| # This file is a part of the vllm-ascend project. | ||
| # | ||
| """End-to-end tests for the Ascend ``SimpleCPUOffloadConnector``. | ||
|
|
||
| The simple CPU offloading scheduler/worker pair is reused from upstream | ||
| vLLM; here we only exercise the NPU-native worker path | ||
| (``aclrtMemcpyBatchAsync`` + ``torch.npu`` streams) to confirm that | ||
| KV blocks are stored to and reloaded from CPU correctly on Ascend. | ||
| """ | ||
|
|
||
| import os | ||
| import time | ||
|
|
||
| import pytest | ||
| from vllm import SamplingParams, TokensPrompt | ||
| from vllm.config import KVTransferConfig | ||
|
|
||
| from tests.e2e.conftest import VllmRunner | ||
|
|
||
| os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" | ||
|
|
||
|
|
||
| def _build_kv_transfer_config( | ||
| cpu_bytes_to_use: int, | ||
| lazy_offload: bool = False, | ||
| ) -> KVTransferConfig: | ||
| return KVTransferConfig( | ||
| kv_connector="SimpleCPUOffloadConnector", | ||
| kv_role="kv_both", | ||
| kv_connector_extra_config={ | ||
| "cpu_bytes_to_use": cpu_bytes_to_use, | ||
| "lazy_offload": lazy_offload, | ||
| }, | ||
| ) | ||
|
|
||
|
|
||
| def test_simple_cpu_offload_accuracy() -> None: | ||
| """Reset GPU prefix cache after a cold run; verify the CPU-loaded KV | ||
| cache reproduces the cold-run output deterministically.""" | ||
| sampling_params = SamplingParams(max_tokens=1, temperature=0) | ||
|
|
||
| # Long enough prompt to occupy multiple full KV blocks. | ||
| prompt = "hi " * 500 + "Let's count to ten. One, two, three, " | ||
|
|
||
| with VllmRunner( | ||
| "Qwen/Qwen3-0.6B", | ||
| max_model_len=4096, | ||
| gpu_memory_utilization=0.5, | ||
| enable_prefix_caching=True, | ||
| kv_transfer_config=_build_kv_transfer_config(1 << 30), # 1 GiB | ||
| enforce_eager=True, | ||
| ) as runner: | ||
| llm = runner.model | ||
|
|
||
| # Cold run — populates GPU cache and triggers CPU offload. | ||
| cold_output = llm.generate(prompt, sampling_params, use_tqdm=False)[0] | ||
| expected = cold_output.outputs[0].text | ||
|
|
||
| success = 0 | ||
| attempts = 5 | ||
| for _ in range(attempts): | ||
| # Let the engine core drain pending store transfers. | ||
| time.sleep(2) | ||
| # Reset GPU prefix cache so the next run must reload from CPU. | ||
| if not llm.reset_prefix_cache(): | ||
| continue | ||
| output = llm.generate(prompt, sampling_params, use_tqdm=False)[0] | ||
| if output.outputs[0].text == expected: | ||
| success += 1 | ||
|
|
||
| assert success >= int(0.5 * attempts), ( | ||
| f"CPU-load accuracy too low: {success}/{attempts} matched baseline output {expected!r}" | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("lazy", [False, True]) | ||
| def test_simple_cpu_offload_no_crash_on_repeat(lazy: bool) -> None: | ||
| """Smoke test: many short generations exercise both eager and lazy | ||
| offload paths without errors and yield non-empty outputs.""" | ||
| sampling_params = SamplingParams(max_tokens=4, temperature=0) | ||
| prompt_token_ids = [0] * 257 | ||
|
|
||
| with VllmRunner( | ||
| "Qwen/Qwen3-0.6B", | ||
| max_model_len=2048, | ||
| gpu_memory_utilization=0.5, | ||
| enable_prefix_caching=True, | ||
| kv_transfer_config=_build_kv_transfer_config( | ||
| cpu_bytes_to_use=512 * (1 << 20), # 512 MiB | ||
| lazy_offload=lazy, | ||
| ), | ||
| enforce_eager=True, | ||
| ) as runner: | ||
| llm = runner.model | ||
| for i in range(8): | ||
| prompt_token_ids[0] = i | ||
| prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)] | ||
| outs = llm.generate(prompts, sampling_params, use_tqdm=False) | ||
| assert outs and len(outs[0].outputs[0].token_ids) > 0 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
1 change: 1 addition & 0 deletions
1
vllm_ascend/distributed/kv_transfer/kv_pool/simple_cpu_offload/__init__.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
|
|
60 changes: 60 additions & 0 deletions
60
...ascend/distributed/kv_transfer/kv_pool/simple_cpu_offload/simple_cpu_offload_connector.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| """Ascend NPU adaptation of vLLM's ``SimpleCPUOffloadConnector``. | ||
|
|
||
| The scheduler-side ``SimpleCPUOffloadScheduler`` is platform-agnostic | ||
| and reused as-is from upstream vLLM. The Ascend variant only swaps the | ||
| worker-side handler with an NPU-native implementation that uses | ||
| ``aclrtMemcpyBatchAsync`` and ``torch.npu`` streams/events. | ||
| """ | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| from vllm.config import VllmConfig | ||
| from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole | ||
| from vllm.distributed.kv_transfer.kv_connector.v1.simple_cpu_offload_connector import ( # noqa: E501 | ||
| SimpleCPUOffloadConnector, | ||
| ) | ||
| from vllm.logger import logger | ||
|
|
||
| from vllm_ascend.simple_kv_offload.worker import SimpleCPUOffloadNPUWorker | ||
|
|
||
| if TYPE_CHECKING: | ||
| from vllm.v1.kv_cache_interface import KVCacheConfig | ||
|
|
||
|
|
||
| class AscendSimpleCPUOffloadConnector(SimpleCPUOffloadConnector): | ||
| """NPU-flavored ``SimpleCPUOffloadConnector``. | ||
|
|
||
| Inherits the full scheduler/worker plumbing from upstream and only | ||
| replaces the CUDA worker handler with the NPU one. All other public | ||
| APIs (``register_kv_caches``, ``bind_connector_metadata``, | ||
| ``get_finished``, ``handle_preemptions``, every scheduler-side | ||
| method, etc.) are inherited verbatim — they all route through | ||
| ``self.worker_handler`` / ``self.scheduler_manager``. | ||
|
|
||
| Why post-init swap (instead of skipping ``super().__init__``): | ||
| ``SimpleCPUOffloadWorker.__init__`` and ``DmaCopyBackend.__init__`` | ||
| only assign ``None``/empty-field defaults — no CUDA resource is | ||
| allocated until ``register_kv_caches`` runs. So letting the parent | ||
| construct a transient CUDA worker and then replacing it costs | ||
| nothing and keeps us free of duplicated configuration parsing. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| vllm_config: VllmConfig, | ||
| role: KVConnectorRole, | ||
| kv_cache_config: "KVCacheConfig | None" = None, | ||
| ) -> None: | ||
| super().__init__(vllm_config, role, kv_cache_config) | ||
|
|
||
| # If prefix caching is disabled, the parent leaves both handlers | ||
| # as None and the connector is a no-op — nothing to swap. | ||
| if role == KVConnectorRole.WORKER and self.worker_handler is not None: | ||
| cpu_capacity = self.worker_handler.cpu_capacity_bytes | ||
| self.worker_handler: SimpleCPUOffloadNPUWorker = SimpleCPUOffloadNPUWorker( | ||
| vllm_config, kv_cache_config, cpu_capacity | ||
| ) | ||
| logger.info( | ||
| "AscendSimpleCPUOffloadConnector: swapped CUDA worker for NPU worker (per_rank=%.2f GB)", | ||
| cpu_capacity / (1024**3), | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| """NPU adaptation of vLLM's simple CPU KV-cache offloading.""" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,124 @@ | ||
| """DMA copy backend for NPU<->CPU block transfers. | ||
|
|
||
| Mirrors :class:`vllm.v1.simple_kv_offload.copy_backend.DmaCopyBackend` | ||
| but routes batched memcpy through ``torch.ops._C_ascend.swap_blocks_batch`` | ||
| and uses ``torch.npu`` streams/events. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import queue | ||
| import threading | ||
|
|
||
| import torch | ||
|
|
||
| from vllm_ascend.simple_kv_offload.npu_mem_ops import ( | ||
| DIRECTION_D2H, | ||
| DIRECTION_H2D, | ||
| BatchMemcpyParams, | ||
| build_params, | ||
| copy_blocks, | ||
| ) | ||
|
|
||
|
|
||
| class NPUDmaCopyBackend: | ||
| """``aclrtMemcpyBatchAsync`` copy backend running on a worker thread. | ||
|
|
||
| Two pre-built ``BatchMemcpyParams`` are cached (load=H2D, store=D2H). | ||
| Submitted jobs are dispatched in FIFO order to a single worker | ||
| thread; each job issues its copies on a dedicated NPU stream and | ||
| records an Event the main thread can poll without synchronizing | ||
| the device. | ||
| """ | ||
|
|
||
| def __init__(self) -> None: | ||
| self._store_params: BatchMemcpyParams | None = None | ||
| self._load_params: BatchMemcpyParams | None = None | ||
| self._load_stream: torch.npu.Stream | None = None | ||
| self._store_stream: torch.npu.Stream | None = None | ||
| self._device: torch.device | None = None | ||
| self._queue: queue.SimpleQueue | None = None | ||
| self._thread: threading.Thread | None = None | ||
| self._shutdown: bool = False | ||
|
|
||
| def init( | ||
| self, | ||
| npu_caches: dict[str, torch.Tensor], | ||
| cpu_caches: dict[str, torch.Tensor], | ||
| device: torch.device, | ||
| load_stream: torch.npu.Stream, | ||
| store_stream: torch.npu.Stream, | ||
| ) -> None: | ||
| self._load_stream = load_stream | ||
| self._store_stream = store_stream | ||
| self._device = device | ||
| # Stores go NPU->CPU (D2H), loads go CPU->NPU (H2D). | ||
| self._store_params = build_params(npu_caches, cpu_caches, DIRECTION_D2H) | ||
| self._load_params = build_params(cpu_caches, npu_caches, DIRECTION_H2D) | ||
|
|
||
| self._queue = queue.SimpleQueue() | ||
| self._thread = threading.Thread( | ||
| target=self._copy_loop, | ||
| name="npu-kv-offload-copy", | ||
| daemon=True, | ||
| ) | ||
| self._thread.start() | ||
|
|
||
| def launch_copy( | ||
| self, | ||
| src_blocks: list[int], | ||
| dst_blocks: list[int], | ||
| is_store: bool, | ||
| event_idx: int, | ||
| events_list: list[tuple[int, torch.npu.Event]], | ||
| ) -> None: | ||
| params = self._store_params if is_store else self._load_params | ||
| assert params is not None and self._queue is not None | ||
| self._queue.put((src_blocks, dst_blocks, params, is_store, event_idx, events_list)) | ||
|
|
||
| def shutdown(self) -> None: | ||
| if self._shutdown: | ||
| return | ||
| self._shutdown = True | ||
| if self._queue is not None: | ||
| self._queue.put(None) | ||
| if self._thread is not None: | ||
| self._thread.join(timeout=5.0) | ||
|
|
||
| # ------------------------------------------------------------------ | ||
| # Worker thread main loop | ||
| # ------------------------------------------------------------------ | ||
| def _copy_loop(self) -> None: | ||
| # NOTE: matches upstream cuda backend semantics — no cross-stream | ||
| # sync. The scheduler manager only schedules stores for blocks | ||
| # whose KV data is **confirmed computed** (see | ||
| # ``confirmed_tokens`` in ``SimpleCPUOffloadScheduler``), so | ||
| # those blocks have long been written and visible across streams | ||
| # by the time we read them here. Loads target GPU blocks held | ||
| # by ``BlockPool.touch`` until load completes, so they are also | ||
| # safe to write without a barrier. | ||
| assert self._device is not None | ||
| assert self._queue is not None | ||
| assert self._load_stream is not None | ||
| assert self._store_stream is not None | ||
| torch.npu.set_device(self._device) | ||
|
|
||
| while True: | ||
| item = self._queue.get() | ||
| if item is None: | ||
| return | ||
| ( | ||
| src_blocks, | ||
| dst_blocks, | ||
| params, | ||
| is_store, | ||
| event_idx, | ||
| events_list, | ||
| ) = item | ||
|
|
||
| stream = self._store_stream if is_store else self._load_stream | ||
| with torch.npu.stream(stream): | ||
| copy_blocks(src_blocks, dst_blocks, params) | ||
| event = torch.npu.Event() | ||
| event.record(stream) | ||
| events_list.append((event_idx, event)) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Race condition on
events_list. The worker thread appends toevents_listwhile the main thread iterates over or pops from it in_poll_stream_eventsand_flush_and_sync_all. While CPython lists are generally thread-safe for single operations likeappendandpop, the multi-step iteration and modification across threads without synchronization is risky and can lead to inconsistent state orRuntimeError. It is recommended to use a thread-safe queue (likequeue.SimpleQueue) for completion events, which the main thread can drain into its local list.