diff --git a/vllm/config.py b/vllm/config.py
index c6b97bbdcd66..9ea3d61766df 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -3469,6 +3469,19 @@ class KVTransferConfig:
     kv_port: int = 14579
     """The KV connector port, used to build distributed connection."""
 
+    kv_connector_external_registration_args: Optional[dict[str, Any]] = None
+    """Extra args for external kv connector registration.
+    Example usage:
+        kv_transfer_config=KVTransferConfig(
+            kv_connector="ExternalConnector",
+            kv_connector_external_registration_args={
+                "name": "ExternalConnector",
+                "module_path": "external_lib.path.external_kv_connector",
+                "class_name": "ExternalConnector",
+            },
+        )
+    """
+
     kv_connector_extra_config: dict[str, Any] = field(default_factory=dict)
     """any extra config that the connector may need."""
 
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index 54cb1871db3c..7d2628f759e6 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
+++ b/vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -59,6 +59,10 @@ def create_connector_v1(
                              f"but found {envs.VLLM_USE_V1=}")
 
         connector_name = config.kv_transfer_config.kv_connector
+        if (config.kv_transfer_config.kv_connector_external_registration_args
+                is not None and connector_name not in cls._registry):
+            cls.register_connector(**config.kv_transfer_config.
+                                   kv_connector_external_registration_args)
         connector_cls = cls._registry[connector_name]()
         assert issubclass(connector_cls, KVConnectorBase_V1)
         logger.info("Creating v1 connector with name: %s", connector_name)
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 5352b1c5a37c..91576ec0697a 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -161,6 +161,7 @@ def load_model(self) -> None:
             context = nullcontext()
         with context:
             self.model_runner.load_model()
+        ensure_kv_transfer_initialized(self.vllm_config)
 
     @torch.inference_mode()
     def determine_available_memory(self) -> int:
@@ -343,8 +344,6 @@ def init_worker_distributed_environment(
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
 
-    ensure_kv_transfer_initialized(vllm_config)
-
 
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.
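
For context, here is a minimal sketch of how an out-of-tree connector could be wired up through the new `kv_connector_external_registration_args` field. The module path and class names are placeholders taken from the docstring above, and the model name is arbitrary; the referenced class is expected to subclass `KVConnectorBase_V1` so that the factory's `issubclass` assertion passes. This is an illustrative usage example under those assumptions, not part of the diff itself.

```python
# Hypothetical usage sketch: "external_lib.path.external_kv_connector" and
# "ExternalConnector" are placeholder names; replace them with your own
# out-of-tree connector module and class.
from vllm import LLM
from vllm.config import KVTransferConfig

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",  # any model works here
    kv_transfer_config=KVTransferConfig(
        # Name the connector; since it is not in the built-in registry,
        # the factory registers it lazily from the args below.
        kv_connector="ExternalConnector",
        kv_role="kv_both",
        kv_connector_external_registration_args={
            "name": "ExternalConnector",
            "module_path": "external_lib.path.external_kv_connector",
            "class_name": "ExternalConnector",
        },
    ),
)
```

With this change, `create_connector_v1` registers the external connector on first use in the worker process, and `ensure_kv_transfer_initialized` is deferred until after `load_model`, so the connector is constructed only once the model (and any registration args) are available.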