4 changes: 3 additions & 1 deletion .github/workflows/scripts/config.yaml
@@ -22,7 +22,9 @@ e2e-singlecard:
 - name: tests/e2e/singlecard/test_completion_with_prompt_embeds.py
   estimated_time: 180
 - name: tests/e2e/singlecard/test_cpu_offloading.py
-  estimated_time: 28
+  estimated_time: 169
+- name: tests/e2e/singlecard/test_simple_cpu_offload.py
+  estimated_time: 240
 - name: tests/e2e/singlecard/test_guided_decoding.py
   estimated_time: 432
 - name: tests/e2e/singlecard/test_ilama_lora.py
103 changes: 103 additions & 0 deletions tests/e2e/singlecard/test_simple_cpu_offload.py
@@ -0,0 +1,103 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
"""End-to-end tests for the Ascend ``SimpleCPUOffloadConnector``.

The simple CPU offloading scheduler/worker pair is reused from upstream
vLLM; here we only exercise the NPU-native worker path
(``aclrtMemcpyBatchAsync`` + ``torch.npu`` streams) to confirm that
KV blocks are stored to and reloaded from CPU correctly on Ascend.
"""

import os
import time

import pytest
from vllm import SamplingParams, TokensPrompt
from vllm.config import KVTransferConfig

from tests.e2e.conftest import VllmRunner

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


def _build_kv_transfer_config(
cpu_bytes_to_use: int,
lazy_offload: bool = False,
) -> KVTransferConfig:
return KVTransferConfig(
kv_connector="SimpleCPUOffloadConnector",
kv_role="kv_both",
kv_connector_extra_config={
"cpu_bytes_to_use": cpu_bytes_to_use,
"lazy_offload": lazy_offload,
},
)


def test_simple_cpu_offload_accuracy() -> None:
"""Reset GPU prefix cache after a cold run; verify the CPU-loaded KV
cache reproduces the cold-run output deterministically."""
sampling_params = SamplingParams(max_tokens=1, temperature=0)

# Long enough prompt to occupy multiple full KV blocks.
prompt = "hi " * 500 + "Let's count to ten. One, two, three, "

with VllmRunner(
"Qwen/Qwen3-0.6B",
max_model_len=4096,
gpu_memory_utilization=0.5,
enable_prefix_caching=True,
kv_transfer_config=_build_kv_transfer_config(1 << 30), # 1 GiB
enforce_eager=True,
) as runner:
llm = runner.model

        # Cold run: populates the device KV cache and triggers CPU offload.
cold_output = llm.generate(prompt, sampling_params, use_tqdm=False)[0]
expected = cold_output.outputs[0].text

success = 0
attempts = 5
for _ in range(attempts):
# Let the engine core drain pending store transfers.
time.sleep(2)
            # Reset the device prefix cache so the next run must reload from CPU.
if not llm.reset_prefix_cache():
continue
output = llm.generate(prompt, sampling_params, use_tqdm=False)[0]
if output.outputs[0].text == expected:
success += 1

assert success >= int(0.5 * attempts), (
f"CPU-load accuracy too low: {success}/{attempts} matched baseline output {expected!r}"
)


@pytest.mark.parametrize("lazy", [False, True])
def test_simple_cpu_offload_no_crash_on_repeat(lazy: bool) -> None:
"""Smoke test: many short generations exercise both eager and lazy
offload paths without errors and yield non-empty outputs."""
sampling_params = SamplingParams(max_tokens=4, temperature=0)
prompt_token_ids = [0] * 257

with VllmRunner(
"Qwen/Qwen3-0.6B",
max_model_len=2048,
gpu_memory_utilization=0.5,
enable_prefix_caching=True,
kv_transfer_config=_build_kv_transfer_config(
cpu_bytes_to_use=512 * (1 << 20), # 512 MiB
lazy_offload=lazy,
),
enforce_eager=True,
) as runner:
llm = runner.model
for i in range(8):
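            # Vary the first token so each iteration is a fresh prompt (no
            # prefix-cache hit) and produces new KV blocks to offload.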
prompt_token_ids[0] = i
prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)]
outs = llm.generate(prompts, sampling_params, use_tqdm=False)
assert outs and len(outs[0].outputs[0].token_ids) > 0
16 changes: 16 additions & 0 deletions vllm_ascend/distributed/kv_transfer/__init__.py
@@ -57,3 +57,19 @@ def register_connector():
"vllm_ascend.distributed.kv_transfer.kv_pool.lmcache_ascend_connector",
"LMCacheConnectorV1",
)

# Override the upstream SimpleCPUOffloadConnector with the NPU
# adaptation that uses aclrtMemcpyBatchAsync + torch.npu streams.
# Only override if the upstream module exists in this vLLM version.
try:
import vllm.v1.simple_kv_offload # noqa: F401
except ImportError:
pass
else:
if "SimpleCPUOffloadConnector" in KVConnectorFactory._registry:
KVConnectorFactory._registry.pop("SimpleCPUOffloadConnector")
KVConnectorFactory.register_connector(
"SimpleCPUOffloadConnector",
"vllm_ascend.distributed.kv_transfer.kv_pool.simple_cpu_offload.simple_cpu_offload_connector", # noqa: E501
"AscendSimpleCPUOffloadConnector",
)
@@ -0,0 +1 @@

60 changes: 60 additions & 0 deletions vllm_ascend/distributed/kv_transfer/kv_pool/simple_cpu_offload/simple_cpu_offload_connector.py
@@ -0,0 +1,60 @@
"""Ascend NPU adaptation of vLLM's ``SimpleCPUOffloadConnector``.

The scheduler-side ``SimpleCPUOffloadScheduler`` is platform-agnostic
and reused as-is from upstream vLLM. The Ascend variant only swaps the
worker-side handler with an NPU-native implementation that uses
``aclrtMemcpyBatchAsync`` and ``torch.npu`` streams/events.
"""

from typing import TYPE_CHECKING

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.simple_cpu_offload_connector import ( # noqa: E501
SimpleCPUOffloadConnector,
)
from vllm.logger import logger

from vllm_ascend.simple_kv_offload.worker import SimpleCPUOffloadNPUWorker

if TYPE_CHECKING:
from vllm.v1.kv_cache_interface import KVCacheConfig


class AscendSimpleCPUOffloadConnector(SimpleCPUOffloadConnector):
"""NPU-flavored ``SimpleCPUOffloadConnector``.

Inherits the full scheduler/worker plumbing from upstream and only
replaces the CUDA worker handler with the NPU one. All other public
APIs (``register_kv_caches``, ``bind_connector_metadata``,
``get_finished``, ``handle_preemptions``, every scheduler-side
method, etc.) are inherited verbatim — they all route through
``self.worker_handler`` / ``self.scheduler_manager``.

Why post-init swap (instead of skipping ``super().__init__``):
``SimpleCPUOffloadWorker.__init__`` and ``DmaCopyBackend.__init__``
only assign ``None``/empty-field defaults — no CUDA resource is
allocated until ``register_kv_caches`` runs. So letting the parent
construct a transient CUDA worker and then replacing it costs
nothing and keeps us free of duplicated configuration parsing.
"""

def __init__(
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig | None" = None,
) -> None:
super().__init__(vllm_config, role, kv_cache_config)

# If prefix caching is disabled, the parent leaves both handlers
# as None and the connector is a no-op — nothing to swap.
if role == KVConnectorRole.WORKER and self.worker_handler is not None:
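            # Reuse the per-rank CPU capacity the upstream worker handler
            # already computed, so no config parsing is duplicated here.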
cpu_capacity = self.worker_handler.cpu_capacity_bytes
self.worker_handler: SimpleCPUOffloadNPUWorker = SimpleCPUOffloadNPUWorker(
vllm_config, kv_cache_config, cpu_capacity
)
logger.info(
"AscendSimpleCPUOffloadConnector: swapped CUDA worker for NPU worker (per_rank=%.2f GB)",
cpu_capacity / (1024**3),
)
1 change: 1 addition & 0 deletions vllm_ascend/simple_kv_offload/__init__.py
@@ -0,0 +1 @@
"""NPU adaptation of vLLM's simple CPU KV-cache offloading."""
124 changes: 124 additions & 0 deletions vllm_ascend/simple_kv_offload/copy_backend.py
@@ -0,0 +1,124 @@
"""DMA copy backend for NPU<->CPU block transfers.

Mirrors :class:`vllm.v1.simple_kv_offload.copy_backend.DmaCopyBackend`
but routes batched memcpy through ``torch.ops._C_ascend.swap_blocks_batch``
and uses ``torch.npu`` streams/events.
"""

from __future__ import annotations

import queue
import threading

import torch

from vllm_ascend.simple_kv_offload.npu_mem_ops import (
DIRECTION_D2H,
DIRECTION_H2D,
BatchMemcpyParams,
build_params,
copy_blocks,
)


class NPUDmaCopyBackend:
"""``aclrtMemcpyBatchAsync`` copy backend running on a worker thread.

Two pre-built ``BatchMemcpyParams`` are cached (load=H2D, store=D2H).
Submitted jobs are dispatched in FIFO order to a single worker
thread; each job issues its copies on a dedicated NPU stream and
records an Event the main thread can poll without synchronizing
the device.
"""

def __init__(self) -> None:
self._store_params: BatchMemcpyParams | None = None
self._load_params: BatchMemcpyParams | None = None
self._load_stream: torch.npu.Stream | None = None
self._store_stream: torch.npu.Stream | None = None
self._device: torch.device | None = None
self._queue: queue.SimpleQueue | None = None
self._thread: threading.Thread | None = None
self._shutdown: bool = False

def init(
self,
npu_caches: dict[str, torch.Tensor],
cpu_caches: dict[str, torch.Tensor],
device: torch.device,
load_stream: torch.npu.Stream,
store_stream: torch.npu.Stream,
) -> None:
self._load_stream = load_stream
self._store_stream = store_stream
self._device = device
# Stores go NPU->CPU (D2H), loads go CPU->NPU (H2D).
self._store_params = build_params(npu_caches, cpu_caches, DIRECTION_D2H)
self._load_params = build_params(cpu_caches, npu_caches, DIRECTION_H2D)

self._queue = queue.SimpleQueue()
self._thread = threading.Thread(
target=self._copy_loop,
name="npu-kv-offload-copy",
daemon=True,
)
self._thread.start()

def launch_copy(
self,
src_blocks: list[int],
dst_blocks: list[int],
is_store: bool,
event_idx: int,
events_list: list[tuple[int, torch.npu.Event]],
) -> None:
params = self._store_params if is_store else self._load_params
assert params is not None and self._queue is not None
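        # Hand the job to the copy thread; _copy_loop unpacks this tuple,
        # issues the copies, and records a completion event into events_list.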
self._queue.put((src_blocks, dst_blocks, params, is_store, event_idx, events_list))

def shutdown(self) -> None:
if self._shutdown:
return
self._shutdown = True
if self._queue is not None:
self._queue.put(None)
if self._thread is not None:
self._thread.join(timeout=5.0)

# ------------------------------------------------------------------
# Worker thread main loop
# ------------------------------------------------------------------
def _copy_loop(self) -> None:
# NOTE: matches upstream cuda backend semantics — no cross-stream
# sync. The scheduler manager only schedules stores for blocks
# whose KV data is **confirmed computed** (see
# ``confirmed_tokens`` in ``SimpleCPUOffloadScheduler``), so
# those blocks have long been written and visible across streams
# by the time we read them here. Loads target GPU blocks held
# by ``BlockPool.touch`` until load completes, so they are also
# safe to write without a barrier.
assert self._device is not None
assert self._queue is not None
assert self._load_stream is not None
assert self._store_stream is not None
torch.npu.set_device(self._device)

while True:
item = self._queue.get()
if item is None:
return
(
src_blocks,
dst_blocks,
params,
is_store,
event_idx,
events_list,
) = item

stream = self._store_stream if is_store else self._load_stream
with torch.npu.stream(stream):
copy_blocks(src_blocks, dst_blocks, params)
event = torch.npu.Event()
event.record(stream)
events_list.append((event_idx, event))

Contributor review comment (severity: high):

Race condition on events_list. The worker thread appends to events_list while the main thread iterates over or pops from it in _poll_stream_events and _flush_and_sync_all. While CPython lists are generally thread-safe for single operations like append and pop, the multi-step iteration and modification across threads without synchronization is risky and can lead to inconsistent state or RuntimeError. It is recommended to use a thread-safe queue (like queue.SimpleQueue) for completion events, which the main thread can drain into its local list.
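
A minimal sketch of that suggestion (illustration only, not part of the PR): the copy thread publishes (event_idx, event) pairs into a queue.SimpleQueue, and the main thread drains them into a list it alone owns before checking readiness. It assumes torch.npu.Event supports the same non-blocking query() call as torch.cuda.Event; the names _done_queue, _pending, and _publish_completion are invented for the example, with _poll_stream_events standing in for the main-thread consumer named above.

```python
import queue

import torch


class NPUDmaCopyBackend:
    """Fragment showing only the completion-event handoff."""

    def __init__(self) -> None:
        # Written only by the copy thread, drained only by the main thread.
        self._done_queue: queue.SimpleQueue = queue.SimpleQueue()
        # Owned exclusively by the main thread once drained.
        self._pending: list[tuple[int, torch.npu.Event]] = []

    # --- copy thread ----------------------------------------------------
    def _publish_completion(self, event_idx: int, stream: "torch.npu.Stream") -> None:
        event = torch.npu.Event()
        event.record(stream)
        self._done_queue.put((event_idx, event))

    # --- main thread ----------------------------------------------------
    def _poll_stream_events(self) -> list[int]:
        # Drain the thread-safe queue into a main-thread-local list first,
        # then check completion without racing against the copy thread.
        while True:
            try:
                self._pending.append(self._done_queue.get_nowait())
            except queue.Empty:
                break
        finished: list[int] = []
        still_pending: list[tuple[int, torch.npu.Event]] = []
        for event_idx, event in self._pending:
            if event.query():  # non-blocking completion check
                finished.append(event_idx)
            else:
                still_pending.append((event_idx, event))
        self._pending = still_pending
        return finished
```

This keeps every container single-writer/single-reader, so no lock is needed beyond what SimpleQueue already provides.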
