-
-
Notifications
You must be signed in to change notification settings - Fork 18.8k
[Perf] Triton fast path for small CPU→GPU swap_blocks_batch in the offloading connector
#42212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 47 commits
dc578d8
de1e5b3
53115d7
72d1764
8b1abf6
d11e6a9
9f8bc77
3234210
58a8405
3165a0b
8357a8e
c5519f7
49c78b2
976423a
5d9b238
a09dbe7
5233a60
253f027
d11d59b
7d134da
3c8b122
0196c2a
7474996
4b5fd51
ad6d4a4
8253db2
85167e3
d171894
507c9bc
5015be6
5402fb9
84ad037
dc8f65b
1a77d38
414ddb6
ac18cb4
13ad6a9
24cc331
fc2af7a
45c8868
c1255cb
691b2f8
107ccb9
5aed71e
a4ba778
cb2b37e
9d07523
feb17f5
70cd571
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """Unit test for the Triton ``swap_blocks_batch`` fast-path kernel.""" | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from vllm import _custom_ops as ops | ||
| from vllm.platforms import current_platform | ||
| from vllm.v1.kv_offload.cpu.swap_blocks_triton import swap_blocks_batch | ||
|
|
||
|
|
||
| def _addrs(buffers: list[torch.Tensor]) -> torch.Tensor: | ||
| return torch.tensor([b.data_ptr() for b in buffers], dtype=torch.int64) | ||
|
|
||
|
|
||
| @pytest.mark.skipif( | ||
| not current_platform.is_cuda(), reason="Triton swap fast path requires CUDA" | ||
| ) | ||
| def test_triton_swap_matches_cpp_path(): | ||
| # 8-byte-aligned, sub-threshold sizes covering 8 KiB chunk boundaries and | ||
| # odd tail-mask lengths, with enough descriptors to take the Triton path. | ||
| sizes = [8, 4096, 8192, 8200, 16384, 4088] * 8 | ||
| src = [torch.randint(256, (s,), dtype=torch.uint8, device="cuda") for s in sizes] | ||
| dst_cpp = [torch.zeros_like(s) for s in src] | ||
| dst_tri = [torch.zeros_like(s) for s in src] | ||
| sizes_t = torch.tensor(sizes, dtype=torch.int64) | ||
|
|
||
| ops.swap_blocks_batch(_addrs(src), _addrs(dst_cpp), sizes_t.clone()) | ||
| swap_blocks_batch( | ||
| _addrs(src), _addrs(dst_tri), sizes_t.clone(), bytes_per_chunk=8192 | ||
| ) | ||
| torch.accelerator.synchronize() | ||
|
|
||
| for s, cpp, tri in zip(src, dst_cpp, dst_tri): | ||
| assert torch.equal(tri, s) # kernel copied the source bytes | ||
| assert torch.equal(tri, cpp) # ... identically to cuMemcpyBatchAsync |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """Triton kernel + tuned constants for the ``swap_blocks_batch`` fast path.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import torch | ||
|
|
||
| from vllm import _custom_ops as ops | ||
| from vllm.triton_utils import tl, triton | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This needs to be guarded by
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
|
|
||
| # Constants tuned empirically on H100 (PCIe Gen5): | ||
| # NUM_SMS - smallest SM slice within 5% of peak bandwidth at the | ||
| # 8-32 KB block sizes that matter in practice | ||
| # THRESHOLD_BYTES - max payload per descriptor where Triton beats DMA; above | ||
| # this the C++ cuMemcpyBatchAsync path takes the lead | ||
| # MIN_N - minimum batch size where Triton's per-launch cost is | ||
| # amortized; below this DMA wins | ||
| NUM_SMS = 12 | ||
| THRESHOLD_BYTES = 28 * 1024 | ||
| MIN_N = 16 | ||
|
|
||
|
|
||
| @triton.jit | ||
| def _swap_blocks_kernel( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Need a unit test.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
| src_addrs, | ||
| dst_addrs, | ||
| sizes, | ||
| n_jobs, # type: ignore[name-defined] | ||
| BYTES_PER_CHUNK: tl.constexpr, # type: ignore[name-defined] | ||
| ): | ||
| pid = tl.program_id(0) | ||
| num_progs = tl.num_programs(0) | ||
| WORDS_PER_CHUNK: tl.constexpr = BYTES_PER_CHUNK // 8 | ||
| offsets = tl.arange(0, WORDS_PER_CHUNK) | ||
| job = pid | ||
| while job < n_jobs: | ||
| src = tl.load(src_addrs + job).to(tl.pointer_type(tl.int64)) | ||
| dst = tl.load(dst_addrs + job).to(tl.pointer_type(tl.int64)) | ||
| words = tl.load(sizes + job) // 8 | ||
| for start in range(0, words, WORDS_PER_CHUNK): | ||
| idx = start + offsets | ||
| mask = idx < words | ||
| data = tl.load(src + idx, mask=mask, other=0) | ||
| tl.store(dst + idx, data, mask=mask) | ||
| job += num_progs | ||
|
|
||
|
|
||
| def swap_blocks_batch( | ||
| src_addrs: torch.Tensor, | ||
| dst_addrs: torch.Tensor, | ||
| sizes: torch.Tensor, | ||
| is_src_access_order_any: bool = False, | ||
| *, | ||
| bytes_per_chunk: int, | ||
| ) -> None: | ||
| """Triton implementation of ``swap_blocks_batch`` for small CPU->GPU batches.""" | ||
| n = src_addrs.numel() | ||
| # Too few descriptors to amortize Triton's launch cost. | ||
| if n < MIN_N: | ||
| ops.swap_blocks_batch( | ||
| src_addrs, | ||
| dst_addrs, | ||
| sizes, | ||
| is_src_access_order_any=is_src_access_order_any, | ||
| ) | ||
| return | ||
| _swap_blocks_kernel[(min(NUM_SMS, n),)]( | ||
| src_addrs.to("cuda", non_blocking=True), | ||
| dst_addrs.to("cuda", non_blocking=True), | ||
| sizes.to("cuda", non_blocking=True), | ||
| n, | ||
| BYTES_PER_CHUNK=bytes_per_chunk, | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we rename this file
swap_blocks_triton.py?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done