Skip to content

Hybrid KV offload: planner, MultiConnector, and mamba alignment for hybrid models#38261

Open
malaiwah wants to merge 35 commits into
vllm-project:mainfrom
malaiwah:codex/hybrid-kv-offload
Open

Hybrid KV offload: planner, MultiConnector, and mamba alignment for hybrid models#38261
malaiwah wants to merge 35 commits into
vllm-project:mainfrom
malaiwah:codex/hybrid-kv-offload

Conversation

@malaiwah

Copy link
Copy Markdown

Summary

Enables external KV cache offloading for hybrid models (mamba + attention) like Qwen3.5. The stock offload path requires LCM of all group block sizes, which is impractical when mamba groups have very different sizes from attention groups.

Core changes

HybridOffloadPlanner (v1/kv_offload/planner.py):

  • Configurable hybrid_chunk_size splits groups where gpu_block_size % chunk_size == 0
  • Per-group coverage tracking, binary search for chunk counting
  • Handles groups that can't be split (mamba with block_size=max_model_len in non-align mode)

MultiConnector (multi_connector.py):

  • Wraps multiple KV connectors (e.g., LMCache CPU + llm-d disk) via MultiConnector
  • Weighted load selection: matched_tokens × load_weight scoring
  • HMA support (SupportsHMA mixin) for hybrid memory allocator compatibility
  • Preemption handling compatible with stock vLLM's set[str] signature
  • Per-connector Prometheus metrics

Scheduler (scheduler.py):

  • Skip mamba block alignment during async KV load
  • Handle disagreeing KV group prefix lengths gracefully (warn + use minimum)

Metrics (loggers.py):

  • Clamp negative prompt_tokens_by_source values that crash Prometheus counters under concurrent external cache hits

Test plan

Validated on Qwen3.5-4B-FP8 (24 mamba + 8 attention layers, RTX 4080 Super):

Closes #38230

AI-assisted: developed with Claude. All changes reviewed and tested by a human.

🤖 Generated with Claude Code

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces hybrid KV cache offloading (HMA) support, enabling more granular management of KV cache groups like Mamba and Attention. It implements a HybridOffloadPlanner to handle fixed-size offload units, a HybridChunkBlockHashList for multi-group hashing, and updates the MultiConnector with weighted selection logic and per-connector metrics. The offloading scheduler and worker are enhanced to support partial block transfers, backpressure for concurrent I/O, and improved error handling for stale cache files. Additionally, the PR includes extensive unit tests for the new hashing and planning logic, and ensures robustness in metrics reporting by clamping negative token counts. I have no feedback to provide.

Comment on lines +33 to +34
if any(block_size <= 0 for block_size in self.gpu_block_sizes):
raise ValueError("gpu_block_sizes must be positive")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Raising a ValueError for non-positive gpu_block_sizes is critical. All GPU block sizes must be positive for correct memory allocation and transfer logic, and a non-positive value would lead to logical errors or crashes.

Comment on lines +131 to +137
assert all(
offloaded_block_size_int % gpu_block_size == 0
for gpu_block_size in self.gpu_block_size
), (
"If 'block_size' is specified in kv_connector_extra_config, "
"there must be at least one KV cache group, "
"and all groups must have the same block size."
"it must be divisible by every KV cache group block size."
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This assertion is critical for correctness. If block_size is specified in kv_connector_extra_config, it must be divisible by every KV cache group block size. Failure to meet this condition would lead to incorrect block calculations and memory management, potentially causing data corruption or crashes.

Comment on lines +27 to +30
raise ValueError(
"fixed_chunk_size must be greater than or equal to "
"hash_block_size"
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Raising a ValueError when fixed_chunk_size is smaller than hash_block_size is critical. This condition would lead to incorrect hashing and chunking logic, potentially causing data corruption or cache misses. Enforcing this constraint ensures the integrity of the offloading mechanism.

Comment on lines +93 to +105
logger.error(
"Hybrid offloading is effectively disabled: "
"first_hashable_chunk_idx=%d requires %d tokens "
"but max_model_len=%d. No chunks can ever be "
"stored. Set max_model_len to a multiple of "
"hybrid_chunk_size=%d (e.g. %d).",
self.hybrid_planner.first_hashable_chunk_idx,
self.hybrid_planner.first_hashable_chunk_idx
* chunk_size_int,
max_model_len,
chunk_size_int,
(max_model_len // chunk_size_int) * chunk_size_int,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Logging an error when hybrid offloading is effectively disabled due to first_hashable_chunk_idx being too large relative to max_model_len is critical. This indicates a severe misconfiguration where no chunks can ever be stored, rendering the feature useless. The error message provides clear guidance for resolution.

Comment on lines +61 to +64
raise TypeError(
"OffloadingConnector requires metadata with reqs_to_load, "
"reqs_to_store, and reqs_to_flush fields."
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Raising a TypeError when OffloadingConnector receives metadata without the expected fields (reqs_to_load, reqs_to_store, reqs_to_flush) is critical. This ensures that the connector operates with valid metadata, preventing potential runtime errors or incorrect behavior due to malformed input.

Comment on lines +335 to +340
raise TypeError(
f"MultiConnector has HMA enabled but these child "
f"connectors do not support it: {non_hma}. Either "
f"use --disable-hybrid-kv-cache-manager or replace "
f"the non-HMA connectors."
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This TypeError is critical as it prevents the MultiConnector from operating in an invalid state where HMA is enabled but not all child connectors support it. This ensures the system's integrity and prevents potential runtime failures.

Comment on lines +73 to +82
logger.warning(
"KV group %d has gpu_block_size=%d which is not "
"divisible by hybrid_chunk_size=%d. This group "
"cannot be split into chunks and will require "
"%d tokens before any offloading can occur. "
"Consider setting max_model_len to a multiple "
"of hybrid_chunk_size.",
i, gbs, chunk_size_int,
gbs,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Logging a warning when gpu_block_size is not divisible by hybrid_chunk_size is important. This configuration can lead to inefficient offloading or unexpected behavior, as the group cannot be split into chunks as intended. The warning helps users understand the implications and adjust their max_model_len accordingly.

Comment on lines +120 to +123
logger.warning(
"offloading worker load submission failed for "
"req_id=%s job_id=%s (stale cache files?), "
"falling back to recompute",

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Logging a warning when an offloading worker load submission fails is important. This indicates that a cache file might be stale or corrupted, and the system is falling back to recompute. This warning provides crucial information for debugging cache-related issues and understanding performance implications.

Comment on lines +354 to +359
logger.warning(
"KV groups disagree on computed prefix length: %s. "
"Using minimum (%d tokens) to avoid partial loads.",
computed_tokens_per_group,
num_computed_tokens,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Logging a warning when KV groups disagree on computed prefix length is important. This scenario can indicate issues with cache consistency or partial loads in hybrid models, which could lead to incorrect model behavior. The warning helps in debugging and identifying such discrepancies.

Comment thread vllm/v1/metrics/loggers.py Outdated
Comment on lines +1123 to +1126
logger.warning(
"Negative prompt_tokens_by_source[%s]=%d "
"(external KV transfer accounting skew), clamping to 0",
source, value,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Logging a warning for negative prompt_tokens_by_source values and clamping them to 0 is a crucial robustness improvement. Negative values can cause Prometheus counters to crash, and this fix prevents such failures, ensuring the stability of metrics reporting.

@mergify

mergify Bot commented Mar 26, 2026

Copy link
Copy Markdown
Contributor

Hi @malaiwah, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@malaiwah

Copy link
Copy Markdown
Author

Thanks for the pre-commit note. Will fix the formatting in the next push — this machine doesn't have a full dev environment (bare Ubuntu with GPU), so running pre-commit locally requires some setup. Working on it.

Re: Gemini's review — all the flagged items are intentional validations and safety checks. No changes needed there.

@malaiwah

Copy link
Copy Markdown
Author

Pre-commit failures addressed in the latest push:

  • typos: storeable_prefix_tokensstorable_prefix_tokens (planner.py and test)
  • ruff E501: wrapped long log format strings
  • mypy type errors:
    • scheduler.py: get_timed_out_loads guarded with hasattr check + type: ignore[attr-defined] (it's an extension method not on the base interface)
    • multi_connector.py: annotated reconstructed_data: dict[str, KVConnectorStats]; type: ignore[arg-type] for set[str] forwarded to handle_preemptions
    • offloading/scheduler.py: # type: ignore[no-redef] on variables re-declared after early return; renamed offloaded_hashes to simple_hashes in non-hybrid path; renamed group_sizes to src_group_sizes in else branch
    • cpu.py: block_size_factorblock_size_factors[0] (attribute was renamed to plural tuple in OffloadingSpec)
    • spec.py: type: ignore[arg-type] on int(hybrid_chunk_size) from dict.get()
    • test_offloading_connector.py: annotated extra_config: dict[str, Any]
  • ruff format: reformatted remaining long lines throughout

All 22 pre-commit hooks pass locally.

@mergify

mergify Bot commented Mar 27, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @malaiwah.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Mar 27, 2026
@malaiwah malaiwah force-pushed the codex/hybrid-kv-offload branch from 0a24daa to 9bddaef Compare March 27, 2026 01:51
@malaiwah

Copy link
Copy Markdown
Author

Rebased onto latest main — one conflict in cpu/spec.py (the CPU offloading refactor moved cpu.pycpu/spec.py; resolved by applying the same block_size_factorblock_size_factors[0] fix to the new path). All other commits applied cleanly.

M. Belleau and others added 17 commits March 28, 2026 17:45
Instead of "first hit wins", MultiConnector now scores each
connector's hit as tokens * load_weight and picks the highest
score. This lets a fast CPU cache (high weight) beat a slow disk
cache unless the disk hit is substantially larger.

Configured via "load_weight" in each connector's
kv_connector_extra_config (default 1.0).

Also adds runtime HMA validation: if HMA is enabled, all child
connectors must support it or a clear TypeError is raised at init.

Co-authored-by: Claude
Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
Stock vLLM 0.18.0 calls handle_preemptions(preempted_req_ids: set[str])
but MultiConnector asserted MultiKVConnectorMetadata, crashing under
memory pressure when preemption is triggered.

Accept both: forward raw set[str] to all children, or unwrap
MultiKVConnectorMetadata per-child as before.

Same pattern as the OffloadingConnector preemption fix (8ca977d67).

Co-authored-by: Claude
Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
Expose five new metrics on /metrics so operators can see which child
connector is serving cache hits, how many tokens are being loaded, and
how many requests are waiting — broken down by connector name:

  vllm_kv_connector_queries_total{connector}
    Total lookup queries issued to each child connector.

  vllm_kv_connector_hits_total{connector}
    Requests where this connector won the weighted selection.

  vllm_kv_connector_hit_tokens_total{connector}
    Total matched tokens served by the winning connector.

  vllm_kv_connector_misses_total{connector}
    Requests where this connector had no cache hit.

  vllm_kv_connector_pending_loads{connector}  (gauge)
    Requests currently in-flight for an external load from this connector.

Previously all external hits were aggregated under
vllm:external_prefix_cache_hits_total with no connector breakdown,
making it impossible to distinguish LMCache (CPU) vs llm-d (disk)
contributions.

Co-authored-by: Claude
Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
…line

Module-level prometheus_client Counter/Gauge objects registered in the
EngineCore subprocess do not appear in the APIServer's /metrics endpoint.
Replace them with a _SelectionStats dataclass that accumulates per-connector
queries/hits/hit_tokens/misses in the scheduler and flows the data through
vLLM's existing KVConnectorStats pipeline:

  EngineCore: MultiConnector.get_kv_connector_stats()
    → MultiKVConnectorStats["__selection__"] = _SelectionStats
  cross-process (msgspec pickle): stats.data sent in SchedulerStats
  APIServer: MultiKVConnectorPromMetrics.observe()
    → _observe_selection() → Prometheus Counters registered in APIServer

New metrics (all labelled with model_name, engine, connector):
  vllm:kv_connector_mc_queries_total
  vllm:kv_connector_mc_hits_total
  vllm:kv_connector_mc_hit_tokens_total
  vllm:kv_connector_mc_misses_total

Co-authored-by: Claude
Signed-off-by: mbelleau
Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
offload_unit_sizes, first_hashable_chunk_idx, and group_hash_factors were
recomputed on every property access.  offload_unit_sizes in particular is
called inside the chunk_count_for_tokens binary search loop (via
group_covered_tokens_for_chunk_count), which runs per-request per-step.

Pre-compute all three in __post_init__ using object.__setattr__ (the
standard pattern for frozen dataclasses with derived cached values).
Properties now return the cached tuple directly — O(1) instead of O(groups).

Co-authored-by: Claude
Signed-off-by: mbelleau
Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
_get_block_hashes returns a lazy iterator (islice over
HybridChunkBlockHashList).  The previous code called it twice with the
same start/end indices: once for prepare_load (which fully consumed the
iterator) and a second time to update _reqs_being_loaded.

Materialise to a list on the first call so both consumers share the
same object.  This removes one full HybridChunkBlockHashList
construction and one islice traversal per load-scheduled request.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
…duler

Previously, every call to _get_block_hashes created a new
HybridChunkBlockHashList with fresh RequestBlockHashList instances.
Those per-group lists lazily cache computed hashes, but the cache was
discarded at the end of each call, forcing a full recomputation from
token index 0 on every subsequent call — including the 2–3 calls per
request per scheduler step in get_num_new_matched_tokens.

For a 23k-token prompt with attention block_size=1056, each fresh
RequestBlockHashList had to compute up to ~22 group-level hashes per
call.  With 2 groups and 3 calls per step, that is ~130 hash_function
invocations that can be avoided after the first step.

Fix: store one HybridChunkBlockHashList per active request in
_hybrid_hash_lists.  The same instance is returned on subsequent calls,
so RequestBlockHashList's internal cache is reused.  The instance is
seeded lazily on the first _get_block_hashes call and cleaned up in
request_finished.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
_get_value_at previously recomputed the combined chunk hash (a
hash_function call over a tuple of per-group BlockHashes) on every
invocation, even for indices that had been computed before.

Add _chunk_hashes, a lazily-grown list mirroring RequestBlockHashList's
_hashes cache.  Sequential accesses (the common case via islice) are
served from the list after the first visit.  Out-of-order accesses skip
caching to keep the list dense.

Combined with the scheduler-level _hybrid_hash_lists cache (which
preserves the HybridChunkBlockHashList instance across scheduler steps),
this eliminates all redundant hash_function calls for previously-seen
chunk indices — both within a step and across steps.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
The assert inside _build_gpu_transfer_spec_from_chunk_range checked
that gpu_block_size % unit_size == 0 on every call, for every group.
Both values are constants derived from the spec at construction time, so
the check only needs to run once.

Move it to __init__ where it fails early with a clear message at startup
rather than silently passing (when asserts are optimised away with -O)
or checking repeatedly on the hot path.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
Verify that _chunk_hashes grows as indices are accessed and that
repeated accesses return the cached value without growing the cache
further.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
Stock vLLM 0.18.0 guards _mamba_block_aligned_split with:
  assert num_external_computed_tokens == 0, "External KV connector is not verified yet"

This blocks any external KV cache connector (including our hybrid
offloading path) from working with mamba/hybrid models.  The function
already correctly adds external tokens to num_computed_tokens (line 298-
301) and the alignment logic is agnostic to the source of those tokens —
it only cares about the total count to decide block-aligned splitting.

External tokens are loaded into GPU KV cache before the forward pass;
from the model's perspective they are indistinguishable from locally
computed tokens.  The mamba state for external chunks is already
populated at block boundaries by the offloading connector's store/load
path.

Validated: Qwen3.5-397B-A17B-NVFP4 on 4×GPU TP=4, the crash occurred
after ~23 minutes of successful serving when a request finally triggered
a cache hit path through the mamba block alignment codepath.

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
…h logs to DEBUG

- Remove _ensure_transfer_supported() stub (body was just `return`) and its
  two call sites in update_state_after_alloc and _get_reqs_to_store.
- Downgrade per-transfer-job logger.info calls in offloading/worker.py to
  DEBUG; these fire on every store/load submission and completion and are
  too noisy at INFO level under real inference load.

Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
…nment

- MultiConnector: don't block all connectors when one returns None
  (backpressured). Let resolved connectors answer immediately; only
  defer if ALL connectors are unresolved.
- OffloadingConnectorWorker: graceful fallback when load submission
  fails (stale cache files) instead of assert. Reports failed loads
  as finished so the scheduler falls back to recompute.
- Scheduler: skip mamba block-aligned split during async KV load
  (num_new_tokens is intentionally 0, not a real prefill).

Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
When some groups fail file size validation on load (e.g., attention
group kernel block size changed between restarts), the scheduler
now takes the minimum computed tokens across groups and warns
instead of asserting.  The request falls back to recomputing the
unloaded portion.

Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
When KV groups disagree on computed prefix length (e.g., attention
group loaded from disk but mamba groups didn't), the fallback must
set both num_computed_tokens AND num_external_tokens to 0.
Previously only num_computed_tokens was zeroed, leaving
num_external_tokens non-zero which failed the
chunk_prefix_tokens round-trip assertion.

Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
When a load from external storage (NFS, shared disk) takes longer
than load_timeout_seconds (default 30s), cancel the wait and fall
back to recompute.  Prevents requests from stalling indefinitely
on slow or hung storage.

The timeout is checked each scheduler step in
_try_promote_blocked_waiting_request.  Timed-out loads are marked
as failed, their block hashes released, and the request re-enters
the WAITING queue for recompute.

Configurable via kv_connector_extra_config:
  "load_timeout_seconds": 30.0

Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
- typos: rename storeable_prefix_tokens → storable_prefix_tokens
  in planner.py and test_planner.py
- ruff E501: break long log format strings in offloading/worker.py
  and offloading/scheduler.py
- mypy [no-redef]: annotate variables redefined after early return
  in _build_gpu_transfer_spec_from_chunk_range
- mypy [attr-defined]: add type: ignore for hasattr-guarded accesses
  and rename offloaded_hashes variable in non-hybrid path to avoid
  type mismatch with HybridChunkBlockHashList | None
- mypy [assignment]: annotate reconstructed_data as dict[str, KVConnectorStats]
  in MultiKVConnectorStats.from_dict_data; annotate extra_config as
  dict[str, Any] in test helper
- mypy [arg-type]: guard get_timed_out_loads with hasattr check in
  scheduler; type: ignore on set[str] forwarded to handle_preemptions
- cpu.py: use block_size_factors[0] instead of removed block_size_factor
- spec.py: add type: ignore[arg-type] on int(hybrid_chunk_size) where
  extra_config.get() can return Any | None
- ruff format: reformat long lines throughout

Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
@malaiwah

Copy link
Copy Markdown
Author

Rebased onto latest main (148 upstream commits since last rebase). All conflicts resolved:

  • tests/v1/kv_connector/unit/test_offloading_connector.py → deleted (upstreamed into offloading_connector/ subdir)
  • vllm/v1/kv_offload/spec.py / cpu_gpu.py → import merge (both sides kept)

Cross-node + cross-restart validation — 10/10 PASS (Creativity RTX 4080 Super SM 8.9 ↔ AIBoss RTX 5090 SM 12.0, lovedheart/Qwen3.5-4B-FP8):

Phase Description Result
A Creativity cold start: 9325-token prompt, NFS miss, correct output
B Creativity APC warm: 8448 cached tokens reused, same answer
C NFS namespace: Creativity wrote 0 new sm_120 files
D AIBoss cold: sm_120 ≠ sm_89, NFS miss, no garbage output
E AIBoss APC warm: not poisoned by garbled 1st response
F Namespace isolation: sm_89 and sm_120 coexist under same group
G Creativity restart: 7392 cached tokens loaded from sm_89 NFS
H AIBoss restart: 7392 cached tokens loaded from sm_120 NFS
I LMCache hybrid detection fires in both Worker and EngineCore
J Concurrent cross-host requests: 97% word overlap, no interference

Both containers now run localhost/vllm-trifecta-cu130:multiarch (same image, no bind mounts). Container image includes this PR's vllm overlay + llm-d #467 + LMCache #2879.

…solution

After rebasing onto upstream/main, the automated conflict resolver merged both
sides of import conflicts, leaving:

- vllm/v1/kv_offload/spec.py: unused AttentionBackend import
- vllm/v1/kv_offload/worker/cpu_gpu.py: duplicate BlockIDsLoadStoreSpec import
  and unused AttentionBackend import, causing I001/F401/F811
- vllm/v1/kv_offload/cpu/spec.py: get_handlers() passed undefined attn_backends
  and gpu_block_size to CpuGpuOffloadingHandlers constructor (F821); also used
  removed self.block_size_factor (singular) instead of self.block_size_factors[0]

All ruff checks now pass on all changed files.

Co-authored-by: Claude <noreply@anthropic.com>

Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
@malaiwah

Copy link
Copy Markdown
Author

Pushed a follow-up commit (77160a8) fixing three ruff errors introduced during the rebase merge resolution:

  • vllm/v1/kv_offload/spec.py: removed unused AttentionBackend import (F401)
  • vllm/v1/kv_offload/worker/cpu_gpu.py: deduplicated BlockIDsLoadStoreSpec import and removed unused AttentionBackend (F811, F401, I001)
  • vllm/v1/kv_offload/cpu/spec.py: removed undefined attn_backends/gpu_block_size kwargs from CpuGpuOffloadingHandlers() constructor call (F821); used self.block_size_factors[0] (plural tuple) consistently

All pre-commit hooks pass locally: ruff check, ruff format, typos, mypy, and repo-specific checks.


Re the pre-run-check gate: the failure message is "PR must have the ready label or the author must have at least 4 merged PRs (found 0)". Could a maintainer please add the ready label so the full CI suite can run?

The Intel XPU failures (xpu-example-test, xpu-v1-test) appear to be pre-existing — they don't touch any Intel-specific code paths and the same failures appear on other recent external-contributor PRs. Happy to be corrected if they're related to our changes.

The merge-both conflict resolver during rebase onto upstream/main
incorrectly appended old test_offloading_connector.py content to
the upstream's offloading_connector/utils.py, producing syntax errors
from line 567 onward. Restore to the upstream version.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>

Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
@malaiwah

Copy link
Copy Markdown
Author

Latest push fixes a mangled tests/v1/kv_connector/unit/offloading_connector/utils.py that resulted from the rebase conflict resolver incorrectly appending old test content to the upstream file. All pre-commit hooks now pass locally (ruff, typos, mypy, SPDX headers, forbidden imports — full run).

CI gate: the pre-run-check requires either the ready label or ≥ 4 merged PRs (we have 0). Could a maintainer add ready so the full CI suite can run? Tagging @njhill @russellb @mgoin — any of you able to add it?

Intel XPU failures (xpu-example-test, xpu-v1-test): these appear to be pre-existing flakiness unrelated to our changes (pure Python KV connector code, no XPU-specific paths touched). Happy to investigate further if a maintainer thinks otherwise.

@mergify

mergify Bot commented Mar 31, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @malaiwah.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify

mergify Bot commented May 7, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @malaiwah.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@waynehacking8

Copy link
Copy Markdown
Contributor

This PR is the fix a real user is currently blocked on — see LMCache/LMCache#3655 (and #45268 / #3655's thread): Qwen3.6-class hybrid + --kv-offloading-backend native hits the assert num_external_computed_tokens == 0 in _mamba_block_aligned_split, which this PR removes. The earlier #42554 scheduler hunk doesn't cover the fp8 hybrid case, whereas this PR's description notes validation on Qwen3.5-4B-FP8 — so it looks like the right fix for the fp8 hybrids too.

Two notes:

  • The branch is currently CONFLICTING against main and its base predates the newer qwen3_5_moe (Qwen3.6) architecture, so it can't be exercised against current hybrid models without a rebase. Any plan to rebase onto main?
  • Once it's rebased, I'm happy to validate it on Blackwell (RTX PRO 6000 / SM120) with a current fp8 hybrid + --kv-offloading-backend native — the description's validation was on an RTX 4080 (SM89), so a Blackwell data point would complement it. Just ping me.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build cpu Related to CPU backends deepseek Related to DeepSeek models documentation Improvements or additions to documentation frontend gpt-oss Related to GPT-OSS models intel-gpu Related to Intel GPU kv-connector llama Related to Llama models multi-modality Related to multi-modality (#4194) needs-rebase new-model Requests to new models nvidia performance Performance-related issues qwen Related to Qwen models rocm Related to AMD ROCm speculative-decoding structured-output tool-calling v1

Projects

Status: Todo
Status: No status
Status: No status
Status: No status
Status: To Triage

Development

Successfully merging this pull request may close these issues.

Hybrid KV offload: MultiConnector + planner for mamba+attention models

3 participants