diff --git a/.github/workflows/scripts/config.yaml b/.github/workflows/scripts/config.yaml index 71978b9a214..b08ec39c5a6 100644 --- a/.github/workflows/scripts/config.yaml +++ b/.github/workflows/scripts/config.yaml @@ -22,7 +22,9 @@ e2e-singlecard: - name: tests/e2e/singlecard/test_completion_with_prompt_embeds.py estimated_time: 180 - name: tests/e2e/singlecard/test_cpu_offloading.py - estimated_time: 28 + estimated_time: 169 +- name: tests/e2e/singlecard/test_simple_cpu_offload.py + estimated_time: 240 - name: tests/e2e/singlecard/test_guided_decoding.py estimated_time: 432 - name: tests/e2e/singlecard/test_ilama_lora.py diff --git a/tests/e2e/singlecard/test_simple_cpu_offload.py b/tests/e2e/singlecard/test_simple_cpu_offload.py new file mode 100644 index 00000000000..688eea8accb --- /dev/null +++ b/tests/e2e/singlecard/test_simple_cpu_offload.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +"""End-to-end tests for the Ascend ``SimpleCPUOffloadConnector``. + +The simple CPU offloading scheduler/worker pair is reused from upstream +vLLM; here we only exercise the NPU-native worker path +(``aclrtMemcpyBatchAsync`` + ``torch.npu`` streams) to confirm that +KV blocks are stored to and reloaded from CPU correctly on Ascend. +""" + +import os +import time + +import pytest +from vllm import SamplingParams, TokensPrompt +from vllm.config import KVTransferConfig + +from tests.e2e.conftest import VllmRunner + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + +def _build_kv_transfer_config( + cpu_bytes_to_use: int, + lazy_offload: bool = False, +) -> KVTransferConfig: + return KVTransferConfig( + kv_connector="SimpleCPUOffloadConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "cpu_bytes_to_use": cpu_bytes_to_use, + "lazy_offload": lazy_offload, + }, + ) + + +def test_simple_cpu_offload_accuracy() -> None: + """Reset GPU prefix cache after a cold run; verify the CPU-loaded KV + cache reproduces the cold-run output deterministically.""" + sampling_params = SamplingParams(max_tokens=1, temperature=0) + + # Long enough prompt to occupy multiple full KV blocks. + prompt = "hi " * 500 + "Let's count to ten. One, two, three, " + + with VllmRunner( + "Qwen/Qwen3-0.6B", + max_model_len=4096, + gpu_memory_utilization=0.5, + enable_prefix_caching=True, + kv_transfer_config=_build_kv_transfer_config(1 << 30), # 1 GiB + enforce_eager=True, + ) as runner: + llm = runner.model + + # Cold run — populates GPU cache and triggers CPU offload. + cold_output = llm.generate(prompt, sampling_params, use_tqdm=False)[0] + expected = cold_output.outputs[0].text + + success = 0 + attempts = 5 + for _ in range(attempts): + # Let the engine core drain pending store transfers. + time.sleep(2) + # Reset GPU prefix cache so the next run must reload from CPU. 
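+            # reset_prefix_cache() can report failure (return False) while
+            # cached blocks are still referenced, e.g. if transfers have not
+            # fully drained; skip this attempt and retry after the next sleep.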
+ if not llm.reset_prefix_cache(): + continue + output = llm.generate(prompt, sampling_params, use_tqdm=False)[0] + if output.outputs[0].text == expected: + success += 1 + + assert success >= int(0.5 * attempts), ( + f"CPU-load accuracy too low: {success}/{attempts} matched baseline output {expected!r}" + ) + + +@pytest.mark.parametrize("lazy", [False, True]) +def test_simple_cpu_offload_no_crash_on_repeat(lazy: bool) -> None: + """Smoke test: many short generations exercise both eager and lazy + offload paths without errors and yield non-empty outputs.""" + sampling_params = SamplingParams(max_tokens=4, temperature=0) + prompt_token_ids = [0] * 257 + + with VllmRunner( + "Qwen/Qwen3-0.6B", + max_model_len=2048, + gpu_memory_utilization=0.5, + enable_prefix_caching=True, + kv_transfer_config=_build_kv_transfer_config( + cpu_bytes_to_use=512 * (1 << 20), # 512 MiB + lazy_offload=lazy, + ), + enforce_eager=True, + ) as runner: + llm = runner.model + for i in range(8): + prompt_token_ids[0] = i + prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)] + outs = llm.generate(prompts, sampling_params, use_tqdm=False) + assert outs and len(outs[0].outputs[0].token_ids) > 0 diff --git a/vllm_ascend/distributed/kv_transfer/__init__.py b/vllm_ascend/distributed/kv_transfer/__init__.py index 45d50414ee3..12fb31a8532 100644 --- a/vllm_ascend/distributed/kv_transfer/__init__.py +++ b/vllm_ascend/distributed/kv_transfer/__init__.py @@ -57,3 +57,19 @@ def register_connector(): "vllm_ascend.distributed.kv_transfer.kv_pool.lmcache_ascend_connector", "LMCacheConnectorV1", ) + + # Override the upstream SimpleCPUOffloadConnector with the NPU + # adaptation that uses aclrtMemcpyBatchAsync + torch.npu streams. + # Only override if the upstream module exists in this vLLM version. + try: + import vllm.v1.simple_kv_offload # noqa: F401 + except ImportError: + pass + else: + if "SimpleCPUOffloadConnector" in KVConnectorFactory._registry: + KVConnectorFactory._registry.pop("SimpleCPUOffloadConnector") + KVConnectorFactory.register_connector( + "SimpleCPUOffloadConnector", + "vllm_ascend.distributed.kv_transfer.kv_pool.simple_cpu_offload.simple_cpu_offload_connector", # noqa: E501 + "AscendSimpleCPUOffloadConnector", + ) diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/simple_cpu_offload/__init__.py b/vllm_ascend/distributed/kv_transfer/kv_pool/simple_cpu_offload/__init__.py new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/simple_cpu_offload/__init__.py @@ -0,0 +1 @@ + diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/simple_cpu_offload/simple_cpu_offload_connector.py b/vllm_ascend/distributed/kv_transfer/kv_pool/simple_cpu_offload/simple_cpu_offload_connector.py new file mode 100644 index 00000000000..d4a59c5b27b --- /dev/null +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/simple_cpu_offload/simple_cpu_offload_connector.py @@ -0,0 +1,60 @@ +"""Ascend NPU adaptation of vLLM's ``SimpleCPUOffloadConnector``. + +The scheduler-side ``SimpleCPUOffloadScheduler`` is platform-agnostic +and reused as-is from upstream vLLM. The Ascend variant only swaps the +worker-side handler with an NPU-native implementation that uses +``aclrtMemcpyBatchAsync`` and ``torch.npu`` streams/events. 
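+
+The connector is enabled through ``KVTransferConfig`` in the same way as
+the upstream variant; the accompanying e2e test builds the config like
+this (sketch)::
+
+    KVTransferConfig(
+        kv_connector="SimpleCPUOffloadConnector",
+        kv_role="kv_both",
+        kv_connector_extra_config={
+            "cpu_bytes_to_use": 1 << 30,  # CPU pool size in bytes
+            "lazy_offload": False,
+        },
+    )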
+""" + +from typing import TYPE_CHECKING + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole +from vllm.distributed.kv_transfer.kv_connector.v1.simple_cpu_offload_connector import ( # noqa: E501 + SimpleCPUOffloadConnector, +) +from vllm.logger import logger + +from vllm_ascend.simple_kv_offload.worker import SimpleCPUOffloadNPUWorker + +if TYPE_CHECKING: + from vllm.v1.kv_cache_interface import KVCacheConfig + + +class AscendSimpleCPUOffloadConnector(SimpleCPUOffloadConnector): + """NPU-flavored ``SimpleCPUOffloadConnector``. + + Inherits the full scheduler/worker plumbing from upstream and only + replaces the CUDA worker handler with the NPU one. All other public + APIs (``register_kv_caches``, ``bind_connector_metadata``, + ``get_finished``, ``handle_preemptions``, every scheduler-side + method, etc.) are inherited verbatim — they all route through + ``self.worker_handler`` / ``self.scheduler_manager``. + + Why post-init swap (instead of skipping ``super().__init__``): + ``SimpleCPUOffloadWorker.__init__`` and ``DmaCopyBackend.__init__`` + only assign ``None``/empty-field defaults — no CUDA resource is + allocated until ``register_kv_caches`` runs. So letting the parent + construct a transient CUDA worker and then replacing it costs + nothing and keeps us free of duplicated configuration parsing. + """ + + def __init__( + self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: "KVCacheConfig | None" = None, + ) -> None: + super().__init__(vllm_config, role, kv_cache_config) + + # If prefix caching is disabled, the parent leaves both handlers + # as None and the connector is a no-op — nothing to swap. + if role == KVConnectorRole.WORKER and self.worker_handler is not None: + cpu_capacity = self.worker_handler.cpu_capacity_bytes + self.worker_handler: SimpleCPUOffloadNPUWorker = SimpleCPUOffloadNPUWorker( + vllm_config, kv_cache_config, cpu_capacity + ) + logger.info( + "AscendSimpleCPUOffloadConnector: swapped CUDA worker for NPU worker (per_rank=%.2f GB)", + cpu_capacity / (1024**3), + ) diff --git a/vllm_ascend/simple_kv_offload/__init__.py b/vllm_ascend/simple_kv_offload/__init__.py new file mode 100644 index 00000000000..485926bd1a7 --- /dev/null +++ b/vllm_ascend/simple_kv_offload/__init__.py @@ -0,0 +1 @@ +"""NPU adaptation of vLLM's simple CPU KV-cache offloading.""" diff --git a/vllm_ascend/simple_kv_offload/copy_backend.py b/vllm_ascend/simple_kv_offload/copy_backend.py new file mode 100644 index 00000000000..5572ee8f1cd --- /dev/null +++ b/vllm_ascend/simple_kv_offload/copy_backend.py @@ -0,0 +1,124 @@ +"""DMA copy backend for NPU<->CPU block transfers. + +Mirrors :class:`vllm.v1.simple_kv_offload.copy_backend.DmaCopyBackend` +but routes batched memcpy through ``torch.ops._C_ascend.swap_blocks_batch`` +and uses ``torch.npu`` streams/events. +""" + +from __future__ import annotations + +import queue +import threading + +import torch + +from vllm_ascend.simple_kv_offload.npu_mem_ops import ( + DIRECTION_D2H, + DIRECTION_H2D, + BatchMemcpyParams, + build_params, + copy_blocks, +) + + +class NPUDmaCopyBackend: + """``aclrtMemcpyBatchAsync`` copy backend running on a worker thread. + + Two pre-built ``BatchMemcpyParams`` are cached (load=H2D, store=D2H). + Submitted jobs are dispatched in FIFO order to a single worker + thread; each job issues its copies on a dedicated NPU stream and + records an Event the main thread can poll without synchronizing + the device. 
+ """ + + def __init__(self) -> None: + self._store_params: BatchMemcpyParams | None = None + self._load_params: BatchMemcpyParams | None = None + self._load_stream: torch.npu.Stream | None = None + self._store_stream: torch.npu.Stream | None = None + self._device: torch.device | None = None + self._queue: queue.SimpleQueue | None = None + self._thread: threading.Thread | None = None + self._shutdown: bool = False + + def init( + self, + npu_caches: dict[str, torch.Tensor], + cpu_caches: dict[str, torch.Tensor], + device: torch.device, + load_stream: torch.npu.Stream, + store_stream: torch.npu.Stream, + ) -> None: + self._load_stream = load_stream + self._store_stream = store_stream + self._device = device + # Stores go NPU->CPU (D2H), loads go CPU->NPU (H2D). + self._store_params = build_params(npu_caches, cpu_caches, DIRECTION_D2H) + self._load_params = build_params(cpu_caches, npu_caches, DIRECTION_H2D) + + self._queue = queue.SimpleQueue() + self._thread = threading.Thread( + target=self._copy_loop, + name="npu-kv-offload-copy", + daemon=True, + ) + self._thread.start() + + def launch_copy( + self, + src_blocks: list[int], + dst_blocks: list[int], + is_store: bool, + event_idx: int, + events_list: list[tuple[int, torch.npu.Event]], + ) -> None: + params = self._store_params if is_store else self._load_params + assert params is not None and self._queue is not None + self._queue.put((src_blocks, dst_blocks, params, is_store, event_idx, events_list)) + + def shutdown(self) -> None: + if self._shutdown: + return + self._shutdown = True + if self._queue is not None: + self._queue.put(None) + if self._thread is not None: + self._thread.join(timeout=5.0) + + # ------------------------------------------------------------------ + # Worker thread main loop + # ------------------------------------------------------------------ + def _copy_loop(self) -> None: + # NOTE: matches upstream cuda backend semantics — no cross-stream + # sync. The scheduler manager only schedules stores for blocks + # whose KV data is **confirmed computed** (see + # ``confirmed_tokens`` in ``SimpleCPUOffloadScheduler``), so + # those blocks have long been written and visible across streams + # by the time we read them here. Loads target GPU blocks held + # by ``BlockPool.touch`` until load completes, so they are also + # safe to write without a barrier. + assert self._device is not None + assert self._queue is not None + assert self._load_stream is not None + assert self._store_stream is not None + torch.npu.set_device(self._device) + + while True: + item = self._queue.get() + if item is None: + return + ( + src_blocks, + dst_blocks, + params, + is_store, + event_idx, + events_list, + ) = item + + stream = self._store_stream if is_store else self._load_stream + with torch.npu.stream(stream): + copy_blocks(src_blocks, dst_blocks, params) + event = torch.npu.Event() + event.record(stream) + events_list.append((event_idx, event)) diff --git a/vllm_ascend/simple_kv_offload/npu_mem_ops.py b/vllm_ascend/simple_kv_offload/npu_mem_ops.py new file mode 100644 index 00000000000..fd873b906ce --- /dev/null +++ b/vllm_ascend/simple_kv_offload/npu_mem_ops.py @@ -0,0 +1,99 @@ +"""Low-level NPU memory helpers: batched DMA transfers. + +Mirrors :mod:`vllm.v1.simple_kv_offload.cuda_mem_ops` but uses the +Ascend ``aclrtMemcpyBatchAsync`` path exposed via +``torch.ops._C_ascend.swap_blocks_batch`` (see +``csrc/torch_binding.cpp``). 
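+
+As invoked in :func:`copy_blocks` below, the op receives three int64 CPU
+tensors (flattened source addresses, destination addresses and per-copy
+byte sizes, one entry per block per sub-tensor) plus a direction code,
+and is expected to issue the batched copy on the current NPU stream::
+
+    torch.ops._C_ascend.swap_blocks_batch(batch_src, batch_dst, batch_sizes, direction)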
+""" + +from __future__ import annotations + +from typing import NamedTuple + +import numpy as np +import torch + +# Direction codes shared with csrc/torch_binding.cpp::swap_blocks_batch. +DIRECTION_H2D = 0 +DIRECTION_D2H = 1 + + +class BatchMemcpyParams(NamedTuple): + """Pre-computed per-tensor descriptors for batched block copy.""" + + src_bases: np.ndarray # [num_sub_tensors] int64 — data_ptr per tensor + dst_bases: np.ndarray # [num_sub_tensors] int64 + bpb: np.ndarray # [num_sub_tensors] int64 — bytes per block + num_sub_tensors: int + direction: int # DIRECTION_H2D or DIRECTION_D2H + + +def _ordered_tensors(caches: dict[str, torch.Tensor]) -> list[torch.Tensor]: + """Return values in insertion order (kept as a function for clarity).""" + return list(caches.values()) + + +def build_params( + src_caches: dict[str, torch.Tensor], + dst_caches: dict[str, torch.Tensor], + direction: int, +) -> BatchMemcpyParams: + """Build cached pointer/stride descriptors for all sub-tensors. + + Both ``src_caches`` and ``dst_caches`` must have identical keys and a + matching ``[num_blocks, block_bytes]`` layout (already prepared by + :class:`SimpleCPUOffloadNPUWorker.register_kv_caches`). + """ + assert list(src_caches.keys()) == list(dst_caches.keys()), "src/dst cache key order must match" + src_tensors = _ordered_tensors(src_caches) + dst_tensors = _ordered_tensors(dst_caches) + + src_bases: list[int] = [] + dst_bases: list[int] = [] + bpb: list[int] = [] + for s, d in zip(src_tensors, dst_tensors): + s_bpb = s.stride(0) * s.element_size() + d_bpb = d.stride(0) * d.element_size() + assert s_bpb == d_bpb, f"per-block bytes mismatch src={s_bpb} dst={d_bpb}" + src_bases.append(s.data_ptr()) + dst_bases.append(d.data_ptr()) + bpb.append(s_bpb) + + return BatchMemcpyParams( + src_bases=np.array(src_bases, dtype=np.int64), + dst_bases=np.array(dst_bases, dtype=np.int64), + bpb=np.array(bpb, dtype=np.int64), + num_sub_tensors=len(src_tensors), + direction=direction, + ) + + +def copy_blocks( + src_block_ids: list[int], + dst_block_ids: list[int], + params: BatchMemcpyParams, +) -> None: + """Issue a batched async DMA on the *current* NPU stream. + + The caller is expected to be inside a ``torch.npu.stream(...)`` + context so the issued copies bind to the dedicated transfer stream. + """ + n = len(src_block_ids) + if n == 0: + return + assert n == len(dst_block_ids), "src/dst block counts must match" + + src_ids = np.asarray(src_block_ids, dtype=np.int64) + dst_ids = np.asarray(dst_block_ids, dtype=np.int64) + + # Layout: (num_sub_tensors, n) flattened — contract of swap_blocks_batch. + bpb_col = params.bpb[:, None] + src_all = (params.src_bases[:, None] + src_ids[None, :] * bpb_col).ravel() + dst_all = (params.dst_bases[:, None] + dst_ids[None, :] * bpb_col).ravel() + sz_all = np.broadcast_to(bpb_col, (params.num_sub_tensors, n)).ravel().copy() + + batch_src = torch.from_numpy(src_all) + batch_dst = torch.from_numpy(dst_all) + batch_sizes = torch.from_numpy(sz_all) + + torch.ops._C_ascend.swap_blocks_batch(batch_src, batch_dst, batch_sizes, params.direction) diff --git a/vllm_ascend/simple_kv_offload/worker.py b/vllm_ascend/simple_kv_offload/worker.py new file mode 100644 index 00000000000..7bc64adbdc4 --- /dev/null +++ b/vllm_ascend/simple_kv_offload/worker.py @@ -0,0 +1,331 @@ +"""Worker-side handler for the Ascend ``SimpleCPUOffloadConnector``. + +Mirrors :class:`vllm.v1.simple_kv_offload.worker.SimpleCPUOffloadWorker` +but uses ``torch.npu`` streams/events and the NPU-flavored DMA backend. 
+The scheduler-side metadata protocol is identical and reused as-is. +""" + +from typing import TYPE_CHECKING + +import torch +from vllm.config import VllmConfig +from vllm.logger import logger +from vllm.utils.platform_utils import is_pin_memory_available +from vllm.v1.simple_kv_offload.metadata import ( + SimpleCPUOffloadMetadata, + SimpleCPUOffloadWorkerMetadata, +) + +from vllm_ascend.simple_kv_offload.copy_backend import NPUDmaCopyBackend + +if TYPE_CHECKING: + from vllm.v1.kv_cache_interface import KVCacheConfig + + +def _flatten_kv_value( + value: torch.Tensor | tuple | list, +) -> list[torch.Tensor]: + """Yield every constituent tensor of a per-layer KV-cache entry. + + On Ascend, attention layers register ``kv_caches[name]`` as a tuple + of independently-allocated tensors (e.g. ``(k_cache, v_cache)``); + Mamba layers register a list. Each tensor has its own backing + storage and shape ``[num_blocks, ...]``. + """ + if isinstance(value, torch.Tensor): + return [value] + assert isinstance(value, (tuple, list)), f"unexpected kv_caches value type: {type(value)}" + return [t for t in value if isinstance(t, torch.Tensor)] + + +class SimpleCPUOffloadNPUWorker: + """Worker-side handler for CPU offloading transfers on Ascend NPU.""" + + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_config: "KVCacheConfig | None", + cpu_capacity_bytes: int, + ) -> None: + self.vllm_config = vllm_config + self.kv_cache_config = kv_cache_config + self.cpu_capacity_bytes = cpu_capacity_bytes + + self.npu_kv_caches: dict[str, torch.Tensor] | None = None + self.cpu_kv_caches: dict[str, torch.Tensor] | None = None + self.device: torch.device | None = None + self.num_cpu_blocks: int = 0 + + self.load_stream: torch.npu.Stream | None = None + self.store_stream: torch.npu.Stream | None = None + + self._backend = NPUDmaCopyBackend() + + # FIFO of (event_idx, Event), monotonic per direction. + self._load_events: list[tuple[int, torch.npu.Event]] = [] + self._store_events: list[tuple[int, torch.npu.Event]] = [] + # High-water marks: highest event_idx completed per stream. + self._load_hwm: int = -1 + self._store_hwm: int = -1 + + self._connector_metadata: SimpleCPUOffloadMetadata | None = None + self._pending_load_event_indices: set[int] = set() + self._pending_store_event_indices: set[int] = set() + self._completed_store_events: dict[int, int] = {} + + # ------------------------------------------------------------------ + # KV cache registration + # ------------------------------------------------------------------ + def register_kv_caches( + self, + kv_caches: dict[str, torch.Tensor | tuple | list], + ) -> None: + """Register NPU KV caches and allocate pinned CPU mirrors. + + For every unique storage backing ``kv_caches`` we expose a + contiguous ``[num_blocks, block_bytes]`` int8 view. The batch + memcpy backend then strides blocks uniformly across all such + sub-tensors in a single ``aclrtMemcpyBatchAsync`` call. + """ + if not kv_caches: + logger.warning("No NPU KV caches to offload.") + return + + first_tensor = _flatten_kv_value(next(iter(kv_caches.values())))[0] + self.device = first_tensor.device + + assert self.kv_cache_config is not None + num_blocks = self.kv_cache_config.num_blocks + + # Deduplicate by untyped_storage().data_ptr(): a single NPU + # allocation may back multiple layers (e.g. shared KV across + # tied weights or via aliasing). On Ascend, K and V live in + # *separate* allocations, so we must iterate every sub-tensor + # — taking only ``value[0]`` would silently drop the V cache. 
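+        # Illustrative keying (hypothetical layer name): a layer registered
+        # as (k_cache, v_cache) ends up as two entries, e.g.
+        #   "model.layers.0.attn"   -> K block view
+        #   "model.layers.0.attn.1" -> V block view
+        # because only the first sub-tensor keeps the bare layer name.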
+ unique_caches: dict[str, torch.Tensor] = {} + seen_ptrs: set[int] = set() + for layer_name, value in kv_caches.items(): + for sub_idx, tensor in enumerate(_flatten_kv_value(value)): + storage = tensor.untyped_storage() + ptr = storage.data_ptr() + if ptr in seen_ptrs: + continue + seen_ptrs.add(ptr) + + key = layer_name if sub_idx == 0 else f"{layer_name}.{sub_idx}" + unique_caches.update(self._build_block_views(key, tensor, num_blocks)) + + per_tensor_bpb = [t.stride(0) * t.element_size() for t in unique_caches.values()] + total_bytes_per_block = sum(per_tensor_bpb) + self.num_cpu_blocks = max(1, self.cpu_capacity_bytes // total_bytes_per_block) + logger.info( + "SimpleCPUOffloadNPUWorker: %d unique NPU KV tensors, allocating %d CPU blocks (%.2f GB)", + len(unique_caches), + self.num_cpu_blocks, + (self.num_cpu_blocks * total_bytes_per_block) / (1024**3), + ) + + pin_memory = is_pin_memory_available() + if not pin_memory: + logger.warning("Pinned memory not available; CPU offload throughput may be degraded on this host.") + + self.npu_kv_caches = unique_caches + self.cpu_kv_caches = { + name: torch.zeros( + (self.num_cpu_blocks,) + tuple(t.shape[1:]), + dtype=t.dtype, + device="cpu", + pin_memory=pin_memory, + ) + for name, t in unique_caches.items() + } + + # Upstream creates these with the lowest CUDA priority so KV I/O + # yields to compute on the default stream. ``torch.npu`` does + # NOT expose ``Stream.priority_range()`` / a ``priority=`` kwarg + # (``RuntimeError: NPU does not support Stream.priority_range() + # currently``) and there is no equivalent torch_npu API today. + # Use plain transfer streams — matches every other + # ``torch.npu.Stream`` site in this repo. The transfers still + # run off the default compute stream, so they overlap with the + # forward pass; we only lose the explicit "always yield" hint, + # which is a soft scheduling preference and not a correctness + # requirement. + self.load_stream = torch.npu.Stream() + self.store_stream = torch.npu.Stream() + self._backend.init( + self.npu_kv_caches, + self.cpu_kv_caches, + self.device, + self.load_stream, + self.store_stream, + ) + + @staticmethod + def _build_block_views( + key: str, + tensor: torch.Tensor, + num_blocks: int, + ) -> dict[str, torch.Tensor]: + """Return ``{name: [num_blocks, block_bytes] int8 view}`` for one tensor. + + Sizes views from the tensor's own metadata, NOT + ``storage.nbytes()``. When offload is enabled, + ``NPUModelRunner._allocate_kv_cache_tensors`` over-allocates + each KV tensor by ``+alignment`` (2 MiB) and slices back with + ``_align_memory(...)[:size]``; ``storage.nbytes()`` then + includes alignment-driven leading offset *and* trailing + padding that are not part of the block grid (the total is in + general not a multiple of ``num_blocks``). + + Most Ascend layers register K and V as separate blocks-outermost + tensors (single segment). The ``cache_only_layers`` path with + ``AscendAttentionBackend`` produces ``(N, num_blocks, ...)`` — + N segments stacked in one allocation; we split it into N keyed + views. The runner's actual blocks-dim size may exceed + ``kv_cache_config.num_blocks``; we only view the leading + ``num_blocks`` blocks the connector knows about. + """ + el = tensor.element_size() + storage = tensor.untyped_storage() + storage_offset_bytes = tensor.storage_offset() * el + + if tensor.ndim >= 1 and tensor.shape[0] >= num_blocks: + # Single-segment, blocks-outermost. 
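+            # Illustrative shapes: a fp16 tensor of shape
+            # [num_blocks, block_size, num_heads, head_dim] becomes a single
+            # int8 view [num_blocks, stride(0) * 2] covering only the block grid.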
+ page_size_bytes = tensor.stride(0) * el + data_bytes = num_blocks * page_size_bytes + raw = torch.empty(0, dtype=torch.int8, device=tensor.device).set_( + storage, storage_offset_bytes, (data_bytes,) + ) + return {key: raw.view(num_blocks, page_size_bytes)} + + # Multi-segment: ``(N, num_blocks, ...)`` is the only NPU layout + # observed (N=2 for K|V stacked). We assume a single outer + # partition dim before the blocks dim. + if tensor.ndim < 2 or tensor.shape[1] < num_blocks: + raise RuntimeError( + f"_build_block_views[{key}]: cannot locate blocks dim " + f"(expected shape[0] or shape[1] >= {num_blocks}) in " + f"shape {tuple(tensor.shape)}" + ) + page_size_bytes = tensor.stride(1) * el + seg_data_bytes = num_blocks * page_size_bytes + seg_stride_bytes = tensor.stride(0) * el + n_segments = tensor.shape[0] + total_bytes = (n_segments - 1) * seg_stride_bytes + seg_data_bytes + + raw = torch.empty(0, dtype=torch.int8, device=tensor.device).set_(storage, storage_offset_bytes, (total_bytes,)) + segs: dict[str, torch.Tensor] = {} + for idx in range(n_segments): + start = idx * seg_stride_bytes + chunk = raw[start : start + seg_data_bytes] + segs[f"{key}.{idx}"] = chunk.view(num_blocks, page_size_bytes) + return segs + + # ------------------------------------------------------------------ + # Per-step metadata plumbing + # ------------------------------------------------------------------ + def bind_connector_metadata(self, metadata: SimpleCPUOffloadMetadata) -> None: + self._connector_metadata = metadata + if metadata.load_event >= 0: + self._pending_load_event_indices.add(metadata.load_event) + if metadata.store_event >= 0: + self._pending_store_event_indices.add(metadata.store_event) + + def clear_connector_metadata(self) -> None: + self._connector_metadata = None + + def start_load_kv(self) -> None: + # Defer launching load/store until after model execution so the + # Python-side block-list build overlaps with NPU compute. 
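+        # Actual submission happens in get_finished() below.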
+ pass + + def wait_for_save(self) -> None: + pass + + def get_finished( + self, + finished_req_ids: set[str], + ) -> tuple[set[str] | None, set[str] | None]: + """Submit transfers and report completed events to the scheduler.""" + metadata = self._connector_metadata + if metadata is not None: + if metadata.load_cpu_blocks: + self._backend.launch_copy( + metadata.load_cpu_blocks, + metadata.load_gpu_blocks, + is_store=False, + event_idx=metadata.load_event, + events_list=self._load_events, + ) + if metadata.store_gpu_blocks: + self._backend.launch_copy( + metadata.store_gpu_blocks, + metadata.store_cpu_blocks, + is_store=True, + event_idx=metadata.store_event, + events_list=self._store_events, + ) + + finished_recving: set[str] = set() + + if self._pending_load_event_indices: + load_wm = self._poll_stream_events(is_store=False) + for j in [j for j in self._pending_load_event_indices if j <= load_wm]: + self._pending_load_event_indices.discard(j) + req_ids = metadata.load_event_to_reqs.get(j) if metadata is not None else None + if req_ids: + finished_recving.update(req_ids) + + if self._pending_store_event_indices: + store_wm = self._poll_stream_events(is_store=True) + for j in [j for j in self._pending_store_event_indices if j <= store_wm]: + self._pending_store_event_indices.discard(j) + self._completed_store_events[j] = 1 + + return None, finished_recving or None + + def build_connector_worker_meta( + self, + ) -> SimpleCPUOffloadWorkerMetadata | None: + if not self._completed_store_events: + return None + meta = SimpleCPUOffloadWorkerMetadata( + completed_store_events=self._completed_store_events, + ) + self._completed_store_events = {} + return meta + + def handle_preemptions(self, kv_connector_metadata: SimpleCPUOffloadMetadata) -> None: + if not kv_connector_metadata.need_flush: + return + self._flush_and_sync_all() + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + def _flush_and_sync_all(self) -> None: + for event_idx, event in self._load_events: + event.synchronize() + self._load_hwm = event_idx + self._load_events.clear() + + for event_idx, event in self._store_events: + event.synchronize() + self._store_hwm = event_idx + self._store_events.clear() + + def _poll_stream_events(self, is_store: bool) -> int: + events = self._store_events if is_store else self._load_events + hwm = self._store_hwm if is_store else self._load_hwm + while events: + event_idx, event = events[0] + if not event.query(): + break + hwm = event_idx + events.pop(0) + if is_store: + self._store_hwm = hwm + else: + self._load_hwm = hwm + return hwm