
[Core] Add register_model() to KVConnectorBase_V1 for CacheBlend #37339

Open
zbennett10 wants to merge 9 commits into vllm-project:main from WorldFlowAI:semblend/register-model-hook

Conversation

@zbennett10

Purpose

LMCache's CacheBlend feature requires access to the loaded model weights for selective layer recomputation via VLLMModelTracker.register_model(). However, vLLM never passes the model reference to KV connectors — register_kv_caches() only receives kv_caches: dict[str, torch.Tensor]. This means CacheBlend's LMCBlenderBuilder.get_or_create() fails at runtime when it calls VLLMModelTracker.get_model() because the model was never registered.

This PR adds a new register_model(model) method to KVConnectorBase_V1 (no-op by default). This avoids changing the register_kv_caches signature, which would require updating 16+ connector implementations.

Changes:

  • KVConnectorBase_V1.register_model() — no-op default
  • ActiveKVConnector.__init__() — calls connector.register_model(model) after register_kv_caches()
  • get_kv_connector() — accepts optional model param, passes to ActiveKVConnector
  • GPUModelRunner.init_kv_caches() — passes self.model to get_kv_connector()
  • LMCacheConnectorV1.register_model() — calls VLLMModelTracker.register_model(ENGINE_NAME, model)
  • LMCacheConnectorV1Impl.register_model() — same (bundled adapter)
  • MultiConnector.register_model() — delegates to sub-connectors
  • LMCacheMPConnector.register_model() — delegates to worker adapter

All other connectors (NIXL, Mooncake, FlexKV, offloading, etc.) inherit the no-op default — zero behavior change.
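
For concreteness, a minimal sketch of the hook shape described above (paraphrased from this PR; init_connector is an illustrative stand-in for the ActiveKVConnector wiring, not the actual diff):

import torch

class KVConnectorBase_V1:
    def register_model(self, model: torch.nn.Module) -> None:
        """Optional hook: receive the loaded model after register_kv_caches().

        No-op by default, so existing connectors need no changes.
        """

# Illustrative call-site wiring: model stays optional, so callers
# without a model reference keep the old behavior.
def init_connector(connector, kv_caches: dict, model=None) -> None:
    connector.register_kv_caches(kv_caches)
    if model is not None:
        connector.register_model(model)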

Test Plan

pytest tests/v1/kv_connector/test_register_model.py -v

Tests cover:

  • Base class has register_model and it's a no-op
  • ActiveKVConnector calls register_model when model is provided
  • ActiveKVConnector skips register_model when model is None
  • get_kv_connector passes model through to ActiveKVConnector
  • get_kv_connector returns NoOp when no transfer group
  • LMCacheConnectorV1.register_model calls VLLMModelTracker
  • LMCacheConnectorV1.register_model handles ImportError gracefully
  • MultiConnector.register_model delegates to all sub-connectors

Test Result

All 9 tests pass. Ruff lint passes on all 8 changed files.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, a small and essential subset of CI tests that quickly catches errors.

You can ask your reviewers to trigger select CI tests on top of the fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

Contributor

@gemini-code-assist (Bot) left a comment


Code Review

This pull request effectively adds a register_model() method to the KVConnector interface, enabling features like LMCache's CacheBlend to access model weights. The changes are well-implemented, maintaining backward compatibility by providing a no-op default in the base class and correctly propagating the model reference through the call stack. The inclusion of comprehensive tests is commendable. I have one suggestion to improve consistency and reduce code duplication in the LMCacheConnectorV1 implementation.

Comment on lines +136 to +159
def register_model(self, model: "torch.nn.Module") -> None:
    """Register model with LMCache's VLLMModelTracker for CacheBlend.

    CacheBlend's blender needs access to model weights for selective
    layer recomputation. This method is called automatically by vLLM
    after model loading.
    """
    try:
        from lmcache.v1.compute.models.utils import VLLMModelTracker

        from vllm.distributed.kv_transfer.kv_connector.v1.\
            lmcache_integration.utils import ENGINE_NAME
        VLLMModelTracker.register_model(ENGINE_NAME, model)
        logger.info("Registered model with LMCache VLLMModelTracker")
    except ImportError:
        logger.debug(
            "LMCache CacheBlend model registration not available"
        )
    except Exception:
        logger.warning(
            "Failed to register model with VLLMModelTracker",
            exc_info=True,
        )

Contributor


high

The register_model method is implemented directly in LMCacheConnectorV1, but other methods in this class (e.g., register_kv_caches) delegate to self._lmcache_engine. This is inconsistent with the class's design and leads to code duplication, as the same register_model logic is also added to LMCacheConnectorV1Impl in vllm_v1_adapter.py.

To improve consistency and maintainability, register_model should delegate the call to self._lmcache_engine. This also makes the implementation in LMCacheConnectorV1Impl effective and avoids having the logic in two places.

    def register_model(self, model: "torch.nn.Module") -> None:
        """Register model with LMCache's VLLMModelTracker for CacheBlend.

        Delegates to the underlying LMCache engine implementation.
        """
        if hasattr(self._lmcache_engine, "register_model"):
            self._lmcache_engine.register_model(model)

Author


Good catch. Updated in a26cd78: LMCacheConnectorV1.register_model now delegates to self._lmcache_engine.register_model() with a hasattr guard (same pattern as register_kv_caches). The VLLMModelTracker logic lives solely in LMCacheConnectorV1Impl. Tests updated accordingly.

@mergify
Contributor

mergify Bot commented Mar 17, 2026

Hi @zbennett10, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

vLLM never passes the model reference to KV connectors, so LMCache's
CacheBlend blender cannot access model weights for selective layer
recomputation. VLLMModelTracker.register_model() is never called.

Add register_model(model) to KVConnectorBase_V1 as a separate method
(no-op by default). ActiveKVConnector calls it after register_kv_caches
when model is available. LMCache connectors override to call
VLLMModelTracker.register_model(ENGINE_NAME, model).

All other connectors inherit the no-op default — zero behavior change.

Signed-off-by: Zach Bennett <zach@worldflow.ai>
Signed-off-by: Zachary Bennett <bennett.zachary@outlook.com>
…torV1

- Remove duplicated VLLMModelTracker logic from LMCacheConnectorV1.register_model
- Delegate to self._lmcache_engine.register_model() using hasattr guard (consistent with register_kv_caches pattern)
- The implementation lives in LMCacheConnectorV1Impl (vllm_v1_adapter.py)
- Update tests: LMCacheConnectorRegisterModel tests now verify delegation; add TestLMCacheConnectorV1ImplRegisterModel for VLLMModelTracker logic

Signed-off-by: Zachary Bennett <bennett.zachary@outlook.com>
Signed-off-by: Zachary Bennett <bennett.zachary@outlook.com>
@zbennett10 force-pushed the semblend/register-model-hook branch from e1d79f1 to 86d9f44 on March 18, 2026 at 02:46
@orozery
Collaborator

orozery commented Mar 18, 2026

@zbennett10 Can you somehow get the weights from the vllm config?
i.e. something like:

attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
    for name, param in attn_module.named_parameters():
        ...

@zbennett10
Author

@zbennett10 Can you somehow get the weights from the vllm config? i.e. something like:

attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
    for name, param in attn_module.named_parameters():
        ...

@orozery Thank you for the quick feedback. With the number of PRs going into this repo... I don't know how you all keep up :)

get_layers_from_vllm_config is exactly the right pattern for most KV
connector use cases. We investigated this and ran into one limitation for CacheBlend specifically.

The Attention objects stored in static_forward_context (registered in
attention/attention.py) are the attention backends — they contain the attention
implementation, KV cache metadata, and quantization config, but not qkv_proj or
o_proj. LMCache's CacheBlend needs access to those projection weights for selective
layer recomputation.

I believe the access pattern in LMCache is:

# In LMCBaseModel.__init__
for i in range(self.num_layers):
    layer = vllm_model.model.layers[i]
    qkv, _ = layer.self_attn.qkv_proj(hidden_states)  # needs qkv_proj
    rotary = layer.self_attn.rotary_emb               # needs rotary_emb

Since qkv_proj lives in the parent self_attn block (not in the Attention backend that
gets registered in static_forward_context), we can't satisfy this via vllm_config alone.

Two options if you'd prefer to avoid threading model through:

  1. Keep register_model(model) as-is — it's called once at init, no overhead at inference time
  2. Store vllm_config in LMCache's connector and call get_layers_from_vllm_config lazily
    on first CacheBlend use, then walk up from the Attention object to its parent
    (though parent-module traversal is not straightforward in PyTorch; see the sketch below)
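
To illustrate that wrinkle, a hedged sketch (parent_of is a hypothetical helper, not a PyTorch API): modules store no parent pointers, so reaching the enclosing self_attn block from an Attention backend means scanning the whole module tree.

import torch.nn as nn

def parent_of(model: nn.Module, child: nn.Module) -> nn.Module | None:
    """Hypothetical helper: find the direct parent of `child` inside `model`.

    PyTorch modules keep no parent pointers, so the only generic route is
    an O(n) scan over every module's direct children.
    """
    for module in model.modules():
        for _, candidate in module.named_children():
            if candidate is child:
                return module
    return None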

Happy to go with whichever approach you prefer. If there's a way to reach the parent
self_attn blocks via vllm_config that I'm not seeing, we can rework this... just let me know. Thanks :)

@orozery
Collaborator

orozery commented Mar 18, 2026

I'm actually in favor of passing information through the front door, instead of having connectors hack their way to get the information they need :)
The thing is, we keep adding very specific functions to the connector API.
Ideally, we would add the model to register_kv_caches and rename it appropriately.
As you said, this would break all existing connectors, which is undesirable.
At a certain point we will redesign a clean connector API that consolidates all existing APIs more cleanly and can capture future APIs neatly without breaking changes.
I think the short-term solution is to add new APIs via as generic an abstraction as possible.
For example, instead of introducing register_model, we can introduce something like:

@dataclass
class WorkerConnectorInitializationData:
    model: "torch.nn.Module | None" = None

@dataclass
class WorkerConnectorInitializationResponse:
    pass

def initialize_worker_connector(
    self, initialization_data: WorkerConnectorInitializationData
) -> WorkerConnectorInitializationResponse:
    pass

This API can later be extended, both in information flowing from the connector and to the connector, without breaking existing connectors.
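
To make that extensibility concrete: a later change could add optional fields without touching any connector signature. The commented fields below are hypothetical examples, not part of this PR:

from dataclasses import dataclass

import torch

@dataclass
class WorkerConnectorInitializationData:
    model: "torch.nn.Module | None" = None
    # Hypothetical future fields; defaults keep existing callers working:
    # vllm_config: "VllmConfig | None" = None
    # kv_caches: "dict[str, torch.Tensor] | None" = None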

…ta pattern

Replace the specific `register_model(model: nn.Module)` method on
KVConnectorBase_V1 with a generic dataclass-based initialization API:

  @dataclass
  class WorkerConnectorInitializationData:
      model: torch.nn.Module | None = None

  @dataclass
  class WorkerConnectorInitializationResponse:
      pass

  def initialize_worker_connector(
      self,
      initialization_data: WorkerConnectorInitializationData,
  ) -> WorkerConnectorInitializationResponse

This addresses reviewer feedback (orozery) on PR vllm-project#37339. The dataclass
pattern is extensible without breaking existing connectors: new optional
fields can be added to WorkerConnectorInitializationData in the future
(e.g. vllm_config, attn_backend) without changing any connector signatures.

Changes:
- base.py: add dataclasses, replace register_model no-op with
  initialize_worker_connector returning WorkerConnectorInitializationResponse
- kv_connector.py: always call initialize_worker_connector(data), removing
  the conditional model-is-not-None guard (each connector decides)
- lmcache_connector.py: override, extract data.model, guard before
  forwarding to _lmcache_engine.register_model
- lmcache_mp_connector.py: delegate to worker_adapter
- multi_connector.py: fan out to all sub-connectors with same data object
- vllm_v1_adapter.py: override, extract data.model, call VLLMModelTracker
- test_register_model.py: fully updated for new API (16 tests)

Signed-off-by: Zach Bennett <zach@worldflowai.com>

Signed-off-by: Zachary Bennett <bennett.zachary@outlook.com>
@zbennett10
Author

I'm actually in favor of passing information through the front door, instead of having connectors hack their way to get the information they need :) The thing is, we keep adding very specific functions to the connector API. Ideally, we would add the model to register_kv_caches and rename it appropriately. As you said, this would break all existing connectors, which is undesirable. At a certain point we will redesign a clean connector API that consolidates all existing APIs more cleanly and can capture future APIs neatly without breaking changes. I think the short-term solution is to add new APIs via as generic an abstraction as possible. For example, instead of introducing register_model, we can introduce something like:

@dataclass
class WorkerConnectorInitializationData:
    model: "torch.nn.Module | None" = None

@dataclass
class WorkerConnectorInitializationResponse:
    pass

def initialize_worker_connector(
    self, initialization_data: WorkerConnectorInitializationData
) -> WorkerConnectorInitializationResponse:
    pass

This API can later be extended, both in information flowing from the connector and to the connector, without breaking existing connectors.

Great feedback... check out this new interface. I think it maps much more cleanly; let me know what you think.

Comment on lines +801 to +830
def initialize_worker_connector(
    self,
    initialization_data: WorkerConnectorInitializationData,
) -> WorkerConnectorInitializationResponse:
    """Register model with LMCache's VLLMModelTracker for CacheBlend.

    CacheBlend's blender needs access to model weights for selective
    layer recomputation. Called automatically by vLLM after model
    loading.
    """
    model = initialization_data.model
    if model is not None:
        try:
            from lmcache.v1.compute.models.utils import VLLMModelTracker

            from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils import (  # noqa: E501
                ENGINE_NAME,
            )

            VLLMModelTracker.register_model(ENGINE_NAME, model)
            logger.info("Registered model with VLLMModelTracker")
        except ImportError:
            logger.debug("LMCache CacheBlend model registration not available")
        except Exception:
            logger.warning(
                "Failed to register model with VLLMModelTracker",
                exc_info=True,
            )
    return WorkerConnectorInitializationResponse()

Author


@orozery Example usage

@orozery
Collaborator

orozery commented Mar 19, 2026

Thanks!
I take it back a bit: we can skip WorkerConnectorInitializationResponse for now, as it is empty.

@NickLucche @njhill your thoughts about this API?
I like it. This will also be useful for passing in the new cross-layer KV caches (for hybrid models).

@orozery
Collaborator

orozery commented Mar 19, 2026

@NickLucche @njhill your thoughts about this API? I like it. This will also be useful for passing in the new cross-layer KV caches (for hybrid models).

Taking that even further, we can add kv_caches (from register_kv_caches) to this API as well, so that connectors can get rid of register_kv_caches.
When we see enough adoption of the new API, we can consider removing register_kv_caches, as well as register_cross_layers_kv_caches.

@zbennett10 BTW, your current implementation only applies this new API to the v2 model runner. I think we also want to add it to v1 (gpu_model_runner.py).

@zbennett10
Author

@NickLucche @njhill your thoughts about this API? I like it. This will also be useful for passing in the new cross-layer KV caches (for hybrid models).

Taking that even further, we can add kv_caches (from register_kv_caches) to this API as well, so that connectors can get rid of register_kv_caches. When we see enough adoption of the new API, we can consider removing register_kv_caches, as well as register_cross_layers_kv_caches.

@zbennett10 BTW, your current implementation only applies this new API to the v2 model runner. I think we also want to add it to v1 (gpu_model_runner.py).

Updated - thanks

@zbennett10 force-pushed the semblend/register-model-hook branch from ecf9a82 to be4243c on March 19, 2026 at 13:59
@mergify
Contributor

mergify Bot commented Mar 19, 2026

Hi @zbennett10, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@zbennett10 force-pushed the semblend/register-model-hook branch from be4243c to b638f98 on March 19, 2026 at 15:12
@zbennett10
Author

zbennett10 commented Mar 19, 2026

@orozery
Quick backstory on why this matters: We're building https://github.com/WorldFlowAI/semblend — an open-source semantic KV cache reuse provider for LLM inference. This repo is meant to be a reference implementation of a semantic caching provider that works with multiple inference engines and also houses the official implementation based on our in-progress "SemBlend" research paper. When two prompts share the same document content but have different instructions (cross-instruction variants), exact-prefix caching misses completely. SemBlend finds these semantic matches and reuses the donor's KV tensors, achieving 2-12x TTFT speedup on long contexts.

We're working across the inference ecosystem to define a clean interface for this.

This vLLM PR is a foundational piece — CacheBlend's selective layer recomputation (which SemBlend uses to maintain output quality after KV injection) needs model weights. I think that the initialize_worker_connector pattern you suggested is exactly right for this and will generalize well as connectors need more initialization context.

We're also collaborating with the NVIDIA cuVS team on GPU-accelerated vector search for semantic matching at fleet scale - they have agreed to help us co-author the "SemBlend" paper we are working on. Happy to share more details if helpful! This is just the beginning of semantic kv-cache reuse... :)

@zbennett10
Author

@orozery Don't mean to pester you - need anything else from me with this one? Thanks :D

@orozery
Collaborator

orozery commented Mar 23, 2026

@orozery Don't mean to pester you - need anything else from me with this one? Thanks :D

This LGTM (the connector API change).
I have not looked into the lmcache-side implementation, probably @ApostaC @KuntaiDu should have a look.
Still looking to see what others think about the API change.

@zbennett10
Author

@orozery Don't mean to pester you - need anything else from me with this one? Thanks :D

This LGTM (the connector API change). I have not looked into the lmcache-side implementation, probably @ApostaC @KuntaiDu should have a look. Still looking to see what others think about the API change.

I think the LMCache PRs still need a lot of review/design. Which I expect :)

Collaborator

@NickLucche left a comment


Thanks for your work @zbennett10.
I think this new API can be arranged.
However, I am not super comfortable adding changes to the ModelRunner v2 side, as I am not 100% sure the direction we want to go there is to copy-paste all changes on both sides.

Would it be an issue to scope this PR for v1 only?

Comment thread vllm/distributed/kv_transfer/kv_connector/v1/base.py Outdated
Comment thread vllm/v1/worker/gpu_model_runner.py Outdated
Comment on lines +6547 to +6554
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    WorkerConnectorInitializationData,
)

kv_transfer_group.initialize_worker_connector(
    WorkerConnectorInitializationData(model=self.model)
)

Collaborator


could you move this into the KVConnector mixin?
Ideally I'd like to have everything that is behind a has_kv_transfer_group() check in there... but we can keep the scope narrow for this PR
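
A hedged sketch of what that could look like, assuming the existing has_kv_transfer_group()/get_kv_transfer_group() helpers (the method name maybe_initialize_worker_connector is illustrative, not from this PR):

from vllm.distributed.kv_transfer import (
    get_kv_transfer_group,
    has_kv_transfer_group,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    WorkerConnectorInitializationData,
)

class KVConnectorModelRunnerMixin:
    def maybe_initialize_worker_connector(self, model) -> None:
        # Centralizes the has_kv_transfer_group() guard so model runners
        # call one mixin method instead of touching connector internals.
        if has_kv_transfer_group():
            get_kv_transfer_group().initialize_worker_connector(
                WorkerConnectorInitializationData(model=model)
            )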

torch is imported at runtime, not under TYPE_CHECKING,
so the string annotation is unnecessary.

Signed-off-by: Zach Bennett <zach@worldflowai.com>

Signed-off-by: Zachary Bennett <bennett.zachary@outlook.com>
@zbennett10 force-pushed the semblend/register-model-hook branch from f624633 to 531cfa2 on March 25, 2026 at 18:44
@zbennett10
Author

Thanks for your work @zbennett10. I think this new API can be arranged. However, I am not super comfortable adding changes to the ModelRunner v2 side, as I am not 100% sure the direction we want to go there is to copy-paste all changes on both sides.

Would it be an issue to scope this PR for v1 only?

That's fair. Handled! @NickLucche

Collaborator

@NickLucche left a comment


LGTM

@NickLucche added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) on Mar 30, 2026
Collaborator

@orozery left a comment


@zbennett10 Let's revert the changes to model runner v2 (v1/worker/gpu) and restore the changes to v1 (gpu_model_runner.py).
The import for WorkerConnectorInitializationData should be at the top together with the other imports.

Also, we got no response from the LMCache maintainers on this.
I don't want us to make any changes to their connector without their approval.
So either reach out to them, or let's move the LMCache changes to a follow-up PR.

Comment thread vllm/distributed/kv_transfer/kv_connector/v1/base.py Outdated
Comment thread tests/v1/kv_connector/test_register_model.py Outdated
@zbennett10
Author

@ApostaC @chunxiaozheng @ziruiliu — This PR adds an initialize_worker_connector() hook to KVConnectorBase_V1 that passes the loaded model to connectors after model loading. This enables CacheBlend's selective layer recomputation via VLLMModelTracker.

The LMCache connector changes are:

  • LMCacheConnectorV1: delegates model to engine's register_model() via hasattr guard (consistent with register_kv_caches pattern)
  • LMCacheConnectorV1Impl: registers model with VLLMModelTracker when CacheBlend is available, with graceful fallback
  • LMCacheMPConnector: delegates to worker adapter

Could you review the LMCache-specific changes? The approach uses the same guard patterns already established in the codebase.
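
For quick reference, a condensed sketch of the delegation chain described above (paraphrased from the diff, with error handling omitted):

class LMCacheConnectorV1:
    def initialize_worker_connector(self, initialization_data):
        model = initialization_data.model
        # hasattr guard: older LMCache releases predate register_model,
        # so skip gracefully instead of raising AttributeError.
        if model is not None and hasattr(
            self._lmcache_engine, "register_model"
        ):
            self._lmcache_engine.register_model(model)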

- Revert v2 model runner changes (v1/worker/gpu/), restore hook to v1
  gpu_model_runner.py as requested by reviewer
- Drop WorkerConnectorInitializationResponse; return None instead
- Move tests from standalone test_register_model.py into
  test_multi_connector.py
- Ensure all imports are at top level

Signed-off-by: Zach Bennett <zach@worldflowai.com>

Signed-off-by: Zachary Bennett <bennett.zachary@outlook.com>
Add LMCache-specific tests for initialize_worker_connector to
test_multi_connector.py: delegation to engine, guard clauses,
VLLMModelTracker integration, and graceful import error handling.

Signed-off-by: Zach Bennett <zach@worldflowai.com>

Signed-off-by: Zachary Bennett <bennett.zachary@outlook.com>
Comment thread tests/v1/kv_connector/unit/test_multi_connector.py Outdated
…nnector

LMCache initialize_worker_connector tests belong alongside other
LMCacheConnectorV1 tests, not in the MultiConnector test file.

Signed-off-by: Zach Bennett <zach@worldflowai.com>

Signed-off-by: Zachary Bennett <bennett.zachary@outlook.com>

Labels

kv-connector, ready (ONLY add when PR is ready to merge/full CI is needed), v1
