diff --git a/tests/v1/kv_offload/cpu/test_swap_blocks_triton.py b/tests/v1/kv_offload/cpu/test_swap_blocks_triton.py
new file mode 100644
index 000000000000..ec14a378434b
--- /dev/null
+++ b/tests/v1/kv_offload/cpu/test_swap_blocks_triton.py
@@ -0,0 +1,31 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit test for the Triton ``swap_blocks_batch`` fast-path kernel."""
+
+import pytest
+import torch
+
+from vllm.platforms import current_platform
+from vllm.v1.kv_offload.cpu.swap_blocks_triton import swap_blocks_batch
+
+
+def _addrs(buffers: list[torch.Tensor]) -> torch.Tensor:  # int64 data_ptr per buffer
+    return torch.tensor([b.data_ptr() for b in buffers], dtype=torch.int64)
+
+
+@pytest.mark.skipif(
+    not current_platform.is_cuda(), reason="Triton swap fast path requires CUDA"
+)
+def test_triton_swap_copies_source_bytes():
+    # 8-byte-aligned, sub-threshold sizes covering 8 KiB chunk boundaries and
+    # odd tail-mask lengths, with enough descriptors to take the Triton path.
+    sizes = [8, 4096, 8192, 8200, 16384, 4088] * 8  # 48 descriptors, above MIN_N
+    src = [torch.randint(256, (s,), dtype=torch.uint8, device="cuda") for s in sizes]
+    dst = [torch.zeros_like(s) for s in src]
+    sizes_t = torch.tensor(sizes, dtype=torch.int64)
+
+    swap_blocks_batch(_addrs(src), _addrs(dst), sizes_t.clone(), bytes_per_chunk=8192)
+    torch.accelerator.synchronize()
+
+    for s, t in zip(src, dst):
+        assert torch.equal(t, s)  # kernel copied the source bytes verbatim
diff --git a/vllm/v1/kv_offload/cpu/gpu_worker.py b/vllm/v1/kv_offload/cpu/gpu_worker.py
index 119778368ca7..d0e73a2a0897 100644
--- a/vllm/v1/kv_offload/cpu/gpu_worker.py
+++ b/vllm/v1/kv_offload/cpu/gpu_worker.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools
 import time
 from collections import deque
 from dataclasses import dataclass
@@ -9,6 +10,7 @@
 
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
+from vllm.triton_utils import HAS_TRITON, triton
 from vllm.utils.math_utils import cdiv
 from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.v1.kv_offload.base import (
@@ -18,6 +20,10 @@
     GPULoadStoreSpec,
 )
 from vllm.v1.kv_offload.cpu.shared_offload_region import SharedOffloadRegion
+from vllm.v1.kv_offload.cpu.swap_blocks_triton import (
+    THRESHOLD_BYTES,
+    swap_blocks_batch,
+)
 from vllm.v1.kv_offload.worker.worker import (
     OffloadingHandler,
     TransferResult,
@@ -27,6 +33,30 @@
 logger = init_logger(__name__)
 
 
+def _select_swap_blocks_fn(
+    kv_cache_groups_data_refs: list[list[CanonicalKVCacheRef]],
+    gpu_to_cpu: bool,
+):
+    """Choose the C++ or Triton swap_blocks_batch for a handler at init time."""
+    # GPU->CPU is bandwidth-bound; the dedicated copy engine beats Triton.
+    if gpu_to_cpu:
+        return ops.swap_blocks_batch
+    # Fall back to the C++ DMA path on platforms where Triton isn't usable
+    # (e.g. ROCm builds without Triton).
+    if not HAS_TRITON:
+        return ops.swap_blocks_batch
+    page_sizes = [r.page_size_bytes for g in kv_cache_groups_data_refs for r in g]
+    # Triton wins only on small payloads; its kernel copies whole 8-byte words.
+    if (
+        not page_sizes
+        or max(page_sizes) >= THRESHOLD_BYTES
+        or any(s % 8 for s in page_sizes)
+    ):
+        return ops.swap_blocks_batch
+    chunk = min(triton.next_power_of_2(max(page_sizes)), 8192)  # capped at 8 KiB
+    return functools.partial(swap_blocks_batch, bytes_per_chunk=chunk)
+
+
 @dataclass
 class Transfer:
     job_id: int
@@ -34,6 +64,9 @@ class Transfer:
     start_event: torch.Event
     end_event: torch.Event
     num_bytes: int
+    batch_src: torch.Tensor  # descriptor buffers, pooled again by get_finished()
+    batch_dst: torch.Tensor
+    batch_sizes: torch.Tensor
 
 
 def compute_sub_block_ptrs(
@@ -108,6 +141,17 @@ def pin_mmap_region(region: SharedOffloadRegion) -> None:
     region.is_pinned = True
 
 
+def _new_descriptor_buffers(
+    num_copy_ops: int,  # element count of each buffer
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:  # (src, dst, sizes)
+    pin = is_pin_memory_available()  # pin only when reported available
+    return (
+        torch.empty(num_copy_ops, dtype=torch.int64, pin_memory=pin),
+        torch.empty(num_copy_ops, dtype=torch.int64, pin_memory=pin),
+        torch.empty(num_copy_ops, dtype=torch.int64, pin_memory=pin),
+    )
+
+
 class SingleDirectionOffloadingHandler(OffloadingHandler):
     """
     SingleDirectionOffloadingHandler handles transfers for a single direction,
@@ -161,6 +205,9 @@ def __init__(
         )
         self.gpu_to_cpu: bool = gpu_to_cpu
         self.kv_cache_groups_data_refs = kv_cache_groups_data_refs
+        self._swap_blocks_batch = _select_swap_blocks_fn(
+            kv_cache_groups_data_refs, gpu_to_cpu
+        )
 
         # GPU blocks may be smaller
         # cpu_page_size = gpu_page_size * block_size_factor.
@@ -178,6 +225,8 @@ def __init__(
         self._stream_pool: list[torch.cuda.Stream] = []
         # list of CUDA events available for re-use
         self._event_pool: list[torch.Event] = []
+        # list of (src, dst, sizes) descriptor buffer sets available for re-use
+        self._buffer_pool: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = []
 
     def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
         src_spec, dst_spec = transfer_spec
@@ -227,9 +276,21 @@ def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
         ):
             num_copy_ops += group_size * len(group_data_refs)
 
-        all_src = np.empty(num_copy_ops, dtype=np.int64)
-        all_dst = np.empty(num_copy_ops, dtype=np.int64)
-        all_sizes = np.empty(num_copy_ops, dtype=np.int64)
+        # reuse a pooled buffer set; allocate a fresh one if it is too small
+        batch_src, batch_dst, batch_sizes = (
+            self._buffer_pool.pop()
+            if self._buffer_pool
+            else _new_descriptor_buffers(num_copy_ops)
+        )
+        if batch_src.numel() < num_copy_ops:
+            batch_src, batch_dst, batch_sizes = _new_descriptor_buffers(num_copy_ops)
+
+        src = batch_src[:num_copy_ops]  # views sized to this transfer
+        dst = batch_dst[:num_copy_ops]
+        sizes = batch_sizes[:num_copy_ops]
+        all_src = src.numpy()  # numpy views sharing the tensors' memory
+        all_dst = dst.numpy()
+        all_sizes = sizes.numpy()
 
         src_offset = 0
         dst_offset = 0
@@ -292,10 +353,6 @@ def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
         assert dst_offset == num_dst_blocks
         assert op_idx == num_copy_ops
 
-        batch_src = torch.from_numpy(all_src)
-        batch_dst = torch.from_numpy(all_dst)
-        batch_sizes = torch.from_numpy(all_sizes)
-
         stream = self._stream_pool.pop() if self._stream_pool else torch.cuda.Stream()
         start_event = (
             self._event_pool.pop()
@@ -326,10 +383,10 @@ def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
         with torch.cuda.stream(stream):
             start_event.record(stream)
             if num_copy_ops > 0:
-                ops.swap_blocks_batch(
-                    batch_src,
-                    batch_dst,
-                    batch_sizes,
+                self._swap_blocks_batch(
+                    src,
+                    dst,
+                    sizes,
                     is_src_access_order_any=is_src_access_order_any,
                 )
             end_event.record(stream)
@@ -342,6 +399,9 @@ def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
                 start_event=start_event,
                 end_event=end_event,
                 num_bytes=num_transfer_bytes,
+                batch_src=batch_src,
+                batch_dst=batch_dst,
+                batch_sizes=batch_sizes,
             )
         )
 
@@ -367,6 +427,9 @@ def get_finished(self) -> list[TransferResult]:
             self._stream_pool.append(transfer.stream)
             self._event_pool.append(transfer.end_event)
             self._event_pool.append(transfer.start_event)
+            self._buffer_pool.append(
+                (transfer.batch_src, transfer.batch_dst, transfer.batch_sizes)
+            )
             del self._transfer_events[transfer.job_id]
         return results
 
@@ -383,6 +446,7 @@ def shutdown(self) -> None:
         self._transfer_events.clear()
         self._stream_pool.clear()
         self._event_pool.clear()
+        self._buffer_pool.clear()
         self.src_tensors.clear()
         self.dst_tensors.clear()
         if self._mmap_region is not None:
diff --git a/vllm/v1/kv_offload/cpu/swap_blocks_triton.py b/vllm/v1/kv_offload/cpu/swap_blocks_triton.py
new file mode 100644
index 000000000000..77d9028d7395
--- /dev/null
+++ b/vllm/v1/kv_offload/cpu/swap_blocks_triton.py
@@ -0,0 +1,74 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Triton kernel + tuned constants for the ``swap_blocks_batch`` fast path."""
+
+from __future__ import annotations
+
+import torch
+
+from vllm import _custom_ops as ops
+from vllm.triton_utils import tl, triton
+
+# Constants tuned empirically on H100 (PCIe Gen5):
+#   NUM_SMS         - smallest SM slice within 5% of peak bandwidth at the
+#                     8-32 KB block sizes that matter in practice
+#   THRESHOLD_BYTES - max payload per descriptor where Triton beats DMA; above
+#                     this the C++ cuMemcpyBatchAsync path takes the lead
+#   MIN_N           - minimum batch size where Triton's per-launch cost is
+#                     amortized; below this DMA wins
+NUM_SMS = 12
+THRESHOLD_BYTES = 28 * 1024
+MIN_N = 16
+
+
+@triton.jit
+def _swap_blocks_kernel(
+    src_addrs,
+    dst_addrs,
+    sizes,
+    n_jobs,  # type: ignore[name-defined]
+    BYTES_PER_CHUNK: tl.constexpr,  # type: ignore[name-defined]
+):
+    pid = tl.program_id(0)
+    num_progs = tl.num_programs(0)
+    WORDS_PER_CHUNK: tl.constexpr = BYTES_PER_CHUNK // 8  # chunk in 8-byte words
+    offsets = tl.arange(0, WORDS_PER_CHUNK)
+    job = pid
+    while job < n_jobs:  # each program takes jobs pid, pid + num_progs, ...
+        src = tl.load(src_addrs + job).to(tl.pointer_type(tl.int64))
+        dst = tl.load(dst_addrs + job).to(tl.pointer_type(tl.int64))
+        words = tl.load(sizes + job) // 8  # 8-byte words; remainder bytes are dropped
+        for start in range(0, words, WORDS_PER_CHUNK):
+            idx = start + offsets
+            mask = idx < words  # guards the partial last chunk
+            data = tl.load(src + idx, mask=mask, other=0)
+            tl.store(dst + idx, data, mask=mask)
+        job += num_progs
+
+
+def swap_blocks_batch(
+    src_addrs: torch.Tensor,
+    dst_addrs: torch.Tensor,
+    sizes: torch.Tensor,
+    is_src_access_order_any: bool = False,  # only forwarded to the fallback op
+    *,
+    bytes_per_chunk: int,  # kernel BYTES_PER_CHUNK constexpr
+) -> None:
+    """Triton implementation of ``swap_blocks_batch`` for small CPU->GPU batches."""
+    n = src_addrs.numel()
+    # Too few descriptors to amortize Triton's launch cost.
+    if n < MIN_N:
+        ops.swap_blocks_batch(
+            src_addrs,
+            dst_addrs,
+            sizes,
+            is_src_access_order_any=is_src_access_order_any,
+        )
+        return
+    _swap_blocks_kernel[(min(NUM_SMS, n),)](  # at most NUM_SMS programs
+        src_addrs.to("cuda", non_blocking=True),  # descriptors go to the GPU first
+        dst_addrs.to("cuda", non_blocking=True),
+        sizes.to("cuda", non_blocking=True),
+        n,
+        BYTES_PER_CHUNK=bytes_per_chunk,
+    )