Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions vllm/distributed/kv_transfer/kv_connector/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
import importlib
from typing import TYPE_CHECKING, Callable

# yapf: disable
import vllm.envs as envs
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.base import (
KVConnectorBase, KVConnectorBaseType)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.logger import init_logger

# yapf: enable

if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config import KVTransferConfig, VllmConfig

logger = init_logger(__name__)

Expand Down Expand Up @@ -42,17 +46,7 @@ def create_connector(
f"but found {envs.VLLM_USE_V1=}")

kv_transfer_config = config.kv_transfer_config
connector_name = kv_transfer_config.kv_connector
if connector_name in cls._registry:
connector_cls = cls._registry[connector_name]()
else:
connector_module_path = kv_transfer_config.kv_connector_module_path
if connector_module_path is None:
raise ValueError(
f"Unsupported connector type: {connector_name}")
connector_module = importlib.import_module(connector_module_path)
connector_cls = getattr(connector_module, connector_name)
assert issubclass(connector_cls, KVConnectorBase)
connector_cls = cls.get_connector_class(kv_transfer_config)
logger.info("Creating v1 connector with name: %s and engine_id: %s",
connector_cls.__name__, kv_transfer_config.engine_id)
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
Expand All @@ -65,6 +59,23 @@ def create_connector(
# We build separately to enforce strict separation
return connector_cls(config, role)

@classmethod
def get_connector_class(
cls, kv_transfer_config: "KVTransferConfig"
) -> type[KVConnectorBaseType]:
"""Get the connector class by name."""
connector_name = kv_transfer_config.kv_connector
if connector_name in cls._registry:
connector_cls = cls._registry[connector_name]()
else:
connector_module_path = kv_transfer_config.kv_connector_module_path
if connector_module_path is None:
raise ValueError(
f"Unsupported connector type: {connector_name}")
connector_module = importlib.import_module(connector_module_path)
connector_cls = getattr(connector_module, connector_name)
return connector_cls


# Register various connectors here.
# The registration should not be done in each individual file, as we want to
Expand Down
11 changes: 7 additions & 4 deletions vllm/distributed/kv_transfer/kv_connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1)
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.logger import init_logger
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput

Expand Down Expand Up @@ -106,8 +106,9 @@ def get_kv_connector_cache_layout():
vllm_config = get_current_vllm_config()
kv_config = vllm_config.kv_transfer_config
if kv_config is not None:
required_kvcache_layout = (
KVConnectorBase_V1.get_required_kvcache_layout(vllm_config))
connector_cls = KVConnectorFactory.get_connector_class(kv_config)
required_kvcache_layout = connector_cls.get_required_kvcache_layout(
vllm_config)
if required_kvcache_layout is not None:
return required_kvcache_layout
logger.info_once("Connectors do not specify a " \
Expand Down Expand Up @@ -143,6 +144,8 @@ def update_finished_set(req_ids: Optional[set[str]],
finished_recving = set[str]()
for output in outputs:
output = output.kv_connector_output
if not output:
continue
update_finished_set(output.finished_sending,
self._send_remaining_count, finished_sending)
update_finished_set(output.finished_recving,
Expand Down