Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions tests/v1/kv_connector/unit/test_lmcache_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from vllm.distributed.kv_events import BlockStored
from vllm.distributed.kv_transfer.kv_connector.v1 import SupportsHMA
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import (
LMCacheConnectorV1,
LMCacheKVEvents,
Expand Down Expand Up @@ -784,3 +785,35 @@ def test_lmcache_kv_events_aggregation(self):
assert aggregated_events[0].block_hashes == ["hash_common"]
assert aggregated_events[0].parent_block_hash == "parent_common"
assert aggregated_events[0].token_ids == [1, 2, 3]


def test_lmcache_connector_supports_hma() -> None:
    """The LMCache connector class must be a subclass of ``SupportsHMA``."""
    connector_cls = LMCacheConnectorV1
    assert issubclass(connector_cls, SupportsHMA)


def test_request_finished_all_groups_uses_adapter_selected_group(
    mock_connector,
) -> None:
    """Only the adapter-selected group's block ids reach ``request_finished``."""
    selected_group = 2
    per_group_block_ids = ([1], [2], [3])
    expected = (False, {"ok": True})
    finished_request = MagicMock()

    # The adapter reports that the last of the three groups is the one it owns.
    engine = mock_connector._lmcache_engine
    engine.get_lmcache_kv_cache_group_id.return_value = selected_group
    mock_connector.request_finished.return_value = expected

    outcome = LMCacheConnectorV1.request_finished_all_groups(
        mock_connector, finished_request, per_group_block_ids
    )

    assert outcome == expected
    mock_connector.request_finished.assert_called_once_with(
        finished_request, per_group_block_ids[selected_group]
    )


def test_request_finished_all_groups_rejects_missing_selected_group(
    mock_connector,
) -> None:
    """A selected group index past the provided groups raises ValueError."""
    finished_request = MagicMock()
    # Only two groups are supplied, so index 2 does not exist.
    too_few_groups = ([1], [2])
    engine = mock_connector._lmcache_engine
    engine.get_lmcache_kv_cache_group_id.return_value = 2

    with pytest.raises(ValueError, match="selected KV cache group 2"):
        LMCacheConnectorV1.request_finished_all_groups(
            mock_connector, finished_request, too_few_groups
        )
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
SupportsHMA,
)
from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionMetadata
Expand Down Expand Up @@ -69,7 +70,7 @@ def __repr__(self) -> str:
return f"<LMCacheKVEvents events={self.get_all_events()}>"


class LMCacheConnectorV1(KVConnectorBase_V1):
class LMCacheConnectorV1(KVConnectorBase_V1, SupportsHMA):
@classmethod
def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
"""
Expand Down Expand Up @@ -114,6 +115,12 @@ def __init__(

self._kv_cache_events: LMCacheKVEvents | None = None

def get_lmcache_kv_cache_config(self) -> "KVCacheConfig | None":
    """
    Expose the stored vLLM KV cache config to LMCache's integration adapter.

    Returns:
        The value held in ``self._kv_cache_config`` (may be ``None``).
    """
    kv_cache_config = self._kv_cache_config
    return kv_cache_config

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method might get picked up as dead code.
We should either add a comment noting that it is used externally, or use the `kv_cache_config` property directly.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 8b58c3ed5d: I removed the LMCache-specific getter and exposed the config through the connector attribute instead.

vLLM side:

  • LMCacheConnectorV1.__init__ now sets self.kv_cache_config = kv_cache_config
  • get_lmcache_kv_cache_config() was removed

Companion LMCache side is updated in LMCache/LMCache#3284 commit 6ae3e7f7 to read parent.kv_cache_config directly.

Fresh validation:

  • timeout 60 uvx ruff check vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py tests/v1/kv_connector/unit/test_lmcache_connector.py
  • timeout 60 uvx ruff format --check vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py tests/v1/kv_connector/unit/test_lmcache_connector.py
  • .venv/bin/python -m py_compile vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py tests/v1/kv_connector/unit/test_lmcache_connector.py
  • .venv/bin/python -m pytest -q tests/v1/kv_connector/unit/test_lmcache_connector.py -> 26 passed, 16 warnings

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


# ==============================
# Worker-side methods
# ==============================
Expand Down Expand Up @@ -339,6 +346,29 @@ def request_finished(
"""
return self._lmcache_engine.request_finished(request, block_ids)

def request_finished_all_groups(
self,
request: "Request",
block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
"""
Called when a request has finished with hybrid KV cache groups enabled.

LMCache currently stores and loads one paged attention KV cache group.
Hybrid models can put that group after state-cache groups, so use the
group selected by the LMCache adapter instead of assuming group 0.
"""
get_group_id = getattr(
self._lmcache_engine, "get_lmcache_kv_cache_group_id", None
)
kv_cache_group_id = get_group_id() if callable(get_group_id) else 0
if kv_cache_group_id >= len(block_ids):
raise ValueError(
f"LMCache selected KV cache group {kv_cache_group_id}, "
f"but vLLM provided {len(block_ids)} block-id group(s)"
)
return self.request_finished(request, block_ids[kv_cache_group_id])

def take_events(self) -> Iterable["KVCacheEvent"]:
"""
Take the KV cache events from the connector.
Expand Down
Loading