Skip to content
Merged
Show file tree
Hide file tree
Changes from 47 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
dc578d8
[Perf] Triton fast path for swap_blocks_batch on small uniform batches
EtelisIBM May 10, 2026
de1e5b3
Merge branch 'main' into perf/triton-swap-blocks-batch
EtelisIBM May 10, 2026
53115d7
[Perf] swap_blocks_batch Triton path: gate on CPU->GPU, allow heterog…
EtelisIBM May 10, 2026
72d1764
[Perf] Move swap_blocks_batch wrapper into gpu_worker; comment tuned …
EtelisIBM May 10, 2026
8b1abf6
[Perf] Resolve swap_blocks function once at handler init
EtelisIBM May 10, 2026
d11e6a9
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis May 11, 2026
9f8bc77
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis May 11, 2026
3234210
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis May 11, 2026
58a8405
[Fix] Include HAS_TRITON import and fallback logic in _select_swap_bl…
EtelisIBM May 17, 2026
3165a0b
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis May 24, 2026
8357a8e
[kv_offload] Pin swap descriptor arrays so non-blocking H2D is async
EtelisIBM May 24, 2026
c5519f7
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis May 25, 2026
49c78b2
[kv_offload] Drop leading underscores from shared swap_blocks constants
EtelisIBM May 26, 2026
976423a
[kv_offload] Recycle swap descriptor buffers across transfers
EtelisIBM May 26, 2026
5d9b238
[kv_offload] Add unit test for the Triton swap_blocks_batch kernel
EtelisIBM May 26, 2026
a09dbe7
[kv_offload] Hoist Triton swap_blocks_batch closure to a top-level fu…
EtelisIBM May 28, 2026
5233a60
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis May 28, 2026
253f027
[kv_offload] Move Triton swap_blocks_batch wrapper into triton_swap.py
EtelisIBM May 29, 2026
d11d59b
Rename triton_swap to swap_blocks_triton and expose swap_blocks_batch
EtelisIBM May 29, 2026
7d134da
Merge branch 'main' into perf/triton-swap-blocks-batch
mergify[bot] May 30, 2026
3c8b122
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis May 30, 2026
0196c2a
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis May 31, 2026
7474996
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis May 31, 2026
4b5fd51
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
ad6d4a4
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
8253db2
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
85167e3
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
d171894
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
507c9bc
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
5015be6
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
5402fb9
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
84ad037
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
dc8f65b
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
1a77d38
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
414ddb6
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
ac18cb4
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
13ad6a9
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
24cc331
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
fc2af7a
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
45c8868
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 2, 2026
c1255cb
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 2, 2026
691b2f8
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 2, 2026
107ccb9
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 2, 2026
5aed71e
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 2, 2026
a4ba778
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 2, 2026
cb2b37e
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 2, 2026
9d07523
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 2, 2026
feb17f5
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 2, 2026
70cd571
[kv_offload] Fix Triton swap test: drop unsupported D2D reference copy
EtelisIBM Jun 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions tests/v1/kv_offload/cpu/test_swap_blocks_triton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit test for the Triton ``swap_blocks_batch`` fast-path kernel."""

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.v1.kv_offload.cpu.swap_blocks_triton import swap_blocks_batch


def _addrs(buffers: list[torch.Tensor]) -> torch.Tensor:
return torch.tensor([b.data_ptr() for b in buffers], dtype=torch.int64)


@pytest.mark.skipif(
not current_platform.is_cuda(), reason="Triton swap fast path requires CUDA"
)
def test_triton_swap_matches_cpp_path():
# 8-byte-aligned, sub-threshold sizes covering 8 KiB chunk boundaries and
# odd tail-mask lengths, with enough descriptors to take the Triton path.
sizes = [8, 4096, 8192, 8200, 16384, 4088] * 8
src = [torch.randint(256, (s,), dtype=torch.uint8, device="cuda") for s in sizes]
dst_cpp = [torch.zeros_like(s) for s in src]
dst_tri = [torch.zeros_like(s) for s in src]
sizes_t = torch.tensor(sizes, dtype=torch.int64)

ops.swap_blocks_batch(_addrs(src), _addrs(dst_cpp), sizes_t.clone())
swap_blocks_batch(
_addrs(src), _addrs(dst_tri), sizes_t.clone(), bytes_per_chunk=8192
)
torch.accelerator.synchronize()

for s, cpp, tri in zip(src, dst_cpp, dst_tri):
assert torch.equal(tri, s) # kernel copied the source bytes
assert torch.equal(tri, cpp) # ... identically to cuMemcpyBatchAsync
86 changes: 75 additions & 11 deletions vllm/v1/kv_offload/cpu/gpu_worker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import time
from collections import deque
from dataclasses import dataclass
Expand All @@ -9,6 +10,7 @@

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.triton_utils import HAS_TRITON, triton
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.kv_offload.base import (
Expand All @@ -18,6 +20,10 @@
GPULoadStoreSpec,
)
from vllm.v1.kv_offload.cpu.shared_offload_region import SharedOffloadRegion
from vllm.v1.kv_offload.cpu.swap_blocks_triton import (
THRESHOLD_BYTES,
swap_blocks_batch,
)
from vllm.v1.kv_offload.worker.worker import (
OffloadingHandler,
TransferResult,
Expand All @@ -27,13 +33,40 @@
logger = init_logger(__name__)


def _select_swap_blocks_fn(
kv_cache_groups_data_refs: list[list[CanonicalKVCacheRef]],
gpu_to_cpu: bool,
):
"""Resolve the swap_blocks function for a handler at init time."""
# GPU->CPU is bandwidth-bound; the dedicated copy engine beats Triton.
if gpu_to_cpu:
return ops.swap_blocks_batch
# Fall back to the C++ DMA path on platforms where Triton isn't usable
# (e.g. ROCm builds without Triton).
if not HAS_TRITON:
return ops.swap_blocks_batch
page_sizes = [r.page_size_bytes for g in kv_cache_groups_data_refs for r in g]
# Triton wins only on small, 8-byte-aligned payloads.
if (
not page_sizes
or max(page_sizes) >= THRESHOLD_BYTES
or any(s % 8 for s in page_sizes)
):
return ops.swap_blocks_batch
chunk = min(triton.next_power_of_2(max(page_sizes)), 8192)
return functools.partial(swap_blocks_batch, bytes_per_chunk=chunk)


@dataclass
class Transfer:
job_id: int
stream: torch.cuda.Stream
start_event: torch.Event
end_event: torch.Event
num_bytes: int
batch_src: torch.Tensor
batch_dst: torch.Tensor
batch_sizes: torch.Tensor


def compute_sub_block_ptrs(
Expand Down Expand Up @@ -108,6 +141,17 @@ def pin_mmap_region(region: SharedOffloadRegion) -> None:
region.is_pinned = True


def _new_descriptor_buffers(
num_copy_ops: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pin = is_pin_memory_available()
return (
torch.empty(num_copy_ops, dtype=torch.int64, pin_memory=pin),
torch.empty(num_copy_ops, dtype=torch.int64, pin_memory=pin),
torch.empty(num_copy_ops, dtype=torch.int64, pin_memory=pin),
)


class SingleDirectionOffloadingHandler(OffloadingHandler):
"""
SingleDirectionOffloadingHandler handles transfers for a single direction,
Expand Down Expand Up @@ -161,6 +205,9 @@ def __init__(
)
self.gpu_to_cpu: bool = gpu_to_cpu
self.kv_cache_groups_data_refs = kv_cache_groups_data_refs
self._swap_blocks_batch = _select_swap_blocks_fn(
kv_cache_groups_data_refs, gpu_to_cpu
)

# GPU blocks may be smaller
# cpu_page_size = gpu_page_size * block_size_factor.
Expand All @@ -178,6 +225,8 @@ def __init__(
self._stream_pool: list[torch.cuda.Stream] = []
# list of CUDA events available for re-use
self._event_pool: list[torch.Event] = []
# list of pinned descriptor buffer sets available for re-use
self._buffer_pool: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = []

def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
src_spec, dst_spec = transfer_spec
Expand Down Expand Up @@ -227,9 +276,21 @@ def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
):
num_copy_ops += group_size * len(group_data_refs)

all_src = np.empty(num_copy_ops, dtype=np.int64)
all_dst = np.empty(num_copy_ops, dtype=np.int64)
all_sizes = np.empty(num_copy_ops, dtype=np.int64)
# reuse a pooled buffer set, growing it if this transfer needs more room
batch_src, batch_dst, batch_sizes = (
self._buffer_pool.pop()
if self._buffer_pool
else _new_descriptor_buffers(num_copy_ops)
)
if batch_src.numel() < num_copy_ops:
batch_src, batch_dst, batch_sizes = _new_descriptor_buffers(num_copy_ops)

src = batch_src[:num_copy_ops]
dst = batch_dst[:num_copy_ops]
sizes = batch_sizes[:num_copy_ops]
all_src = src.numpy()
all_dst = dst.numpy()
all_sizes = sizes.numpy()

src_offset = 0
dst_offset = 0
Expand Down Expand Up @@ -292,10 +353,6 @@ def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
assert dst_offset == num_dst_blocks
assert op_idx == num_copy_ops

batch_src = torch.from_numpy(all_src)
batch_dst = torch.from_numpy(all_dst)
batch_sizes = torch.from_numpy(all_sizes)

stream = self._stream_pool.pop() if self._stream_pool else torch.cuda.Stream()
start_event = (
self._event_pool.pop()
Expand Down Expand Up @@ -326,10 +383,10 @@ def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
with torch.cuda.stream(stream):
start_event.record(stream)
if num_copy_ops > 0:
ops.swap_blocks_batch(
batch_src,
batch_dst,
batch_sizes,
self._swap_blocks_batch(
src,
dst,
sizes,
is_src_access_order_any=is_src_access_order_any,
)
end_event.record(stream)
Expand All @@ -342,6 +399,9 @@ def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
start_event=start_event,
end_event=end_event,
num_bytes=num_transfer_bytes,
batch_src=batch_src,
batch_dst=batch_dst,
batch_sizes=batch_sizes,
)
)

Expand All @@ -367,6 +427,9 @@ def get_finished(self) -> list[TransferResult]:
self._stream_pool.append(transfer.stream)
self._event_pool.append(transfer.end_event)
self._event_pool.append(transfer.start_event)
self._buffer_pool.append(
(transfer.batch_src, transfer.batch_dst, transfer.batch_sizes)
)
del self._transfer_events[transfer.job_id]
return results

Expand All @@ -383,6 +446,7 @@ def shutdown(self) -> None:
self._transfer_events.clear()
self._stream_pool.clear()
self._event_pool.clear()
self._buffer_pool.clear()
self.src_tensors.clear()
self.dst_tensors.clear()
if self._mmap_region is not None:
Expand Down
74 changes: 74 additions & 0 deletions vllm/v1/kv_offload/cpu/swap_blocks_triton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we rename this file swap_blocks_triton.py?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Triton kernel + tuned constants for the ``swap_blocks_batch`` fast path."""

from __future__ import annotations

import torch

from vllm import _custom_ops as ops
from vllm.triton_utils import tl, triton

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be guarded by HAS_TRITON

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thevllm.triton_utils has a mechanism there, we'd need that if we'd directly import triton.


# Constants tuned empirically on H100 (PCIe Gen5):
# NUM_SMS - smallest SM slice within 5% of peak bandwidth at the
# 8-32 KB block sizes that matter in practice
# THRESHOLD_BYTES - max payload per descriptor where Triton beats DMA; above
# this the C++ cuMemcpyBatchAsync path takes the lead
# MIN_N - minimum batch size where Triton's per-launch cost is
# amortized; below this DMA wins
NUM_SMS = 12
THRESHOLD_BYTES = 28 * 1024
MIN_N = 16


@triton.jit
def _swap_blocks_kernel(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need a unit test.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

src_addrs,
dst_addrs,
sizes,
n_jobs, # type: ignore[name-defined]
BYTES_PER_CHUNK: tl.constexpr, # type: ignore[name-defined]
):
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
WORDS_PER_CHUNK: tl.constexpr = BYTES_PER_CHUNK // 8
offsets = tl.arange(0, WORDS_PER_CHUNK)
job = pid
while job < n_jobs:
src = tl.load(src_addrs + job).to(tl.pointer_type(tl.int64))
dst = tl.load(dst_addrs + job).to(tl.pointer_type(tl.int64))
words = tl.load(sizes + job) // 8
for start in range(0, words, WORDS_PER_CHUNK):
idx = start + offsets
mask = idx < words
data = tl.load(src + idx, mask=mask, other=0)
tl.store(dst + idx, data, mask=mask)
job += num_progs


def swap_blocks_batch(
src_addrs: torch.Tensor,
dst_addrs: torch.Tensor,
sizes: torch.Tensor,
is_src_access_order_any: bool = False,
*,
bytes_per_chunk: int,
) -> None:
"""Triton implementation of ``swap_blocks_batch`` for small CPU->GPU batches."""
n = src_addrs.numel()
# Too few descriptors to amortize Triton's launch cost.
if n < MIN_N:
ops.swap_blocks_batch(
src_addrs,
dst_addrs,
sizes,
is_src_access_order_any=is_src_access_order_any,
)
return
_swap_blocks_kernel[(min(NUM_SMS, n),)](
src_addrs.to("cuda", non_blocking=True),
dst_addrs.to("cuda", non_blocking=True),
sizes.to("cuda", non_blocking=True),
n,
BYTES_PER_CHUNK=bytes_per_chunk,
)
Loading