[Perf] Triton fast path for small CPU→GPU swap_blocks_batch in the offloading connector #42212

Merged
orozery merged 49 commits into vllm-project:main from Etelis:perf/triton-swap-blocks-batch on Jun 3, 2026

Etelis commented May 10, 2026
Contributor

Summary

OffloadingConnector copies KV blocks between host and device via
cuMemcpyBatchAsync. That call saturates PCIe for large contiguous copies,
but in the CPU→GPU (onload / "read") direction its throughput collapses for
small per-descriptor payloads, which is the regime KV offload actually runs in.
This PR adds a small Triton kernel (_swap_blocks_kernel) that takes over the
CPU→GPU direction, gated on batch size (n ≥ 16) and small payloads
(max(sizes) < 28 KiB).
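
The gate described above can be sketched in plain Python. This is only an illustration: `use_triton_fast_path` and its parameters are hypothetical names, and the constants mirror the values quoted in this description (n ≥ 16, max(sizes) < 28 KiB), not the merged code. The 8-byte alignment check discussed later in the review is left out for brevity.

```python
# Illustrative sketch of the gate described in the summary; all names are hypothetical.
MIN_BATCH = 16                 # n >= 16, as quoted in the PR description
THRESHOLD_BYTES = 28 * 1024    # max(sizes) < 28 KiB


def use_triton_fast_path(sizes: list[int], cpu_to_gpu: bool) -> bool:
    """Return True when a batch would qualify for the Triton onload path."""
    if not cpu_to_gpu:          # offload (GPU->CPU) stays on cuMemcpyBatchAsync
        return False
    if len(sizes) < MIN_BATCH:  # tiny batches stay on the DMA path
        return False
    return max(sizes) < THRESHOLD_BYTES


print(use_triton_fast_path([8 * 1024] * 64, cpu_to_gpu=True))    # small pages, large batch
print(use_triton_fast_path([64 * 1024] * 64, cpu_to_gpu=True))   # pages above the threshold
```

The threshold is exclusive, so a batch of exactly 28 KiB descriptors stays on the DMA path.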

Inspired by @ivanium's prototype, adapted to the connector's flat
(src_addr, dst_addr, size) interface.

Motivation — the DMA small-page cliff (onload)

A single cuMemcpyBatchAsync of N small descriptors (one batched copy = one
production swap) is flat at ~5–7 GB/s for every page size from 4 KiB up to 24 KiB,
then jumps ~6× at 28 KiB to the copy-engine ceiling. The cliff deepens with
descriptor count N (i.e. with prefix length), which is exactly the wrong direction
for KV offload. Per-call DMA read throughput (GB/s) vs page size (H100, fast GPU):

| page | DMA single-call N=10000 | N=50000 | N=100000 |
|---|---|---|---|
| 4 KiB | 6.8 | 4.0 | 3.8 |
| 8 KiB | 6.8 | 5.2 | 5.3 |
| 16 KiB | 7.1 | 6.3 | 6.3 |
| 24 KiB | 25.8 | 6.7 | 6.9 |
| 28 KiB | 27.5 | 39.9 | 44.1 |
| 64 KiB | 38.8 | 51.4 | 51.8 |

[Figure: pr_dma_cliff_read]
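
As a quick check on the "~6×" figure, the jump between 24 KiB and 28 KiB can be computed from the table rows above. The numbers are copied from the table; the helper name `jump_ratio` is made up for illustration.

```python
# Throughput (GB/s) copied from the DMA single-call table, keyed by descriptor count N.
at_24_kib = {10_000: 25.8, 50_000: 6.7, 100_000: 6.9}
at_28_kib = {10_000: 27.5, 50_000: 39.9, 100_000: 44.1}


def jump_ratio(n: int) -> float:
    """Ratio of 28 KiB to 24 KiB throughput for a given descriptor count."""
    return at_28_kib[n] / at_24_kib[n]


for n in at_24_kib:
    print(f"N={n}: {jump_ratio(n):.2f}x")
```

For N=50000 and N=100000 this gives roughly 6×, while at N=10000 the table already shows the higher throughput at 24 KiB.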

The GPU→CPU (offload) direction has no such cliff in the aggregate (posted
writes degrade gradually), so this PR leaves offload on cuMemcpyBatchAsync.

Threshold choice (_THRESHOLD_BYTES = 28 KiB)

DMA vs Triton onload bandwidth (GB/s), matched N=10000, fast GPU. DMA is measured
back-to-back (sustained, so it reaches the ~55 GB/s ceiling):

| page | DMA | Triton sm12 | Triton sm16 | winner |
|---|---|---|---|---|
| 4 KiB | 7.7 | 23.8 | 29.5 | Triton 3.1× |
| 8 KiB | 14.9 | 40.9 | 48.6 | Triton 2.7× |
| 16 KiB | 30.0 | 46.1 | 51.4 | Triton 1.5× |
| 24 KiB | 44.9 | 47.8 | 51.4 | Triton 1.06× |
| 28 KiB | 52.3 | 43.2 | 46.7 | DMA |
| 32 KiB | 53.9 | 48.9 | 51.5 | DMA |
| 64 KiB | 55.0 | 49.8 | 51.3 | DMA |

The crossover is 28 KiB across every SM count ≥ 12.
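
The crossover can be read off the table programmatically. The sketch below uses only the sm12 and sm16 columns listed above (the claim for other SM counts rests on the accompanying figure), and `crossover_kib` is a hypothetical helper.

```python
# (page KiB, DMA, Triton sm12, Triton sm16) in GB/s, copied from the table above.
rows = [
    (4, 7.7, 23.8, 29.5),
    (8, 14.9, 40.9, 48.6),
    (16, 30.0, 46.1, 51.4),
    (24, 44.9, 47.8, 51.4),
    (28, 52.3, 43.2, 46.7),
    (32, 53.9, 48.9, 51.5),
    (64, 55.0, 49.8, 51.3),
]


def crossover_kib(table):
    """Smallest page size at which DMA beats every Triton column."""
    for page, dma, *triton in table:
        if dma > max(triton):
            return page
    return None


print(crossover_kib(rows))
```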

[Figure: v2_phase1_crossover_allsm]

SM count (_NUM_SMS = 12)

Onload bandwidth vs SM count, median over N ∈ {5K…200K}, fast GPU. Ceiling ≈ 51 GB/s:

| page | DMA | sm8 | sm12 | sm16 | sm20 | sm24 |
|---|---|---|---|---|---|---|
| 4 KiB | 4.0 | 13.9 | 23.0 | 27.9 | 37.1 | 42.2 |
| 8 KiB | 5.4 | 20.1 | 36.2 | 40.3 | 51.1 | 51.3 |
| 16 KiB | 6.5 | 27.7 | 41.4 | 49.0 | 51.5 | 51.4 |

E2E req/s is identical for sm12 / sm16 / sm20 despite their 38 / 45 / 51 GB/s
kernel ladder, so the smallest slice wins (least decode contention).

[Figure: v2_phase1_finepage]

End-to-end — gpt-oss-120b, TP=2, OffloadingConnector (24 GB), repeated measures

3 separate-boot reps per cell, median [min,max]; fast-path GPUs; per-layer KV
(8 KiB onload descriptors → kernel engages). B1 = vanilla DMA (HMA-on),
B2 = crosslayers (HMA-off, bundled DMA), PR = Triton onload.

| config | req/s 12K | req/s 24K | req/s 48K |
|---|---|---|---|
| B1 vanilla DMA | 11.1 [9.3, 11.8] | 6.5 [5.8, 6.8] | 3.1 [2.9, 3.4] |
| B2 crosslayers | 21.5 [21.1, 22.0] | 12.3 [11.4, 12.6] | 2.3 [2.2, 2.4] |
| PR sm12 | 28.1 [23.9, 28.4] | 13.7 [13.0, 14.1] | 6.4 [5.6, 6.4] |

| config | TTFT 12K/24K/48K (ms) | TPOT 12K/24K/48K (ms) | delivered C→G GB/s | kernel GB/s |
|---|---|---|---|---|
| B1 | 46750 / 78783 / 169643 | 224 / 258 / 250 | 4.8 | n/a (DMA) |
| B2 | 23962 / 41758 / 217905 | 70 / 69 / 158 | 53.8 | n/a (DMA) |
| PR sm12 | 17519 / 36475 / 79505 | 102 / 91 / 56 | 37.4 | 38.0 |

The PR beats vanilla DMA (B1) by +105–154% req/s at every prefix length and lifts
delivered onload bandwidth from 4.8 to 37 GB/s.
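
For reference, the relative gains can be recomputed from the median req/s values in the table. The values are copied from the B1 and PR sm12 rows; `gain_percent` is a hypothetical helper.

```python
# Median req/s copied from the end-to-end table (prefix length -> value).
b1 = {"12K": 11.1, "24K": 6.5, "48K": 3.1}   # vanilla DMA
pr = {"12K": 28.1, "24K": 13.7, "48K": 6.4}  # Triton onload, sm12


def gain_percent(baseline: float, new: float) -> float:
    """Relative improvement of `new` over `baseline`, in percent."""
    return (new / baseline - 1.0) * 100.0


for prefix in b1:
    print(f"{prefix}: +{gain_percent(b1[prefix], pr[prefix]):.0f}%")
```

The rounded medians give roughly +106% to +153%, in line with the quoted range.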

cuMemcpyBatchAsync is descriptor-overhead-bound for small blocks. For
uniformly-sized batches under 28 KiB per block, route the copy through
a single Triton-kernel launch instead. Falls back to the existing C++
path above the threshold or for non-uniform sizes.

Empirically tuned on H100 (PCIe Gen5):
  * num_sms = 12 — knee within 5% of peak; 9% of compute taken
  * threshold = 28 KiB — exact crossover identical across pair counts
  * 1.13-3.16x speedup below threshold; 1.00x above (no regression)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
@Etelis Etelis requested review from ApostaC and orozery as code owners May 10, 2026 07:12

@claude claude Bot left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify Bot added the v1 label May 10, 2026
mergify Bot commented May 10, 2026
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Etelis.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 10, 2026
Signed-off-by: Itay Etelis <itay.etelis@ibm.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a Triton-based fast path for the swap_blocks_batch operation, optimizing small uniform batches during KV offloading. The implementation includes a fallback to the original custom ops for large or non-uniform batches. Feedback indicates that the address tensors are currently allocated in unpinned memory, which causes the .to('cuda', non_blocking=True) calls to perform synchronous copies and block the CPU thread. It is recommended to use pinned memory for these tensors to ensure true asynchronous execution and avoid blocking the worker thread.

vllm/v1/kv_offload/cpu/gpu_worker.py (319)

high

The batch_src and batch_dst tensors are created from unpinned numpy arrays (lines 292-293). Consequently, the .to('cuda', non_blocking=True) calls inside swap_blocks_batch (lines 56-57 of triton_swap.py) will result in synchronous host-to-device copies, blocking the CPU thread. To achieve true asynchronous execution and avoid blocking the scheduler/worker thread, consider using pinned memory for these address tensors (e.g., by allocating them with torch.empty(..., pin_memory=True) and using their numpy views).

@mergify mergify Bot removed the needs-rebase label May 10, 2026
mergify Bot commented May 10, 2026
Contributor

Hi @Etelis, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

job += num_progs


def swap_blocks_batch(

Collaborator

I think this function should move to gpu_worker.py.

Contributor Author

Done

Comment thread vllm/v1/kv_offload/cpu/triton_swap.py Outdated
n = src_addrs.numel()
if n == 0:
    return
bpj = int(sizes[0].item())

Collaborator

Can we choose a more meaningful variable name?

Comment thread vllm/v1/kv_offload/cpu/triton_swap.py Outdated
if n == 0:
    return
bpj = int(sizes[0].item())
if bpj >= _THRESHOLD_BYTES or bpj % 8 != 0 or not bool((sizes == bpj).all()):

Collaborator

Can we add a comment explaining the criteria for choosing between cudaMemcpyBatch and Triton?

Contributor Author

Done.

Comment thread vllm/v1/kv_offload/cpu/triton_swap.py Outdated
Comment on lines +12 to +13
_NUM_SMS = 12
_THRESHOLD_BYTES = 28 * 1024

Collaborator

Let's add a comment on why we chose these default values.

Contributor Author

Done.

Comment thread vllm/v1/kv_offload/cpu/triton_swap.py Outdated
Comment on lines +16 to +17
@triton.jit
def _kernel(

Collaborator

will this work for other architectures?
e.g. AMD, XPU, HPU, TPU?

Contributor Author

AMD: the kernel is plain Triton, so it should run on ROCm, but the premise doesn't carry over, so _THRESHOLD_BYTES and _NUM_SMS (SMs vs CUs, different PCIe gen) would need re-measuring on AMD before it'd be worth enabling there.
XPU / HPU / TPU: these can't use the OffloadingConnector CPU-offload path at all today, or am I missing something there?

Collaborator

> XPU / HPU / TPU: these can't use the OffloadingConnector CPU-offload path at all today? or am I missing something there?

Right. I'm just wondering if this can lead to an easy path to support offloading on these platforms using this triton kernel.

@Etelis Etelis May 11, 2026

Contributor Author

Ah I get it.

The kernel itself would run on ROCm, but not on HPU/TPU (no Triton).

XPU might work, but it would be hacky.

orozery commented May 10, 2026
Collaborator

Thanks @Etelis !
Can you compare CPU load (CPU->GPU) with this vs. cudaMemcpyBatch+Cross layers on gpt-oss models?
For GPU->CPU I think cudaMemcpyBatch is always better?

orozery commented May 10, 2026
Collaborator

Also, I think that we can select which swap_blocks function to use on init (instead of per-transfer) after examining the minimum size of Refs' page_size_bytes.

EtelisIBM added 3 commits May 10, 2026 16:03
…eneous sizes

Adds an explicit gpu_to_cpu kwarg used to gate the Triton fast path so
only CPU->GPU reads take it; GPU->CPU writes always defer to the C++
DMA path. Generalizes the kernel to handle non-uniform per-job sizes
(each job loads its own size from the sizes tensor) and adds an
n >= 16 batch-size floor so n=1 calls don't take the fast path.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
…constants

Wrapper now lives next to its only caller. triton_swap keeps the
kernel and the empirically-tuned constants (_NUM_SMS, _THRESHOLD_BYTES,
_MIN_N) with a comment explaining how each was derived.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Move the gate (direction, page-size threshold, 8-byte alignment) and
the chunk-size computation out of the per-call path. The handler
binds either ops.swap_blocks_batch or a Triton closure with chunk
pre-baked, so per-call work shrinks to a single n>=_MIN_N check.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
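
A minimal sketch of the init-time binding this commit message describes, using stand-in functions. Every name here is hypothetical (the real handler binds ops.swap_blocks_batch or a Triton closure), and the pre-baked chunk size is omitted.

```python
from typing import Callable, Sequence

# All names below are hypothetical stand-ins, not the vLLM API.
MIN_N = 16
THRESHOLD_BYTES = 28 * 1024

SwapFn = Callable[[Sequence[int]], str]


def dma_swap(sizes: Sequence[int]) -> str:
    return "dma"      # stand-in for ops.swap_blocks_batch


def triton_swap(sizes: Sequence[int]) -> str:
    return "triton"   # stand-in for the Triton kernel launch


def select_swap_fn(cpu_to_gpu: bool, page_size_bytes: int) -> SwapFn:
    """Decide once, at handler init, which swap function to bind."""
    eligible = (
        cpu_to_gpu
        and page_size_bytes < THRESHOLD_BYTES
        and page_size_bytes % 8 == 0
    )
    if not eligible:
        return dma_swap

    def swap(sizes: Sequence[int]) -> str:
        # The only per-call decision left is the batch-size floor.
        return triton_swap(sizes) if len(sizes) >= MIN_N else dma_swap(sizes)

    return swap


onload = select_swap_fn(cpu_to_gpu=True, page_size_bytes=8 * 1024)
print(onload([8 * 1024] * 64), onload([8 * 1024] * 4))
```

Binding the function once keeps direction, threshold, and alignment checks off the per-transfer path, which is the point of the change.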
ivanium commented May 10, 2026
Collaborator

What's the e2e performance for the tested models? I deprecated the SM-based approach because I found that copy kernels will contend SMs and L1/L2 caches with concurrent GPU kernels and slow down the computation.
In my tests, the CPU overhead of cudaMemcpy is actually fine if we put them into a background thread.

@Etelis Etelis changed the title from "[Perf] Triton fast path for swap_blocks_batch on small uniform batches" to "[Perf] Triton fast path for small CPU→GPU swap_blocks_batch in the offloading connector" May 11, 2026
Etelis commented May 11, 2026
Contributor Author

> What's the e2e performance for the tested models? I deprecated the SM-based approach because I found that copy kernels will contend SMs and L1/L2 caches with concurrent GPU kernels and slow down the computation.
> In my tests, the CPU overhead of cudaMemcpy is actually fine if we put them into a background thread.

Thanks @ivanium, full detail (cliff charts, threshold/SM sweeps, complete tables) is in the updated PR description;
I must not have explained myself clearly before.

e2e — gpt-oss-120B / TP=4, OffloadingConnector, cache-hit-heavy workload (long shared prefix, short generations → the onload is on the TTFT critical path), kernel on vs off:

| HMA | req/s | TTFT p50 | TPOT p99 |
|---|---|---|---|
| on | 29.9 → 32.2 (+7 %) | 628 → 476 ms (−24 %) | 498 → 318 (−36 %) |
| off | 26.4 → 31.6 (+20 %) | 761 → 499 ms (−35 %) | 432 → 316 (−27 %) |

The onload transfer goes ~3.8 → ~18–19 GB/s (4 KiB descriptors, right on the cuMemcpyBatchAsync small-descriptor cliff); TPOT/ITL improve.

wdyt? it looks decent? do you want other benchmarking or models?

kfirtoledo added a commit to kfirtoledo/vllm that referenced this pull request May 12, 2026
…atch

Squashed cherry-pick of vllm-project#42212 (Itay Etelis).
Adds Triton fast path for small uniform CPU->GPU swap_blocks_batch,
gated at handler init.

Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Signed-off-by: Kfir Toledo <kfir.toledo@ibm.com>

@orozery orozery left a comment

Collaborator

Thanks @Etelis !
Can we add a unit test to the triton swap?
Also, need to be sure we're not breaking AMD.

Comment thread vllm/v1/kv_offload/cpu/gpu_worker.py Outdated

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.triton_utils import triton

Collaborator

Claude comment:

The import from vllm.triton_utils import triton is unconditional. If running on a platform where Triton isn't
available (e.g., older ROCm builds, CPU-only), this will fail at import time even for GPU→CPU handlers that wouldn't
use the Triton path. Consider guarding with a lazy import or checking triton is not None in _select_swap_blocks_fn.
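
One way to express such a guard, as a sketch only: the thread does not show the fix that was actually applied, and `HAS_TRITON` and `pick_backend` here are illustrative names. The check uses `importlib.util.find_spec`, which reports whether a module is installed without importing it.

```python
import importlib.util

# Hypothetical sketch: detect Triton without importing it at module import time.
HAS_TRITON = importlib.util.find_spec("triton") is not None


def pick_backend(cpu_to_gpu: bool, has_triton: bool = HAS_TRITON) -> str:
    """Return which copy path a handler would bind (labels are illustrative)."""
    if cpu_to_gpu and has_triton:
        return "triton"   # the real import would happen lazily at this point
    return "dma"          # GPU->CPU handlers never need Triton


print(pick_backend(cpu_to_gpu=False))
```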

Contributor Author

Done,
Thank you.

Comment thread vllm/v1/kv_offload/cpu/gpu_worker.py Outdated
ivanium commented May 13, 2026
Collaborator

> > What's the e2e performance for the tested models? I deprecated the SM-based approach because I found that copy kernels will contend SMs and L1/L2 caches with concurrent GPU kernels and slow down the computation.
> > In my tests, the CPU overhead of cudaMemcpy is actually fine if we put them into a background thread.
>
> Thanks @ivanium, full detail (cliff charts, threshold/SM sweeps, complete tables) is in the updated PR description; I must have not explained myself correctly before.
>
> e2e: gpt-oss-120B / TP=4, OffloadingConnector, cache-hit-heavy workload (long shared prefix, short generations → the onload is on the TTFT critical path), kernel on vs off:
>
> | HMA | req/s | TTFT p50 | TPOT p99 |
> |---|---|---|---|
> | on | 29.9 → 32.2 (+7 %) | 628 → 476 ms (−24 %) | 498 → 318 (−36 %) |
> | off | 26.4 → 31.6 (+20 %) | 761 → 499 ms (−35 %) | 432 → 316 (−27 %) |
>
> The onload transfer goes ~3.8 → ~18–19 GB/s (4 KiB descriptors, right on the cuMemcpyBatchAsync small-descriptor cliff); TPOT/ITL improve.
>
> wdyt? it looks decent? do you want other benchmarking or models?

Nice results! I am actually also curious about the overhead in the no-cache-hit settings (e.g., random workload), where the SM interference can be more visible. Also for results here, is cudaMemcpyBatchAsync called in a background thread?

EtelisIBM and others added 3 commits May 17, 2026 20:13
…ocks_fn

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
The (src, dst, sizes) descriptor tensors were built with torch.from_numpy
(pageable), so `.to("cuda", non_blocking=True)` on the Triton swap path
silently fell back to a synchronous copy. Back them with pinned host memory
and fill via numpy views, so the per-swap H2D actually overlaps. No extra
copy: the numpy views share the pinned buffers.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
@Etelis Etelis requested a review from orozery May 24, 2026 18:56
orozery commented Jun 2, 2026
Collaborator

The test compared against a device->device ops.swap_blocks_batch call, which cuMemcpyBatchAsync rejects with CUDA_ERROR_INVALID_VALUE. Validate the Triton kernel's output against the source bytes directly instead.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
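
The validation idea in this commit message (compare the copied bytes against the source bytes rather than against a second copy path) can be illustrated without a GPU. This is a pure-Python stand-in, not the Triton kernel or the actual unit test; `copy_descriptors` is a hypothetical name, and its loop mimics the `job += num_progs` striding visible in the diff.

```python
# Pure-Python stand-in for a batched descriptor copy; not the Triton kernel itself.
def copy_descriptors(src: bytes, dst: bytearray, jobs, num_progs: int = 4) -> None:
    """Copy (src_off, dst_off, size) jobs, striding over them like a persistent grid."""
    for prog in range(num_progs):
        job = prog
        while job < len(jobs):
            src_off, dst_off, size = jobs[job]
            dst[dst_off:dst_off + size] = src[src_off:src_off + size]
            job += num_progs  # each program handles every num_progs-th job


page = 8
src = bytes(range(64))
dst = bytearray(64)
# Scatter source pages 0..7 into reversed destination slots.
jobs = [(i * page, (7 - i) * page, page) for i in range(8)]
copy_descriptors(src, dst, jobs)

# Validate each destination range against the source bytes directly.
for src_off, dst_off, size in jobs:
    assert dst[dst_off:dst_off + size] == src[src_off:src_off + size]
print("all descriptors match")
```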
Etelis commented Jun 3, 2026
Contributor Author

Sorry, all passing now.

@orozery orozery merged commit 1fa9ea0 into vllm-project:main Jun 3, 2026
59 checks passed
mvanhorn pushed a commit to mvanhorn/vllm that referenced this pull request Jun 4, 2026
…offloading connector (vllm-project#42212)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
pawel-olejniczak added a commit to pawel-olejniczak/vllm-gaudi that referenced this pull request Jun 5, 2026
Root cause: upstream PR #42212 added a pinned descriptor-buffer pool (self._buffer_pool) and Transfer.batch_* fields to SingleDirectionOffloadingHandler. The HPU plugin only overrides __init__ and transfer_async, so the inherited upstream get_finished/shutdown access self._buffer_pool (never initialized in the HPU __init__), raising AttributeError and an EngineDeadError in the CPU offloading test.

Upstream: vllm-project/vllm#42212

Fix: add HPU-specific get_finished/shutdown overrides consistent with the HPU Transfer dataclass and __init__ (no buffer-pool/batch_* usage).

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
JisoLya pushed a commit to JisoLya/vllm that referenced this pull request Jun 5, 2026
…offloading connector (vllm-project#42212)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Signed-off-by: JisoLya <523420504@qq.com>
knight0528 pushed a commit to knight0528/vllm that referenced this pull request Jun 8, 2026
…offloading connector (vllm-project#42212)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
adobrzyn added a commit to vllm-project/vllm-gaudi that referenced this pull request Jun 9, 2026
…de HPU offloading handler get_finished/shutdown (+2 more) (#1525)

This PR is the rolling hourly-CI fix PR. It consolidates 3 fixes against
vllm@`4efd6ffde09477800294a8ed9cc752017812c3b1` per the
single-rolling-PR rule (invariant I9).

## Bug 1: Override HPU offloading handler get_finished/shutdown

- **State machine id**: offloading_handler_buffer_pool_attrerror
- **Commit**: e05b5e9

### Root cause
Upstream vLLM PR #42212 added a pinned descriptor-buffer pool
(`self._buffer_pool`) and `Transfer.batch_*` fields to
`SingleDirectionOffloadingHandler`. The HPU plugin only overrides
`__init__` and `transfer_async`, so the inherited upstream
`get_finished`/`shutdown` reach for `self._buffer_pool`, which the HPU
`__init__` never initializes. This raised an `AttributeError` (surfacing
as `EngineDeadError`) in the CPU offloading test.

### Upstream PR
vllm-project/vllm#42212

### Fix
Add HPU-specific `get_finished` and `shutdown` overrides that are
consistent with the HPU `Transfer` dataclass and `__init__` — recycling
streams/events from the HPU pools and clearing HPU-local state, without
touching the upstream buffer-pool / `batch_*` machinery.
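
The failure mode described in this root cause can be reproduced with a few lines of plain Python. The classes below are hypothetical and only mirror the shape of the problem (a subclass that overrides `__init__` without initializing an attribute that inherited methods use); they are not the vLLM or vllm-gaudi code.

```python
# Hypothetical classes illustrating the failure mode; not the vLLM or vllm-gaudi code.
class BaseHandler:
    def __init__(self):
        self._buffer_pool = []          # attribute added by the upstream change

    def shutdown(self):
        self._buffer_pool.clear()       # inherited method relies on that attribute


class BrokenPluginHandler(BaseHandler):
    def __init__(self):                 # overrides __init__ without calling super()
        self._streams = []


class FixedPluginHandler(BrokenPluginHandler):
    def shutdown(self):                 # override touches only plugin-local state
        self._streams.clear()


try:
    BrokenPluginHandler().shutdown()
except AttributeError as exc:
    print("broken:", exc)

FixedPluginHandler().shutdown()
print("fixed: ok")
```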

## Bug 2: Fix minimax_m2 import after mamba LINEAR refactor

- **State machine id**: mamba_linear_attn_import_missing
- **Commit**: e0a7774

### Root cause
Upstream vLLM moved `MiniMaxText01RMSNormTP` out of
`vllm.model_executor.layers.mamba.linear_attn` into a dedicated module
`vllm.model_executor.layers.minimax_rms_norm`. The HPU `minimax_m2`
model is eagerly imported by `register_model()`, so the stale import
broke every CI test at import time.

### Upstream PR
vllm-project/vllm#43556

### Fix
Update the import of `MiniMaxText01RMSNormTP` to
`vllm.model_executor.layers.minimax_rms_norm`.

## Bug 3: Fix multi_model_api_server imports after serving-utils consolidation

- **State machine id**: multi_model_entrypoints_logger_missing
- **Commit**: a9ba162

### Root cause
Upstream vLLM consolidated the online serving utilities, removing
`entrypoints/logger.py`, `entrypoints/openai/server_utils.py` and
`entrypoints/utils.py`. The HPU multi-model API server imported three
symbols from those removed modules, breaking the unit-test import path.

### Upstream PR
vllm-project/vllm#44479

### Fix
Repoint three imports in `multi_model_api_server.py` to the consolidated
locations: `serve.utils.request_logger` (RequestLogger),
`serve.utils.server_utils` (get_uvicorn_log_config) and
`serve.utils.api_utils` (cli_env_setup, process_lora_modules).

## HPU verification
- Pod: Gaudi g3
- Full commit stack (`origin/main..HEAD`) re-verified against
vllm@`4efd6ffde09477800294a8ed9cc752017812c3b1`: import-clean with the
HPU platform plugin active (minimax_m2, multi_model_api_server,
kv_offload cpu_hpu, and register_model all load).

## Related PRs
None

---------

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com>
waqahmed-amd-fi pushed a commit to waqahmed-amd-fi/vllm that referenced this pull request Jun 10, 2026
…offloading connector (vllm-project#42212)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Signed-off-by: Waqar Ahmed <waqar.ahmed@amd.com>
Saddss pushed a commit to Saddss/vllm that referenced this pull request Jun 14, 2026
…offloading connector (vllm-project#42212)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>

Labels

ready, v1

4 participants