Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3469,6 +3469,19 @@ class KVTransferConfig:
kv_port: int = 14579
"""The KV connector port, used to build distributed connection."""

kv_connector_external_registration_args: Optional[dict[str, Any]] = None
"""Extra args for external kv connector registration.
Example Usages:
kv_transfer_config=KVTransferConfig(
kv_connector="ExternalConnector",
kv_connector_external_registration_args={
"name": "ExternalConnector",
"module_path": "external_lib.path.external_kv_connector",
"class_name": "ExternalConnector",
},
)
"""

kv_connector_extra_config: dict[str, Any] = field(default_factory=dict)
"""any extra config that the connector may need."""

Expand Down
4 changes: 4 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def create_connector_v1(
f"but found {envs.VLLM_USE_V1=}")

connector_name = config.kv_transfer_config.kv_connector
if (config.kv_transfer_config.kv_connector_external_registration_args
is not None and connector_name not in cls._registry):
cls.register_connector(**config.kv_transfer_config.
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this be called multiple times?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No — based on the current v1 design, each process will only call it once.

Copy link
Copy Markdown
Contributor Author

@KingsleyZhang123 KingsleyZhang123 May 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Double-checked: when running with world_size = 1, UniProcExecutor is used, so the scheduler and worker run in the same process and this will be called twice.

vllm/vllm/config.py

Lines 1828 to 1829 in d67085c

if self.distributed_executor_backend is None and self.world_size == 1:
self.distributed_executor_backend = "uni"

kv_connector_external_registration_args)
connector_cls = cls._registry[connector_name]()
assert issubclass(connector_cls, KVConnectorBase_V1)
logger.info("Creating v1 connector with name: %s", connector_name)
Expand Down
3 changes: 1 addition & 2 deletions vllm/v1/worker/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def load_model(self) -> None:
context = nullcontext()
with context:
self.model_runner.load_model()
ensure_kv_transfer_initialized(self.vllm_config)

@torch.inference_mode()
def determine_available_memory(self) -> int:
Expand Down Expand Up @@ -343,8 +344,6 @@ def init_worker_distributed_environment(
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)

ensure_kv_transfer_initialized(vllm_config)


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
Expand Down