diff --git a/tests/v1/kv_connector/unit/offloading_connector/test_connector.py b/tests/v1/kv_connector/unit/offloading_connector/test_connector.py
new file mode 100644
index 000000000000..93a197d2414e
--- /dev/null
+++ b/tests/v1/kv_connector/unit/offloading_connector/test_connector.py
@@ -0,0 +1,33 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from unittest.mock import MagicMock
+
+from vllm.distributed.kv_transfer.kv_connector.v1.base import SupportsHMA
+from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
+    OffloadingConnector,
+)
+
+
+def test_offloading_connector_supports_hma() -> None:
+    """OffloadingConnector must be a subclass of SupportsHMA."""
+    assert issubclass(OffloadingConnector, SupportsHMA)
+
+
+def test_request_finished_all_groups_delegates_to_scheduler() -> None:
+    """request_finished_all_groups forwards its arguments and result unchanged."""
+    # Build the scheduler double first so its canned result is in place
+    # before the connector is wired up.
+    scheduler = MagicMock()
+    scheduler.request_finished.return_value = (False, None)
+
+    # Bypass __init__: only the connector_scheduler attribute is needed here.
+    connector = object.__new__(OffloadingConnector)
+    connector.connector_scheduler = scheduler
+
+    request = MagicMock()
+    block_ids = ([1, 2], [3, 4])
+
+    result = connector.request_finished_all_groups(request, block_ids)
+
+    assert result == (False, None)
+    scheduler.request_finished.assert_called_once_with(request, block_ids)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
index f437782070df..e28729901cc7 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
@@ -647,7 +647,7 @@ def update_connector_output(self, connector_output: KVConnectorOutput):
     def request_finished(
         self,
         request: Request,
-        block_ids: list[int],
+        block_ids: list[int] | tuple[list[int], ...],
     ) -> tuple[bool, dict[str, Any] | None]:
         """
         Called when a request has finished, before its blocks are freed.
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
index 05b835572c9f..e0e83fd992c0 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
@@ -10,6 +10,7 @@
 from vllm.distributed.kv_transfer.kv_connector.v1 import (
     KVConnectorBase_V1,
     KVConnectorRole,
+    SupportsHMA,
 )
 from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
 from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
@@ -42,7 +43,7 @@
 from vllm.v1.request import Request
 
 
-class OffloadingConnector(KVConnectorBase_V1):
+class OffloadingConnector(KVConnectorBase_V1, SupportsHMA):
     @property
     def prefer_cross_layer_blocks(self) -> bool:
         return True
@@ -151,6 +152,30 @@ def request_finished(
         assert self.connector_scheduler is not None
         return self.connector_scheduler.request_finished(request, block_ids)
 
+    def request_finished_all_groups(
+        self,
+        request: "Request",
+        block_ids: tuple[list[int], ...],
+    ) -> tuple[bool, dict[str, Any] | None]:
+        """Notify the scheduler-side connector that a request has finished.
+
+        Counterpart of ``request_finished`` that receives a tuple of block-ID
+        lists (presumably one list per KV cache group) instead of a single
+        list, and forwards it to ``connector_scheduler.request_finished``
+        unchanged.
+
+        Args:
+            request: The finished request.
+            block_ids: Tuple of block-ID lists, passed through as-is.
+
+        Returns:
+            The ``(bool, dict | None)`` pair returned by the scheduler's
+            ``request_finished``.
+        """
+        scheduler = self.connector_scheduler
+        assert scheduler is not None
+        return scheduler.request_finished(request, block_ids)
+
     def take_events(self) -> Iterable[KVCacheEvent]:
         assert self.connector_scheduler is not None
         return self.connector_scheduler.take_events()