[Perf] Triton fast path for small CPU→GPU swap_blocks_batch in the offloading connector
#42212
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """Triton fast path for ``swap_blocks_batch`` on small uniform batches.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import torch | ||
|
|
||
| from vllm import _custom_ops as ops | ||
| from vllm.triton_utils import tl, triton | ||
|
|
||
| _NUM_SMS = 12 | ||
| _THRESHOLD_BYTES = 28 * 1024 | ||
|
|
||
|
|
||
| @triton.jit | ||
| def _kernel( | ||
|
Collaborator
Will this work for other architectures?
Contributor
Author
On AMD: the kernel is plain Triton, so it should run on ROCm, but the premise doesn't carry over. _THRESHOLD_BYTES and _NUM_SMS (SMs vs. CUs, different PCIe generation) would need re-measuring on AMD before it would be worth enabling there.
Collaborator
Right. I'm just wondering whether this could be an easy path to supporting offloading on these platforms using this Triton kernel.
Contributor
Author
Ah, I get it. The kernel itself would run on ROCm, but not on HPU/TPU (no Triton). XPU might work, but it would be hacky. |
||
| src_addrs, | ||
| dst_addrs, | ||
| n_jobs, # type: ignore[name-defined] | ||
| bytes_per_job, # type: ignore[name-defined] | ||
| BYTES_PER_CHUNK: tl.constexpr, # type: ignore[name-defined] | ||
| ): | ||
| pid = tl.program_id(0) | ||
| num_progs = tl.num_programs(0) | ||
| WORDS_PER_CHUNK: tl.constexpr = BYTES_PER_CHUNK // 8 | ||
| words = bytes_per_job // 8 | ||
| offsets = tl.arange(0, WORDS_PER_CHUNK) | ||
| job = pid | ||
| while job < n_jobs: | ||
| src = tl.load(src_addrs + job).to(tl.pointer_type(tl.int64)) | ||
| dst = tl.load(dst_addrs + job).to(tl.pointer_type(tl.int64)) | ||
| for start in range(0, words, WORDS_PER_CHUNK): | ||
| idx = start + offsets | ||
| mask = idx < words | ||
| data = tl.load(src + idx, mask=mask, other=0) | ||
| tl.store(dst + idx, data, mask=mask) | ||
| job += num_progs | ||
|
|
||
|
|
||
| def swap_blocks_batch( | ||
|
Collaborator
I think this function should move to
Contributor
Author
Done. |
||
| src_addrs: torch.Tensor, | ||
| dst_addrs: torch.Tensor, | ||
| sizes: torch.Tensor, | ||
| ) -> None: | ||
| """Drop-in replacement for ``ops.swap_blocks_batch`` with Triton fast path.""" | ||
| n = src_addrs.numel() | ||
| if n == 0: | ||
| return | ||
| bpj = int(sizes[0].item()) | ||
|
Collaborator
Can we choose a more meaningful variable name? |
||
| if bpj >= _THRESHOLD_BYTES or bpj % 8 != 0 or not bool((sizes == bpj).all()): | ||
|
Collaborator
Can we add a comment explaining the criteria for choosing between cudamemcpybatch and Triton?
Contributor
Author
Done. |
||
| ops.swap_blocks_batch(src_addrs, dst_addrs, sizes) | ||
| return | ||
| chunk = min(triton.next_power_of_2(bpj), 8192) | ||
| _kernel[(min(_NUM_SMS, n),)]( | ||
| src_addrs.to("cuda", non_blocking=True), | ||
| dst_addrs.to("cuda", non_blocking=True), | ||
| n, | ||
| bpj, | ||
| BYTES_PER_CHUNK=chunk, | ||
| ) | ||
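
For readers without a GPU at hand, the control flow of `_kernel` can be modeled in plain Python. This is only a sketch under stated assumptions: memory is a Python list of 8-byte "words", addresses are list indices rather than device pointers, and `simulate_kernel` is a hypothetical name that does not appear in the PR. It shows the two ideas the kernel relies on: a grid-stride loop over jobs (`job += num_progs`) and a masked chunk copy that ignores lanes past the end of a job.

```python
# Plain-Python model of the kernel's control flow; no GPU or Triton required.
# Assumption: "memory" is a list of 8-byte words and addresses are list indices.


def simulate_kernel(memory, src_addrs, dst_addrs, words_per_job,
                    words_per_chunk, num_progs):
    """Copy words_per_job words per job and report which program ran each job."""
    n_jobs = len(src_addrs)
    jobs_by_program = {}
    for pid in range(num_progs):  # each iteration stands in for one program
        job = pid
        while job < n_jobs:  # grid-stride loop: pid, pid + num_progs, ...
            src, dst = src_addrs[job], dst_addrs[job]
            for start in range(0, words_per_job, words_per_chunk):
                for lane in range(words_per_chunk):
                    idx = start + lane
                    if idx < words_per_job:  # the mask: skip lanes past the end
                        memory[dst + idx] = memory[src + idx]
            jobs_by_program.setdefault(pid, []).append(job)
            job += num_progs
    return jobs_by_program


# Five jobs of 6 words each, copied in chunks of 4 words by 2 programs.
memory = list(range(100)) + [0] * 100
assignment = simulate_kernel(
    memory,
    src_addrs=[0, 10, 20, 30, 40],
    dst_addrs=[100, 110, 120, 130, 140],
    words_per_job=6,
    words_per_chunk=4,
    num_progs=2,
)
print(assignment)        # {0: [0, 2, 4], 1: [1, 3]}
print(memory[100:110])   # [0, 1, 2, 3, 4, 5, 0, 0, 0, 0]
```

The second chunk of each job covers word indices 4 to 7, but only indices 4 and 5 pass the mask, so the four words after each destination block stay untouched.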
Let's add a comment explaining why we chose these default values.
Done.
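
To make the role of the constants discussed in this thread concrete, the wrapper's dispatch test and launch arithmetic can be reproduced in plain Python. This is a sketch, not the PR's code: `plan_launch` and the upper-case constant names are hypothetical, `sizes` is a plain list instead of a tensor, and `next_power_of_2` is a local stand-in for `triton.next_power_of_2` that is valid for inputs of at least 1. The values 12, 28 * 1024, and 8192 are copied from the diff; why they were chosen is not stated in this view.

```python
NUM_PROGRAMS_CAP = 12          # mirrors _NUM_SMS in the diff
THRESHOLD_BYTES = 28 * 1024    # mirrors _THRESHOLD_BYTES in the diff
MAX_CHUNK_BYTES = 8192         # mirrors the cap applied to BYTES_PER_CHUNK


def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1 (local stand-in, see lead-in)."""
    return 1 << (n - 1).bit_length()


def plan_launch(sizes):
    """Return None when the Triton kernel would not be launched (empty batch
    or fallback to ops.swap_blocks_batch); otherwise return
    (grid, chunk_bytes, chunks_per_job)."""
    if not sizes:
        return None
    bytes_per_job = sizes[0]
    uniform = all(s == bytes_per_job for s in sizes)
    # Same three conditions as the diff: large, unaligned, or non-uniform
    # batches take the existing batched-copy path.
    if bytes_per_job >= THRESHOLD_BYTES or bytes_per_job % 8 != 0 or not uniform:
        return None
    grid = min(NUM_PROGRAMS_CAP, len(sizes))
    chunk_bytes = min(next_power_of_2(bytes_per_job), MAX_CHUNK_BYTES)
    chunks_per_job = -(-bytes_per_job // chunk_bytes)  # ceiling division
    return grid, chunk_bytes, chunks_per_job


print(plan_launch([4096] * 40))       # (12, 4096, 1)
print(plan_launch([24576] * 3))       # (3, 8192, 3)
print(plan_launch([28 * 1024] * 2))   # None: at or above the threshold
print(plan_launch([4096, 8192]))      # None: sizes are not uniform
print(plan_launch([4100]))            # None: not a multiple of 8 bytes
```

The second call shows the effect of the 8192-byte cap: a 24 KiB job would round up to a 32 KiB chunk, so it is instead copied in three 8 KiB chunks per job.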