Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
dc578d8
[Perf] Triton fast path for swap_blocks_batch on small uniform batches
EtelisIBM May 10, 2026
de1e5b3
Merge branch 'main' into perf/triton-swap-blocks-batch
EtelisIBM May 10, 2026
53115d7
[Perf] swap_blocks_batch Triton path: gate on CPU->GPU, allow heterog…
EtelisIBM May 10, 2026
72d1764
[Perf] Move swap_blocks_batch wrapper into gpu_worker; comment tuned …
EtelisIBM May 10, 2026
8b1abf6
[Perf] Resolve swap_blocks function once at handler init
EtelisIBM May 10, 2026
d11e6a9
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis May 11, 2026
9f8bc77
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis May 11, 2026
3234210
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis May 11, 2026
58a8405
[Fix] Include HAS_TRITON import and fallback logic in _select_swap_bl…
EtelisIBM May 17, 2026
3165a0b
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis May 24, 2026
8357a8e
[kv_offload] Pin swap descriptor arrays so non-blocking H2D is async
EtelisIBM May 24, 2026
c5519f7
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis May 25, 2026
49c78b2
[kv_offload] Drop leading underscores from shared swap_blocks constants
EtelisIBM May 26, 2026
976423a
[kv_offload] Recycle swap descriptor buffers across transfers
EtelisIBM May 26, 2026
5d9b238
[kv_offload] Add unit test for the Triton swap_blocks_batch kernel
EtelisIBM May 26, 2026
a09dbe7
[kv_offload] Hoist Triton swap_blocks_batch closure to a top-level fu…
EtelisIBM May 28, 2026
5233a60
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis May 28, 2026
253f027
[kv_offload] Move Triton swap_blocks_batch wrapper into triton_swap.py
EtelisIBM May 29, 2026
d11d59b
Rename triton_swap to swap_blocks_triton and expose swap_blocks_batch
EtelisIBM May 29, 2026
7d134da
Merge branch 'main' into perf/triton-swap-blocks-batch
mergify[bot] May 30, 2026
3c8b122
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis May 30, 2026
0196c2a
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis May 31, 2026
7474996
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis May 31, 2026
4b5fd51
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
ad6d4a4
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
8253db2
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
85167e3
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
d171894
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
507c9bc
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
5015be6
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
5402fb9
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
84ad037
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
dc8f65b
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
1a77d38
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
414ddb6
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
ac18cb4
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
13ad6a9
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
24cc331
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
fc2af7a
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 1, 2026
45c8868
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 2, 2026
c1255cb
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 2, 2026
691b2f8
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 2, 2026
107ccb9
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 2, 2026
5aed71e
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 2, 2026
a4ba778
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 2, 2026
cb2b37e
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 2, 2026
9d07523
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 2, 2026
feb17f5
Merge branch 'main' into perf/triton-swap-blocks-batch
Etelis Jun 2, 2026
70cd571
[kv_offload] Fix Triton swap test: drop unsupported D2D reference copy
EtelisIBM Jun 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vllm/v1/kv_offload/cpu/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import numpy as np
import torch

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_pin_memory_available
Expand All @@ -18,6 +17,7 @@
GPULoadStoreSpec,
)
from vllm.v1.kv_offload.cpu.shared_offload_region import SharedOffloadRegion
from vllm.v1.kv_offload.cpu.triton_swap import swap_blocks_batch
from vllm.v1.kv_offload.worker.worker import (
OffloadingHandler,
TransferResult,
Expand Down Expand Up @@ -316,7 +316,7 @@ def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
with torch.cuda.stream(stream):
start_event.record(stream)
if num_copy_ops > 0:
ops.swap_blocks_batch(batch_src, batch_dst, batch_sizes)
swap_blocks_batch(batch_src, batch_dst, batch_sizes)
end_event.record(stream)

self._transfer_events[job_id] = end_event
Expand Down
61 changes: 61 additions & 0 deletions vllm/v1/kv_offload/cpu/triton_swap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Triton fast path for ``swap_blocks_batch`` on small uniform batches."""

from __future__ import annotations

import torch

from vllm import _custom_ops as ops
from vllm.triton_utils import tl, triton

_NUM_SMS = 12
_THRESHOLD_BYTES = 28 * 1024

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a comment on why did we choose these default values.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.



@triton.jit
def _kernel(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will this work for other architectures?
e.g. AMD, XPU, HPU, TPU?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AMD — the kernel is plain Triton, so it should run on ROCm, but the premise doesn't carry over: So _THRESHOLD_BYTES and _NUM_SMS (SMs vs CUs, different PCIe gen) would need re-measuring on AMD before it'd be worth enabling there.
XPU / HPU / TPU — these can't use the OffloadingConnector CPU-offload path at all today? or am I missing something there?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

XPU / HPU / TPU — these can't use the OffloadingConnector CPU-offload path at all today? or am I missing something there?

Right. I'm just wondering if this can lead to an easy path to support offloading on these platforms using this triton kernel.

@Etelis Etelis May 11, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I get it.

The kernel itself would run on ROCm, but not on HPU/TPU (no Triton)

XPU might be but hacky.

src_addrs,
dst_addrs,
n_jobs, # type: ignore[name-defined]
bytes_per_job, # type: ignore[name-defined]
BYTES_PER_CHUNK: tl.constexpr, # type: ignore[name-defined]
):
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
WORDS_PER_CHUNK: tl.constexpr = BYTES_PER_CHUNK // 8
words = bytes_per_job // 8
offsets = tl.arange(0, WORDS_PER_CHUNK)
job = pid
while job < n_jobs:
src = tl.load(src_addrs + job).to(tl.pointer_type(tl.int64))
dst = tl.load(dst_addrs + job).to(tl.pointer_type(tl.int64))
for start in range(0, words, WORDS_PER_CHUNK):
idx = start + offsets
mask = idx < words
data = tl.load(src + idx, mask=mask, other=0)
tl.store(dst + idx, data, mask=mask)
job += num_progs


def swap_blocks_batch(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this function should move to gpu_worker.py.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

src_addrs: torch.Tensor,
dst_addrs: torch.Tensor,
sizes: torch.Tensor,
) -> None:
"""Drop-in replacement for ``ops.swap_blocks_batch`` with Triton fast path."""
n = src_addrs.numel()
if n == 0:
return
bpj = int(sizes[0].item())

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we choose a more meaningful variable name?

if bpj >= _THRESHOLD_BYTES or bpj % 8 != 0 or not bool((sizes == bpj).all()):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a comment explaining this criteria for choosing between cudamemcpybatch/triton?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

ops.swap_blocks_batch(src_addrs, dst_addrs, sizes)
return
chunk = min(triton.next_power_of_2(bpj), 8192)
_kernel[(min(_NUM_SMS, n),)](
src_addrs.to("cuda", non_blocking=True),
dst_addrs.to("cuda", non_blocking=True),
n,
bpj,
BYTES_PER_CHUNK=chunk,
)